Skip to content

Commit 6ce2f37

Browse files
committed
add in-place modulo operator
1 parent 9f71594 commit 6ce2f37

File tree

5 files changed

+121
-1
lines changed

5 files changed

+121
-1
lines changed

code/ndarray.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1648,6 +1648,12 @@ mp_obj_t ndarray_binary_op(mp_binary_op_t _op, mp_obj_t lobj, mp_obj_t robj) {
16481648
return ndarray_inplace_ams(lhs, rhs, rstrides, op);
16491649
break;
16501650
#endif
1651+
#if NDARRAY_HAS_INPLACE_MODULO
1652+
case MP_BINARY_OP_INPLACE_MODULO:
1653+
COMPLEX_DTYPE_NOT_IMPLEMENTED(lhs->dtype);
1654+
return ndarray_inplace_modulo(lhs, rhs, rstrides);
1655+
break;
1656+
#endif
16511657
#if NDARRAY_HAS_INPLACE_MULTIPLY
16521658
case MP_BINARY_OP_INPLACE_MULTIPLY:
16531659
COMPLEX_DTYPE_NOT_IMPLEMENTED(lhs->dtype);

code/ndarray_operators.c

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,29 @@ mp_obj_t ndarray_inplace_ams(ndarray_obj_t *lhs, ndarray_obj_t *rhs, int32_t *rs
11731173
}
11741174
#endif /* NDARRAY_HAS_INPLACE_ADD || NDARRAY_HAS_INPLACE_MULTIPLY || NDARRAY_HAS_INPLACE_SUBTRACT */
11751175

1176+
1177+
#if NDARRAY_HAS_INPLACE_MODULO
1178+
mp_obj_t ndarray_inplace_modulo(ndarray_obj_t *lhs, ndarray_obj_t *rhs, int32_t *rstrides) {
1179+
if((lhs->dtype != NDARRAY_FLOAT) && (rhs->dtype == NDARRAY_FLOAT)) {
1180+
mp_raise_TypeError(MP_ERROR_TEXT("results cannot be cast to specified type"));
1181+
}
1182+
if(lhs->dtype == NDARRAY_FLOAT) {
1183+
if(rhs->dtype == NDARRAY_UINT8) {
1184+
INLINE_MODULO_FLOAT_LOOP(lhs, uint8_t, larray, rarray, rstrides);
1185+
} else if(rhs->dtype == NDARRAY_UINT8) {
1186+
INLINE_MODULO_FLOAT_LOOP(lhs, int8_t, larray, rarray, rstrides);
1187+
} else if(rhs->dtype == NDARRAY_UINT16) {
1188+
INLINE_MODULO_FLOAT_LOOP(lhs, uint16_t, larray, rarray, rstrides);
1189+
} else if(rhs->dtype == NDARRAY_INT16) {
1190+
INLINE_MODULO_FLOAT_LOOP(lhs, int16_t, larray, rarray, rstrides);
1191+
} else {
1192+
INLINE_MODULO_FLOAT_LOOP(lhs, mp_float_t, larray, rarray, rstrides);
1193+
}
1194+
}
1195+
return MP_OBJ_FROM_PTR(lhs);
1196+
}
1197+
#endif /* NDARRAY_HAS_INPLACE_MODULO */
1198+
11761199
#if NDARRAY_HAS_INPLACE_TRUE_DIVIDE
11771200
mp_obj_t ndarray_inplace_divide(ndarray_obj_t *lhs, ndarray_obj_t *rhs, int32_t *rstrides) {
11781201

code/ndarray_operators.h

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ mp_obj_t ndarray_binary_logical(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size
2222
mp_obj_t ndarray_binary_floor_divide(ndarray_obj_t *, ndarray_obj_t *, uint8_t , size_t *, int32_t *, int32_t *);
2323

2424
mp_obj_t ndarray_inplace_ams(ndarray_obj_t *, ndarray_obj_t *, int32_t *, uint8_t );
25+
mp_obj_t ndarray_inplace_modulo(ndarray_obj_t *, ndarray_obj_t *, int32_t *);
2526
mp_obj_t ndarray_inplace_power(ndarray_obj_t *, ndarray_obj_t *, int32_t *);
2627
mp_obj_t ndarray_inplace_divide(ndarray_obj_t *, ndarray_obj_t *, int32_t *);
2728

@@ -624,4 +625,90 @@ mp_obj_t ndarray_inplace_divide(ndarray_obj_t *, ndarray_obj_t *, int32_t *);
624625
j++;\
625626
} while(j < (results)->shape[ULAB_MAX_DIMS - 4]);\
626627
} while(0)
628+
#endif /* ULAB_MAX_DIMS == 4 */
629+
630+
631+
#define INPLACE_MODULO_FLOAT1(results, type_right, larray, rarray, rstrides)\
632+
({\
633+
size_t l = 0;\
634+
do {\
635+
*((mp_float_t *)larray) = MICROPY_FLOAT_C_FUN(fmod)(*((mp_float_t *)(larray)), *((type_right *)(rarray)));\
636+
(larray) += (results)->strides[ULAB_MAX_DIMS - 1];\
637+
(rarray) += (rstrides)[ULAB_MAX_DIMS - 1];\
638+
l++;\
639+
} while(l < (results)->shape[ULAB_MAX_DIMS - 1]);\
640+
})
641+
642+
643+
#if ULAB_MAX_DIMS == 1
644+
#define INPLACE_MODULO_FLOAT_LOOP(results, type_right, larray, rarray, rstrides) do {\
645+
INPLACE_MODULO_FLOAT1((results), type_right, (larray), (rarray), (rstrides));\
646+
} while(0)
647+
#endif /* ULAB_MAX_DIMS == 1 */
648+
649+
650+
#if ULAB_MAX_DIMS == 2
651+
#define INLINE_MODULO_FLOAT_LOOP(results, type_right, larray, rarray, rstrides) do {\
652+
size_t l = 0;\
653+
do {\
654+
INPLACE_MODULO_FLOAT1((results), type_right, (larray), (rarray), (rstrides));\
655+
(larray) -= (results)->strides[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS - 1];\
656+
(larray) += (results)->strides[ULAB_MAX_DIMS - 2];\
657+
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS - 1];\
658+
(rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
659+
l++;\
660+
} while(l < (results)->shape[ULAB_MAX_DIMS - 2]);\
661+
} while(0)
662+
#endif /* ULAB_MAX_DIMS == 2 */
663+
664+
#if ULAB_MAX_DIMS == 3
665+
#define INLINE_MODULO_FLOAT_LOOP(results, type_right, larray, rarray, rstrides) do {\
666+
size_t k = 0;\
667+
do {\
668+
size_t l = 0;\
669+
do {\
670+
INPLACE_MODULO_FLOAT1((results), type_right, (larray), (rarray), (rstrides));\
671+
(larray) -= (results)->strides[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS - 1];\
672+
(larray) += (results)->strides[ULAB_MAX_DIMS - 2];\
673+
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS - 1];\
674+
(rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
675+
l++;\
676+
} while(l < (results)->shape[ULAB_MAX_DIMS - 2]);\
677+
(larray) -= (results)->strides[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS - 2];\
678+
(larray) += (results)->strides[ULAB_MAX_DIMS - 3];\
679+
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS - 2];\
680+
(rarray) += (rstrides)[ULAB_MAX_DIMS - 3];\
681+
k++;\
682+
} while(k < (results)->shape[ULAB_MAX_DIMS - 3]);\
683+
} while(0)
684+
#endif /* ULAB_MAX_DIMS == 3 */
685+
686+
#if ULAB_MAX_DIMS == 4
687+
#define INLINE_MODULO_FLOAT_LOOP(results, type_right, larray, rarray, rstrides) do {\
688+
size_t j = 0;\
689+
do {\
690+
size_t k = 0;\
691+
do {\
692+
size_t l = 0;\
693+
do {\
694+
INPLACE_MODULO_FLOAT1((results), type_right, (larray), (rarray), (rstrides));\
695+
(larray) -= (results)->strides[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS - 1];\
696+
(larray) += (results)->strides[ULAB_MAX_DIMS - 2];\
697+
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 1] * (results)->shape[ULAB_MAX_DIMS - 1];\
698+
(rarray) += (rstrides)[ULAB_MAX_DIMS - 2];\
699+
l++;\
700+
} while(l < (results)->shape[ULAB_MAX_DIMS - 2]);\
701+
(larray) -= (results)->strides[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS - 2];\
702+
(larray) += (results)->strides[ULAB_MAX_DIMS - 3];\
703+
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 2] * (results)->shape[ULAB_MAX_DIMS - 2];\
704+
(rarray) += (rstrides)[ULAB_MAX_DIMS - 3];\
705+
k++;\
706+
} while(k < (results)->shape[ULAB_MAX_DIMS - 3]);\
707+
(larray) -= (results)->strides[ULAB_MAX_DIMS - 3] * (results)->shape[ULAB_MAX_DIMS - 3];\
708+
(larray) += (results)->strides[ULAB_MAX_DIMS - 4];\
709+
(rarray) -= (rstrides)[ULAB_MAX_DIMS - 3] * (results)->shape[ULAB_MAX_DIMS - 3];\
710+
(rarray) += (rstrides)[ULAB_MAX_DIMS - 4];\
711+
j++;\
712+
} while(j < (results)->shape[ULAB_MAX_DIMS - 4]);\
713+
} while(0)
627714
#endif /* ULAB_MAX_DIMS == 4 */

