Skip to content

Commit 261e606

Browse files
committed
implement axis keyword of transpose
1 parent 068da5f commit 261e606

File tree

3 files changed

+99
-10
lines changed

3 files changed

+99
-10
lines changed

code/ndarray.c

Lines changed: 90 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1874,28 +1874,110 @@ mp_obj_t ndarray_unary_op(mp_unary_op_t op, mp_obj_t self_in) {
18741874
#endif /* NDARRAY_HAS_UNARY_OPS */
18751875

18761876
#if NDARRAY_HAS_TRANSPOSE
1877-
mp_obj_t ndarray_transpose(mp_obj_t self_in) {
1878-
#if ULAB_MAX_DIMS == 1
1879-
return self_in;
1880-
#endif
1881-
// TODO: check, what happens to the offset here, if we have a view
1877+
// We have to implement the T property separately, for the property can't take keyword arguments
1878+
1879+
#if ULAB_MAX_DIMS == 1
1880+
// isolating the one-dimensional case saves space, because the transpose is sort of meaningless
1881+
mp_obj_t ndarray_T(mp_obj_t self_in) {
1882+
return self_in;
1883+
}
1884+
#else
1885+
mp_obj_t ndarray_T(mp_obj_t self_in) {
1886+
// without argument, simply return a view with axes in reverse order
18821887
ndarray_obj_t *self = MP_OBJ_TO_PTR(self_in);
18831888
if(self->ndim == 1) {
18841889
return self_in;
18851890
}
18861891
size_t *shape = m_new(size_t, self->ndim);
18871892
int32_t *strides = m_new(int32_t, self->ndim);
1888-
for(uint8_t i=0; i < self->ndim; i++) {
1893+
for(uint8_t i = 0; i < self->ndim; i++) {
18891894
shape[ULAB_MAX_DIMS - 1 - i] = self->shape[ULAB_MAX_DIMS - self->ndim + i];
18901895
strides[ULAB_MAX_DIMS - 1 - i] = self->strides[ULAB_MAX_DIMS - self->ndim + i];
18911896
}
1892-
// TODO: I am not sure ndarray_new_view is OK here...
1893-
// should be deep copy...
18941897
ndarray_obj_t *ndarray = ndarray_new_view(self, self->ndim, shape, strides, 0);
18951898
return MP_OBJ_FROM_PTR(ndarray);
18961899
}
1900+
#endif /* ULAB_MAX_DIMS == 1 */
1901+
1902+
MP_DEFINE_CONST_FUN_OBJ_1(ndarray_T_obj, ndarray_T);
1903+
1904+
# if ULAB_MAX_DIMS == 1
1905+
// again, nothing to do, if there is only one dimension, though, the arguments might still have to be parsed...
1906+
mp_obj_t ndarray_transpose(mp_obj_t self_in) {
1907+
return self_in;
1908+
}
18971909

18981910
MP_DEFINE_CONST_FUN_OBJ_1(ndarray_transpose_obj, ndarray_transpose);
1911+
#else
1912+
mp_obj_t ndarray_transpose(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
1913+
static const mp_arg_t allowed_args[] = {
1914+
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
1915+
{ MP_QSTR_axis, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
1916+
};
1917+
1918+
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
1919+
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
1920+
1921+
ndarray_obj_t *self = MP_OBJ_TO_PTR(args[0].u_obj);
1922+
1923+
if(self->ndim == 1) {
1924+
return args[0].u_obj;
1925+
}
1926+
1927+
size_t *shape = m_new(size_t, self->ndim);
1928+
int32_t *strides = m_new(int32_t, self->ndim);
1929+
uint8_t *order = m_new(uint8_t, self->ndim);
1930+
1931+
mp_obj_t axis = args[1].u_obj;
1932+
1933+
if(axis == mp_const_none) {
1934+
// simply swap the order of the axes
1935+
for(uint8_t i = 0; i < self->ndim; i++) {
1936+
order[i] = self->ndim - 1 - i;
1937+
}
1938+
} else {
1939+
if(!mp_obj_is_type(axis, &mp_type_tuple)) {
1940+
mp_raise_TypeError(MP_ERROR_TEXT("keyword argument must be tuple of integers"));
1941+
}
1942+
// start with the straight array, and then swap only those specified in the argument
1943+
for(uint8_t i = 0; i < self->ndim; i++) {
1944+
order[i] = i;
1945+
}
1946+
1947+
mp_obj_tuple_t *axes = MP_OBJ_TO_PTR(axis);
1948+
1949+
if(axes->len > self->ndim) {
1950+
mp_raise_ValueError(MP_ERROR_TEXT("too many axes specified"));
1951+
}
1952+
1953+
for(uint8_t i = 0; i < axes->len; i++) {
1954+
int32_t ax = mp_obj_get_int(axes->items[i]);
1955+
if((ax >= self->ndim) || (ax < 0)) {
1956+
mp_raise_ValueError(MP_ERROR_TEXT("axis index out of bounds"));
1957+
} else {
1958+
order[i] = (uint8_t)ax;
1959+
// TODO: check that no two identical numbers appear in the tuple
1960+
for(uint8_t j = 0; j < i; j++) {
1961+
if(order[i] == order[j]) {
1962+
mp_raise_ValueError(MP_ERROR_TEXT("repeated indices"));
1963+
}
1964+
}
1965+
}
1966+
}
1967+
}
1968+
1969+
uint8_t axis_offset = ULAB_MAX_DIMS - self->ndim;
1970+
for(uint8_t i = 0; i < self->ndim; i++) {
1971+
shape[axis_offset + i] = self->shape[axis_offset + order[i]];
1972+
strides[axis_offset + i] = self->strides[axis_offset + order[i]];
1973+
}
1974+
1975+
ndarray_obj_t *ndarray = ndarray_new_view(self, self->ndim, shape, strides, 0);
1976+
return MP_OBJ_FROM_PTR(ndarray);
1977+
}
1978+
1979+
MP_DEFINE_CONST_FUN_OBJ_KW(ndarray_transpose_obj, 1, ndarray_transpose);
1980+
#endif /* ULAB_MAX_DIMS == 1 */
18991981
#endif /* NDARRAY_HAS_TRANSPOSE */
19001982

19011983
#if ULAB_MAX_DIMS > 1

code/ndarray.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,16 @@ MP_DECLARE_CONST_FUN_OBJ_1(ndarray_tolist_obj);
265265
#endif
266266

267267
#if NDARRAY_HAS_TRANSPOSE
268+
mp_obj_t ndarray_T(mp_obj_t );
269+
MP_DECLARE_CONST_FUN_OBJ_1(ndarray_T_obj);
270+
#if ULAB_MAX_DIMS == 1
268271
mp_obj_t ndarray_transpose(mp_obj_t );
269272
MP_DECLARE_CONST_FUN_OBJ_1(ndarray_transpose_obj);
270-
#endif
273+
#else
274+
mp_obj_t ndarray_transpose(size_t , const mp_obj_t *, mp_map_t *);
275+
MP_DECLARE_CONST_FUN_OBJ_KW(ndarray_transpose_obj);
276+
#endif /* ULAB_MAX_DIMS == 1 */
277+
#endif /* NDARRAY_HAS_TRANSPOSE */
271278

272279
#if ULAB_NUMPY_HAS_NDINFO
273280
mp_obj_t ndarray_info(mp_obj_t );

code/ndarray_properties.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ void ndarray_properties_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) {
6464
#endif
6565
#if NDARRAY_HAS_TRANSPOSE
6666
case MP_QSTR_T:
67-
dest[0] = ndarray_transpose(self_in);
67+
dest[0] = ndarray_T(self_in);
6868
break;
6969
#endif
7070
#if ULAB_SUPPORTS_COMPLEX

0 commit comments

Comments
 (0)