diff --git a/src/polyglot/data.py b/src/polyglot/data.py index 5028d0c2949..14a7bfdc001 100644 --- a/src/polyglot/data.py +++ b/src/polyglot/data.py @@ -118,36 +118,46 @@ class Variable(object): fully described outside the context of its parent Dataset. """ def __init__(self, dims, data, attributes=None): - object.__setattr__(self, 'dimensions', dims) - object.__setattr__(self, 'data', data) + self._dimensions = tuple(dims) + if len(dims) != data.ndim: + raise ValueError('data must have same shape as the number of ' + 'dimensions') + self._data = data if attributes is None: attributes = {} - object.__setattr__(self, 'attributes', AttributesDict(attributes)) + self._attributes = AttributesDict(attributes) - def _allocate(self): - return self.__class__(dims=(), data=0) + @property + def dimensions(self): + return self._dimensions + + @property + def data(self): + return self._data + + @data.setter + def data(self, value): + if value.shape != self.shape: + raise ValueError("replacement data must match the Variable's " + "shape") + self._data = value + @property + def attributes(self): + return self._attributes + + #TODO: replace these with explicit properties? def __getattribute__(self, key): """ Here we give some of the attributes of self.data preference over attributes in the object instelf. """ if key in ['dtype', 'shape', 'size', 'ndim', 'nbytes', - 'flat', '__iter__', 'view']: + 'flat', '__iter__']: return getattr(self.data, key) else: return object.__getattribute__(self, key) - def __setattr__(self, attr, value): - """"__setattr__ is overloaded to prevent operations that could - cause loss of data consistency. If you really intend to update - dir(self), use the self.__dict__.update method or the - super(type(a), self).__setattr__ method to bypass.""" - raise AttributeError("Object is tamper-proof") - - def __delattr__(self, attr): - raise AttributeError("Object is tamper-proof") - def __getitem__(self, index): """__getitem__ is overloaded to access the underlying numpy data""" return self.data[index] @@ -156,17 +166,12 @@ def __setitem__(self, index, data): """__setitem__ is overloaded to access the underlying numpy data""" self.data[index] = data - def __hash__(self): - """__hash__ is overloaded to guarantee that two variables with the same - attributes and np.data values have the same hash (the converse is not true)""" - return hash((self.dimensions, - frozenset((k,v.tostring()) if isinstance(v,np.ndarray) else (k,v) - for (k,v) in self.attributes.items()), - self.data.tostring())) + # mutable objects should not have a hash + __hash__ = None def __len__(self): """__len__ is overloaded to access the underlying numpy data""" - return self.data.__len__() + return len(self.data) def copy(self): """ @@ -175,12 +180,13 @@ def copy(self): return self.__copy__() def _copy(self, deepcopy=False): - data = self.data[:].copy() if deepcopy else self.data - obj = self._allocate() - object.__setattr__(obj, 'dimensions', copy.copy(self.dimensions)) - object.__setattr__(obj, 'data', data) - object.__setattr__(obj, 'attributes', self.attributes.copy()) - return obj + # dimensions is already an immutable tuple + dims = self.dimensions + data = copy.deepcopy(self.data) if deepcopy else self.data + # deepcopy attributes for sanity since there should be essentially no + # performance penalty + attributes = copy.deepcopy(self.attributes) + return type(self)(dims, data, attributes) def __copy__(self): return self._copy(deepcopy=False) @@ -194,15 +200,15 @@ def __deepcopy__(self, memo=None): return self._copy(deepcopy=True) def __eq__(self, other): - if self.dimensions != other.dimensions or \ - (self.data.tostring() != other.data.tostring()): - return False - if not self.attributes == other.attributes: + try: + return (self.dimensions == other.dimensions + and np.all(self.data[:] == other.data[:]) + and self.attributes == other.attributes) + except AttributeError: return False - return True def __ne__(self, other): - return not self.__eq__(other) + return not self == other def __str__(self): """Create a ncdump-like summary of the object""" @@ -251,9 +257,11 @@ def views(self, slicers): for i, dim in enumerate(self.dimensions): if dim in slicers: slices[i] = slicers[dim] - # Shallow copy - obj = copy.copy(self) - object.__setattr__(obj, 'data', self.data[slices]) + obj = self.copy() + obj._dimensions = tuple(dim for s, dim in zip(slices, self.dimensions) + if not isinstance(s, int)) + obj._data = self.data[slices] + assert len(obj.dimensions) == obj.ndim return obj def view(self, s, dim): @@ -265,9 +273,7 @@ def view(self, s, dim): s : slice The slice representing the range of the values to extract. dim : string - The dimension to slice along. If multiple dimensions equal - dim (e.g. a correlation matrix), then the slicing is done - only along the first matching dimension. + The dimension to slice along. Returns ------- @@ -305,6 +311,10 @@ def take(self, indices, dim): as the original. Data contents are taken along the specified dimension. + Notes + ----- + This operation does NOT preserve lazy data. + See Also -------- numpy.take @@ -315,12 +325,40 @@ def take(self, indices, dim): # When dim appears repeatedly in self.dimensions, using the index() # method gives us only the first one, which is the desired behavior axis = list(self.dimensions).index(dim) - # Deep copy - obj = copy.deepcopy(self) + obj = self.copy() # In case data is lazy we need to slice out all the data before taking. - object.__setattr__(obj, 'data', self.data[:].take(indices, axis=axis)) + obj._data = self.data[:].take(indices, axis=axis) return obj + def transpose(self, dimensions=None): + """Return a new Variable object with transposed dimensions + + Parameters + ---------- + dimensions : list of string, optional + By default, reverse the dimensions, otherwise permute the axes + according to the values given. + + Returns + ------- + obj : Variable object + The returned object has transposed data and dimensions with the + same attributes as the original. + + Notes + ----- + This operation does NOT preserve lazy data. + + See Also + -------- + numpy.transpose + """ + if dimensions is None: + dimensions = self.dimensions[::-1] + axes = [self.dimensions.index(dim) for dim in dimensions] + data = self.data[:].transpose(axes) + return type(self)(dimensions, data, self.attributes) + class Dataset(object): """ @@ -338,6 +376,7 @@ def _load_scipy(self, scipy_nc, *args, **kwdargs): try: nc = netcdf.netcdf_file(scipy_nc, mode='r', *args, **kwdargs) except: + #FIXME: can we catch a specific exception here? scipy_nc = StringIO(scipy_nc) scipy_nc.seek(0) nc = netcdf.netcdf_file(scipy_nc, mode='r', *args, **kwdargs) @@ -347,15 +386,12 @@ def from_scipy_variable(sci_var): data = sci_var.data, attributes = sci_var._attributes) - object.__setattr__(self, 'attributes', AttributesDict()) self.attributes.update(nc._attributes) - object.__setattr__(self, 'dimensions', OrderedDict()) dimensions = OrderedDict((k, len(d)) for k, d in nc.dimensions.iteritems()) self.dimensions.update(dimensions) - object.__setattr__(self, 'variables', OrderedDict()) variables = OrderedDict((vn, from_scipy_variable(v)) for vn, v in nc.variables.iteritems()) self.variables.update(variables) @@ -367,10 +403,8 @@ def _load_netcdf4(self, netcdf_path, *args, **kwdargs): """ nc = nc4.Dataset(netcdf_path, *args, **kwdargs) - object.__setattr__(self, 'attributes', AttributesDict()) self.attributes.update(dict((k.encode(), nc.getncattr(k)) for k in nc.ncattrs())) - object.__setattr__(self, 'dimensions', OrderedDict()) dimensions = OrderedDict((k.encode(), len(d)) for k, d in nc.dimensions.iteritems()) self.dimensions.update(dimensions) @@ -382,21 +416,22 @@ def from_netcdf4_variable(nc4_var): data = nc4_var, attributes = attributes) - object.__setattr__(self, 'variables', OrderedDict()) self.variables.update(dict((vn.encode(), from_netcdf4_variable(v)) for vn, v in nc.variables.iteritems())) + #TODO: alternate constructors for loading files or python objects + def __init__(self, nc=None, *args, **kwdargs): + object.__setattr__(self, 'attributes', AttributesDict()) + object.__setattr__(self, 'dimensions', OrderedDict()) + object.__setattr__(self, 'variables', OrderedDict()) + if isinstance(nc, basestring) and not nc.startswith('CDF'): # If the initialization nc is a string and it doesn't # appear to be the contents of a netcdf file we load # it using the netCDF4 package self._load_netcdf4(nc, *args, **kwdargs) - elif nc is None: - object.__setattr__(self, 'attributes', AttributesDict()) - object.__setattr__(self, 'dimensions', OrderedDict()) - object.__setattr__(self, 'variables', OrderedDict()) - else: + elif nc is not None: # If nc is a file-like object we read it using # the scipy.io.netcdf package self._load_scipy(nc) @@ -410,8 +445,10 @@ def copy(self): def __copy__(self): obj = self._allocate() object.__setattr__(obj, 'dimensions', self.dimensions.copy()) - object.__setattr__(obj, 'variables', self.variables.copy()) - object.__setattr__(obj, 'attributes', self.attributes.copy()) + object.__setattr__(obj, 'variables', + OrderedDict((k, v.copy()) for k, v + in self.variables.iteritems())) + object.__setattr__(obj, 'attributes', copy.deepcopy(self.attributes)) return obj def __setattr__(self, attr, value): @@ -424,23 +461,20 @@ def __setattr__(self, attr, value): def __contains__(self, key): """ The 'in' operator will return true or false depending on - whether 'key' is a varibale in the data object or not. + whether 'key' is a variable in the data object or not. """ return key in self.variables def __eq__(self, other): - if not isinstance(other, Dataset): - return False - if dict(self.dimensions) != dict(other.dimensions): - return False - if not dict(self.variables) == dict(other.variables): - return False - if not self.attributes == other.attributes: + try: + return (self.dimensions == other.dimensions and + self.variables == other.variables and + self.attributes == other.attributes) + except AttributeError: return False - return True def __ne__(self, other): - return not self.__eq__(other) + return not self == other @property def coordinates(self): @@ -739,26 +773,20 @@ def views(self, slicers): numpy.take Variable.take """ - if not all([isinstance(sl, slice) for sl in slicers.values()]): - raise ValueError("view expects a dict whose values are slice objects") - if not all([k in self.dimensions for k in slicers.keys()]): + if not all(k in self.dimensions for k in slicers.keys()): invalid = [k for k in slicers.keys() if not k in self.dimensions] raise KeyError("dimensions %s don't exist" % ', '.join(map(str, invalid))) - # Create a new object - obj = self._allocate() - # Create views onto the variables and infer the new dimension length - new_dims = dict(self.dimensions.iteritems()) + + obj = type(self)() + new_dims = self.dimensions.copy() for (name, var) in self.variables.iteritems(): - var_slicers = dict((k, v) for k, v in slicers.iteritems() if k in var.dimensions) - if len(var_slicers): - obj.unchecked_add_variable(name, var.views(var_slicers)) - new_dims.update(dict(zip(obj[name].dimensions, obj[name].shape))) - else: - obj.unchecked_add_variable(name, var) - # Hard write the dimensions, skipping validation + var_slicers = dict((k, v) for k, v in slicers.iteritems() + if k in var.dimensions) + obj.unchecked_add_variable(name, var.views(var_slicers)) + new_dims.update(dict(zip(obj[name].dimensions, obj[name].shape))) + obj.unchecked_set_dimensions(new_dims) - # Reference to the attributes, this intentionally does not copy. - obj.unchecked_set_attributes(self.attributes) + obj.unchecked_set_attributes(copy.deepcopy(self.attributes)) return obj def view(self, s, dim=None): @@ -770,11 +798,7 @@ def view(self, s, dim=None): s : slice The slice representing the range of the values to extract. dim : string, optional - The dimension to slice along. If multiple dimensions of a - variable equal dim (e.g. a correlation matrix), then that - variable is sliced only along both dimensions. Without - this behavior the resulting data object would have - inconsistent dimensions. + The dimension to slice along. Returns ------- @@ -800,7 +824,7 @@ def view(self, s, dim=None): raise IndexError("view results in a dimension of length zero") return obj - def take(self, indices, dim=None): + def take(self, indices, dim): """Return a new object whose contents are taken from the current object along a specified dimension @@ -813,9 +837,6 @@ def take(self, indices, dim=None): The dimension to slice along. If multiple dimensions of a variable equal dim (e.g. a correlation matrix), then that variable is sliced only along its first matching dimension. - If None (default), then the object is sliced along its - unlimited dimension; an exception is raised if the object - does not have an unlimited dimension. Returns ------- @@ -832,8 +853,6 @@ def take(self, indices, dim=None): numpy.take Variable.take """ - if dim is None: - raise ValueError("dim cannot be None") # Create a new object obj = self._allocate() # Create fancy-indexed variables and infer the new dimension length @@ -902,9 +921,14 @@ def renamed(self, name_dict): self.unchecked_set_attributes(self.attributes.copy()) return obj + def updated(self, other): + """ + + """ + def update(self, other): """ - An update method (simular to dict.update) for data objects whereby each + An update method (similar to dict.update) for data objects whereby each dimension, variable and attribute from 'other' is updated in the current object. Note however that because Data object attributes are often write protected an exception will be raised if an attempt to overwrite @@ -913,12 +937,6 @@ def update(self, other): # if a dimension is a new one it gets added, if the dimension already # exists we confirm that they are identical (or throw an exception) for (name, length) in other.dimensions.iteritems(): - if (name == other.record_dimension and - name != self.record_dimension): - raise ValueError( - ("record dimensions do not match: " - "self: %s, other: %s") % - (self.record_dimension, other.record_dimension)) if not name in self.dimensions: self.create_dimension(name, length) else: @@ -995,6 +1013,8 @@ def select(self, var): obj.unchecked_set_attributes(self.attributes.copy()) return obj + #TODO: move iterator and iterarray to Variable or the new "Cube" object + def iterator(self, dim=None, views=False): """Iterator along a data dimension @@ -1103,7 +1123,7 @@ def iterator(self, dim=None, views=False): for i in xrange(n): yield (None, self.take(np.array([i]), dim=dim)) - def iterarray(self, var, dim=None): + def iterarray(self, var, dim): """Iterator along a data dimension returning the corresponding slices of the underlying data of a varaible. @@ -1117,10 +1137,7 @@ def iterarray(self, var, dim=None): The variable over which you want to iterate. dim : string, optional - The dimension along which you want to iterate. If None - (default), then the iterator operates along the record - dimension; if there is no record dimension, an exception - will be raised. + The dimension along which you want to iterate. Returns ------- diff --git a/test/test_data.py b/test/test_data.py index 3d2873ac80b..d60b6f21657 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -118,8 +118,6 @@ def test_variable(self): self.assertFalse(v1 == v3) self.assertFalse(v1 == v4) self.assertFalse(v1 == v5) - # Variable hash - self.assertEquals(hash(v1), hash(v2)) def test_coordinate(self): a = Dataset() @@ -230,6 +228,8 @@ def test_views(self): ret = data.views(slicers) data.views(slicers) + # Verify dimensions still have matching items + self.assertItemsEqual(ret.dimensions, data.dimensions) # Verify that only the specified dimension was altered for d in data.dimensions: if d in slicers: