diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py
index 17838ed81..8035b527c 100644
--- a/src/databricks/sql/backend/sea/result_set.py
+++ b/src/databricks/sql/backend/sea/result_set.py
@@ -104,6 +104,7 @@ def _convert_json_types(self, row: List[str]) -> List[Any]:
                     column_name=column_name,
                     precision=precision,
                     scale=scale,
+                    timestamp_format=self.connection.non_arrow_timestamp_format,
                 )
 
                 converted_row.append(converted_value)
diff --git a/src/databricks/sql/backend/sea/utils/conversion.py b/src/databricks/sql/backend/sea/utils/conversion.py
index 69c6dfbe2..0836aa0a1 100644
--- a/src/databricks/sql/backend/sea/utils/conversion.py
+++ b/src/databricks/sql/backend/sea/utils/conversion.py
@@ -11,6 +11,8 @@
 from dateutil import parser
 from typing import Callable, Dict, Optional
 
+from databricks.sql.utils import parse_timestamp
+
 logger = logging.getLogger(__name__)
 
 
@@ -162,6 +164,9 @@ def convert_value(
                 precision = kwargs.get("precision", None)
                 scale = kwargs.get("scale", None)
                 return converter_func(value, precision, scale)
+            elif sql_type == SqlType.TIMESTAMP:
+                timestamp_format = kwargs.get("timestamp_format", None)
+                return parse_timestamp(value, timestamp_format)
             else:
                 return converter_func(value)
         except Exception as e:
diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py
index e23f3389b..d1bcaa766 100644
--- a/src/databricks/sql/backend/thrift_backend.py
+++ b/src/databricks/sql/backend/thrift_backend.py
@@ -1296,6 +1296,7 @@ def fetch_results(
         description,
         chunk_id: int,
         use_cloud_fetch=True,
+        timestamp_format=None,
     ):
         thrift_handle = command_id.to_thrift_handle()
         if not thrift_handle:
@@ -1336,6 +1337,7 @@ def fetch_results(
             statement_id=command_id.to_hex_guid(),
             chunk_id=chunk_id,
             http_client=self._http_client,
+            timestamp_format=timestamp_format,
         )
 
         return (
diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py
index 2aeea175e..d58191427 100755
--- a/src/databricks/sql/client.py
+++ b/src/databricks/sql/client.py
@@ -295,6 +295,7 @@ def read(self) -> Optional[OAuthToken]:
         self.disable_pandas = kwargs.get("_disable_pandas", False)
         self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
         self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
+        self.non_arrow_timestamp_format = kwargs.get("non_arrow_timestamp_format", None)
         self._cursors = []  # type: List[Cursor]
         self.telemetry_batch_size = kwargs.get(
             "telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE
diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py
index 6c4c3a43a..3c0834e98 100644
--- a/src/databricks/sql/result_set.py
+++ b/src/databricks/sql/result_set.py
@@ -245,6 +245,7 @@ def __init__(
                 statement_id=execute_response.command_id.to_hex_guid(),
                 chunk_id=self.num_chunks,
                 http_client=connection.http_client,
+                timestamp_format=connection.non_arrow_timestamp_format,
             )
             if t_row_set.resultLinks:
                 self.num_chunks += len(t_row_set.resultLinks)
@@ -281,6 +282,7 @@ def _fill_results_buffer(self):
             description=self.description,
             use_cloud_fetch=self._use_cloud_fetch,
             chunk_id=self.num_chunks,
+            timestamp_format=self.connection.non_arrow_timestamp_format,
         )
         self.results = results
         self.has_more_rows = has_more_rows
diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py
index b1fff7202..4a00538d1 100644
--- a/src/databricks/sql/utils.py
+++ b/src/databricks/sql/utils.py
@@ -10,6 +10,7 @@
 from decimal import Decimal
 from enum import Enum
 import re
+import pytz
 
 import lz4.frame
 
@@ -53,6 +54,32 @@ def get_session_config_value(
     return None
 
 
+def parse_timestamp(
+    value: str, timestamp_format: Optional[str] = None
+) -> datetime.datetime:
+    """Parse a timestamp string into a datetime object.
+
+    If timestamp_format is provided, tries strptime first and falls back to
+    dateutil.parser.parse on ValueError. If timestamp_format is None, uses
+    dateutil.parser.parse directly.
+
+    Args:
+        value: The timestamp string to parse.
+        timestamp_format: An optional strptime-compatible format string.
+
+    Returns:
+        A datetime.datetime object.
+    """
+    if timestamp_format is not None:
+        try:
+            return datetime.datetime.strptime(value, timestamp_format).replace(
+                tzinfo=pytz.UTC
+            )
+        except ValueError:
+            return parser.parse(value)
+    return parser.parse(value)
+
+
 class ResultSetQueue(ABC):
     @abstractmethod
     def next_n_rows(self, num_rows: int):
@@ -81,6 +108,7 @@ def build_queue(
         http_client,
         lz4_compressed: bool = True,
         description: List[Tuple] = [],
+        timestamp_format: Optional[str] = None,
     ) -> ResultSetQueue:
         """
         Factory method to build a result set queue for Thrift backend.
@@ -93,6 +121,7 @@ def build_queue(
             description (List[List[Any]]): Hive table schema description.
             max_download_threads (int): Maximum number of downloader thread pool threads.
             ssl_options (SSLOptions): SSLOptions object for CloudFetchQueue
+            timestamp_format: Optional strptime-compatible format for timestamp parsing.
 
         Returns:
             ResultSetQueue
@@ -112,7 +141,7 @@ def build_queue(
             )
 
             converted_column_table = convert_to_assigned_datatypes_in_column_table(
-                column_table, description
+                column_table, description, timestamp_format=timestamp_format
             )
 
             return ColumnQueue(ColumnTable(converted_column_table, column_names))
@@ -760,7 +789,9 @@ def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":
     return pyarrow.Table.from_arrays(new_columns, schema=new_schema)
 
 
-def convert_to_assigned_datatypes_in_column_table(column_table, description):
+def convert_to_assigned_datatypes_in_column_table(
+    column_table, description, timestamp_format=None
+):
     converted_column_table = []
 
     for i, col in enumerate(column_table):
@@ -774,7 +805,10 @@ def convert_to_assigned_datatypes_in_column_table(column_table, description):
             )
         elif description[i][1] == "timestamp":
             converted_column_table.append(
-                tuple((v if v is None else parser.parse(v)) for v in col)
+                tuple(
+                    (v if v is None else parse_timestamp(v, timestamp_format))
+                    for v in col
+                )
             )
         else:
             converted_column_table.append(col)
diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py
index 7a0706838..e281fa870 100644
--- a/tests/unit/test_fetches.py
+++ b/tests/unit/test_fetches.py
@@ -80,6 +80,7 @@ def fetch_results(
             description,
             use_cloud_fetch=True,
             chunk_id=0,
+            timestamp_format=None,
         ):
             nonlocal batch_index
             results = FetchTests.make_arrow_queue(batch_list[batch_index])
diff --git a/tests/unit/test_sea_conversion.py b/tests/unit/test_sea_conversion.py
index 234cca868..8bf85fbbb 100644
--- a/tests/unit/test_sea_conversion.py
+++ b/tests/unit/test_sea_conversion.py
@@ -7,6 +7,7 @@
 import pytest
 import datetime
 import decimal
+import pytz
 from unittest.mock import Mock, patch
 
 from databricks.sql.backend.sea.utils.conversion import SqlType, SqlTypeConverter
@@ -147,3 +148,37 @@ def test_convert_unsupported_type(self):
             SqlTypeConverter.convert_value("complex_value", SqlType.STRUCT, None)
             == "complex_value"
         )
+
+    def test_convert_timestamp_with_format(self):
+        """Test converting timestamp with an explicit strptime format."""
+        fmt = "%Y-%m-%d %H:%M:%S.%f"
+        result = SqlTypeConverter.convert_value(
+            "2023-12-31 12:30:00.123000",
+            SqlType.TIMESTAMP,
+            None,
+            timestamp_format=fmt,
+        )
+        assert isinstance(result, datetime.datetime)
+        assert result == datetime.datetime(2023, 12, 31, 12, 30, 0, 123000, tzinfo=pytz.UTC)
+
+    def test_convert_timestamp_with_format_fallback(self):
+        """Test that non-matching format falls back to dateutil."""
+        fmt = "%Y-%m-%d %H:%M:%S.%f"
+        result = SqlTypeConverter.convert_value(
+            "08-Mar-2024 14:30:15",
+            SqlType.TIMESTAMP,
+            None,
+            timestamp_format=fmt,
+        )
+        assert isinstance(result, datetime.datetime)
+        assert result == datetime.datetime(2024, 3, 8, 14, 30, 15)
+
+    def test_convert_timestamp_without_format(self):
+        """Test converting timestamp without explicit format uses dateutil."""
+        result = SqlTypeConverter.convert_value(
+            "2023-01-15T12:30:45",
+            SqlType.TIMESTAMP,
+            None,
+        )
+        assert isinstance(result, datetime.datetime)
+        assert result == datetime.datetime(2023, 1, 15, 12, 30, 45)
diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py
index 687bdd391..7b3aad8c9 100644
--- a/tests/unit/test_util.py
+++ b/tests/unit/test_util.py
@@ -2,8 +2,10 @@
 import datetime
 from datetime import timezone, timedelta
 import pytest
+import pytz
 from databricks.sql.utils import (
     convert_to_assigned_datatypes_in_column_table,
+    parse_timestamp,
     ColumnTable,
     concat_table_chunks,
     serialize_query_tags,
@@ -224,3 +226,41 @@ def test_serialize_query_tags_all_none_values(self):
         query_tags = {"key1": None, "key2": None, "key3": None}
         result = serialize_query_tags(query_tags)
         assert result == "key1,key2,key3"
+
+
+class TestParseTimestamp:
+    def test_no_format_uses_dateutil(self):
+        result = parse_timestamp("2023-12-31 12:30:00")
+        assert result == datetime.datetime(2023, 12, 31, 12, 30, 0)
+
+    def test_matching_format_uses_strptime(self):
+        fmt = "%Y-%m-%d %H:%M:%S.%f"
+        result = parse_timestamp("2023-12-31 12:30:00.123000", fmt)
+        assert result == datetime.datetime(2023, 12, 31, 12, 30, 0, 123000, tzinfo=pytz.UTC)
+
+    def test_non_matching_format_falls_back_to_dateutil(self):
+        fmt = "%Y-%m-%d %H:%M:%S.%f"
+        # This doesn't match the format, so should fall back to dateutil
+        result = parse_timestamp("08-Mar-2024 14:30:15", fmt)
+        assert result == datetime.datetime(2024, 3, 8, 14, 30, 15)
+
+    def test_convert_column_table_with_timestamp_format(self):
+        description = [
+            ("ts_col", "timestamp", None, None, None, None, None),
+        ]
+        column_table = [("2023-12-31 12:30:00.000000",)]
+        fmt = "%Y-%m-%d %H:%M:%S.%f"
+        result = convert_to_assigned_datatypes_in_column_table(
+            column_table, description, timestamp_format=fmt
+        )
+        assert result[0][0] == datetime.datetime(2023, 12, 31, 12, 30, 0, tzinfo=pytz.UTC)
+
+    def test_convert_column_table_without_timestamp_format(self):
+        description = [
+            ("ts_col", "timestamp", None, None, None, None, None),
+        ]
+        column_table = [("2023-12-31 12:30:00",)]
+        result = convert_to_assigned_datatypes_in_column_table(
+            column_table, description
+        )
+        assert result[0][0] == datetime.datetime(2023, 12, 31, 12, 30, 0)