Skip to content
This repository was archived by the owner on Jan 7, 2023. It is now read-only.

[MRG] stretch() improvements #185

Merged
merged 2 commits into from
May 24, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 46 additions & 43 deletions root_numpy/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@
VLEN = np.vectorize(len)


def _is_object_field(arr, col):
return arr.dtype[col] == 'O'


def rec2array(rec, fields=None):
"""Convert a record array into a ndarray with a homogeneous data type.

Expand Down Expand Up @@ -72,7 +68,7 @@ def stack(recs, fields=None):
return np.hstack([rec[fields] for rec in recs])


def stretch(arr, fields):
def stretch(arr, fields=None):
"""Stretch an array.

Stretch an array by ``hstack()``-ing multiple array fields while
Expand All @@ -83,8 +79,8 @@ def stretch(arr, fields):
----------
arr : NumPy structured or record array
The array to be stretched.
fields : list of strings
A list of column names to stretch.
fields : list of strings, optional (default=None)
A list of column names to stretch. If None, then stretch all fields.

Returns
-------
Expand All @@ -103,44 +99,51 @@ def stretch(arr, fields):
dtype=[('scalar', '<i8'), ('array', '<f8')])

"""
dt = []
has_array_field = False
has_scalar_filed = False
first_array = None

# Construct dtype
for c in fields:
if _is_object_field(arr, c):
dt.append((c, arr[c][0].dtype))
has_array_field = True
first_array = c if first_array is None else first_array
else:
# Assume scalar
dt.append((c, arr[c].dtype))
has_scalar_filed = True

if not has_array_field:
raise RuntimeError("No array column specified")

len_array = VLEN(arr[first_array])
numrec = np.sum(len_array)
ret = np.empty(numrec, dtype=dt)

for c in fields:
if _is_object_field(arr, c):
# FIXME: this is rather inefficient since the stack
# is copied over to the return value
stack = np.hstack(arr[c])
if len(stack) != numrec:
dtype = []
len_array = None

if fields is None:
fields = arr.dtype.names

# Construct dtype and check consistency
for field in fields:
dt = arr.dtype[field]
if dt == 'O' or len(dt.shape):
if dt == 'O':
# Variable-length array field
lengths = VLEN(arr[field])
else:
lengths = np.repeat(dt.shape[0], arr.shape[0])
# Fixed-length array field
if len_array is None:
len_array = lengths
elif not np.array_equal(lengths, len_array):
raise ValueError(
"Array lengths do not match: "
"expected %d but found %d in %s" %
(numrec, len(stack), c))
ret[c] = stack
"inconsistent lengths of array columns in input")
if dt == 'O':
dtype.append((field, arr[field][0].dtype))
else:
dtype.append((field, arr[field].dtype, dt.shape[1:]))
else:
# Scalar field
dtype.append((field, dt))

if len_array is None:
raise RuntimeError("no array column in input")

# Build stretched output
ret = np.empty(np.sum(len_array), dtype=dtype)
for field in fields:
dt = arr.dtype[field]
if dt == 'O' or len(dt.shape) == 1:
# Variable-length or 1D fixed-length array field
ret[field] = np.hstack(arr[field])
elif len(dt.shape):
# Multidimensional fixed-length array field
ret[field] = np.vstack(arr[field])
else:
# FIXME: this is rather inefficient since the repeat result
# is copied over to the return value
ret[c] = np.repeat(arr[c], len_array)
# Scalar field
ret[field] = np.repeat(arr[field], len_array)

return ret

Expand Down
85 changes: 41 additions & 44 deletions root_numpy/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,56 +574,53 @@ def test_fill_graph():


def test_stretch():
nrec = 5
arr = np.empty(nrec,
arr = np.empty(5,
dtype=[
('scalar', np.int),
('df1', 'O'),
('df2', 'O'),
('df3', 'O')])

for i in range(nrec):
df1 = np.array(range(i + 1), dtype=np.float)
df2 = np.array(range(i + 1), dtype=np.int) * 2
df3 = np.array(range(i + 1), dtype=np.double) * 3
arr[i] = (i, df1, df2, df3)
('vl1', 'O'),
('vl2', 'O'),
('vl3', 'O'),
('fl1', np.int, (2, 2)),
('fl2', np.float, (2, 3)),
('fl3', np.double, (3, 2))])

for i in range(arr.shape[0]):
vl1 = np.array(range(i + 1), dtype=np.int)
vl2 = np.array(range(i + 2), dtype=np.float) * 2
vl3 = np.array(range(2), dtype=np.double) * 3
fl1 = np.array(range(4), dtype=np.int).reshape((2, 2))
fl2 = np.array(range(6), dtype=np.float).reshape((2, 3))
fl3 = np.array(range(6), dtype=np.double).reshape((3, 2))
arr[i] = (i, vl1, vl2, vl3, fl1, fl2, fl3)

# no array columns included
assert_raises(RuntimeError, rnp.stretch, arr, ['scalar',])

stretched = rnp.stretch(
arr, ['scalar', 'df1', 'df2', 'df3'])
# lengths don't match
assert_raises(ValueError, rnp.stretch, arr, ['scalar', 'vl1', 'vl2',])
assert_raises(ValueError, rnp.stretch, arr, ['scalar', 'fl1', 'fl3',])
assert_raises(ValueError, rnp.stretch, arr)

# variable-length stretch
stretched = rnp.stretch(arr, ['scalar', 'vl1',])
assert_equal(stretched.dtype,
[('scalar', np.int),
('df1', np.float),
('df2', np.int),
('df3', np.double)])
assert_equal(stretched.size, 15)

assert_almost_equal(stretched['df1'][14], 4.0)
assert_almost_equal(stretched['df2'][14], 8)
assert_almost_equal(stretched['df3'][14], 12.0)
assert_almost_equal(stretched['scalar'][14], 4)
assert_almost_equal(stretched['scalar'][13], 4)
assert_almost_equal(stretched['scalar'][12], 4)
assert_almost_equal(stretched['scalar'][11], 4)
assert_almost_equal(stretched['scalar'][10], 4)
assert_almost_equal(stretched['scalar'][9], 3)

arr = np.empty(1, dtype=[('scalar', np.int),])
arr[0] = (1,)
assert_raises(RuntimeError, rnp.stretch, arr, ['scalar',])
[('scalar', np.int),
('vl1', np.int)])
assert_equal(stretched.shape[0], 15)
assert_array_equal(
stretched['scalar'],
np.repeat(arr['scalar'], np.vectorize(len)(arr['vl1'])))

nrec = 5
arr = np.empty(nrec,
dtype=[
('scalar', np.int),
('df1', 'O'),
('df2', 'O')])

for i in range(nrec):
df1 = np.array(range(i + 1), dtype=np.float)
df2 = np.array(range(i + 2), dtype=np.int) * 2
arr[i] = (i, df1, df2)
assert_raises(ValueError, rnp.stretch, arr, ['scalar', 'df1', 'df2'])
# fixed-length stretch
stretched = rnp.stretch(arr, ['scalar', 'vl3', 'fl1', 'fl2',])
assert_equal(stretched.dtype,
[('scalar', np.int),
('vl3', np.double),
('fl1', np.int, (2,)),
('fl2', np.float, (3,))])
assert_equal(stretched.shape[0], 10)
assert_array_equal(
stretched['scalar'], np.repeat(arr['scalar'], 2))


def test_blockwise_inner_join():
Expand Down