@@ -42,53 +42,108 @@ mp_obj_t compare_bincount(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_
42
42
#if ULAB_MAX_DIMS > 1
43
43
// no need to check anything, if the maximum number of dimensions is 1
44
44
if (input -> ndim != 1 ) {
45
- mp_raise_ValueError (MP_ERROR_TEXT ("object too deep for desired arrayy " ));
45
+ mp_raise_ValueError (MP_ERROR_TEXT ("object too deep for desired array " ));
46
46
}
47
47
#endif
48
48
if ((input -> dtype != NDARRAY_UINT8 ) && (input -> dtype != NDARRAY_UINT16 )) {
49
49
mp_raise_TypeError (MP_ERROR_TEXT ("cannot cast array data from dtype" ));
50
50
}
51
51
52
52
// first find the maximum of the array, and figure out how long the result should be
53
- uint16_t max = 0 ;
53
+ size_t length = 0 ;
54
54
int32_t stride = input -> strides [ULAB_MAX_DIMS - 1 ];
55
55
if (input -> dtype == NDARRAY_UINT8 ) {
56
56
uint8_t * iarray = (uint8_t * )input -> array ;
57
57
for (size_t i = 0 ; i < input -> len ; i ++ ) {
58
- if (* iarray > max ) {
59
- max = * iarray ;
58
+ if (* iarray > length ) {
59
+ length = * iarray ;
60
60
}
61
61
iarray += stride ;
62
62
}
63
63
} else if (input -> dtype == NDARRAY_UINT16 ) {
64
64
stride /= 2 ;
65
65
uint16_t * iarray = (uint16_t * )input -> array ;
66
66
for (size_t i = 0 ; i < input -> len ; i ++ ) {
67
- if (* iarray > max ) {
68
- max = * iarray ;
67
+ if (* iarray > length ) {
68
+ length = * iarray ;
69
69
}
70
70
iarray += stride ;
71
71
}
72
72
}
73
- ndarray_obj_t * result = ndarray_new_linear_array ( max + 1 , NDARRAY_UINT16 ) ;
73
+ length += 1 ;
74
74
75
- // now we can do the binning
76
- uint16_t * rarray = (uint16_t * )result -> array ;
75
+ if (args [2 ].u_obj != mp_const_none ) {
76
+ int32_t minlength = mp_obj_get_int (args [2 ].u_obj );
77
+ if (minlength < 0 ) {
78
+ mp_raise_ValueError (MP_ERROR_TEXT ("minlength must not be negative" ));
79
+ }
80
+ if ((size_t )minlength > length ) {
81
+ length = minlength ;
82
+ }
83
+ }
77
84
78
- if (input -> dtype == NDARRAY_UINT8 ) {
79
- uint8_t * iarray = (uint8_t * )input -> array ;
80
- for (size_t i = 0 ; i < input -> len ; i ++ ) {
81
- rarray [* iarray ] += 1 ;
82
- iarray += stride ;
85
+ ndarray_obj_t * result = NULL ;
86
+ ndarray_obj_t * weights = NULL ;
87
+
88
+ if (args [1 ].u_obj == mp_const_none ) {
89
+ result = ndarray_new_linear_array (length , NDARRAY_UINT16 );
90
+ } else {
91
+ if (!mp_obj_is_type (args [1 ].u_obj , & ulab_ndarray_type )) {
92
+ mp_raise_TypeError (MP_ERROR_TEXT ("input must be an ndarray" ));
83
93
}
84
- } else if (input -> dtype == NDARRAY_UINT16 ) {
85
- uint16_t * iarray = (uint16_t * )input -> array ;
86
- for (size_t i = 0 ; i < input -> len ; i ++ ) {
87
- rarray [* iarray ] += 1 ;
88
- iarray += stride ;
94
+ weights = MP_OBJ_TO_PTR (args [1 ].u_obj );
95
+ result = ndarray_new_linear_array (length , NDARRAY_FLOAT );
96
+ }
97
+
98
+ // now we can do the binning
99
+ if (result -> dtype == NDARRAY_UINT16 ) {
100
+ uint16_t * rarray = (uint16_t * )result -> array ;
101
+ if (input -> dtype == NDARRAY_UINT8 ) {
102
+ uint8_t * iarray = (uint8_t * )input -> array ;
103
+ for (size_t i = 0 ; i < input -> len ; i ++ ) {
104
+ rarray [* iarray ] += 1 ;
105
+ iarray += stride ;
106
+ }
107
+ } else if (input -> dtype == NDARRAY_UINT16 ) {
108
+ uint16_t * iarray = (uint16_t * )input -> array ;
109
+ for (size_t i = 0 ; i < input -> len ; i ++ ) {
110
+ rarray [* iarray ] += 1 ;
111
+ iarray += stride ;
112
+ }
113
+ }
114
+ } else {
115
+ mp_float_t * rarray = (mp_float_t * )result -> array ;
116
+ if (input -> dtype == NDARRAY_UINT8 ) {
117
+ uint8_t * iarray = (uint8_t * )input -> array ;
118
+ for (size_t i = 0 ; i < input -> len ; i ++ ) {
119
+ rarray [* iarray ] += MICROPY_FLOAT_CONST (1.0 );
120
+ iarray += stride ;
121
+ }
122
+ } else if (input -> dtype == NDARRAY_UINT16 ) {
123
+ uint16_t * iarray = (uint16_t * )input -> array ;
124
+ for (size_t i = 0 ; i < input -> len ; i ++ ) {
125
+ rarray [* iarray ] += MICROPY_FLOAT_CONST (1.0 );
126
+ iarray += stride ;
127
+ }
89
128
}
90
129
}
91
130
131
+ if (weights != NULL ) {
132
+ mp_float_t (* get_weights )(void * ) = ndarray_get_float_function (weights -> dtype );
133
+ mp_float_t * rarray = (mp_float_t * )result -> array ;
134
+ uint8_t * warray = (uint8_t * )weights -> array ;
135
+
136
+ size_t fill_length = result -> len ;
137
+ if (weights -> len < result -> len ) {
138
+ fill_length = weights -> len ;
139
+ }
140
+
141
+ for (size_t i = 0 ; i < fill_length ; i ++ ) {
142
+ * rarray = * rarray * get_weights (warray );
143
+ rarray ++ ;
144
+ warray += weights -> strides [ULAB_MAX_DIMS - 1 ];
145
+ }
146
+ }
92
147
return MP_OBJ_FROM_PTR (result );
93
148
}
94
149
0 commit comments