Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/databricks/sql/backend/sea/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def _convert_json_types(self, row: List[str]) -> List[Any]:
column_name=column_name,
precision=precision,
scale=scale,
timestamp_format=self.connection.non_arrow_timestamp_format,
)
converted_row.append(converted_value)

Expand Down
5 changes: 5 additions & 0 deletions src/databricks/sql/backend/sea/utils/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from dateutil import parser
from typing import Callable, Dict, Optional

from databricks.sql.utils import parse_timestamp

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -162,6 +164,9 @@ def convert_value(
precision = kwargs.get("precision", None)
scale = kwargs.get("scale", None)
return converter_func(value, precision, scale)
elif sql_type == SqlType.TIMESTAMP:
timestamp_format = kwargs.get("timestamp_format", None)
return parse_timestamp(value, timestamp_format)
else:
return converter_func(value)
Comment on lines +167 to 171
except Exception as e:
Expand Down
2 changes: 2 additions & 0 deletions src/databricks/sql/backend/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,7 @@ def fetch_results(
description,
chunk_id: int,
use_cloud_fetch=True,
timestamp_format=None,
):
thrift_handle = command_id.to_thrift_handle()
if not thrift_handle:
Expand Down Expand Up @@ -1336,6 +1337,7 @@ def fetch_results(
statement_id=command_id.to_hex_guid(),
chunk_id=chunk_id,
http_client=self._http_client,
timestamp_format=timestamp_format,
)

return (
Expand Down
1 change: 1 addition & 0 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def read(self) -> Optional[OAuthToken]:
self.disable_pandas = kwargs.get("_disable_pandas", False)
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
self.non_arrow_timestamp_format = kwargs.get("non_arrow_timestamp_format", None)
self._cursors = [] # type: List[Cursor]
Comment on lines 295 to 299
self.telemetry_batch_size = kwargs.get(
"telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE
Expand Down
2 changes: 2 additions & 0 deletions src/databricks/sql/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def __init__(
statement_id=execute_response.command_id.to_hex_guid(),
chunk_id=self.num_chunks,
http_client=connection.http_client,
timestamp_format=connection.non_arrow_timestamp_format,
)
if t_row_set.resultLinks:
self.num_chunks += len(t_row_set.resultLinks)
Expand Down Expand Up @@ -281,6 +282,7 @@ def _fill_results_buffer(self):
description=self.description,
use_cloud_fetch=self._use_cloud_fetch,
chunk_id=self.num_chunks,
timestamp_format=self.connection.non_arrow_timestamp_format,
)
self.results = results
self.has_more_rows = has_more_rows
Expand Down
40 changes: 37 additions & 3 deletions src/databricks/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from decimal import Decimal
from enum import Enum
import re
import pytz

import lz4.frame

Expand Down Expand Up @@ -53,6 +54,32 @@ def get_session_config_value(
return None


def parse_timestamp(
    value: str, timestamp_format: Optional[str] = None
) -> datetime.datetime:
    """Parse a timestamp string into a datetime object.

    If ``timestamp_format`` is provided, ``strptime`` is tried first and the
    parsed value is tagged as UTC; when the value does not match the format
    (``ValueError``) the function falls back to ``dateutil.parser.parse``.
    Without a format, ``dateutil.parser.parse`` is used directly.

    NOTE(review): the two paths are inconsistent — the strptime path returns
    an aware (UTC) datetime while both dateutil paths return whatever
    ``parser.parse`` produces (naive for inputs without an offset). The
    existing unit tests pin this behavior, so it is preserved here; confirm
    whether the fallback should also be normalized to UTC.

    Args:
        value: The timestamp string to parse.
        timestamp_format: An optional strptime-compatible format string.

    Returns:
        A ``datetime.datetime`` (UTC-aware when the explicit format matches,
        otherwise whatever dateutil infers from the string).
    """
    if timestamp_format is not None:
        try:
            # stdlib datetime.timezone.utc replaces pytz.UTC: same zero
            # offset, no third-party dependency, and datetimes tagged with
            # either compare equal.
            return datetime.datetime.strptime(value, timestamp_format).replace(
                tzinfo=datetime.timezone.utc
            )
        except ValueError:
            # Value did not match the caller-supplied format; best-effort
            # parse instead of failing the whole fetch.
            return parser.parse(value)
    return parser.parse(value)


class ResultSetQueue(ABC):
@abstractmethod
def next_n_rows(self, num_rows: int):
Expand Down Expand Up @@ -81,6 +108,7 @@ def build_queue(
http_client,
lz4_compressed: bool = True,
description: List[Tuple] = [],
timestamp_format: Optional[str] = None,
) -> ResultSetQueue:
"""
Factory method to build a result set queue for Thrift backend.
Expand All @@ -93,6 +121,7 @@ def build_queue(
description (List[List[Any]]): Hive table schema description.
max_download_threads (int): Maximum number of downloader thread pool threads.
ssl_options (SSLOptions): SSLOptions object for CloudFetchQueue
timestamp_format: Optional strptime-compatible format for timestamp parsing.

Returns:
ResultSetQueue
Expand All @@ -112,7 +141,7 @@ def build_queue(
)

converted_column_table = convert_to_assigned_datatypes_in_column_table(
column_table, description
column_table, description, timestamp_format=timestamp_format
)

return ColumnQueue(ColumnTable(converted_column_table, column_names))
Expand Down Expand Up @@ -760,7 +789,9 @@ def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":
return pyarrow.Table.from_arrays(new_columns, schema=new_schema)


def convert_to_assigned_datatypes_in_column_table(column_table, description):
def convert_to_assigned_datatypes_in_column_table(
column_table, description, timestamp_format=None
):

converted_column_table = []
for i, col in enumerate(column_table):
Expand All @@ -774,7 +805,10 @@ def convert_to_assigned_datatypes_in_column_table(column_table, description):
)
elif description[i][1] == "timestamp":
converted_column_table.append(
tuple((v if v is None else parser.parse(v)) for v in col)
tuple(
(v if v is None else parse_timestamp(v, timestamp_format))
for v in col
)
)
else:
converted_column_table.append(col)
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_fetches.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def fetch_results(
description,
use_cloud_fetch=True,
chunk_id=0,
timestamp_format=None,
):
nonlocal batch_index
results = FetchTests.make_arrow_queue(batch_list[batch_index])
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/test_sea_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
import datetime
import decimal
import pytz
from unittest.mock import Mock, patch

from databricks.sql.backend.sea.utils.conversion import SqlType, SqlTypeConverter
Expand Down Expand Up @@ -147,3 +148,37 @@ def test_convert_unsupported_type(self):
SqlTypeConverter.convert_value("complex_value", SqlType.STRUCT, None)
== "complex_value"
)

def test_convert_timestamp_with_format(self):
"""Test converting timestamp with an explicit strptime format."""
fmt = "%Y-%m-%d %H:%M:%S.%f"
result = SqlTypeConverter.convert_value(
"2023-12-31 12:30:00.123000",
SqlType.TIMESTAMP,
None,
timestamp_format=fmt,
)
assert isinstance(result, datetime.datetime)
assert result == datetime.datetime(2023, 12, 31, 12, 30, 0, 123000, tzinfo=pytz.UTC)

def test_convert_timestamp_with_format_fallback(self):
"""Test that non-matching format falls back to dateutil."""
fmt = "%Y-%m-%d %H:%M:%S.%f"
result = SqlTypeConverter.convert_value(
"08-Mar-2024 14:30:15",
SqlType.TIMESTAMP,
None,
timestamp_format=fmt,
)
assert isinstance(result, datetime.datetime)
assert result == datetime.datetime(2024, 3, 8, 14, 30, 15)

def test_convert_timestamp_without_format(self):
"""Test converting timestamp without explicit format uses dateutil."""
result = SqlTypeConverter.convert_value(
"2023-01-15T12:30:45",
SqlType.TIMESTAMP,
None,
)
assert isinstance(result, datetime.datetime)
assert result == datetime.datetime(2023, 1, 15, 12, 30, 45)
40 changes: 40 additions & 0 deletions tests/unit/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import datetime
from datetime import timezone, timedelta
import pytest
import pytz
from databricks.sql.utils import (
convert_to_assigned_datatypes_in_column_table,
parse_timestamp,
ColumnTable,
concat_table_chunks,
serialize_query_tags,
Expand Down Expand Up @@ -224,3 +226,41 @@ def test_serialize_query_tags_all_none_values(self):
query_tags = {"key1": None, "key2": None, "key3": None}
result = serialize_query_tags(query_tags)
assert result == "key1,key2,key3"


class TestParseTimestamp:
    """Unit tests for parse_timestamp and timestamp column conversion."""

    # strptime format shared by the format-aware test cases.
    TS_FMT = "%Y-%m-%d %H:%M:%S.%f"

    def test_no_format_uses_dateutil(self):
        parsed = parse_timestamp("2023-12-31 12:30:00")
        assert parsed == datetime.datetime(2023, 12, 31, 12, 30, 0)

    def test_matching_format_uses_strptime(self):
        parsed = parse_timestamp("2023-12-31 12:30:00.123000", self.TS_FMT)
        expected = datetime.datetime(
            2023, 12, 31, 12, 30, 0, 123000, tzinfo=pytz.UTC
        )
        assert parsed == expected

    def test_non_matching_format_falls_back_to_dateutil(self):
        # Input does not match TS_FMT, so dateutil handles the parse.
        parsed = parse_timestamp("08-Mar-2024 14:30:15", self.TS_FMT)
        assert parsed == datetime.datetime(2024, 3, 8, 14, 30, 15)

    def test_convert_column_table_with_timestamp_format(self):
        schema = [("ts_col", "timestamp", None, None, None, None, None)]
        table = [("2023-12-31 12:30:00.000000",)]
        converted = convert_to_assigned_datatypes_in_column_table(
            table, schema, timestamp_format=self.TS_FMT
        )
        expected = datetime.datetime(2023, 12, 31, 12, 30, 0, tzinfo=pytz.UTC)
        assert converted[0][0] == expected

    def test_convert_column_table_without_timestamp_format(self):
        schema = [("ts_col", "timestamp", None, None, None, None, None)]
        table = [("2023-12-31 12:30:00",)]
        converted = convert_to_assigned_datatypes_in_column_table(table, schema)
        assert converted[0][0] == datetime.datetime(2023, 12, 31, 12, 30, 0)
Loading