code/ulab.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,10 @@
165165
#define NDARRAY_HAS_INPLACE_ADD (1)
166166
#endif
167167

168+
#ifndef NDARRAY_HAS_INPLACE_MODULO
169+
#define NDARRAY_HAS_INPLACE_MODU (1)
170+
#endif
171+
168172
#ifndef NDARRAY_HAS_INPLACE_MULTIPLY
169173
#define NDARRAY_HAS_INPLACE_MULTIPLY (1)
170174
#endif

docs/ulab-ndarray.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2599,7 +2599,7 @@
25992599
"source": [
26002600
"# Binary operators\n",
26012601
"\n",
2602-
"`ulab` implements the `+`, `-`, `*`, `/`, `**`, `%`, `<`, `>`, `<=`, `>=`, `==`, `!=`, `+=`, `-=`, `*=`, `/=`, `**=` binary operators, as well as the `AND`, `OR`, `XOR` bit-wise operators that work element-wise. Note that the bit-wise operators will raise an exception, if either of the operands is of `float` or `complex` type.\n",
2602+
"`ulab` implements the `+`, `-`, `*`, `/`, `**`, `%`, `<`, `>`, `<=`, `>=`, `==`, `!=`, `+=`, `-=`, `*=`, `/=`, `**=`, `%=` binary operators, as well as the `AND`, `OR`, `XOR` bit-wise operators that work element-wise. Note that the bit-wise operators will raise an exception, if either of the operands is of `float` or `complex` type.\n",
26032603
"\n",
26042604
"Broadcasting is available, meaning that the two operands do not even have to have the same shape. If the lengths along the respective axes are equal, or one of them is 1, or the axis is missing, the element-wise operation can still be carried out. \n",
26052605
"A thorough explanation of broadcasting can be found under https://numpy.org/doc/stable/user/basics.broadcasting.html. \n",

0 commit comments

Comments
 (0)