Skip to content

Commit 03c8655

Browse files
committed
integer arrays can be binned
1 parent 13974df commit 03c8655

File tree

2 files changed

+60
-5
lines changed

2 files changed

+60
-5
lines changed

code/numpy/compare.c

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
*
77
* The MIT License (MIT)
88
*
9-
* Copyright (c) 2020-2021 Zoltán Vörös
9+
* Copyright (c) 2020-2025 Zoltán Vörös
1010
* 2020 Jeff Epler for Adafruit Industries
1111
*/
1212

@@ -27,14 +27,69 @@
2727
mp_obj_t compare_bincount(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
2828
static const mp_arg_t allowed_args[] = {
2929
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE} } ,
30-
{ MP_QSTR_weights, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
31-
{ MP_QSTR_minlength, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
30+
{ MP_QSTR_weights, MP_ARG_OBJ | MP_ARG_KW_ONLY, { .u_rom_obj = MP_ROM_NONE } },
31+
{ MP_QSTR_minlength, MP_ARG_OBJ | MP_ARG_KW_ONLY, { .u_rom_obj = MP_ROM_NONE } },
3232
};
3333

3434
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
3535
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
3636

37-
return mp_const_none;
37+
if(!mp_obj_is_type(args[0].u_obj, &ulab_ndarray_type)) {
38+
mp_raise_TypeError(MP_ERROR_TEXT("input must be an ndarray"));
39+
}
40+
ndarray_obj_t *input = MP_OBJ_TO_PTR(args[0].u_obj);
41+
42+
#if ULAB_MAX_DIMS > 1
43+
// no need to check anything, if the maximum number of dimensions is 1
44+
if(input->ndim != 1) {
45+
mp_raise_ValueError(MP_ERROR_TEXT("object too deep for desired arrayy"));
46+
}
47+
#endif
48+
if((input->dtype != NDARRAY_UINT8) && (input->dtype != NDARRAY_UINT16)) {
49+
mp_raise_TypeError(MP_ERROR_TEXT("cannot cast array data from dtype"));
50+
}
51+
52+
// first find the maximum of the array, and figure out how long the result should be
53+
uint16_t max = 0;
54+
int32_t stride = input->strides[ULAB_MAX_DIMS - 1];
55+
if(input->dtype == NDARRAY_UINT8) {
56+
uint8_t *iarray = (uint8_t *)input->array;
57+
for(size_t i = 0; i < input->len; i++) {
58+
if(*iarray > max) {
59+
max = *iarray;
60+
}
61+
iarray += stride;
62+
}
63+
} else if(input->dtype == NDARRAY_UINT16) {
64+
stride /= 2;
65+
uint16_t *iarray = (uint16_t *)input->array;
66+
for(size_t i = 0; i < input->len; i++) {
67+
if(*iarray > max) {
68+
max = *iarray;
69+
}
70+
iarray += stride;
71+
}
72+
}
73+
ndarray_obj_t *result = ndarray_new_linear_array(max + 1, NDARRAY_UINT16);
74+
75+
// now we can do the binning
76+
uint16_t *rarray = (uint16_t *)result->array;
77+
78+
if(input->dtype == NDARRAY_UINT8) {
79+
uint8_t *iarray = (uint8_t *)input->array;
80+
for(size_t i = 0; i < input->len; i++) {
81+
rarray[*iarray] += 1;
82+
iarray += stride;
83+
}
84+
} else if(input->dtype == NDARRAY_UINT16) {
85+
uint16_t *iarray = (uint16_t *)input->array;
86+
for(size_t i = 0; i < input->len; i++) {
87+
rarray[*iarray] += 1;
88+
iarray += stride;
89+
}
90+
}
91+
92+
return MP_OBJ_FROM_PTR(result);
3893
}
3994

4095
MP_DEFINE_CONST_FUN_OBJ_KW(compare_bincount_obj, 1, compare_bincount);

code/numpy/compare.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
*
77
* The MIT License (MIT)
88
*
9-
* Copyright (c) 2020-2021 Zoltán Vörös
9+
* Copyright (c) 2020-2025 Zoltán Vörös
1010
*/
1111

1212
#ifndef _COMPARE_

0 commit comments

Comments
 (0)