@@ -921,10 +921,10 @@ def iterate_nested(nested_list):
921
921
922
922
923
923
def expand_args_to_dims (
924
- dim : ItemOrSequence [str ],
924
+ dim : ItemOrSequence [Hashable ],
925
925
arg_names : Sequence [str ],
926
926
args : Sequence [ItemOrSequence [Any ]],
927
- ) -> Tuple [Sequence [str ], Sequence [Sequence [Any ]]]:
927
+ ) -> Tuple [Sequence [Hashable ], Sequence [Sequence [Any ]]]:
928
928
"""Expand dims and all elements in args to be arrays of the length of the number of dimensions
929
929
930
930
Parameters
@@ -946,31 +946,33 @@ def expand_args_to_dims(
946
946
list of dims, list of lists of arguments
947
947
"""
948
948
if is_scalar (dim ):
949
- for name , arg in zip (arg_names , args ):
950
- if not is_scalar (arg ):
951
- raise ValueError (f"Expected { name } ={ arg !r} to be a scalar like 'dim'." )
952
- assert isinstance (dim , str ) or isinstance (dim , bytes )
953
- dim = [dim ]
949
+ dim_list : Sequence [Hashable ] = [dim ]
950
+ else :
951
+ assert isinstance (dim , list )
952
+ dim_list = dim
954
953
955
954
# dim is now a list
956
- nroll = len (dim )
955
+ nroll = len (dim_list )
957
956
958
- def to_array (arg ):
957
+ def to_list (arg ):
959
958
if is_scalar (arg ):
960
959
return [arg ] * nroll
960
+ elif isinstance (arg , list ) and len (arg ) == 1 :
961
+ return [arg [0 ]] * nroll
961
962
return arg
962
963
963
- arr_args = [to_array (arg ) for arg in args ]
964
+ arr_args = [to_list (arg ) for arg in args ]
964
965
965
- if any (len (dim ) != len (arg ) for arg in arr_args ):
966
- names_vals = ", " .join (
967
- f"{ name } ={ val !r} " for name , val in zip (arg_names , arr_args )
966
+ if any (len (arg ) != len (dim_list ) for arg in arr_args ):
967
+ names_args_len = ", " .join (
968
+ f"{ name } ={ args !r} (len={ len (args )} )" for name , args in zip (arg_names , arr_args )
969
+ if len (args ) != len (dim_list )
968
970
)
969
971
raise ValueError (
970
- "Arguments must all be the same length. " f" Received { names_vals } . "
972
+ f"Expected all arguments to have len= { len ( dim_list ) } . Received: { names_args_len } "
971
973
)
972
974
973
- return dim , arr_args
975
+ return dim_list , arr_args
974
976
975
977
976
978
def get_pads (
0 commit comments