Normalise Execution Response (clean backend interfaces) #587

Merged · 49 commits · Jun 11, 2025

Commits
138c2ae
[squash from exec-sea] bring over execution phase changes
varun-edachali-dbx Jun 9, 2025
3e3ab94
remove excess test
varun-edachali-dbx Jun 9, 2025
4a78165
add docstring
varun-edachali-dbx Jun 9, 2025
0dac4aa
remove exec func in sea backend
varun-edachali-dbx Jun 9, 2025
1b794c7
remove excess files
varun-edachali-dbx Jun 9, 2025
da5a6fe
remove excess models
varun-edachali-dbx Jun 9, 2025
686ade4
remove excess sea backend tests
varun-edachali-dbx Jun 9, 2025
31e6c83
cleanup
varun-edachali-dbx Jun 9, 2025
69ea238
re-introduce get_schema_desc
varun-edachali-dbx Jun 9, 2025
66d7517
remove SeaResultSet
varun-edachali-dbx Jun 9, 2025
71feef9
clean imports and attributes
varun-edachali-dbx Jun 9, 2025
ae9862f
pass CommandId to ExecResp
varun-edachali-dbx Jun 9, 2025
d8aa69e
remove changes in types
varun-edachali-dbx Jun 9, 2025
db139bc
add back essential types (ExecResponse, from_sea_state)
varun-edachali-dbx Jun 9, 2025
b977b12
fix fetch types
varun-edachali-dbx Jun 9, 2025
da615c0
excess imports
varun-edachali-dbx Jun 9, 2025
0da04a6
reduce diff by maintaining logs
varun-edachali-dbx Jun 9, 2025
ea9d456
fix int test types
varun-edachali-dbx Jun 9, 2025
d97463b
move guid_to_hex_id import to utils
varun-edachali-dbx Jun 9, 2025
139e246
reduce diff in guid utils import
varun-edachali-dbx Jun 9, 2025
e3ee4e4
move arrow_schema_bytes back into ExecuteResult
varun-edachali-dbx Jun 9, 2025
f448a8f
maintain log
varun-edachali-dbx Jun 9, 2025
82ca1ee
remove un-necessary assignment
varun-edachali-dbx Jun 9, 2025
e96a078
remove un-necessary tuple response
varun-edachali-dbx Jun 9, 2025
27158b1
remove un-necessary verbose mocking
varun-edachali-dbx Jun 9, 2025
d3200c4
move Queue construction to ResultSet
varun-edachali-dbx Jun 10, 2025
8a014f0
move description to List[Tuple]
varun-edachali-dbx Jun 10, 2025
39c41ab
formatting (black)
varun-edachali-dbx Jun 10, 2025
2cd04df
reduce diff (remove explicit tuple conversion)
varun-edachali-dbx Jun 10, 2025
067a019
remove has_more_rows from ExecuteResponse
varun-edachali-dbx Jun 10, 2025
48c83e0
remove un-necessary has_more_rows calc
varun-edachali-dbx Jun 10, 2025
281a9e9
default has_more_rows to True
varun-edachali-dbx Jun 10, 2025
192901d
return has_more_rows from ExecResponse conversion during GetRespMetadata
varun-edachali-dbx Jun 10, 2025
55f5c45
remove unnecessary replacement
varun-edachali-dbx Jun 10, 2025
edc36b5
better mocked backend naming
varun-edachali-dbx Jun 10, 2025
81280e7
remove has_more_rows test in ExecuteResponse
varun-edachali-dbx Jun 10, 2025
c1d3be2
introduce replacement of original has_more_rows read test
varun-edachali-dbx Jun 10, 2025
5ee4136
call correct method in test_use_arrow_schema
varun-edachali-dbx Jun 10, 2025
b881ab0
call correct method in test_fall_back_to_hive_schema
varun-edachali-dbx Jun 10, 2025
53bf715
re-introduce result response read test
varun-edachali-dbx Jun 10, 2025
45a32be
simplify test
varun-edachali-dbx Jun 10, 2025
e3fe299
remove excess fetch_results mocks
varun-edachali-dbx Jun 10, 2025
e8038d3
more minimal changes to thrift_backend tests
varun-edachali-dbx Jun 10, 2025
2f6ec19
move back to old table types
varun-edachali-dbx Jun 10, 2025
73bc282
remove outdated arrow_schema_bytes return
varun-edachali-dbx Jun 10, 2025
7c483f2
remove duplicate import
varun-edachali-dbx Jun 11, 2025
8cbeb08
rephrase model docstrings to explicitly denote that they are represen…
varun-edachali-dbx Jun 11, 2025
36b9cfb
has_more_rows -> is_direct_results
varun-edachali-dbx Jun 11, 2025
c04d583
switch docstring format to align with Connection class
varun-edachali-dbx Jun 11, 2025
2 changes: 0 additions & 2 deletions src/databricks/sql/backend/databricks_client.py
@@ -16,8 +16,6 @@

from databricks.sql.thrift_api.TCLIService import ttypes
from databricks.sql.backend.types import SessionId, CommandId, CommandState
from databricks.sql.utils import ExecuteResponse
from databricks.sql.types import SSLOptions

# Forward reference for type hints
from typing import TYPE_CHECKING
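Note: `ExecuteResponse` moves out of `databricks.sql.utils` and into `databricks.sql.backend.types` (see the thrift_backend.py imports below). A minimal sketch of the normalised dataclass, with the field list inferred from the constructor calls in this PR — the defaults and exact types are assumptions:

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple

from databricks.sql.backend.types import CommandId, CommandState

@dataclass
class ExecuteResponse:
    """Normalised execute response shared by backends (sketch; fields
    inferred from the call sites in this diff)."""

    command_id: CommandId
    status: CommandState
    description: Optional[List[Tuple]] = None  # cursor description rows
    has_been_closed_server_side: bool = False
    lz4_compressed: bool = True                # assumed default
    is_staging_operation: bool = False
    arrow_schema_bytes: Optional[bytes] = None
    result_format: Optional[Any] = None        # Thrift resultFormat enum
```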
4 changes: 2 additions & 2 deletions src/databricks/sql/backend/sea/models/requests.py
@@ -4,7 +4,7 @@

@dataclass
class CreateSessionRequest:
"""Request to create a new session."""
"""Representation of a request to create a new session."""

warehouse_id: str
session_confs: Optional[Dict[str, str]] = None
@@ -29,7 +29,7 @@ def to_dict(self) -> Dict[str, Any]:

@dataclass
class DeleteSessionRequest:
"""Request to delete a session."""
"""Representation of a request to delete a session."""

warehouse_id: str
session_id: str
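For illustration, a hypothetical round-trip through one of these request models — the warehouse ID and session conf values here are made up:

```python
req = CreateSessionRequest(
    warehouse_id="abc123",                # hypothetical warehouse ID
    session_confs={"ansi_mode": "true"},  # hypothetical session conf
)
payload = req.to_dict()  # dict body for the SEA create-session HTTP call
```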
2 changes: 1 addition & 1 deletion src/databricks/sql/backend/sea/models/responses.py
@@ -4,7 +4,7 @@

@dataclass
class CreateSessionResponse:
"""Response from creating a new session."""
"""Representation of the response from creating a new session."""

session_id: str

2 changes: 1 addition & 1 deletion src/databricks/sql/backend/sea/utils/http_client.py
@@ -1,7 +1,7 @@
import json
import logging
import requests
from typing import Callable, Dict, Any, Optional, Union, List, Tuple
from typing import Callable, Dict, Any, Optional, List, Tuple
from urllib.parse import urljoin

from databricks.sql.auth.authenticators import AuthProvider
153 changes: 88 additions & 65 deletions src/databricks/sql/backend/thrift_backend.py
@@ -3,23 +3,21 @@
import logging
import math
import time
import uuid
import threading
from typing import List, Optional, Union, Any, TYPE_CHECKING
from typing import List, Union, Any, TYPE_CHECKING

if TYPE_CHECKING:
from databricks.sql.client import Cursor
from databricks.sql.result_set import ResultSet, ThriftResultSet

from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState
from databricks.sql.backend.types import (
CommandState,
SessionId,
CommandId,
BackendType,
ExecuteResponse,
)
from databricks.sql.backend.utils import guid_to_hex_id


try:
import pyarrow
except ImportError:
@@ -42,7 +40,7 @@
)

from databricks.sql.utils import (
ExecuteResponse,
ResultSetQueueFactory,
_bound,
RequestErrorInfo,
NoRetryReason,
@@ -53,6 +51,7 @@
)
from databricks.sql.types import SSLOptions
from databricks.sql.backend.databricks_client import DatabricksClient
from databricks.sql.result_set import ResultSet, ThriftResultSet

logger = logging.getLogger(__name__)

@@ -758,11 +757,13 @@ def _results_message_to_execute_response(self, resp, operation_state):
)
direct_results = resp.directResults
has_been_closed_server_side = direct_results and direct_results.closeOperation
has_more_rows = (

is_direct_results = (
(not direct_results)
or (not direct_results.resultSet)
or direct_results.resultSet.hasMoreRows
)

description = self._hive_schema_to_description(
t_result_set_metadata_resp.schema
)
Expand All @@ -778,42 +779,28 @@ def _results_message_to_execute_response(self, resp, operation_state):
schema_bytes = None

lz4_compressed = t_result_set_metadata_resp.lz4Compressed
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
if direct_results and direct_results.resultSet:
assert direct_results.resultSet.results.startRowOffset == 0
assert direct_results.resultSetMetadata

arrow_queue_opt = ResultSetQueueFactory.build_queue(
row_set_type=t_result_set_metadata_resp.resultFormat,
t_row_set=direct_results.resultSet.results,
arrow_schema_bytes=schema_bytes,
max_download_threads=self.max_download_threads,
lz4_compressed=lz4_compressed,
description=description,
ssl_options=self._ssl_options,
)
else:
arrow_queue_opt = None

command_id = CommandId.from_thrift_handle(resp.operationHandle)

return ExecuteResponse(
arrow_queue=arrow_queue_opt,
status=CommandState.from_thrift_state(operation_state),
has_been_closed_server_side=has_been_closed_server_side,
has_more_rows=has_more_rows,
lz4_compressed=lz4_compressed,
is_staging_operation=is_staging_operation,
status = CommandState.from_thrift_state(operation_state)
if status is None:
raise ValueError(f"Unknown command state: {operation_state}")

execute_response = ExecuteResponse(
command_id=command_id,
status=status,
description=description,
has_been_closed_server_side=has_been_closed_server_side,
lz4_compressed=lz4_compressed,
is_staging_operation=t_result_set_metadata_resp.isStagingOperation,
arrow_schema_bytes=schema_bytes,
result_format=t_result_set_metadata_resp.resultFormat,
)

return execute_response, is_direct_results
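
Per the commit history, `has_more_rows` is renamed to `is_direct_results` and is now returned alongside the normalised response instead of being stored on it. A minimal illustration of the truthiness logic above, using a hypothetical stand-in for the Thrift struct:

```python
def compute_is_direct_results(direct_results) -> bool:
    # True unless the server inlined a complete result set in directResults.
    return (
        (not direct_results)                     # nothing inlined: fetch later
        or (not direct_results.resultSet)        # metadata only: fetch later
        or direct_results.resultSet.hasMoreRows  # inlined, but rows remain
    )
```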

def get_execution_result(
self, command_id: CommandId, cursor: "Cursor"
) -> "ResultSet":
from databricks.sql.result_set import ThriftResultSet

thrift_handle = command_id.to_thrift_handle()
if not thrift_handle:
raise ValueError("Not a valid Thrift command ID")
@@ -835,9 +822,6 @@ def get_execution_result

t_result_set_metadata_resp = resp.resultSetMetadata

lz4_compressed = t_result_set_metadata_resp.lz4Compressed
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
has_more_rows = resp.hasMoreRows
description = self._hive_schema_to_description(
t_result_set_metadata_resp.schema
)
@@ -852,26 +836,21 @@ def get_execution_result
else:
schema_bytes = None

queue = ResultSetQueueFactory.build_queue(
row_set_type=resp.resultSetMetadata.resultFormat,
t_row_set=resp.results,
arrow_schema_bytes=schema_bytes,
max_download_threads=self.max_download_threads,
lz4_compressed=lz4_compressed,
description=description,
ssl_options=self._ssl_options,
)
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
is_direct_results = resp.hasMoreRows

status = self.get_query_state(command_id)

execute_response = ExecuteResponse(
arrow_queue=queue,
status=CommandState.from_thrift_state(resp.status),
command_id=command_id,
status=status,
description=description,
has_been_closed_server_side=False,
has_more_rows=has_more_rows,
lz4_compressed=lz4_compressed,
is_staging_operation=is_staging_operation,
command_id=command_id,
description=description,
arrow_schema_bytes=schema_bytes,
result_format=t_result_set_metadata_resp.resultFormat,
)

return ThriftResultSet(
@@ -881,6 +860,10 @@
buffer_size_bytes=cursor.buffer_size_bytes,
arraysize=cursor.arraysize,
use_cloud_fetch=cursor.connection.use_cloud_fetch,
t_row_set=resp.results,
max_download_threads=self.max_download_threads,
ssl_options=self._ssl_options,
is_direct_results=is_direct_results,
)
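
`ThriftResultSet` now receives the raw `t_row_set` plus download and SSL settings, so queue construction happens inside the result set rather than in the backend (the `ResultSetQueueFactory.build_queue` call removed above). A sketch of the constructor implied by these call sites — parameter defaults and the internal wiring are assumptions, not the PR's actual implementation:

```python
from databricks.sql.utils import ResultSetQueueFactory

class ThriftResultSet:
    def __init__(
        self,
        connection,
        execute_response,         # the normalised ExecuteResponse
        thrift_client,
        buffer_size_bytes,
        arraysize,
        use_cloud_fetch,
        t_row_set=None,           # inlined TRowSet, when direct results exist
        max_download_threads=10,  # assumed default
        ssl_options=None,
        is_direct_results=True,   # assumed default
    ):
        # Queue construction moved here from the backend: build it only
        # when the server inlined row data with the response.
        queue = None
        if t_row_set is not None:
            queue = ResultSetQueueFactory.build_queue(
                row_set_type=execute_response.result_format,
                t_row_set=t_row_set,
                arrow_schema_bytes=execute_response.arrow_schema_bytes,
                max_download_threads=max_download_threads,
                lz4_compressed=execute_response.lz4_compressed,
                description=execute_response.description,
                ssl_options=ssl_options,
            )
        self.results = queue
        self.is_direct_results = is_direct_results
```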

def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
@@ -947,8 +930,6 @@ def execute_command(
async_op=False,
enforce_embedded_schema_correctness=False,
) -> Union["ResultSet", None]:
from databricks.sql.result_set import ThriftResultSet

thrift_handle = session_id.to_thrift_handle()
if not thrift_handle:
raise ValueError("Not a valid Thrift session ID")
@@ -995,7 +976,13 @@
self._handle_execute_response_async(resp, cursor)
return None
else:
execute_response = self._handle_execute_response(resp, cursor)
execute_response, is_direct_results = self._handle_execute_response(
resp, cursor
)

t_row_set = None
if resp.directResults and resp.directResults.resultSet:
t_row_set = resp.directResults.resultSet.results

return ThriftResultSet(
connection=cursor.connection,
Expand All @@ -1004,6 +991,10 @@ def execute_command(
buffer_size_bytes=max_bytes,
arraysize=max_rows,
use_cloud_fetch=use_cloud_fetch,
t_row_set=t_row_set,
max_download_threads=self.max_download_threads,
ssl_options=self._ssl_options,
is_direct_results=is_direct_results,
)
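
The same unpack-and-forward sequence now repeats in `execute_command`, `get_catalogs`, `get_schemas`, `get_tables`, and `get_columns` below. A hypothetical helper capturing the shared pattern (not part of this PR):

```python
def _build_thrift_result_set(self, resp, cursor, max_rows, max_bytes, use_cloud_fetch):
    # Normalise the Thrift response into (ExecuteResponse, is_direct_results).
    execute_response, is_direct_results = self._handle_execute_response(resp, cursor)

    # Direct results may carry inlined row data; pass it through untouched.
    t_row_set = None
    if resp.directResults and resp.directResults.resultSet:
        t_row_set = resp.directResults.resultSet.results

    return ThriftResultSet(
        connection=cursor.connection,
        execute_response=execute_response,
        thrift_client=self,
        buffer_size_bytes=max_bytes,
        arraysize=max_rows,
        use_cloud_fetch=use_cloud_fetch,
        t_row_set=t_row_set,
        max_download_threads=self.max_download_threads,
        ssl_options=self._ssl_options,
        is_direct_results=is_direct_results,
    )
```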

def get_catalogs(
@@ -1013,8 +1004,6 @@
max_bytes: int,
cursor: "Cursor",
) -> "ResultSet":
from databricks.sql.result_set import ThriftResultSet

thrift_handle = session_id.to_thrift_handle()
if not thrift_handle:
raise ValueError("Not a valid Thrift session ID")
@@ -1027,7 +1016,13 @@
)
resp = self.make_request(self._client.GetCatalogs, req)

execute_response = self._handle_execute_response(resp, cursor)
execute_response, is_direct_results = self._handle_execute_response(
resp, cursor
)

t_row_set = None
if resp.directResults and resp.directResults.resultSet:
t_row_set = resp.directResults.resultSet.results

return ThriftResultSet(
connection=cursor.connection,
@@ -1036,6 +1031,10 @@
buffer_size_bytes=max_bytes,
arraysize=max_rows,
use_cloud_fetch=cursor.connection.use_cloud_fetch,
t_row_set=t_row_set,
max_download_threads=self.max_download_threads,
ssl_options=self._ssl_options,
is_direct_results=is_direct_results,
)

def get_schemas(
@@ -1047,8 +1046,6 @@
catalog_name=None,
schema_name=None,
) -> "ResultSet":
from databricks.sql.result_set import ThriftResultSet

thrift_handle = session_id.to_thrift_handle()
if not thrift_handle:
raise ValueError("Not a valid Thrift session ID")
@@ -1063,7 +1060,13 @@
)
resp = self.make_request(self._client.GetSchemas, req)

execute_response = self._handle_execute_response(resp, cursor)
execute_response, is_direct_results = self._handle_execute_response(
resp, cursor
)

t_row_set = None
if resp.directResults and resp.directResults.resultSet:
t_row_set = resp.directResults.resultSet.results

return ThriftResultSet(
connection=cursor.connection,
@@ -1072,6 +1075,10 @@
buffer_size_bytes=max_bytes,
arraysize=max_rows,
use_cloud_fetch=cursor.connection.use_cloud_fetch,
t_row_set=t_row_set,
max_download_threads=self.max_download_threads,
ssl_options=self._ssl_options,
is_direct_results=is_direct_results,
)

def get_tables(
@@ -1085,8 +1092,6 @@
table_name=None,
table_types=None,
) -> "ResultSet":
from databricks.sql.result_set import ThriftResultSet

thrift_handle = session_id.to_thrift_handle()
if not thrift_handle:
raise ValueError("Not a valid Thrift session ID")
@@ -1103,7 +1108,13 @@
)
resp = self.make_request(self._client.GetTables, req)

execute_response = self._handle_execute_response(resp, cursor)
execute_response, is_direct_results = self._handle_execute_response(
resp, cursor
)

t_row_set = None
if resp.directResults and resp.directResults.resultSet:
t_row_set = resp.directResults.resultSet.results

return ThriftResultSet(
connection=cursor.connection,
@@ -1112,6 +1123,10 @@
buffer_size_bytes=max_bytes,
arraysize=max_rows,
use_cloud_fetch=cursor.connection.use_cloud_fetch,
t_row_set=t_row_set,
max_download_threads=self.max_download_threads,
ssl_options=self._ssl_options,
is_direct_results=is_direct_results,
)

def get_columns(
@@ -1125,8 +1140,6 @@
table_name=None,
column_name=None,
) -> "ResultSet":
from databricks.sql.result_set import ThriftResultSet

thrift_handle = session_id.to_thrift_handle()
if not thrift_handle:
raise ValueError("Not a valid Thrift session ID")
@@ -1143,7 +1156,13 @@
)
resp = self.make_request(self._client.GetColumns, req)

execute_response = self._handle_execute_response(resp, cursor)
execute_response, is_direct_results = self._handle_execute_response(
resp, cursor
)

t_row_set = None
if resp.directResults and resp.directResults.resultSet:
t_row_set = resp.directResults.resultSet.results

return ThriftResultSet(
connection=cursor.connection,
@@ -1152,6 +1171,10 @@
buffer_size_bytes=max_bytes,
arraysize=max_rows,
use_cloud_fetch=cursor.connection.use_cloud_fetch,
t_row_set=t_row_set,
max_download_threads=self.max_download_threads,
ssl_options=self._ssl_options,
is_direct_results=is_direct_results,
)

def _handle_execute_response(self, resp, cursor):
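The body of `_handle_execute_response` is collapsed in this view. A hypothetical reconstruction consistent with the new tuple contract — the internal calls and Thrift field access are assumptions based on the surrounding diff:

```python
def _handle_execute_response(self, resp, cursor):
    # Wait for the operation to reach a terminal state, then normalise the
    # final message into the (ExecuteResponse, is_direct_results) pair.
    final_operation_state = self._wait_until_command_done(
        resp.operationHandle,
        resp.directResults and resp.directResults.operationStatus,
    )
    return self._results_message_to_execute_response(resp, final_operation_state)
```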