@@ -780,6 +780,36 @@ def yellow_taxi_endpoint_generator(ticket_data: Any) -> list[flight.FlightEndpoi
780
780
),
781
781
)
782
782
783
+
784
+ def collatz_step_count (n : int ) -> int :
785
+ steps = 0
786
+ while n != 1 :
787
+ if n % 2 == 0 :
788
+ n //= 2
789
+ else :
790
+ n = 3 * n + 1
791
+ steps += 1
792
+ return steps
793
+
794
+
795
+ def collatz (inputs : pa .Array ) -> pa .Array :
796
+ results = [collatz_step_count (n ) for n in inputs .to_pylist ()]
797
+ return pa .array (results , type = pa .int64 ())
798
+
799
+
800
+ def collatz_steps (n : int ) -> list [int ]:
801
+ steps = 0
802
+ results = []
803
+ while n != 1 :
804
+ if n % 2 == 0 :
805
+ n //= 2
806
+ else :
807
+ n = 3 * n + 1
808
+ results .append (n )
809
+ steps += 1
810
+ return results
811
+
812
+
783
813
util_schema = SchemaCollection (
784
814
scalar_functions_by_name = CaseInsensitiveDict (
785
815
{
@@ -798,6 +828,18 @@ def yellow_taxi_endpoint_generator(ticket_data: Any) -> list[flight.FlightEndpoi
798
828
output_schema = pa .schema ([pa .field ("result" , pa .int64 ())]),
799
829
handler = add_handler ,
800
830
),
831
+ "collatz" : ScalarFunction (
832
+ input_schema = pa .schema ([pa .field ("n" , pa .int64 ())]),
833
+ output_schema = pa .schema ([pa .field ("result" , pa .int64 ())]),
834
+ handler = lambda table : collatz (table .column (0 )),
835
+ ),
836
+ "collatz_sequence" : ScalarFunction (
837
+ input_schema = pa .schema ([pa .field ("n" , pa .int64 ())]),
838
+ output_schema = pa .schema ([pa .field ("result" , pa .list_ (pa .int64 ()))]),
839
+ handler = lambda table : pa .array (
840
+ [collatz_steps (n ) for n in table .column (0 ).to_pylist ()], type = pa .list_ (pa .int64 ())
841
+ ),
842
+ ),
801
843
}
802
844
),
803
845
table_functions_by_name = CaseInsensitiveDict (
0 commit comments