Skip to content

Commit d148074

Browse files
committed
More like arguments
1 parent 934a248 commit d148074

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

numpy_groupies/aggregate_numpy.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _sum(group_idx, a, size, fill_value, dtype=None):
4444

4545
def _prod(group_idx, a, size, fill_value, dtype=None):
4646
dtype = minimum_dtype_scalar(fill_value, dtype, a)
47-
ret = np.full(size, fill_value, dtype=dtype)
47+
ret = np.full(size, fill_value, dtype=dtype, like=a)
4848
if fill_value != 1:
4949
ret[group_idx] = 1 # product starts from 1
5050
np.multiply.at(ret, group_idx, a)
@@ -57,7 +57,7 @@ def _len(group_idx, a, size, fill_value, dtype=None):
5757

5858
def _last(group_idx, a, size, fill_value, dtype=None):
5959
dtype = minimum_dtype(fill_value, dtype or a.dtype)
60-
ret = np.full(size, fill_value, dtype=dtype)
60+
ret = np.full(size, fill_value, dtype=dtype, like=a)
6161
# repeated indexing gives last value, see:
6262
# the phrase "leaving behind the last value" on this page:
6363
# http://wiki.scipy.org/Tentative_NumPy_Tutorial
@@ -67,14 +67,14 @@ def _last(group_idx, a, size, fill_value, dtype=None):
6767

6868
def _first(group_idx, a, size, fill_value, dtype=None):
6969
dtype = minimum_dtype(fill_value, dtype or a.dtype)
70-
ret = np.full(size, fill_value, dtype=dtype)
70+
ret = np.full(size, fill_value, dtype=dtype, like=a)
7171
ret[group_idx[::-1]] = a[::-1] # same trick as _last, but in reverse
7272
return ret
7373

7474

7575
def _all(group_idx, a, size, fill_value, dtype=None):
7676
check_boolean(fill_value)
77-
ret = np.full(size, fill_value, dtype=bool)
77+
ret = np.full(size, fill_value, dtype=bool, like=a)
7878
if not fill_value:
7979
ret[group_idx] = True
8080
ret[group_idx.compress(np.logical_not(a))] = False
@@ -83,7 +83,7 @@ def _all(group_idx, a, size, fill_value, dtype=None):
8383

8484
def _any(group_idx, a, size, fill_value, dtype=None):
8585
check_boolean(fill_value)
86-
ret = np.full(size, fill_value, dtype=bool)
86+
ret = np.full(size, fill_value, dtype=bool, like=a)
8787
if fill_value:
8888
ret[group_idx] = False
8989
ret[group_idx.compress(a)] = True
@@ -93,7 +93,7 @@ def _any(group_idx, a, size, fill_value, dtype=None):
9393
def _min(group_idx, a, size, fill_value, dtype=None):
9494
dtype = minimum_dtype(fill_value, dtype or a.dtype)
9595
dmax = maxval(fill_value, dtype)
96-
ret = np.full(size, fill_value, dtype=dtype)
96+
ret = np.full(size, fill_value, dtype=dtype, like=a)
9797
if fill_value != dmax:
9898
ret[group_idx] = dmax # min starts from maximum
9999
np.minimum.at(ret, group_idx, a)
@@ -103,7 +103,7 @@ def _min(group_idx, a, size, fill_value, dtype=None):
103103
def _max(group_idx, a, size, fill_value, dtype=None):
104104
dtype = minimum_dtype(fill_value, dtype or a.dtype)
105105
dmin = minval(fill_value, dtype)
106-
ret = np.full(size, fill_value, dtype=dtype)
106+
ret = np.full(size, fill_value, dtype=dtype, like=a)
107107
if fill_value != dmin:
108108
ret[group_idx] = dmin # max starts from minimum
109109
np.maximum.at(ret, group_idx, a)
@@ -115,7 +115,7 @@ def _argmax(group_idx, a, size, fill_value, dtype=int, _nansqueeze=False):
115115
group_max = _max(group_idx, a_, size, np.nan)
116116
# nan should never be maximum, so use a and not a_
117117
is_max = a == group_max[group_idx]
118-
ret = np.full(size, fill_value, dtype=dtype)
118+
ret = np.full(size, fill_value, dtype=dtype, like=a)
119119
group_idx_max = group_idx[is_max]
120120
(argmax,) = is_max.nonzero()
121121
ret[group_idx_max[::-1]] = argmax[
@@ -129,7 +129,7 @@ def _argmin(group_idx, a, size, fill_value, dtype=int, _nansqueeze=False):
129129
group_min = _min(group_idx, a_, size, np.nan)
130130
# nan should never be minimum, so use a and not a_
131131
is_min = a == group_min[group_idx]
132-
ret = np.full(size, fill_value, dtype=dtype)
132+
ret = np.full(size, fill_value, dtype=dtype, like=a)
133133
group_idx_min = group_idx[is_min]
134134
(argmin,) = is_min.nonzero()
135135
ret[group_idx_min[::-1]] = argmin[
@@ -144,7 +144,7 @@ def _mean(group_idx, a, size, fill_value, dtype=np.dtype(np.float64)):
144144
counts = np.bincount(group_idx, minlength=size)
145145
if iscomplexobj(a):
146146
dtype = a.dtype # TODO: this is a bit clumsy
147-
sums = np.empty(size, dtype=dtype)
147+
sums = np.empty(size, dtype=dtype, like=a)
148148
sums.real = np.bincount(group_idx, weights=a.real, minlength=size)
149149
sums.imag = np.bincount(group_idx, weights=a.imag, minlength=size)
150150
else:

0 commit comments

Comments
 (0)