Skip to content

Commit ec7caa8

Browse files
committed
add keyword handling
1 parent 03c8655 commit ec7caa8

File tree

1 file changed

+74
-19
lines changed

1 file changed

+74
-19
lines changed

code/numpy/compare.c

Lines changed: 74 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,53 +42,108 @@ mp_obj_t compare_bincount(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_
4242
#if ULAB_MAX_DIMS > 1
4343
// no need to check anything, if the maximum number of dimensions is 1
4444
if(input->ndim != 1) {
45-
mp_raise_ValueError(MP_ERROR_TEXT("object too deep for desired arrayy"));
45+
mp_raise_ValueError(MP_ERROR_TEXT("object too deep for desired array"));
4646
}
4747
#endif
4848
if((input->dtype != NDARRAY_UINT8) && (input->dtype != NDARRAY_UINT16)) {
4949
mp_raise_TypeError(MP_ERROR_TEXT("cannot cast array data from dtype"));
5050
}
5151

5252
// first find the maximum of the array, and figure out how long the result should be
53-
uint16_t max = 0;
53+
size_t length = 0;
5454
int32_t stride = input->strides[ULAB_MAX_DIMS - 1];
5555
if(input->dtype == NDARRAY_UINT8) {
5656
uint8_t *iarray = (uint8_t *)input->array;
5757
for(size_t i = 0; i < input->len; i++) {
58-
if(*iarray > max) {
59-
max = *iarray;
58+
if(*iarray > length) {
59+
length = *iarray;
6060
}
6161
iarray += stride;
6262
}
6363
} else if(input->dtype == NDARRAY_UINT16) {
6464
stride /= 2;
6565
uint16_t *iarray = (uint16_t *)input->array;
6666
for(size_t i = 0; i < input->len; i++) {
67-
if(*iarray > max) {
68-
max = *iarray;
67+
if(*iarray > length) {
68+
length = *iarray;
6969
}
7070
iarray += stride;
7171
}
7272
}
73-
ndarray_obj_t *result = ndarray_new_linear_array(max + 1, NDARRAY_UINT16);
73+
length += 1;
7474

75-
// now we can do the binning
76-
uint16_t *rarray = (uint16_t *)result->array;
75+
if(args[2].u_obj != mp_const_none) {
76+
int32_t minlength = mp_obj_get_int(args[2].u_obj);
77+
if(minlength < 0) {
78+
mp_raise_ValueError(MP_ERROR_TEXT("minlength must not be negative"));
79+
}
80+
if((size_t)minlength > length) {
81+
length = minlength;
82+
}
83+
}
7784

78-
if(input->dtype == NDARRAY_UINT8) {
79-
uint8_t *iarray = (uint8_t *)input->array;
80-
for(size_t i = 0; i < input->len; i++) {
81-
rarray[*iarray] += 1;
82-
iarray += stride;
85+
ndarray_obj_t *result = NULL;
86+
ndarray_obj_t *weights = NULL;
87+
88+
if(args[1].u_obj == mp_const_none) {
89+
result = ndarray_new_linear_array(length, NDARRAY_UINT16);
90+
} else {
91+
if(!mp_obj_is_type(args[1].u_obj, &ulab_ndarray_type)) {
92+
mp_raise_TypeError(MP_ERROR_TEXT("input must be an ndarray"));
8393
}
84-
} else if(input->dtype == NDARRAY_UINT16) {
85-
uint16_t *iarray = (uint16_t *)input->array;
86-
for(size_t i = 0; i < input->len; i++) {
87-
rarray[*iarray] += 1;
88-
iarray += stride;
94+
weights = MP_OBJ_TO_PTR(args[1].u_obj);
95+
result = ndarray_new_linear_array(length, NDARRAY_FLOAT);
96+
}
97+
98+
// now we can do the binning
99+
if(result->dtype == NDARRAY_UINT16) {
100+
uint16_t *rarray = (uint16_t *)result->array;
101+
if(input->dtype == NDARRAY_UINT8) {
102+
uint8_t *iarray = (uint8_t *)input->array;
103+
for(size_t i = 0; i < input->len; i++) {
104+
rarray[*iarray] += 1;
105+
iarray += stride;
106+
}
107+
} else if(input->dtype == NDARRAY_UINT16) {
108+
uint16_t *iarray = (uint16_t *)input->array;
109+
for(size_t i = 0; i < input->len; i++) {
110+
rarray[*iarray] += 1;
111+
iarray += stride;
112+
}
113+
}
114+
} else {
115+
mp_float_t *rarray = (mp_float_t *)result->array;
116+
if(input->dtype == NDARRAY_UINT8) {
117+
uint8_t *iarray = (uint8_t *)input->array;
118+
for(size_t i = 0; i < input->len; i++) {
119+
rarray[*iarray] += MICROPY_FLOAT_CONST(1.0);
120+
iarray += stride;
121+
}
122+
} else if(input->dtype == NDARRAY_UINT16) {
123+
uint16_t *iarray = (uint16_t *)input->array;
124+
for(size_t i = 0; i < input->len; i++) {
125+
rarray[*iarray] += MICROPY_FLOAT_CONST(1.0);
126+
iarray += stride;
127+
}
89128
}
90129
}
91130

131+
if(weights != NULL) {
132+
mp_float_t (*get_weights)(void *) = ndarray_get_float_function(weights->dtype);
133+
mp_float_t *rarray = (mp_float_t *)result->array;
134+
uint8_t *warray = (uint8_t *)weights->array;
135+
136+
size_t fill_length = result->len;
137+
if(weights->len < result->len) {
138+
fill_length = weights->len;
139+
}
140+
141+
for(size_t i = 0; i < fill_length; i++) {
142+
*rarray = *rarray * get_weights(warray);
143+
rarray++;
144+
warray += weights->strides[ULAB_MAX_DIMS - 1];
145+
}
146+
}
92147
return MP_OBJ_FROM_PTR(result);
93148
}
94149

0 commit comments

Comments
 (0)