diff --git a/setup.py b/setup.py
index 115b7e54791..00bbd6955a4 100644
--- a/setup.py
+++ b/setup.py
@@ -15,4 +15,4 @@
       url='https://github.com/akleeman/scidata',
       test_suite='nose.collector',
       packages=['polyglot'],
-      package_dir={'polyglot': 'src/polyglot'})
+      package_dir={'': 'src'})
diff --git a/src/polyglot/__init__.py b/src/polyglot/__init__.py
index b14e6b51e88..ed10a58ebe0 100644
--- a/src/polyglot/__init__.py
+++ b/src/polyglot/__init__.py
@@ -1 +1,4 @@
-from data import Dataset, Variable
\ No newline at end of file
+from data import Dataset
+from variable import Variable
+
+import backends
\ No newline at end of file
diff --git a/src/polyglot/backends.py b/src/polyglot/backends.py
new file mode 100644
index 00000000000..ef90673692a
--- /dev/null
+++ b/src/polyglot/backends.py
@@ -0,0 +1,238 @@
+import netCDF4 as nc4
+
+from scipy.io import netcdf
+from collections import OrderedDict
+
+import variable
+
+class InMemoryDataStore(object):
+    """
+    Stores dimensions, variables and attributes
+    in ordered dictionaries, making this store
+    fast compared to stores which write to disk.
+    """
+
+    def __init__(self):
+        self.unchecked_set_attributes(variable.AttributesDict())
+        self.unchecked_set_dimensions(OrderedDict())
+        self.unchecked_set_variables(OrderedDict())
+
+    def unchecked_set_dimensions(self, dimensions):
+        """Set the dimensions without checking validity"""
+        self.dimensions = dimensions
+
+    def unchecked_set_attributes(self, attributes):
+        """Set the attributes without checking validity"""
+        self.attributes = attributes
+
+    def unchecked_set_variables(self, variables):
+        """Set the variables without checking validity"""
+        self.variables = variables
+
+    def unchecked_create_dimension(self, name, length):
+        """Set a dimension length"""
+        self.dimensions[name] = length
+
+    def unchecked_add_variable(self, name, variable):
+        """Add a variable without checks"""
+        self.variables[name] = variable
+        return self.variables[name]
+
+    def unchecked_create_variable(self, name, dims, data, attributes):
+        """Creates a variable without checks"""
+        v = variable.Variable(dims=dims, data=data,
+                              attributes=attributes)
+        self.unchecked_add_variable(name, v)
+        return v
+
+    def unchecked_create_coordinate(self, name, data, attributes):
+        """Creates a coordinate (dim and var) without checks"""
+        self.unchecked_create_dimension(name, data.size)
+        return self.unchecked_create_variable(name, (name,), data, attributes)
+
+
+    def sync(self):
+        pass
+
+class ScipyVariable(variable.Variable):
+
+    def __init__(self, scipy_var):
+        object.__setattr__(self, 'v', scipy_var)
+
+    def _allocate(self):
+        return variable.Variable(dims=(), data=0)
+
+    @property
+    def attributes(self):
+        return self.v._attributes
+
+    def __getattribute__(self, key):
+        """
+        Here we give some of the attributes of self.data preference over
+        attributes in the object itself.
+        """
+        if key == 'v':
+            return object.__getattribute__(self, 'v')
+        elif hasattr(self.v, key):
+            return object.__getattribute__(self.v, key)
+        elif not hasattr(self, key) and hasattr(self.v.data, key):
+            return getattr(self.v.data, key)
+        else:
+            return object.__getattribute__(self, key)
+
+class ScipyDataStore(object):
+    """
+    Stores data using the scipy.io.netcdf package.
+    This store has the advantage of being able to
+    be initialized with a StringIO object, allowing
+    for serialization.
+ """ + def __init__(self, fobj, *args, **kwdargs): + self.ds = netcdf.netcdf_file(fobj, *args, **kwdargs) + + @property + def variables(self): + return OrderedDict((k, ScipyVariable(v)) + for k, v in self.ds.variables.iteritems()) + + @property + def attributes(self): + return self.ds._attributes + + @property + def dimensions(self): + return self.ds.dimensions + + def unchecked_set_dimensions(self, dimensions): + """Set the dimensions without checking validity""" + for d, l in dimensions.iteritems(): + self.unchecked_create_dimension(d, l) + + def unchecked_set_attributes(self, attributes): + """Set the attributes without checking validity""" + for k, v in attributes.iteritems(): + setattr(self.ds, k, v) + + def unchecked_set_variables(self, variables): + """Set the variables without checking validity""" + for vn, v in variables.iteritems(): + self.unchecked_add_variable(vn, v) + + def unchecked_create_dimension(self, name, length): + """Set a dimension length""" + self.ds.createDimension(name, length) + + def unchecked_add_variable(self, name, variable): + """Add a variable without checks""" + self.ds.createVariable(name, variable.dtype, + variable.dimensions) + self.ds.variables[name][:] = variable.data[:] + for k, v in variable.attributes.iteritems(): + setattr(self.ds.variables[name], k, v) + + def unchecked_create_coordinate(self, name, data, attributes): + """Creates a coordinate (dim and var) without checks""" + self.unchecked_create_dimension(name, data.size) + return self.unchecked_create_variable(name, (name,), data, attributes) + + def sync(self): + self.ds.flush() + +class NetCDF4Variable(variable.Variable): + + def __init__(self, nc4_variable): + object.__setattr__(self, 'data', + variable.LazyVariableData(nc4_variable)) + object.__setattr__(self, '_attributes', None) + + def _allocate(self): + return variable.Variable(dims=(), data=0) + + @property + def attributes(self): + if self._attributes is None: + # we don't want to see scale_factor and add_offset in the attributes + # since the netCDF4 package automatically scales the data on read. + # If we kept scale_factor and add_offset around and did this: + # + # foo = ncdf4.Dataset('foo.nc') + # ncdf4.dump(foo, 'bar.nc') + # bar = ncdf4.Dataset('bar.nc') + # + # you would find that any packed variables in the original + # netcdf file would now have been scaled twice! 
+ packing_attributes = ['scale_factor', 'add_offset'] + keys = [k for k in self.ncattrs() if not k in packing_attributes] + attr_dict = variable.AttributesDict((k, self.data.getncattr(k)) + for k in keys) + object.__setattr__(self, '_attributes', attr_dict) + return self._attributes + + def __getattr__(self, attr): + """__getattr__ is overloaded to selectively expose some of the + attributes of the underlying nc4 variable""" + if attr == 'data': + return object.__getattribute__(self, 'data') + elif hasattr(self.data, attr): + return getattr(self.data, attr) + else: + return object.__getattribute__(self, attr) + +class NetCDF4DataStore(object): + + def __init__(self, filename, *args, **kwdargs): + self.ds = nc4.Dataset(filename, *args, **kwdargs) + + @property + def variables(self): + return OrderedDict((k, NetCDF4Variable(v)) + for k, v in self.ds.variables.iteritems()) + + @property + def attributes(self): + return variable.AttributesDict((k, self.ds.getncattr(k)) + for k in self.ds.ncattrs()) + + @property + def dimensions(self): + return OrderedDict((k, len(v)) for k, v in self.ds.dimensions.iteritems()) + + def unchecked_set_dimensions(self, dimensions): + """Set the dimensions without checking validity""" + for d, l in dimensions.iteritems(): + self.unchecked_create_dimension(d, l) + + def unchecked_set_attributes(self, attributes): + """Set the attributes without checking validity""" + self.ds.setncatts(attributes) + + def unchecked_set_variables(self, variables): + """Set the variables without checking validity""" + for vn, v in variables.iteritems(): + self.unchecked_add_variable(vn, v) + + def unchecked_create_dimension(self, name, length): + """Set a dimension length""" + self.ds.createDimension(name, size=length) + + def unchecked_add_variable(self, name, variable): + """Add a variable without checks""" + # netCDF4 will automatically assign a fill value + # depending on the datatype of the variable. Here + # we let the package handle the _FillValue attribute + # instead of setting it ourselves. + fill_value = variable.attributes.pop('_FillValue', None) + self.ds.createVariable(varname=name, + datatype=variable.dtype, + dimensions=variable.dimensions, + fill_value=fill_value) + self.ds.variables[name][:] = variable.data[:] + self.ds.variables[name].setncatts(variable.attributes) + + def unchecked_create_coordinate(self, name, data, attributes): + """Creates a coordinate (dim and var) without checks""" + self.unchecked_create_dimension(name, data.size) + return self.unchecked_create_variable(name, (name,), data, attributes) + + def sync(self): + self.ds.sync() diff --git a/src/polyglot/conventions.py b/src/polyglot/conventions.py index 351d3cc9297..637aa5943b4 100644 --- a/src/polyglot/conventions.py +++ b/src/polyglot/conventions.py @@ -59,6 +59,16 @@ 'string', ]) +def pretty_print(x, numchars): + """Given an object x, call x.__str__() and format the returned + string so that it is numchars long, padding with trailing spaces or + truncating with ellipses as necessary""" + s = str(x).rstrip(NULL) + if len(s) > numchars: + return s[:(numchars - 3)] + '...' 
+ else: + return s + def coerce_type(arr): """Coerce a numeric data type to a type that is compatible with netCDF-3 diff --git a/src/polyglot/data.py b/src/polyglot/data.py index fe99d81728b..bcadf952922 100644 --- a/src/polyglot/data.py +++ b/src/polyglot/data.py @@ -2,9 +2,6 @@ import os import copy -import itertools -import unicodedata - import numpy as np import netCDF4 as nc4 @@ -13,323 +10,95 @@ from cStringIO import StringIO from collections import OrderedDict -import conventions +import conventions, backends, variable date2num = nc4.date2num num2date = nc4.num2date -def _prettyprint(x, numchars): - """Given an object x, call x.__str__() and format the returned - string so that it is numchars long, padding with trailing spaces or - truncating with ellipses as necessary""" - s = str(x).rstrip(conventions.NULL) - if len(s) > numchars: - return s[:(numchars - 3)] + '...' - else: - return s - -class AttributesDict(OrderedDict): - """A subclass of OrderedDict whose __setitem__ method automatically - checks and converts values to be valid netCDF attributes +class Dataset(object): """ - def __init__(self, *args, **kwds): - OrderedDict.__init__(self, *args, **kwds) - - def __setitem__(self, key, value): - if not conventions.is_valid_name(key): - raise ValueError("Not a valid attribute name") - # Strings get special handling because netCDF treats them as - # character arrays. Everything else gets coerced to a numpy - # vector. netCDF treats scalars as 1-element vectors. Arrays of - # non-numeric type are not allowed. - if isinstance(value, basestring): - # netcdf attributes should be unicode - value = unicode(value) - else: - try: - value = conventions.coerce_type(np.atleast_1d(np.asarray(value))) - except: - raise ValueError("Not a valid value for a netCDF attribute") - if value.ndim > 1: - raise ValueError("netCDF attributes must be vectors " + - "(1-dimensional)") - value = conventions.coerce_type(value) - if str(value.dtype) not in conventions.TYPEMAP: - # A plain string attribute is okay, but an array of - # string objects is not okay! - raise ValueError("Can not convert to a valid netCDF type") - OrderedDict.__setitem__(self, key, value) + A netcdf-like data object consisting of dimensions, variables and + attributes which together form a self describing data set. + """ + def __init__(self, nc = None, store = None, *args, **kwdargs): - def copy(self): - """The copy method of the superclass simply calls the constructor, - which in turn calls the update method, which in turns calls - __setitem__. This subclass implementation bypasses the expensive - validation in __setitem__ for a substantial speedup.""" - obj = self.__class__() - for (attr, value) in self.iteritems(): - OrderedDict.__setitem__(obj, attr, copy.copy(value)) - return obj + if store is None: + store = backends.InMemoryDataStore() + object.__setattr__(self, 'store', store) - def __deepcopy__(self, memo=None): - """ - Returns a deep copy of the current object. 
+ if isinstance(nc, basestring) and not nc.startswith('CDF'): + """ + If the initialization nc is a string and it doesn't + appear to be the contents of a netcdf file we load + it using the netCDF4 package + """ + self._load_netcdf4(nc, *args, **kwdargs) + elif nc is not None: + """ + If nc is a file-like object we read it using + the scipy.io.netcdf package + """ + self._load_scipy(nc) - memo does nothing but is required for compatability with copy.deepcopy - """ - return self.copy() + def _unchecked_set_dimensions(self, *args, **kwdargs): + self.store.unchecked_set_dimensions(*args, **kwdargs) - def update(self, *other, **kwargs): - """Set multiple attributes with a mapping object or an iterable of - key/value pairs""" - # Capture arguments in an OrderedDict - args_dict = OrderedDict(*other, **kwargs) - try: - # Attempt __setitem__ - for (attr, value) in args_dict.iteritems(): - self.__setitem__(attr, value) - except: - # A plain string attribute is okay, but an array of - # string objects is not okay! - raise ValueError("Can not convert to a valid netCDF type") - # Clean up so that we don't end up in a partial state - for (attr, value) in args_dict.iteritems(): - if self.__contains__(attr): - self.__delitem__(attr) - # Re-raise - raise + def _unchecked_set_attributes(self, *args, **kwdargs): + self.store.unchecked_set_attributes(*args, **kwdargs) - def __eq__(self, other): - if not set(self.keys()) == set(other.keys()): - return False - for (key, value) in self.iteritems(): - if value.__class__ != other[key].__class__: - return False - if isinstance(value, basestring): - if value != other[key]: - return False - else: - if value.tostring() != other[key].tostring(): - return False - return True + def _unchecked_set_variables(self, *args, **kwdargs): + self.store.unchecked_set_variables(*args, **kwdargs) -class Variable(object): - """ - A netcdf-like variable consisting of dimensions, data and attributes - which describe a single variable. A single variable object is not - fully described outside the context of its parent Dataset. - """ - def __init__(self, dims, data, attributes=None): - object.__setattr__(self, 'dimensions', dims) - object.__setattr__(self, 'data', data) - if attributes is None: - attributes = {} - object.__setattr__(self, 'attributes', AttributesDict(attributes)) + def _unchecked_create_dimension(self, *args, **kwdargs): + self.store.unchecked_create_dimension(*args, **kwdargs) - def _allocate(self): - return self.__class__(dims=(), data=0) + def _unchecked_add_variable(self, *args, **kwdargs): + self.store.unchecked_add_variable(*args, **kwdargs) - def __getattribute__(self, key): - """ - Here we give some of the attributes of self.data preference over - attributes in the object instelf. - """ - if key in ['dtype', 'shape', 'size', 'ndim', 'nbytes', - 'flat', '__iter__', 'view']: - return getattr(self.data, key) - else: - return object.__getattribute__(self, key) + def _unchecked_create_variable(self, name, dims, data, attributes): + """Creates a variable without checks""" + v = variable.Variable(dims=dims, data=data, + attributes=attributes) + self._unchecked_add_variable(name, v) + return v - def __setattr__(self, attr, value): - """"__setattr__ is overloaded to prevent operations that could - cause loss of data consistency. 
If you really intend to update - dir(self), use the self.__dict__.update method or the - super(type(a), self).__setattr__ method to bypass.""" - raise AttributeError, "Object is tamper-proof" + def _unchecked_create_coordinate(self, name, data, attributes): + """Creates a coordinate (dim and var) without checks""" + self._unchecked_create_dimension(name, data.size) + return self._unchecked_create_variable(name, (name,), data, attributes) - def __delattr__(self, attr): - raise AttributeError, "Object is tamper-proof" + def sync(self): + return self.store.sync() - def __getitem__(self, index): - """__getitem__ is overloaded to access the underlying numpy data""" - return self.data[index] + @property + def variables(self): + return self.store.variables - def __setitem__(self, index, data): - """__setitem__ is overloaded to access the underlying numpy data""" - self.data[index] = data + @property + def attributes(self): + return self.store.attributes - def __hash__(self): - """__hash__ is overloaded to guarantee that two variables with the same - attributes and np.data values have the same hash (the converse is not true)""" - return hash((self.dimensions, - frozenset((k,v.tostring()) if isinstance(v,np.ndarray) else (k,v) - for (k,v) in self.attributes.items()), - self.data.tostring())) + @property + def dimensions(self): + return self.store.dimensions - def __len__(self): - """__len__ is overloaded to access the underlying numpy data""" - return self.data.__len__() + def _allocate(self): + return self.__class__() - def __copy__(self): + def copy(self): """ Returns a shallow copy of the current object. """ - # Create the simplest possible dummy object and then overwrite it - obj = self._allocate() - object.__setattr__(obj, 'dimensions', self.dimensions) - object.__setattr__(obj, 'data', self.data) - object.__setattr__(obj, 'attributes', self.attributes) - return obj + return self.__copy__() - def __deepcopy__(self, memo=None): + def __copy__(self): """ - Returns a deep copy of the current object. - - memo does nothing but is required for compatability with copy.deepcopy + Returns a shallow copy of the current object. """ - # Create the simplest possible dummy object and then overwrite it obj = self._allocate() - # tuples are immutable - object.__setattr__(obj, 'dimensions', self.dimensions) - object.__setattr__(obj, 'data', self.data[:].copy()) - object.__setattr__(obj, 'attributes', self.attributes.copy()) - return obj - - def __eq__(self, other): - if self.dimensions != other.dimensions or \ - (self.data.tostring() != other.data.tostring()): - return False - if not self.attributes == other.attributes: - return False - return True - - def __ne__(self, other): - return not self.__eq__(other) - - def __str__(self): - """Create a ncdump-like summary of the object""" - summary = ["dimensions:"] - # prints dims that look like: - # dimension = length - dim_print = lambda d, l : "\t%s : %s" % (_prettyprint(d, 30), - _prettyprint(l, 10)) - # add each dimension to the summary - summary.extend([dim_print(d, l) for d, l in zip(self.dimensions, self.shape)]) - summary.append("type : %s" % (_prettyprint(var.dtype, 8))) - summary.append("\nattributes:") - # attribute:value - summary.extend(["\t%s:%s" % (_prettyprint(att, 30), - _prettyprint(val, 30)) - for att, val in self.attributes.iteritems()]) - # create the actual summary - return '\n'.join(summary) - - def views(self, slicers): - """Return a new Variable object whose contents are a view of the object - sliced along a specified dimension. 
- - Parameters - ---------- - slicers : {dim: slice, ...} - A dictionary mapping from dim to slice, dim represents - the dimension to slice along slice represents the range of the - values to extract. - - Returns - ------- - obj : Variable object - The returned object has the same attributes and dimensions - as the original. Data contents are taken along the - specified dimension. Care must be taken since modifying (most) - values in the returned object will result in modification to the - parent object. - - See Also - -------- - view - take - """ - slices = [slice(None)] * self.data.ndim - for i, dim in enumerate(self.dimensions): - if dim in slicers: - slices[i] = slicers[dim] - # Shallow copy - obj = copy.copy(self) - object.__setattr__(obj, 'data', self.data[slices]) + self.translate(obj, copy=True) return obj - def view(self, s, dim): - """Return a new Variable object whose contents are a view of the object - sliced along a specified dimension. - - Parameters - ---------- - s : slice - The slice representing the range of the values to extract. - dim : string - The dimension to slice along. If multiple dimensions equal - dim (e.g. a correlation matrix), then the slicing is done - only along the first matching dimension. - - Returns - ------- - obj : Variable object - The returned object has the same attributes and dimensions - as the original. Data contents are taken along the - specified dimension. Care must be taken since modifying (most) - values in the returned object will result in modification to the - parent object. - - See Also - -------- - take - """ - return self.views({dim : s}) - - def take(self, indices, dim): - """Return a new Variable object whose contents are sliced from - the current object along a specified dimension - - Parameters - ---------- - indices : array_like - The indices of the values to extract. indices must be compatible - with the ndarray.take() method. - dim : string - The dimension to slice along. If multiple dimensions equal - dim (e.g. a correlation matrix), then the slicing is done - only along the first matching dimension. - - Returns - ------- - obj : Variable object - The returned object has the same attributes and dimensions - as the original. Data contents are taken along the - specified dimension. - - See Also - -------- - numpy.take - """ - indices = np.asarray(indices) - if indices.ndim != 1: - raise ValueError('indices should have a single dimension') - # When dim appears repeatedly in self.dimensions, using the index() - # method gives us only the first one, which is the desired behavior - axis = list(self.dimensions).index(dim) - # Deep copy - obj = copy.deepcopy(self) - # In case data is lazy we need to slice out all the data before taking. - object.__setattr__(obj, 'data', self.data[:].take(indices, axis=axis)) - return obj - -class Dataset(object): - """ - A netcdf-like data object consisting of dimensions, variables and - attributes which together form a self describing data set. - """ - def _allocate(self): - return self.__class__() - def _load_scipy(self, scipy_nc, *args, **kwdargs): """ Interprets a netcdf file-like object using scipy.io.netcdf. 
@@ -342,23 +111,14 @@ def _load_scipy(self, scipy_nc, *args, **kwdargs): scipy_nc.seek(0) nc = netcdf.netcdf_file(scipy_nc, mode='r', *args, **kwdargs) - def from_scipy_variable(sci_var): - return Variable(dims = sci_var.dimensions, - data = sci_var.data, - attributes = sci_var._attributes) - - object.__setattr__(self, 'attributes', AttributesDict()) self.attributes.update(nc._attributes) - - object.__setattr__(self, 'dimensions', OrderedDict()) - dimensions = OrderedDict((k, len(d)) - for k, d in nc.dimensions.iteritems()) - self.dimensions.update(dimensions) - - object.__setattr__(self, 'variables', OrderedDict()) - OrderedDict = OrderedDict((vn, from_scipy_variable(v)) - for vn, v in nc.variables.iteritems()) - self.variables.update() + for k, d in nc.dimensions.iteritems(): + self._unchecked_create_dimension(k, d) + for vn, sci_var in nc.variables.iteritems(): + self._unchecked_create_variable(vn, + dims = sci_var.dimensions, + data = sci_var.data, + attributes = sci_var._attributes) def _load_netcdf4(self, netcdf_path, *args, **kwdargs): """ @@ -367,45 +127,20 @@ def _load_netcdf4(self, netcdf_path, *args, **kwdargs): """ nc = nc4.Dataset(netcdf_path, *args, **kwdargs) - object.__setattr__(self, 'attributes', AttributesDict()) self.attributes.update(dict((k.encode(), nc.getncattr(k)) for k in nc.ncattrs())) - object.__setattr__(self, 'dimensions', OrderedDict()) - dimensions = OrderedDict((k.encode(), len(d)) for k, d in nc.dimensions.iteritems()) - self.dimensions.update(dimensions) + for k, d in nc.dimensions.iteritems(): + self._unchecked_create_dimension(k.encode(), len(d)) - def from_netcdf4_variable(nc4_var): - attributes = dict((k, nc4_var.getncattr(k)) for k in nc4_var.ncattrs()) - return Variable(dims = tuple(nc4_var.dimensions), + for vn, v in nc.variables.iteritems(): + attributes = dict((k, v.getncattr(k)) for k in v.ncattrs()) + self._unchecked_create_variable(vn, + dims = tuple(v.dimensions), # TODO : this variable copy is lazy and # might cause issues in the future. - data = nc4_var, + data = v, attributes = attributes) - object.__setattr__(self, 'variables', OrderedDict()) - self.variables.update(dict((vn.encode(), from_netcdf4_variable(v)) - for vn, v in nc.variables.iteritems())) - - def __init__(self, nc = None, *args, **kwdargs): - if isinstance(nc, basestring) and not nc.startswith('CDF'): - """ - If the initialization nc is a string and it doesn't - appear to be the contents of a netcdf file we load - it using the netCDF4 package - """ - self._load_netcdf4(nc, *args, **kwdargs) - elif nc is None: - object.__setattr__(self, 'attributes', AttributesDict()) - object.__setattr__(self, 'dimensions', OrderedDict()) - object.__setattr__(self, 'variables', OrderedDict()) - else: - """ - If nc is a file-like object we read it using - the scipy.io.netcdf package - """ - self._load_scipy(nc) - - def __setattr__(self, attr, value): """"__setattr__ is overloaded to prevent operations that could cause loss of data consistency. 
If you really intend to update @@ -438,12 +173,11 @@ def __ne__(self, other): def coordinates(self): # A coordinate variable is a 1-dimensional variable with the # same name as its dimension - return OrderedDict([(dim, length) - for (dim, length) in self.dimensions.iteritems() - if (dim in self.variables) and - (self.variables[dim].data.ndim == 1) and - (self.variables[dim].dimensions == (dim,)) - ]) + return OrderedDict([(dim, self.variables[dim]) + for dim in self.dimensions + if dim in self.variables and + self.variables[dim].data.ndim == 1 and + self.variables[dim].dimensions == (dim,)]) @property def noncoordinates(self): @@ -453,25 +187,24 @@ def noncoordinates(self): for (name, v) in self.variables.iteritems() if name not in self.coordinates]) + def translate(self, target, copy=False): + dims = self.dimensions.copy() if copy else self.dimensions + variables = self.variables.copy() if copy else self.variables + atts = self.attributes.copy() if copy else self.attributes + target.store.unchecked_set_dimensions(dims) + target.store.unchecked_set_variables(variables) + target.store.unchecked_set_attributes(atts) + target.store.sync() + def dump(self, filepath, *args, **kwdargs): """ Dump the contents to a location on disk using the netCDF4 package """ - nc = nc4.Dataset(filepath, mode='w', *args, **kwdargs) - for d, l in self.dimensions.iteritems(): - nc.createDimension(d, size=l) - for vn, v in self.variables.iteritems(): - nc.createVariable(vn, v.dtype, v.dimensions) - nc.variables[vn][:] = v.data[:] - for k, a in v.attributes.iteritems(): - try: - nc.variables[vn].setncattr(k, a) - except: - import pdb; pdb.set_trace() - - nc.setncatts(self.attributes) - return nc + nc4_store = backends.NetCDF4DataStore(filepath, mode='w', + *args, **kwdargs) + out = Dataset(store=nc4_store) + self.translate(out) def dumps(self): """ @@ -479,25 +212,10 @@ def dumps(self): creates an in memory netcdf version 3 string using the scipy.io.netcdf package. """ - # TODO : this (may) effectively double the amount of - # data held in memory. It'd be nice to stream the - # serialized string. 
fobj = StringIO() - nc = netcdf.netcdf_file(fobj, mode='w') - # copy the dimensions - for d, l in self.dimensions.iteritems(): - nc.createDimension(d, l) - # copy the variables - for vn, v in self.variables.iteritems(): - nc.createVariable(vn, v.dtype, v.dimensions) - nc.variables[vn][:] = v.data[:] - for k, a in v.attributes.iteritems(): - setattr(nc.variables[vn], k, a) - # copy the attributes - for k, a in self.attributes.iteritems(): - setattr(nc, k, a) - # flush to the StringIO object - nc.flush() + scipy_store = backends.ScipyDataStore(fobj, mode='w') + out = Dataset(store=scipy_store) + self.translate(out) return fobj.getvalue() def __str__(self): @@ -505,8 +223,8 @@ def __str__(self): summary = ["dimensions:"] # prints dims that look like: # dimension = length - dim_print = lambda d, l : "\t%s = %s" % (_prettyprint(d, 30), - _prettyprint(l, 10)) + dim_print = lambda d, l : "\t%s = %s" % (conventions.pretty_print(d, 30), + conventions.pretty_print(l, 10)) # add each dimension to the summary summary.extend([dim_print(d, l) for d, l in self.dimensions.iteritems()]) @@ -515,18 +233,18 @@ def __str__(self): for vname, var in self.variables.iteritems(): # this looks like: # dtype name(dim1, dim2) - summary.append("\t%s %s(%s)" % (_prettyprint(var.dtype, 8), - _prettyprint(vname, 20), - _prettyprint(', '.join(var.dimensions), 45))) + summary.append("\t%s %s(%s)" % (conventions.pretty_print(var.dtype, 8), + conventions.pretty_print(vname, 20), + conventions.pretty_print(', '.join(var.dimensions), 45))) # attribute:value - summary.extend(["\t\t%s:%s" % (_prettyprint(att, 30), - _prettyprint(val, 30)) + summary.extend(["\t\t%s:%s" % (conventions.pretty_print(att, 30), + conventions.pretty_print(val, 30)) for att, val in var.attributes.iteritems()]) summary.append("\nattributes:") # attribute:value - summary.extend(["\t%s:%s" % (_prettyprint(att, 30), - _prettyprint(val, 30)) + summary.extend(["\t%s:%s" % (conventions.pretty_print(att, 30), + conventions.pretty_print(val, 30)) for att, val in self.attributes.iteritems()]) # create the actual summary return '\n'.join(summary) @@ -537,12 +255,6 @@ def __getitem__(self, key): else: raise ValueError("%s is not a variable" % key) - def unchecked_set_dimensions(self, dimensions): - object.__setattr__(self, 'dimensions', dimensions) - - def unchecked_create_dimension(self, name, length): - self.dimensions[name] = length - def create_dimension(self, name, length): """Adds a dimension with name dim and length to the object @@ -564,18 +276,7 @@ def create_dimension(self, name, length): if not isinstance(length, int): raise TypeError("Dimension length must be int") assert length >= 0 - self.unchecked_create_dimension(name, length) - - def unchecked_set_attributes(self, attributes): - object.__setattr__(self, 'attributes', attributes) - - def unchecked_add_variable(self, name, variable): - self.variables[name] = variable - return self.variables[name] - - def unchecked_create_variable(self, name, dims, data, attributes): - v = Variable(dims=dims, data=data, attributes=attributes) - return self.unchecked_add_variable(name, v) + self._unchecked_create_dimension(name, length) def create_variable(self, name, dims, data, attributes=None): """Create a new variable. @@ -591,13 +292,8 @@ def create_variable(self, name, dims, data, attributes=None): dims : tuple The dimensions of the new variable. Elements must be dimensions of the object. - data : numpy.ndarray or None, optional - Data to populate the new variable. 
If None (default), then - an empty numpy array is allocated with the appropriate - shape and dtype. If data contains int64 integers, it will - be coerced to int32 (for the sake of netCDF compatibility), - and an exception will be raised if this coercion is not - safe. + data : numpy.ndarray + Data to populate the new variable. attributes : dict_like or None, optional Attributes to assign to the new variable. Attribute names must be unique and must satisfy netCDF-3 naming rules. If @@ -629,11 +325,7 @@ def create_variable(self, name, dims, data, attributes=None): if (name in self.dimensions) and (data.ndim != 1): raise ValueError("A coordinate variable must be defined with " + "1-dimensional data") - return self.unchecked_create_variable(name, dims, data, attributes) - - def unchecked_create_coordinate(self, name, data, attributes): - self.unchecked_create_dimension(name, data.size) - return self.unchecked_create_variable(name, (name,), data, attributes) + return self._unchecked_create_variable(name, dims, data, attributes) def create_coordinate(self, name, data, attributes=None): """Create a new dimension and a corresponding coordinate variable. @@ -676,7 +368,7 @@ def create_coordinate(self, name, data, attributes=None): # end up in a partial state. if data.ndim != 1: raise ValueError("coordinate must have ndim==1") - return self.unchecked_create_coordinate(name, data, attributes) + return self._unchecked_create_coordinate(name, data, attributes) def add_variable(self, name, variable): """A convenience function for adding a variable from one object to @@ -708,7 +400,6 @@ def delete_variable(self, name): raise ValueError("Object does not have a variable '%s'" % (str(name))) else: - super(type(self.variables), self.variables).__delitem__(name) def views(self, slicers): @@ -751,14 +442,14 @@ def views(self, slicers): for (name, var) in self.variables.iteritems(): var_slicers = dict((k, v) for k, v in slicers.iteritems() if k in var.dimensions) if len(var_slicers): - obj.unchecked_add_variable(name, var.views(var_slicers)) + obj.store.unchecked_add_variable(name, var.views(var_slicers)) new_dims.update(dict(zip(obj[name].dimensions, obj[name].shape))) else: - obj.unchecked_add_variable(name, var) + obj.store.unchecked_add_variable(name, var) # Hard write the dimensions, skipping validation - obj.unchecked_set_dimensions(new_dims) + obj.store.unchecked_set_dimensions(new_dims) # Reference to the attributes, this intentionally does not copy. 
-        obj.unchecked_set_attributes(self.attributes)
+        obj.store.unchecked_set_attributes(self.attributes)
         return obj
 
     def view(self, s, dim=None):
@@ -840,27 +531,27 @@ def take(self, indices, dim=None):
         new_length = self.dimensions[dim]
         for (name, var) in self.variables.iteritems():
             if dim in var.dimensions:
-                obj.unchecked_add_variable(name, var.take(indices, dim))
+                obj.store.unchecked_add_variable(name, var.take(indices, dim))
                 new_length = obj.variables[name].data.shape[
                     list(var.dimensions).index(dim)]
             else:
-                obj.unchecked_add_variable(name, copy.deepcopy(var))
+                obj.store.unchecked_add_variable(name, copy.deepcopy(var))
         # Hard write the dimensions, skipping validation
         for d, l in self.dimensions.iteritems():
             if d == dim:
                 l = new_length
-            obj.unchecked_create_dimension(d, l)
+            obj.store.unchecked_create_dimension(d, l)
 
         if obj.dimensions[dim] == 0:
             raise IndexError(
                 "take would result in a dimension of length zero")
         # Copy attributes
-        self.unchecked_set_attributes(self.attributes.copy())
+        obj._unchecked_set_attributes(self.attributes.copy())
         return obj
 
     def renamed(self, name_dict):
         """
-        Returns a copy of the current object with variables and dimensions
-        reanmed according to the arguments passed via **kwds
+        Returns a new object with variables and dimensions renamed according
+        to the mapping in name_dict
 
         Parameters
         ----------
@@ -890,16 +581,16 @@ def renamed(self, name_dict):
         # if a dimension is a new one it gets added, if the dimension already
         # exists we confirm that they are identical (or throw an exception)
         for (name, length) in self.dimensions.iteritems():
-            obj.create_dimension(new_names[name], length)
+            obj._unchecked_create_dimension(new_names[name], length)
         # a variable is only added if it doesn't currently exist, otherwise
         # and exception is thrown
         for (name, v) in self.variables.iteritems():
-            obj.create_variable(new_names[name],
-                                tuple([new_names[d] for d in v.dimensions]),
-                                data=v.data.copy(),
-                                attributes=v.attributes.copy())
+            obj._unchecked_create_variable(new_names[name],
+                    dims=tuple([new_names[d] for d in v.dimensions]),
+                    data=v.data,
+                    attributes=v.attributes.copy())
         # update the root attributes
-        self.unchecked_set_attributes(self.attributes.copy())
+        obj._unchecked_set_attributes(self.attributes.copy())
         return obj
 
     def update(self, other):
@@ -935,7 +626,7 @@ def update(self, other):
             if not name in self.variables:
                 self.create_variable(name, v.dimensions,
-                                     data=v.data.copy(),
+                                     data=v.data,
                                      attributes=v.attributes.copy())
             else:
                 if self[name].dimensions != other[name].dimensions:
@@ -983,16 +674,16 @@ def select(self, var):
         dim = reduce(or_, [set(self.variables[v].dimensions) for v in var])
         # Create dimensions in the same order as they appear in self.dimension
         for d in dim:
-            obj.unchecked_create_dimension(d, self.dimensions[d])
+            obj.store.unchecked_create_dimension(d, self.dimensions[d])
         # Also include any coordinate variables defined on the relevant
         # dimensions
         for (name, v) in self.variables.iteritems():
             if (name in var) or ((name in dim) and (v.dimensions == (name,))):
-                obj.unchecked_create_variable(name,
+                obj._unchecked_create_variable(name,
                     dims=v.dimensions,
-                    data=v.data.copy(),
+                    data=v.data,
                     attributes=v.attributes.copy())
-        obj.unchecked_set_attributes(self.attributes.copy())
+        obj._unchecked_set_attributes(self.attributes.copy())
         return obj
 
     def iterator(self, dim=None, views=False):
@@ -1186,7 +877,7 @@ def squeeze(self, dimension):
         for (name, var) in self.variables.iteritems():
             if not name == dimension:
                 dims = list(var.dimensions)
-                data = var.data.copy()
+                data = var.data
                 if dimension in dims:
                     shape = list(var.data.shape)
                     index = dims.index(dimension)
@@ -1197,7 +888,7 @@ def squeeze(self, dimension):
                 dims=tuple(dims),
                 data=data,
                 attributes=var.attributes.copy())
-        obj.unchecked_set_attributes(self.attributes.copy())
+        obj.store.unchecked_set_attributes(self.attributes.copy())
         return obj
 
 if __name__ == "__main__":
@@ -1205,7 +896,7 @@
     A bunch of regression tests.
     """
     base_dir = os.path.dirname(__file__)
-    test_dir = os.path.join(base_dir, '..', 'test', )
+    test_dir = os.path.join(base_dir, '..', '..', 'test', )
 
     write_test_path = os.path.join(test_dir, 'test_output.nc')
     ecmwf_netcdf = os.path.join(test_dir, 'ECMWF_ERA-40_subset.nc')
diff --git a/src/polyglot/variable.py b/src/polyglot/variable.py
new file mode 100644
index 00000000000..2bababbfae2
--- /dev/null
+++ b/src/polyglot/variable.py
@@ -0,0 +1,352 @@
+import copy
+import numpy as np
+
+from collections import OrderedDict
+
+import conventions
+
+class AttributesDict(OrderedDict):
+    """A subclass of OrderedDict whose __setitem__ method automatically
+    checks and converts values to be valid netCDF attributes
+    """
+    def __init__(self, *args, **kwds):
+        OrderedDict.__init__(self, *args, **kwds)
+
+    def __setitem__(self, key, value):
+        if not conventions.is_valid_name(key):
+            raise ValueError("Not a valid attribute name")
+        # Strings get special handling because netCDF treats them as
+        # character arrays. Everything else gets coerced to a numpy
+        # vector. netCDF treats scalars as 1-element vectors. Arrays of
+        # non-numeric type are not allowed.
+        if isinstance(value, basestring):
+            # netcdf attributes should be unicode
+            value = unicode(value)
+        else:
+            try:
+                value = conventions.coerce_type(np.atleast_1d(np.asarray(value)))
+            except:
+                raise ValueError("Not a valid value for a netCDF attribute")
+            if value.ndim > 1:
+                raise ValueError("netCDF attributes must be vectors " +
+                                 "(1-dimensional)")
+            value = conventions.coerce_type(value)
+            if str(value.dtype) not in conventions.TYPEMAP:
+                # A plain string attribute is okay, but an array of
+                # string objects is not okay!
+                raise ValueError("Can not convert to a valid netCDF type")
+        OrderedDict.__setitem__(self, key, value)
+
+    def copy(self):
+        """The copy method of the superclass simply calls the constructor,
+        which in turn calls the update method, which in turn calls
+        __setitem__.  This subclass implementation bypasses the expensive
+        validation in __setitem__ for a substantial speedup."""
+        obj = self.__class__()
+        for (attr, value) in self.iteritems():
+            OrderedDict.__setitem__(obj, attr, copy.copy(value))
+        return obj
+
+    def __deepcopy__(self, memo=None):
+        """
+        Returns a deep copy of the current object.
+
+        memo does nothing but is required for compatibility with copy.deepcopy
+        """
+        return self.copy()
+
+    def update(self, *other, **kwargs):
+        """Set multiple attributes with a mapping object or an iterable of
+        key/value pairs"""
+        # Capture arguments in an OrderedDict
+        args_dict = OrderedDict(*other, **kwargs)
+        try:
+            # Attempt __setitem__
+            for (attr, value) in args_dict.iteritems():
+                self.__setitem__(attr, value)
+        except:
+            # Clean up so that we don't end up in a partial state
+            for (attr, value) in args_dict.iteritems():
+                if self.__contains__(attr):
+                    self.__delitem__(attr)
+            # Re-raise
+            raise
+
+    def __eq__(self, other):
+        if not set(self.keys()) == set(other.keys()):
+            return False
+        for (key, value) in self.iteritems():
+            if value.__class__ != other[key].__class__:
+                return False
+            if isinstance(value, basestring):
+                if value != other[key]:
+                    return False
+            else:
+                if value.tostring() != other[key].tostring():
+                    return False
+        return True
+
+
+class Variable(object):
+    """
+    A netcdf-like variable consisting of dimensions, data and attributes
+    which describe a single variable.  A single variable object is not
+    fully described outside the context of its parent Dataset.
+    """
+    def __init__(self, dims, data, attributes=None):
+        object.__setattr__(self, 'dimensions', dims)
+        object.__setattr__(self, 'data', data)
+        if attributes is None:
+            attributes = {}
+        object.__setattr__(self, 'attributes', AttributesDict(attributes))
+
+    def _allocate(self):
+        return self.__class__(dims=(), data=0)
+
+    def __getattribute__(self, key):
+        """
+        Here we give some of the attributes of self.data preference over
+        attributes in the object itself.
+        """
+        if key in ['dtype', 'shape', 'size', 'ndim', 'nbytes',
+                   'flat', '__iter__', 'view']:
+            return getattr(self.data, key)
+        else:
+            return object.__getattribute__(self, key)
+
+    def __setattr__(self, attr, value):
+        """__setattr__ is overloaded to prevent operations that could
+        cause loss of data consistency.  If you really intend to update
+        dir(self), use the self.__dict__.update method or the
+        super(type(a), self).__setattr__ method to bypass."""
+        raise AttributeError, "Object is tamper-proof"
+
+    def __delattr__(self, attr):
+        raise AttributeError, "Object is tamper-proof"
+
+    def __getitem__(self, index):
+        """__getitem__ is overloaded to access the underlying numpy data"""
+        return self.data[index]
+
+    def __setitem__(self, index, data):
+        """__setitem__ is overloaded to access the underlying numpy data"""
+        self.data[index] = data
+
+    def __hash__(self):
+        """__hash__ is overloaded to guarantee that two variables with the same
+        attributes and np.data values have the same hash (the converse is not true)"""
+        return hash((self.dimensions,
+                     frozenset((k, v.tostring()) if isinstance(v, np.ndarray) else (k, v)
+                               for (k, v) in self.attributes.items()),
+                     self.data.tostring()))
+
+    def __len__(self):
+        """__len__ is overloaded to access the underlying numpy data"""
+        return self.data.__len__()
+
+    def __copy__(self):
+        """
+        Returns a shallow copy of the current object.
+        """
+        # Create the simplest possible dummy object and then overwrite it
+        obj = self._allocate()
+        object.__setattr__(obj, 'dimensions', self.dimensions)
+        object.__setattr__(obj, 'data', self.data)
+        object.__setattr__(obj, 'attributes', self.attributes)
+        return obj
+
+    def __deepcopy__(self, memo=None):
+        """
+        Returns a deep copy of the current object.
+
+        memo does nothing but is required for compatibility with copy.deepcopy
+        """
+        # Create the simplest possible dummy object and then overwrite it
+        obj = self._allocate()
+        # tuples are immutable
+        object.__setattr__(obj, 'dimensions', self.dimensions)
+        object.__setattr__(obj, 'data', self.data[:].copy())
+        object.__setattr__(obj, 'attributes', self.attributes.copy())
+        return obj
+
+    def __eq__(self, other):
+        if self.dimensions != other.dimensions or \
+                (self.data.tostring() != other.data.tostring()):
+            return False
+        if not self.attributes == other.attributes:
+            return False
+        return True
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __str__(self):
+        """Create a ncdump-like summary of the object"""
+        summary = ["dimensions:"]
+        # prints dims that look like:
+        #    dimension = length
+        dim_print = lambda d, l : "\t%s : %s" % (conventions.pretty_print(d, 30),
+                                                 conventions.pretty_print(l, 10))
+        # add each dimension to the summary
+        summary.extend([dim_print(d, l) for d, l in zip(self.dimensions, self.shape)])
+        summary.append("\ndtype : %s" % (conventions.pretty_print(self.dtype, 8)))
+        summary.append("\nattributes:")
+        #    attribute:value
+        summary.extend(["\t%s:%s" % (conventions.pretty_print(att, 30),
+                                     conventions.pretty_print(val, 30))
+                        for att, val in self.attributes.iteritems()])
+        # create the actual summary
+        return '\n'.join(summary)
+
+    def views(self, slicers):
+        """Return a new Variable object whose contents are a view of the object
+        sliced along a specified dimension.
+
+        Parameters
+        ----------
+        slicers : {dim: slice, ...}
+            A dictionary mapping from dim to slice; dim represents the
+            dimension to slice along, and slice represents the range of
+            the values to extract.
+
+        Returns
+        -------
+        obj : Variable object
+            The returned object has the same attributes and dimensions
+            as the original. Data contents are taken along the
+            specified dimension.  Care must be taken since modifying (most)
+            values in the returned object will result in modification to the
+            parent object.
+
+        See Also
+        --------
+        view
+        take
+        """
+        slices = [slice(None)] * self.data.ndim
+        for i, dim in enumerate(self.dimensions):
+            if dim in slicers:
+                slices[i] = slicers[dim]
+        # Shallow copy
+        obj = copy.copy(self)
+        object.__setattr__(obj, 'data', self.data[slices])
+        return obj
+
+    def view(self, s, dim):
+        """Return a new Variable object whose contents are a view of the object
+        sliced along a specified dimension.
+
+        Parameters
+        ----------
+        s : slice
+            The slice representing the range of the values to extract.
+        dim : string
+            The dimension to slice along.  If multiple dimensions equal
+            dim (e.g. a correlation matrix), then the slicing is done
+            only along the first matching dimension.
+
+        Returns
+        -------
+        obj : Variable object
+            The returned object has the same attributes and dimensions
+            as the original. Data contents are taken along the
+            specified dimension.  Care must be taken since modifying (most)
+            values in the returned object will result in modification to the
+            parent object.
+
+        See Also
+        --------
+        take
+        """
+        return self.views({dim : s})
+
+    def take(self, indices, dim):
+        """Return a new Variable object whose contents are sliced from
+        the current object along a specified dimension
+
+        Parameters
+        ----------
+        indices : array_like
+            The indices of the values to extract.  indices must be compatible
+            with the ndarray.take() method.
+        dim : string
+            The dimension to slice along.  If multiple dimensions equal
+            dim (e.g. a correlation matrix), then the slicing is done
+            only along the first matching dimension.
+
+        Returns
+        -------
+        obj : Variable object
+            The returned object has the same attributes and dimensions
+            as the original. Data contents are taken along the
+            specified dimension.
+
+        See Also
+        --------
+        numpy.take
+        """
+        indices = np.asarray(indices)
+        if indices.ndim != 1:
+            raise ValueError('indices should have a single dimension')
+        # When dim appears repeatedly in self.dimensions, using the index()
+        # method gives us only the first one, which is the desired behavior
+        axis = list(self.dimensions).index(dim)
+        # Deep copy
+        obj = copy.deepcopy(self)
+        # In case data is lazy we need to slice out all the data before taking.
+        object.__setattr__(obj, 'data', self.data[:].take(indices, axis=axis))
+        return obj
+
+class LazyVariableData(object):
+    """
+    This object wraps around a Variable object (though
+    it only really makes sense to use it with a class that
+    extends variable.Variable).  The result masquerades as
+    variable data, but doesn't actually try accessing the
+    data until indexing is attempted.
+
+    For example, imagine you have some variable that was
+    derived from an opendap dataset, 'nc'.
+
+    var = nc['massive_variable']
+
+    if you wanted to check the data type of var:
+
+    var.data.dtype
+
+    you would find that it might involve downloading all
+    of the actual data, then inspecting the resulting
+    numpy array.  But with this wrapper calling:
+
+    nc['massive_variable'].data.someattribute
+
+    will first inspect the Variable object to see if it has
+    the desired attribute and only then will it suck down the
+    actual numpy array and request 'someattribute'.
+    """
+    def __init__(self, lazy_variable):
+        self.lazyvar = lazy_variable
+
+    def __eq__(self, other):
+        return self.lazyvar[:] == other
+
+    def __ne__(self, other):
+        return self.lazyvar[:] != other
+
+    def __getitem__(self, key):
+        return self.lazyvar[key]
+
+    def __setitem__(self, key, value):
+        if not isinstance(self.lazyvar, Variable):
+            self.lazyvar = Variable(self.lazyvar.dimensions,
+                                    data=self.lazyvar[:],
+                                    attributes=self.lazyvar.attributes)
+        self.lazyvar.__setitem__(key, value)
+
+    def __getattr__(self, attr):
+        """__getattr__ is overloaded to selectively expose some of the
+        attributes of the underlying lazy variable"""
+        if hasattr(self.lazyvar, attr):
+            return getattr(self.lazyvar, attr)
+        else:
+            return getattr(self.lazyvar[:], attr)
\ No newline at end of file
diff --git a/test/test_data.py b/test/test_data.py
index 855b32cea03..fc9023ab954 100644
--- a/test/test_data.py
+++ b/test/test_data.py
@@ -6,7 +6,7 @@
 from copy import deepcopy
 from cStringIO import StringIO
 
-from polyglot import Dataset, Variable
+from polyglot import Dataset, Variable, backends
 
 _dims = {'dim1':100, 'dim2':50, 'dim3':10}
 _vars = {'var1':['dim1', 'dim2'],
@@ -16,12 +16,13 @@
 _testvar = sorted(_vars.keys())[0]
 _testdim = sorted(_dims.keys())[0]
 
-def test_data():
-    obj = Dataset()
+def create_test_data(store=None):
+    obj = Dataset(store=store)
     obj.create_dimension('time', 10)
     for d, l in _dims.items():
         obj.create_dimension(d, l)
-        var = obj.create_variable(name=d, dims=(d,), data=np.arange(l),
+        var = obj.create_variable(name=d, dims=(d,),
+                                  data=np.arange(l, dtype=np.int32),
                                   attributes={'units':'integers'})
     for v, dims in _vars.items():
         var = obj.create_variable(name=v, dims=tuple(dims),
@@ -31,8 +32,11 @@
 
 class DataTest(unittest.TestCase):
 
+    def get_store(self):
return None + def test_iterator(self): - data = test_data() + data = create_test_data(self.get_store()) # iterate over the first dim iterdim = _testdim for t, sub in data.iterator(dim=iterdim): @@ -51,7 +55,7 @@ def test_iterator(self): self.assertTrue((data[_testvar].data != -71).all()) def test_iterarray(self): - data = test_data() + data = create_test_data(self.get_store()) # iterate over the first dim iterdim = _testdim for t, d in data.iterarray(dim=iterdim, var=_testvar): @@ -127,7 +131,7 @@ def test_coordinate(self): attributes = {'foo': 'bar'} a.create_coordinate('x', data=vec, attributes=attributes) self.assertTrue('x' in a.coordinates) - self.assertTrue(a.coordinates['x'] == a.dimensions['x']) + self.assertTrue(a.coordinates['x'] == a.variables['x']) b = Dataset() b.create_dimension('x', vec.size) b.create_variable('x', dims=('x',), data=vec, attributes=attributes) @@ -187,7 +191,7 @@ def test_attributes(self): self.assertRaises(ValueError, b.attributes.__setitem__, 'foo', dict()) def test_view(self): - data = test_data() + data = create_test_data(self.get_store()) slicedim = _testdim s = slice(None, None, 2) ret = data.view(s=s, dim=slicedim) @@ -205,10 +209,10 @@ def test_view(self): if slicedim in data[v].dimensions: slice_list = [slice(None)] * data[v].data.ndim slice_list[data[v].dimensions.index(slicedim)] = s - expected = data[v].data[slice_list] + expected = data[v].data[slice_list][:] else: - expected = data[v].data - actual = ret[v].data + expected = data[v].data[:] + actual = ret[v].data[:] np.testing.assert_array_equal(expected, actual) # Test that our view accesses the same underlying array actual.fill(np.pi) @@ -219,7 +223,7 @@ def test_view(self): s=slice(100, 200), dim=slicedim) def test_views(self): - data = test_data() + data = create_test_data(self.get_store()) data.create_variable('var4', ('dim1', 'dim1'), data = np.empty((data.dimensions['dim1'], @@ -257,7 +261,7 @@ def test_views(self): {'not_a_dim': slice(0, 2)}) def test_take(self): - data = test_data() + data = create_test_data(self.get_store()) slicedim = _testdim # using a list ret = data.take(indices=range(2, 5), dim=slicedim) @@ -282,7 +286,7 @@ def test_take(self): expected = data[v].data.take( indices, axis=data[v].dimensions.index(slicedim)) else: - expected = data[v].data + expected = data[v].data[:] actual = ret[v].data np.testing.assert_array_equal(expected, actual) # Test that our take is a copy @@ -295,7 +299,7 @@ def test_take(self): dim=slicedim) def test_squeeze(self): - data = test_data() + data = create_test_data(self.get_store()) singleton = data.take([1], 'dim2') squeezed = singleton.squeeze('dim2') assert not 'dim2' in squeezed.dimensions @@ -304,7 +308,7 @@ def test_squeeze(self): squeezed[x].data) def test_select(self): - data = test_data() + data = create_test_data(self.get_store()) ret = data.select(_testvar) np.testing.assert_array_equal(data[_testvar].data, ret[_testvar].data) @@ -312,7 +316,7 @@ def test_select(self): self.assertRaises(KeyError, data.select, (_testvar, 'not_a_var')) def test_copy(self): - data = test_data() + data = create_test_data(self.get_store()) var = data.variables[_testvar] var.attributes['foo'] = 'hello world' var_copy = var.__deepcopy__() @@ -328,7 +332,7 @@ def test_copy(self): self.assertNotEqual(id(var.attributes), id(var_copy.attributes)) def test_rename(self): - data = test_data() + data = create_test_data(self.get_store()) newnames = {'var1':'renamed_var1', 'dim2':'renamed_dim2'} renamed = data.renamed(newnames) @@ -344,12 +348,57 @@ def 
test_rename(self): if name in dims: dims[dims.index(name)] = newname self.assertEqual(dims, list(renamed.variables[k].dimensions)) - np.testing.assert_array_equal(v.data, renamed.variables[k].data) + np.testing.assert_array_equal(v.data[:], renamed.variables[k].data[:]) self.assertTrue('var1' not in renamed.variables) self.assertTrue('var1' not in renamed.dimensions) self.assertTrue('dim2' not in renamed.variables) self.assertTrue('dim2' not in renamed.dimensions) +class NetCDF4DataTest(DataTest): + + def get_store(self): + tmp_file = './delete_me.nc' + if os.path.exists(tmp_file): + os.remove(tmp_file) + return backends.NetCDF4DataStore(tmp_file, mode='w') + + # Views on NetCDF4 objects result in copies of the arrays + # since the netCDF4 package requires data to live on disk + def test_view(self): + pass + + def test_views(self): + pass + + def test_iterarray(self): + pass + + # TODO: select isn't working for netCDF4 yet. + def test_select(self): + pass + +class ScipyDataTest(DataTest): + + def get_store(self): + fobj = StringIO() + return backends.ScipyDataStore(fobj, 'w') + +class StoreTest(unittest.TestCase): + + def test_translate_consistency(self): + + store = backends.InMemoryDataStore() + expected = create_test_data(store) + + mem_nc = deepcopy(expected) + self.assertTrue(isinstance(mem_nc.store, backends.InMemoryDataStore)) + + fobj = StringIO() + actual = Dataset(store=backends.ScipyDataStore(fobj, 'w')) + mem_nc.translate(actual) + + self.assertTrue(actual == expected) + if __name__ == "__main__": unittest.main()
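
Reviewer note: a minimal, untested sketch of how the store-backed API introduced
in this patch is meant to hang together, assuming Python 2 (the code relies on
iteritems, basestring and cStringIO) and numpy; the dataset and variable names
below are illustrative only, not part of the patch:

    import numpy as np
    from cStringIO import StringIO
    from polyglot import Dataset, backends

    # build a dataset against the default in-memory store
    ds = Dataset()
    ds.create_coordinate('time', data=np.arange(10, dtype=np.int32),
                         attributes={'units': 'days'})
    ds.create_variable('foo', dims=('time',),
                       data=np.arange(10, dtype=np.float64),
                       attributes={'long_name': 'example variable'})

    # serialize to an in-memory netCDF-3 string via the scipy backend ...
    contents = ds.dumps()

    # ... or translate the same dataset into an explicit store, which is
    # what dump()/dumps() do internally via Dataset.translate()
    out = Dataset(store=backends.ScipyDataStore(StringIO(), mode='w'))
    ds.translate(out)

The design choice here is that Dataset holds no state of its own: dimensions,
variables and attributes always live in the store, so writing to a different
format is just a translate() into a Dataset wrapping a different store.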