@@ -44,7 +44,7 @@ def _sum(group_idx, a, size, fill_value, dtype=None):
44
44
45
45
def _prod (group_idx , a , size , fill_value , dtype = None ):
46
46
dtype = minimum_dtype_scalar (fill_value , dtype , a )
47
- ret = np .full (size , fill_value , dtype = dtype )
47
+ ret = np .full (size , fill_value , dtype = dtype , like = a )
48
48
if fill_value != 1 :
49
49
ret [group_idx ] = 1 # product starts from 1
50
50
np .multiply .at (ret , group_idx , a )
@@ -57,7 +57,7 @@ def _len(group_idx, a, size, fill_value, dtype=None):
57
57
58
58
def _last (group_idx , a , size , fill_value , dtype = None ):
59
59
dtype = minimum_dtype (fill_value , dtype or a .dtype )
60
- ret = np .full (size , fill_value , dtype = dtype )
60
+ ret = np .full (size , fill_value , dtype = dtype , like = a )
61
61
# repeated indexing gives last value, see:
62
62
# the phrase "leaving behind the last value" on this page:
63
63
# http://wiki.scipy.org/Tentative_NumPy_Tutorial
@@ -67,14 +67,14 @@ def _last(group_idx, a, size, fill_value, dtype=None):
67
67
68
68
def _first (group_idx , a , size , fill_value , dtype = None ):
69
69
dtype = minimum_dtype (fill_value , dtype or a .dtype )
70
- ret = np .full (size , fill_value , dtype = dtype )
70
+ ret = np .full (size , fill_value , dtype = dtype , like = a )
71
71
ret [group_idx [::- 1 ]] = a [::- 1 ] # same trick as _last, but in reverse
72
72
return ret
73
73
74
74
75
75
def _all (group_idx , a , size , fill_value , dtype = None ):
76
76
check_boolean (fill_value )
77
- ret = np .full (size , fill_value , dtype = bool )
77
+ ret = np .full (size , fill_value , dtype = bool , like = a )
78
78
if not fill_value :
79
79
ret [group_idx ] = True
80
80
ret [group_idx .compress (np .logical_not (a ))] = False
@@ -83,7 +83,7 @@ def _all(group_idx, a, size, fill_value, dtype=None):
83
83
84
84
def _any (group_idx , a , size , fill_value , dtype = None ):
85
85
check_boolean (fill_value )
86
- ret = np .full (size , fill_value , dtype = bool )
86
+ ret = np .full (size , fill_value , dtype = bool , like = a )
87
87
if fill_value :
88
88
ret [group_idx ] = False
89
89
ret [group_idx .compress (a )] = True
@@ -93,7 +93,7 @@ def _any(group_idx, a, size, fill_value, dtype=None):
93
93
def _min (group_idx , a , size , fill_value , dtype = None ):
94
94
dtype = minimum_dtype (fill_value , dtype or a .dtype )
95
95
dmax = maxval (fill_value , dtype )
96
- ret = np .full (size , fill_value , dtype = dtype )
96
+ ret = np .full (size , fill_value , dtype = dtype , like = a )
97
97
if fill_value != dmax :
98
98
ret [group_idx ] = dmax # min starts from maximum
99
99
np .minimum .at (ret , group_idx , a )
@@ -103,7 +103,7 @@ def _min(group_idx, a, size, fill_value, dtype=None):
103
103
def _max (group_idx , a , size , fill_value , dtype = None ):
104
104
dtype = minimum_dtype (fill_value , dtype or a .dtype )
105
105
dmin = minval (fill_value , dtype )
106
- ret = np .full (size , fill_value , dtype = dtype )
106
+ ret = np .full (size , fill_value , dtype = dtype , like = a )
107
107
if fill_value != dmin :
108
108
ret [group_idx ] = dmin # max starts from minimum
109
109
np .maximum .at (ret , group_idx , a )
@@ -115,7 +115,7 @@ def _argmax(group_idx, a, size, fill_value, dtype=int, _nansqueeze=False):
115
115
group_max = _max (group_idx , a_ , size , np .nan )
116
116
# nan should never be maximum, so use a and not a_
117
117
is_max = a == group_max [group_idx ]
118
- ret = np .full (size , fill_value , dtype = dtype )
118
+ ret = np .full (size , fill_value , dtype = dtype , like = a )
119
119
group_idx_max = group_idx [is_max ]
120
120
(argmax ,) = is_max .nonzero ()
121
121
ret [group_idx_max [::- 1 ]] = argmax [
@@ -129,7 +129,7 @@ def _argmin(group_idx, a, size, fill_value, dtype=int, _nansqueeze=False):
129
129
group_min = _min (group_idx , a_ , size , np .nan )
130
130
# nan should never be minimum, so use a and not a_
131
131
is_min = a == group_min [group_idx ]
132
- ret = np .full (size , fill_value , dtype = dtype )
132
+ ret = np .full (size , fill_value , dtype = dtype , like = a )
133
133
group_idx_min = group_idx [is_min ]
134
134
(argmin ,) = is_min .nonzero ()
135
135
ret [group_idx_min [::- 1 ]] = argmin [
@@ -144,7 +144,7 @@ def _mean(group_idx, a, size, fill_value, dtype=np.dtype(np.float64)):
144
144
counts = np .bincount (group_idx , minlength = size )
145
145
if iscomplexobj (a ):
146
146
dtype = a .dtype # TODO: this is a bit clumsy
147
- sums = np .empty (size , dtype = dtype )
147
+ sums = np .empty (size , dtype = dtype , like = a )
148
148
sums .real = np .bincount (group_idx , weights = a .real , minlength = size )
149
149
sums .imag = np .bincount (group_idx , weights = a .imag , minlength = size )
150
150
else :
0 commit comments