@@ -37,10 +37,7 @@ class TableFunction:
37
37
output_schema_source : pa .Schema | TableFunctionDynamicOutput
38
38
39
39
# The function to call to process a chunk of rows.
40
- handler : Callable [
41
- [parameter_types .TableFunctionParameters , pa .Schema ],
42
- Generator [pa .RecordBatch , pa .RecordBatch , pa .RecordBatch ],
43
- ]
40
+ handler : Callable [[parameter_types .TableFunctionParameters , pa .Schema ], parameter_types .TableFunctionInOutGenerator ]
44
41
45
42
estimated_rows : int | Callable [[parameter_types .TableFunctionFlightInfo ], int ] = - 1
46
43
@@ -578,6 +575,11 @@ def dynamic_schema_handler_output_schema(
578
575
return parameters .schema
579
576
580
577
578
+ def in_out_long_schema_handler (parameters : pa .RecordBatch , input_schema : pa .Schema | None = None ) -> pa .Schema :
579
+ assert input_schema is not None
580
+ return pa .schema ([input_schema .field (0 )])
581
+
582
+
581
583
def in_out_schema_handler (parameters : pa .RecordBatch , input_schema : pa .Schema | None = None ) -> pa .Schema :
582
584
assert input_schema is not None
583
585
return pa .schema ([parameters .schema .field (0 ), input_schema .field (0 )])
@@ -613,35 +615,43 @@ def in_out_echo_handler(
613
615
def in_out_wide_handler (
614
616
parameters : parameter_types .TableFunctionParameters ,
615
617
output_schema : pa .Schema ,
616
- ) -> Generator [ pa . RecordBatch , pa . RecordBatch , None ] :
618
+ ) -> parameter_types . TableFunctionInOutGenerator :
617
619
result = output_schema .empty_table ()
618
620
619
621
while True :
620
- input_chunk = yield result
622
+ input_chunk = yield ( result , True )
621
623
622
624
if input_chunk is None :
623
625
break
624
626
627
+ if isinstance (input_chunk , bool ):
628
+ raise NotImplementedError ("Not expecting continuing output for input chunk." )
629
+
630
+ chunk_length = len (input_chunk )
631
+
625
632
result = pa .RecordBatch .from_arrays (
626
- [[i ] * len ( input_chunk ) for i in range (20 )],
633
+ [[i ] * chunk_length for i in range (20 )],
627
634
schema = output_schema ,
628
635
)
629
636
630
- return
637
+ return None
631
638
632
639
633
640
def in_out_handler (
634
641
parameters : parameter_types .TableFunctionParameters ,
635
642
output_schema : pa .Schema ,
636
- ) -> Generator [ pa . RecordBatch , pa . RecordBatch , None ] :
643
+ ) -> parameter_types . TableFunctionInOutGenerator :
637
644
result = output_schema .empty_table ()
638
645
639
646
while True :
640
- input_chunk = yield result
647
+ input_chunk = yield ( result , True )
641
648
642
649
if input_chunk is None :
643
650
break
644
651
652
+ if isinstance (input_chunk , bool ):
653
+ raise NotImplementedError ("Not expecting continuing output for input chunk." )
654
+
645
655
assert parameters .parameters is not None
646
656
parameter_value = parameters .parameters .column (0 ).to_pylist ()[0 ]
647
657
@@ -654,7 +664,75 @@ def in_out_handler(
654
664
schema = output_schema ,
655
665
)
656
666
657
- return pa .RecordBatch .from_arrays ([["last" ], ["row" ]], schema = output_schema )
667
+ return [pa .RecordBatch .from_arrays ([["last" ], ["row" ]], schema = output_schema )]
668
+
669
+
670
+ def in_out_long_handler (
671
+ parameters : parameter_types .TableFunctionParameters ,
672
+ output_schema : pa .Schema ,
673
+ ) -> parameter_types .TableFunctionInOutGenerator :
674
+ result = output_schema .empty_table ()
675
+
676
+ while True :
677
+ input_chunk = yield (result , True )
678
+
679
+ if input_chunk is None :
680
+ break
681
+
682
+ if isinstance (input_chunk , bool ):
683
+ raise NotImplementedError ("Not expecting continuing output for input chunk." )
684
+
685
+ # Return the input chunk ten times.
686
+ multiplier = 10
687
+ copied_results = [
688
+ pa .RecordBatch .from_arrays (
689
+ [
690
+ input_chunk .column (0 ),
691
+ ],
692
+ schema = output_schema ,
693
+ )
694
+ for index in range (multiplier )
695
+ ]
696
+
697
+ for item in copied_results [0 :- 1 ]:
698
+ yield (item , False )
699
+ result = copied_results [- 1 ]
700
+
701
+ return None
702
+
703
+
704
+ def in_out_huge_chunk_handler (
705
+ parameters : parameter_types .TableFunctionParameters ,
706
+ output_schema : pa .Schema ,
707
+ ) -> parameter_types .TableFunctionInOutGenerator :
708
+ result = output_schema .empty_table ()
709
+ multiplier = 10
710
+ chunk_length = 5000
711
+
712
+ while True :
713
+ input_chunk = yield (result , True )
714
+
715
+ if input_chunk is None :
716
+ break
717
+
718
+ if isinstance (input_chunk , bool ):
719
+ raise NotImplementedError ("Not expecting continuing output for input chunk." )
720
+
721
+ for index , _i in enumerate (range (multiplier )):
722
+ output = pa .RecordBatch .from_arrays (
723
+ [list (range (chunk_length )), list ([index ] * chunk_length )],
724
+ schema = output_schema ,
725
+ )
726
+ if index < multiplier - 1 :
727
+ yield (output , False )
728
+ else :
729
+ result = output
730
+
731
+ # test big chunks returned as the last results.
732
+ return [
733
+ pa .RecordBatch .from_arrays ([list (range (chunk_length )), list ([footer_id ] * chunk_length )], schema = output_schema )
734
+ for footer_id in (- 1 , - 2 , - 3 )
735
+ ]
658
736
659
737
660
738
def yellow_taxi_endpoint_generator (ticket_data : Any ) -> list [flight .FlightEndpoint ]:
@@ -739,6 +817,17 @@ def yellow_taxi_endpoint_generator(ticket_data: Any) -> list[flight.FlightEndpoi
739
817
table_functions_by_name = CaseInsensitiveDict (),
740
818
tables_by_name = CaseInsensitiveDict (
741
819
{
820
+ "big_chunk" : TableInfo (
821
+ table_versions = [
822
+ pa .Table .from_arrays (
823
+ [
824
+ list (range (100000 )),
825
+ ],
826
+ schema = pa .schema ([pa .field ("id" , pa .int64 ())]),
827
+ )
828
+ ],
829
+ row_id_counter = 0 ,
830
+ ),
742
831
"employees" : TableInfo (
743
832
table_versions = [
744
833
pa .Table .from_arrays (
@@ -775,7 +864,7 @@ def yellow_taxi_endpoint_generator(ticket_data: Any) -> list[flight.FlightEndpoi
775
864
)
776
865
],
777
866
row_id_counter = 2 ,
778
- )
867
+ ),
779
868
}
780
869
),
781
870
)
@@ -948,6 +1037,41 @@ def collatz_steps(n: int) -> list[int]:
948
1037
),
949
1038
handler = in_out_handler ,
950
1039
),
1040
+ "test_table_in_out_long" : TableFunction (
1041
+ input_schema = pa .schema (
1042
+ [
1043
+ pa .field (
1044
+ "table_input" ,
1045
+ pa .string (),
1046
+ metadata = {"is_table_type" : "1" },
1047
+ ),
1048
+ ]
1049
+ ),
1050
+ output_schema_source = TableFunctionDynamicOutput (
1051
+ schema_creator = in_out_long_schema_handler ,
1052
+ default_values = (
1053
+ pa .RecordBatch .from_arrays (
1054
+ [pa .array ([1 ], type = pa .int32 ())],
1055
+ schema = pa .schema ([pa .field ("input" , pa .int32 ())]),
1056
+ ),
1057
+ pa .schema ([pa .field ("input" , pa .int32 ())]),
1058
+ ),
1059
+ ),
1060
+ handler = in_out_long_handler ,
1061
+ ),
1062
+ "test_table_in_out_huge" : TableFunction (
1063
+ input_schema = pa .schema (
1064
+ [
1065
+ pa .field (
1066
+ "table_input" ,
1067
+ pa .string (),
1068
+ metadata = {"is_table_type" : "1" },
1069
+ ),
1070
+ ]
1071
+ ),
1072
+ output_schema_source = pa .schema ([("multiplier" , pa .int64 ()), ("value" , pa .int64 ())]),
1073
+ handler = in_out_huge_chunk_handler ,
1074
+ ),
951
1075
"test_table_in_out_wide" : TableFunction (
952
1076
input_schema = pa .schema (
953
1077
[
0 commit comments