Skip to content

Commit 37f1da8

Browse files
committed
feat: use new table in/out handling, add checks for larger chunk sizes and multiple chunks per input on table in/out functions
1 parent 01cea87 commit 37f1da8

File tree

4 files changed

+146
-22
lines changed

4 files changed

+146
-22
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "query-farm-airport-test-server"
3-
version = "0.1.0"
3+
version = "0.1.1"
44
description = "An Apache Arrow Flight server that is used to test the Airport extension for DuckDB."
55
authors = [
66
{ name = "Rusty Conover", email = "[email protected]" }
@@ -9,7 +9,7 @@ dependencies = [
99
"pyarrow>=20.0.0",
1010
"query-farm-flight-server",
1111
"duckdb>=1.3.1",
12-
"query-farm-duckdb-json-serialization>=0.1.1"
12+
"query-farm-duckdb-json-serialization>=0.1.1",
1313
]
1414
readme = "README.md"
1515
requires-python = ">= 3.12"

requirements-dev.lock

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
-e file:.
1313
annotated-types==0.7.0
1414
# via pydantic
15-
boto3==1.39.9
15+
boto3==1.39.12
1616
# via query-farm-flight-server
17-
botocore==1.39.9
17+
botocore==1.39.12
1818
# via boto3
1919
# via s3transfer
2020
cache3==0.4.3
@@ -86,14 +86,14 @@ python-levenshtein==0.27.1
8686
# via query-farm-flight-server
8787
query-farm-duckdb-json-serialization==0.1.2
8888
# via query-farm-airport-test-server
89-
query-farm-flight-server==0.1.5
89+
query-farm-flight-server==0.1.8
9090
# via query-farm-airport-test-server
9191
rapidfuzz==3.13.0
9292
# via levenshtein
9393
ruff==0.11.2
9494
s3transfer==0.13.1
9595
# via boto3
96-
sentry-sdk==2.33.0
96+
sentry-sdk==2.33.2
9797
# via query-farm-flight-server
9898
six==1.17.0
9999
# via python-dateutil

requirements.lock

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
-e file:.
1313
annotated-types==0.7.0
1414
# via pydantic
15-
boto3==1.39.9
15+
boto3==1.39.12
1616
# via query-farm-flight-server
17-
botocore==1.39.9
17+
botocore==1.39.12
1818
# via boto3
1919
# via s3transfer
2020
cache3==0.4.3
@@ -54,13 +54,13 @@ python-levenshtein==0.27.1
5454
# via query-farm-flight-server
5555
query-farm-duckdb-json-serialization==0.1.2
5656
# via query-farm-airport-test-server
57-
query-farm-flight-server==0.1.5
57+
query-farm-flight-server==0.1.8
5858
# via query-farm-airport-test-server
5959
rapidfuzz==3.13.0
6060
# via levenshtein
6161
s3transfer==0.13.1
6262
# via boto3
63-
sentry-sdk==2.33.0
63+
sentry-sdk==2.33.2
6464
# via query-farm-flight-server
6565
six==1.17.0
6666
# via python-dateutil

src/query_farm_airport_test_server/database_impl.py

Lines changed: 136 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,7 @@ class TableFunction:
3737
output_schema_source: pa.Schema | TableFunctionDynamicOutput
3838

3939
# The function to call to process a chunk of rows.
40-
handler: Callable[
41-
[parameter_types.TableFunctionParameters, pa.Schema],
42-
Generator[pa.RecordBatch, pa.RecordBatch, pa.RecordBatch],
43-
]
40+
handler: Callable[[parameter_types.TableFunctionParameters, pa.Schema], parameter_types.TableFunctionInOutGenerator]
4441

4542
estimated_rows: int | Callable[[parameter_types.TableFunctionFlightInfo], int] = -1
4643

@@ -578,6 +575,11 @@ def dynamic_schema_handler_output_schema(
578575
return parameters.schema
579576

580577

578+
def in_out_long_schema_handler(parameters: pa.RecordBatch, input_schema: pa.Schema | None = None) -> pa.Schema:
579+
assert input_schema is not None
580+
return pa.schema([input_schema.field(0)])
581+
582+
581583
def in_out_schema_handler(parameters: pa.RecordBatch, input_schema: pa.Schema | None = None) -> pa.Schema:
582584
assert input_schema is not None
583585
return pa.schema([parameters.schema.field(0), input_schema.field(0)])
@@ -613,35 +615,43 @@ def in_out_echo_handler(
613615
def in_out_wide_handler(
614616
parameters: parameter_types.TableFunctionParameters,
615617
output_schema: pa.Schema,
616-
) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
618+
) -> parameter_types.TableFunctionInOutGenerator:
617619
result = output_schema.empty_table()
618620

619621
while True:
620-
input_chunk = yield result
622+
input_chunk = yield (result, True)
621623

622624
if input_chunk is None:
623625
break
624626

627+
if isinstance(input_chunk, bool):
628+
raise NotImplementedError("Not expecting continuing output for input chunk.")
629+
630+
chunk_length = len(input_chunk)
631+
625632
result = pa.RecordBatch.from_arrays(
626-
[[i] * len(input_chunk) for i in range(20)],
633+
[[i] * chunk_length for i in range(20)],
627634
schema=output_schema,
628635
)
629636

630-
return
637+
return None
631638

632639

633640
def in_out_handler(
634641
parameters: parameter_types.TableFunctionParameters,
635642
output_schema: pa.Schema,
636-
) -> Generator[pa.RecordBatch, pa.RecordBatch, None]:
643+
) -> parameter_types.TableFunctionInOutGenerator:
637644
result = output_schema.empty_table()
638645

639646
while True:
640-
input_chunk = yield result
647+
input_chunk = yield (result, True)
641648

642649
if input_chunk is None:
643650
break
644651

652+
if isinstance(input_chunk, bool):
653+
raise NotImplementedError("Not expecting continuing output for input chunk.")
654+
645655
assert parameters.parameters is not None
646656
parameter_value = parameters.parameters.column(0).to_pylist()[0]
647657

@@ -654,7 +664,75 @@ def in_out_handler(
654664
schema=output_schema,
655665
)
656666

657-
return pa.RecordBatch.from_arrays([["last"], ["row"]], schema=output_schema)
667+
return [pa.RecordBatch.from_arrays([["last"], ["row"]], schema=output_schema)]
668+
669+
670+
def in_out_long_handler(
671+
parameters: parameter_types.TableFunctionParameters,
672+
output_schema: pa.Schema,
673+
) -> parameter_types.TableFunctionInOutGenerator:
674+
result = output_schema.empty_table()
675+
676+
while True:
677+
input_chunk = yield (result, True)
678+
679+
if input_chunk is None:
680+
break
681+
682+
if isinstance(input_chunk, bool):
683+
raise NotImplementedError("Not expecting continuing output for input chunk.")
684+
685+
# Return the input chunk ten times.
686+
multiplier = 10
687+
copied_results = [
688+
pa.RecordBatch.from_arrays(
689+
[
690+
input_chunk.column(0),
691+
],
692+
schema=output_schema,
693+
)
694+
for index in range(multiplier)
695+
]
696+
697+
for item in copied_results[0:-1]:
698+
yield (item, False)
699+
result = copied_results[-1]
700+
701+
return None
702+
703+
704+
def in_out_huge_chunk_handler(
705+
parameters: parameter_types.TableFunctionParameters,
706+
output_schema: pa.Schema,
707+
) -> parameter_types.TableFunctionInOutGenerator:
708+
result = output_schema.empty_table()
709+
multiplier = 10
710+
chunk_length = 5000
711+
712+
while True:
713+
input_chunk = yield (result, True)
714+
715+
if input_chunk is None:
716+
break
717+
718+
if isinstance(input_chunk, bool):
719+
raise NotImplementedError("Not expecting continuing output for input chunk.")
720+
721+
for index, _i in enumerate(range(multiplier)):
722+
output = pa.RecordBatch.from_arrays(
723+
[list(range(chunk_length)), list([index] * chunk_length)],
724+
schema=output_schema,
725+
)
726+
if index < multiplier - 1:
727+
yield (output, False)
728+
else:
729+
result = output
730+
731+
# test big chunks returned as the last results.
732+
return [
733+
pa.RecordBatch.from_arrays([list(range(chunk_length)), list([footer_id] * chunk_length)], schema=output_schema)
734+
for footer_id in (-1, -2, -3)
735+
]
658736

659737

660738
def yellow_taxi_endpoint_generator(ticket_data: Any) -> list[flight.FlightEndpoint]:
@@ -739,6 +817,17 @@ def yellow_taxi_endpoint_generator(ticket_data: Any) -> list[flight.FlightEndpoi
739817
table_functions_by_name=CaseInsensitiveDict(),
740818
tables_by_name=CaseInsensitiveDict(
741819
{
820+
"big_chunk": TableInfo(
821+
table_versions=[
822+
pa.Table.from_arrays(
823+
[
824+
list(range(100000)),
825+
],
826+
schema=pa.schema([pa.field("id", pa.int64())]),
827+
)
828+
],
829+
row_id_counter=0,
830+
),
742831
"employees": TableInfo(
743832
table_versions=[
744833
pa.Table.from_arrays(
@@ -775,7 +864,7 @@ def yellow_taxi_endpoint_generator(ticket_data: Any) -> list[flight.FlightEndpoi
775864
)
776865
],
777866
row_id_counter=2,
778-
)
867+
),
779868
}
780869
),
781870
)
@@ -948,6 +1037,41 @@ def collatz_steps(n: int) -> list[int]:
9481037
),
9491038
handler=in_out_handler,
9501039
),
1040+
"test_table_in_out_long": TableFunction(
1041+
input_schema=pa.schema(
1042+
[
1043+
pa.field(
1044+
"table_input",
1045+
pa.string(),
1046+
metadata={"is_table_type": "1"},
1047+
),
1048+
]
1049+
),
1050+
output_schema_source=TableFunctionDynamicOutput(
1051+
schema_creator=in_out_long_schema_handler,
1052+
default_values=(
1053+
pa.RecordBatch.from_arrays(
1054+
[pa.array([1], type=pa.int32())],
1055+
schema=pa.schema([pa.field("input", pa.int32())]),
1056+
),
1057+
pa.schema([pa.field("input", pa.int32())]),
1058+
),
1059+
),
1060+
handler=in_out_long_handler,
1061+
),
1062+
"test_table_in_out_huge": TableFunction(
1063+
input_schema=pa.schema(
1064+
[
1065+
pa.field(
1066+
"table_input",
1067+
pa.string(),
1068+
metadata={"is_table_type": "1"},
1069+
),
1070+
]
1071+
),
1072+
output_schema_source=pa.schema([("multiplier", pa.int64()), ("value", pa.int64())]),
1073+
handler=in_out_huge_chunk_handler,
1074+
),
9511075
"test_table_in_out_wide": TableFunction(
9521076
input_schema=pa.schema(
9531077
[

0 commit comments

Comments (0)