# Source code for pywatershed.base.data_model

import warnings
from copy import deepcopy
from typing import Iterable, Literal

import cftime
import netCDF4 as nc4
import numpy as np
import xarray as xr

from ..constants import fileish, fill_values_dict, np_type_to_netcdf_type_dict
from .accessor import Accessor

# This file defines the data model for pywatershed. It is called a
# "dataset_dict" and has an invertible mapping with non-hierarchical netcdf
# or xarray datasets.

# TODO: Show a basic example of how an xr_dd is changed to a dd.
# What is the difference?


# TODO: what about hierarchical/groups in netcdf files?


# Skeleton of a dataset_dict. The metadata for both coords and data_vars
# lives under the "metadata" key.
# The values are mutable dicts: always deepcopy a template before filling it.
template_dd = {
    section: {}
    for section in ("dims", "coords", "data_vars", "metadata", "encoding")
}

# Skeleton of the xarray-style dictionary, which carries a global "attrs"
# entry where the dataset_dict carries "metadata".
template_xr_dd = {
    section: {}
    for section in ("attrs", "dims", "coords", "data_vars", "encoding")
}


# Note: Methods do not deep copy by default, but not all references may be
# preserved. Use with caution and test.


class DatasetDict(Accessor):
    """DatasetDict: a data model following NetCDF-like conventions

    This is the core class in the data model adopted by pywatershed.

    The DatasetDict handles dimensions, coordinates, data, and metadata in
    a way like `NetCDF <https://www.unidata.ucar.edu/software/netcdf/>`__ and
    `xarray <https://docs.xarray.dev/en/stable/>`__ and provides invertible
    mappings between both the
    `netCDF4 <https://unidata.github.io/netcdf4-python/>`__ and
    `xarray <https://docs.xarray.dev/en/stable/>`__ Python packages.

    Where metadata is typically stored on a variable in NetCDF and in xarray,
    a DatasetDict maintains metadata in a dictionary collocated with
    coordinate and data variables. The data model is a DatasetDict with dims,
    coords, data_vars, and metadata keys. The dims track the length of each
    dimension. The coordinates are the discrete locations along dimensions or
    sets of dimensions. The data_vars contain the data located on dims and
    coordinates. The metadata describes the relationship between both coords
    and data_vars and their dims. Together the coords and data_vars are the
    variables of the DatasetDict. All keys in the variables must be present in
    the metadata dictionary and each key contains two more keys: dims and
    attrs. The dims is a tuple of the variable's dimensions and attrs are more
    general attributes.

    When a NetCDF file is read from disk, it has encoding properties that may
    come along. Alternatively, encodings may be specified before writing to
    file.

    Args:
        dims: A dictionary of pairs of `dim_names: dim_len` where `dim_len` is
            an integer value.
        coords: A dictionary of pairs of `coord_names: coord_data` where
            `coord_data` is an np.ndarray.
        data_vars: A dictionary of pairs of `var_names: var_data` where
            `var_data` is an np.ndarray.
        metadata: For all names in `coords` and `data_vars`, metadata entries
            with the required fields:

            - dims: tuple of names in dim,
            - attrs: dictionary whose values may be strings, ints, floats

            The metadata argument may also contain a special `global` key
            paired with a dictionary of global metadata of arbitrary name and
            values of string, integer, or float types.
        encoding: Optional dictionary of per-variable encoding dictionaries.
            Every key other than the special `global` key must name a
            variable in `coords` or `data_vars`.
        validate: A bool that defaults to True, enforcing the consistency
            of the supplied dictionaries

    Values in `coords` and `data_vars` which are not np.ndarray are coerced to
    np.ndarray with a warning. The supplied dictionaries are stored by
    reference, not copied.

    See Also
    --------
    pywatershed.Parameters

    Examples
    ---------

    ..
        # This code is commented, copy and paste in to python, then paste the
        # output below to keep it clean
        from pprint import pprint
        import pywatershed as pws
        import numpy as np
        coords = {
            'time': np.arange(
                '2005-02-01', '2005-02-03', dtype='datetime64[D]'
            ),
            'space': np.arange(3)
        }
        dims = {'ntime': len(coords['time']), 'nspace': len(coords['space'])}
        data = {'precip': 10 * np.random.rand(dims['ntime'], dims['nspace'])}
        metadata = {
            "time": {"dims": ("ntime",), "attrs": {"description": "days"}},
            "space": {
                "dims": ("nspace",),
                "attrs": {"description": "points of interest"},
            },
            "precip": {
                "dims": (
                    "ntime",
                    "nspace",
                ),
                "attrs": {
                    "description": "precipitation rate of all phases",
                    "units": "mm/day",
                },
            },
        }
        dd = pws.base.DatasetDict(
            dims=dims, coords=coords, data_vars=data, metadata=metadata
        )
        dd.dims.keys()
        dd.variables.keys()
        ds = dd.to_xr_ds()
        print(ds)

    >>> from pprint import pprint
    >>> import pywatershed as pws
    >>> import numpy as np
    >>> coords = {
    ...     "time": np.arange(
    ...         "2005-02-01", "2005-02-03", dtype="datetime64[D]"
    ...     ),
    ...     "space": np.arange(3),
    ... }
    >>> dims = {"ntime": len(coords["time"]), "nspace": len(coords["space"])}
    >>> data = {"precip": 10 * np.random.rand(dims["ntime"], dims["nspace"])}
    >>> metadata = {
    ...     "time": {"dims": ("ntime",), "attrs": {"description": "days"}},
    ...     "space": {
    ...         "dims": ("nspace",),
    ...         "attrs": {"description": "points of interest"},
    ...     },
    ...     "precip": {
    ...         "dims": (
    ...             "ntime",
    ...             "nspace",
    ...         ),
    ...         "attrs": {
    ...             "description": "precipitation rate of all phases",
    ...             "units": "mm/day",
    ...         },
    ...     },
    ... }
    >>> dd = pws.base.DatasetDict(
    ...     dims=dims, coords=coords, data_vars=data, metadata=metadata
    ... )
    >>> dd.dims.keys()
    dict_keys(['ntime', 'nspace'])
    >>> dd.variables.keys()
    dict_keys(['time', 'space', 'precip'])
    >>> ds = dd.to_xr_ds()
    >>> print(ds)
    <xarray.Dataset>
    Dimensions:  (ntime: 2, nspace: 3)
    Coordinates:
        time     (ntime) datetime64[ns] 2005-02-01 2005-02-02
        space    (nspace) int64 0 1 2
    Dimensions without coordinates: ntime, nspace
    Data variables:
        precip   (ntime, nspace) float64 8.835 5.667 9.593 7.239 3.92 0.4195

    """

    def __init__(
        self,
        dims: dict[str, int] = None,
        coords: dict = None,
        data_vars: dict = None,
        metadata: dict = None,
        encoding: dict = None,
        validate: bool = True,
    ) -> None:
        # None defaults avoid sharing mutable default dictionaries.
        self._dims = {} if dims is None else dims
        self._coords = {} if coords is None else coords
        self._data_vars = {} if data_vars is None else data_vars
        self._metadata = {} if metadata is None else metadata
        self._encoding = {} if encoding is None else encoding

        # All variable data must be np.ndarray; coerce anything else loudly.
        for cat_data in ["_data_vars", "_coords"]:
            for kk, vv in self[cat_data].items():
                if not isinstance(vv, np.ndarray):
                    msg = f"Coercing {cat_data[1:]}: {kk} to an np.ndarray."
                    warnings.warn(msg)
                    self[cat_data][kk] = np.array(vv)

        # The global metadata entry is always present (possibly empty).
        if "global" not in self._metadata.keys():
            self._metadata["global"] = {}

        if validate:
            self.validate()

        return

    # NOTE: the `copy` parameters on the properties below can not be supplied
    # through attribute access, so the properties always return references.

    @property
    def dims(self, copy=False) -> dict:
        """Return the dimensions"""
        if copy:
            return deepcopy(self._dims)
        return self._dims

    @property
    def coords(self, copy=False) -> dict:
        """Return the coordinates"""
        if copy:
            return deepcopy(self._coords)
        return self._coords

    @property
    def data_vars(self, copy=False) -> dict:
        """Return the data_vars."""
        if copy:
            return deepcopy(self._data_vars)
        return self._data_vars

    @property
    def variables(self, copy=False) -> dict:
        """Return coords and data_vars together

        A new dictionary is built on every access; its values reference the
        stored arrays.
        """
        vars = {**self._coords, **self._data_vars}
        if copy:
            vars = deepcopy(vars)
        return vars

    @property
    def metadata(self, copy=False) -> dict:
        """Return the metadata"""
        if copy:
            return deepcopy(self._metadata)
        return self._metadata

    @property
    def encoding(self, copy=False) -> dict:
        """Return the encoding"""
        if copy:
            return deepcopy(self._encoding)
        return self._encoding

    @property
    def data(self, copy=False) -> dict:
        """Return a dict of dicts: dims, coords, data_vars, metadata, encoding

        Args:
            copy: boolean if a deepcopy is desired

        Returns:
            A dict of dicts containing all the data
        """
        data = {
            "dims": self._dims,
            "coords": self._coords,
            "data_vars": self._data_vars,
            "metadata": self._metadata,
            "encoding": self._encoding,
        }
        if copy:
            return deepcopy(data)
        return data

    def _keys(self) -> list:
        """Return the names of the top-level components of the data model."""
        return ["dims", "coords", "data_vars", "metadata", "encoding"]

    # are __repr__ and __str__ better than default?
    # def __repr__(self):
    #     return pprint.pformat(
    #         {
    #             "dims": self.dims,
    #             "coords": self.coords,
    #             "data_vars": self.data_vars,
    #             "metadata": self.metadata,
    #             "encoding": self.encoding,
    #         }
    #     )
    # def __str__(self):
    #     return

    @classmethod
    def from_dict(cls, dict_in, copy=False):
        """Return this class from a passed dictionary.

        Parameters
        ----------
        dict_in : dict
            A dictionary from which to create an instance of this class
        copy : bool, optional
            If True, the passed dictionary will be deep copied. Default is
            False.

        Returns
        -------
        DatasetDict
            An object of this class.
        """
        if copy:
            return cls(**deepcopy(dict_in))
        return cls(**dict_in)

    @property
    def spatial_coord_names(self) -> dict:
        """Return the spatial coordinate names.

        Args:
            None

        Returns:
            Dictionary of the global metadata entries whose key contains
            "spatial".
        """
        attrs = self._metadata["global"]
        return {kk: vv for kk, vv in attrs.items() if "spatial" in kk}

    @classmethod
    def from_ds(cls, ds):
        """Get this class from a dataset (nc4 or xarray).

        Raises:
            ValueError: if `ds` is neither an xarray nor a netCDF4 Dataset.
        """
        # detect typ as xr or nc4
        if isinstance(ds, xr.Dataset):
            return cls(**xr_ds_to_dd(ds))
        elif isinstance(ds, nc4.Dataset):
            return cls(**nc4_ds_to_dd(ds))
        else:
            raise ValueError("Passed dataset neither from xarray nor netCDF4")

    @classmethod
    def from_netcdf(
        cls, nc_file: fileish, use_xr: bool = False, encoding=False
    ) -> "DatasetDict":
        """Load this class from a netcdf file.

        Args:
            nc_file: the NetCDF file to read
            use_xr: read with xarray (True) or with netCDF4 (False)
            encoding: passed on to the reader to request encoding information
        """
        # handle more than one file?
        if use_xr:
            return cls(**xr_ds_to_dd(nc_file, encoding=encoding))
        else:
            return cls(**nc4_ds_to_dd(nc_file, use_xr_enc=encoding))

    def to_xr_ds(self) -> xr.Dataset:
        """Export to an xarray Dataset"""
        return dd_to_xr_ds(self.data)

    def to_xr_dd(self) -> dict:
        """Export to an xarray DatasetDict (xr.Dataset.to_dict())."""
        return dd_to_xr_dd(self.data)

    def to_nc4_ds(self, filename) -> None:
        """Export to a netcdf file via netcdf4"""
        return dd_to_nc4_ds(self.data, filename)

    def to_netcdf(self, filename, use_xr=False) -> None:
        """Write parameters to a netcdf file

        Args:
            filename: the file to write
            use_xr: write with xarray (True) or with netCDF4 (False)
        """
        if use_xr:
            self.to_xr_ds().to_netcdf(filename)
        else:
            self.to_nc4_ds(filename)
        return

    def rename_dim(self, name_maps: dict, in_place: bool = True):
        """Rename dimensions.

        Args:
            name_maps: dictionary of `old_name: new_name` pairs
            in_place: only True is implemented

        Raises:
            NotImplementedError: if `in_place` is False.
            KeyError: if an old name is not an existing dimension.
        """
        if not in_place:
            raise NotImplementedError
        for old_name, new_name in name_maps.items():
            self.dims[new_name] = self.dims.pop(old_name)
            # Update every variable's dims tuple which uses the old name.
            for mv in self.metadata.values():
                if "dims" in mv.keys() and old_name in mv["dims"]:
                    dim_list = list(mv["dims"])
                    dim_list[dim_list.index(old_name)] = new_name
                    mv["dims"] = tuple(dim_list)

        self.validate()
        return

    def rename_var(self, name_maps: dict, in_place=True):
        """Rename variables.

        Args:
            name_maps: dictionary of `old_name: new_name` pairs
            in_place: only True is implemented

        Raises:
            NotImplementedError: if `in_place` is False.
            KeyError: if an old name has no metadata entry.
        """
        if not in_place:
            raise NotImplementedError
        for old_name, new_name in name_maps.items():
            for cv in ["coords", "data_vars"]:
                if old_name in self[cv].keys():
                    self[cv][new_name] = self[cv].pop(old_name)
            self["metadata"][new_name] = self["metadata"].pop(old_name)
            # Encoding entries are optional (see validate), so only move the
            # entry when one exists.
            if old_name in self["encoding"].keys():
                self["encoding"][new_name] = self["encoding"].pop(old_name)

        self.validate()
        return

    def drop_var(self, var_names):
        """Drop variables

        Dimensions no longer used by any remaining variable are dropped too.

        Args:
            var_names: a variable name or an iterable of variable names

        Raises:
            KeyError: if a name has no metadata entry.
        """
        if isinstance(var_names, str) or not isinstance(var_names, Iterable):
            var_names = [var_names]
        for vv in var_names:
            for cv in ["coords", "data_vars"]:
                if vv in self[cv].keys():
                    del self[cv][vv]
            del self.metadata[vv]
            # Encoding entries are optional, a variable may not have one.
            self.encoding.pop(vv, None)

        # todo: can any coords be dropped?
        # can any dims be dropped?
        var_dims = self._get_var_dims(list(self.variables.keys()))
        dims_in_use = {ii for cc in var_dims.values() for ii in cc}
        dims_rm = set(self.dims.keys()).difference(dims_in_use)
        for dd in dims_rm:
            del self.dims[dd]

        return

    # TODO: add_var

    def _get_var_dims(self, var_names, data=False):
        """Get the dims of variables

        Args:
            var_names: a variable name or a list of variable names
            data: if False return the tuple of dim names for each variable,
                if True return a dictionary of `dim_name: dim_len` for each

        Returns:
            A dictionary keyed by variable name.
        """
        # dims will never return a dict or numpy array
        result = {}
        if not isinstance(var_names, list):
            var_names = [var_names]
        for var_name in var_names:
            dim_names = self._metadata[var_name]["dims"]
            if not data:
                result[var_name] = dim_names
            else:
                result[var_name] = {
                    kk: vv for kk, vv in self._dims.items() if kk in dim_names
                }
        return result

    def _get_dim_coords(self, dim_list, data=False, copy=False):
        """Given a set of dimensions, get the corresponding coords

        Args:
            dim_list: the dimension names
            data: if False return a list of coordinate names, if True return a
                dictionary of `coord_name: coord_data`
            copy: deep copy the returned data (only used when data is True)
        """
        # if all of a coords dims are in supplied dims, take the coordinate
        dim_set = set(dim_list)
        coord_dims = self._get_var_dims(list(self._coords.keys()))
        # if c_dims is empty (a scalar coord), it automatically goes
        coords_out = [
            c_name
            for c_name, c_dims in coord_dims.items()
            if len(set(c_dims) - dim_set) == 0
        ]

        if not data:
            return coords_out

        coord_data = {
            ck: cv for ck, cv in self._coords.items() if ck in coords_out
        }
        if copy:
            coord_data = deepcopy(coord_data)
        return coord_data

    def subset(
        self,
        keys: Iterable,
        copy: bool = False,
        keep_global: bool = False,
        keep_global_metadata: bool = None,
        keep_global_encoding: bool = None,
        strict: bool = False,
    ) -> "DatasetDict":
        """Subset a DatasetDict to keys in data_vars or coordinates

        Coordinates whose dims are all among the dims of a selected data
        variable are brought along with that variable.

        Args:
            keys: Iterable to subset on
            copy: bool to copy the input or edit it
            keep_global: bool that sets both keep_global_metadata and
                keep_global_encoding
            keep_global_metadata: bool retain the global metadata in the subset
            keep_global_encoding: bool retain the global encoding in the subset
            strict: bool, raise a KeyError if a key is not a variable of this
                object (such keys are silently skipped otherwise)

        Returns:
            A subset object of this class on the passed keys.
        """
        # Instantiate the DatasetDict at end as deepcopy will be used
        # on the constructed subset dict (if requested)

        if not isinstance(keys, Iterable) or isinstance(keys, str):
            keys = [keys]
        # Materialize: keys is traversed more than once and may be a generator
        keys = list(keys)
        key_set = set(keys)

        # self.variables builds a new dict on each access, get it once
        variables = self.variables

        for kk in keys:
            if kk not in variables:
                if not strict:
                    continue
                msg = f"key '{kk}' not in this {type(self).__name__} object"
                raise KeyError(msg)

        if keep_global_metadata is None:
            keep_global_metadata = keep_global
        if keep_global_encoding is None:
            keep_global_encoding = keep_global

        subset = deepcopy(template_dd)

        subset["metadata"]["global"] = {}
        subset["encoding"]["global"] = {}
        if keep_global_metadata:
            subset["metadata"]["global"] = self.metadata["global"]
        if keep_global_encoding:
            # a global encoding is optional on this object
            subset["encoding"]["global"] = self.encoding.get("global", {})

        for vv in variables:
            if vv not in key_set:
                continue

            is_coord = vv in self._coords.keys()

            if is_coord:
                subset["coords"][vv] = self._coords[vv]
            else:
                subset["data_vars"][vv] = self._data_vars[vv]

            subset["metadata"][vv] = self._metadata[vv]
            if vv in self._encoding.keys():
                subset["encoding"][vv] = self._encoding[vv]

            var_dim_data = self._get_var_dims(vv, data=True)[vv]

            # dims
            for dd in var_dim_data:
                if dd not in subset["dims"].keys():
                    subset["dims"][dd] = var_dim_data[dd]

            if not is_coord:
                # build coords from variables
                var_coord_data = self._get_dim_coords(
                    list(var_dim_data.keys()), data=True, copy=copy
                )
                for ck, cv in var_coord_data.items():
                    if ck not in subset["coords"].keys():
                        subset["coords"][ck] = cv

        # build metadata and encoding from coords and data_vars
        # (this covers the coordinates brought along by data variables)
        for cv in ["coords", "data_vars"]:
            for cc in subset[cv].keys():
                for aa in ["metadata", "encoding"]:
                    if cc in subset[aa].keys():
                        continue
                    if aa == "encoding" and cc not in self[aa].keys():
                        continue
                    subset[aa][cc] = self[aa][cc]

        # result = DatasetDict.from_dict(subset, copy=copy)
        # If this is in-place, then cant be a classmethod
        result = type(self).from_dict(subset, copy=copy)

        return result

    def subset_on_coord(
        self,
        coord_name: str,
        where: np.ndarray,
    ) -> None:
        """Subset DatasetDict to a np.where along a named coordinate in-place

        The indices used are stored in a new data variable "subset_inds" and
        every other variable is stamped with the attributes "subset_on_coord"
        and "subset_inds_on_orig".

        Args:
            coord_name: string name of a coordinate
            where: the result of an np.where along that coordinate (or
                likewise constructed)

        Returns:
            None

        Raises:
            NotImplementedError: if `where` has more than one element or if
                this object has already been subset.
        """
        # only doing it in place for now

        # should we reall roll our own?
        # TODO: should almost work for 2+D? just linearizes np.where
        # except that dims should be droped with >1
        if len(where) > 1:
            raise NotImplementedError("at least not tested")

        # add the where to the data_vars
        wh_data_name = "subset_inds"
        if wh_data_name in self.data_vars.keys():
            raise NotImplementedError("more work needed to subset twice")

        # what are the dim names of of the coord
        coord_dims = self.metadata[coord_name]["dims"]

        # new dim for the numer of dims in the where
        subset_dims_dim = wh_data_name + "_n_dims"
        self.dims[subset_dims_dim] = len(where)

        # new var
        self.data_vars[wh_data_name] = np.array(where).transpose()
        self.metadata[wh_data_name] = {
            "dims": coord_dims + (subset_dims_dim,),
            "attrs": {
                "description": (
                    "Zero-based indices used to subset original data "
                    f"coordinate '{coord_name}'"
                ),
            },
        }
        self.encoding[wh_data_name] = {}

        # have to edit the dim AND all variables with this dim
        for ii, dd in enumerate(coord_dims):
            self.dims[dd] = len(where[ii])

        dim_where = dict(zip(coord_dims, where))
        # self.variables is a fresh dict, so the stored dicts can be edited
        # while looping over it.
        for vk, vv in self.variables.items():
            if vk == wh_data_name:
                continue
            var_dims = self.metadata[vk]["dims"]
            # Variables which do not have a subset dim keep their data as is.
            if any(dd in dim_where for dd in var_dims):
                # One indexer per axis: the indices on subset dims and a full
                # slice elsewhere, so the indices land on the right axis even
                # if the subset dim is not the variable's first dim.
                var_wh = tuple(
                    dim_where.get(dd, slice(None)) for dd in var_dims
                )
                if vk in self.coords.keys():
                    self["coords"][vk] = vv[var_wh]
                else:
                    self["data_vars"][vk] = vv[var_wh]
            self.metadata[vk]["attrs"]["subset_on_coord"] = coord_name
            self.metadata[vk]["attrs"]["subset_inds_on_orig"] = wh_data_name

        return

    def validate(self) -> None:
        """Check that a DatasetDict is internally consistent.

        Returns:
            None

        Raises:
            AssertionError: if any of the checks fail.
        """
        # required keys
        assert sorted(self.data.keys()) == sorted(
            ["dims", "coords", "data_vars", "metadata", "encoding"]
        ), "The data model does not have the required keys"

        # can not have same names in coords and data_vars
        common_keys = set(self.coords.keys()).intersection(
            set(self.data_vars.keys())
        )
        assert (
            len(common_keys) == 0
        ), f"Names found in both coords and data_vars: {sorted(common_keys)}"

        # metadata and encoding keys against variable keys
        var_keys = set(self.variables.keys())

        # all meta keys have to be in var keys
        meta_keys = set(self.metadata.keys())
        assert meta_keys == var_keys.union(
            {"global"}
        ), "The metadata keys must be the variable names plus 'global'"

        # all enc_keys besides global have to be in var_keys
        enc_keys = set(self.encoding.keys()).difference({"global"})
        assert (
            enc_keys.intersection(var_keys) == enc_keys
        ), "The encoding has keys which are not variable names"

        # check all vars dims exist
        var_dims = [self.metadata[kk]["dims"] for kk in var_keys]
        dims = {dim for dims in var_dims for dim in dims}
        for dd in dims:
            assert dd in self.dims.keys(), f"Dimension '{dd}' is not in dims"

        # TODO: check dims lens equal data in the coordinates with those dims

        return

    @classmethod
    def merge(cls, *dd_list, copy=True, del_global_src=True):
        """Merge a list of this class in to a single instance

        Args:
            dd_list: a list of object of this class
            copy: boolean if a deep copy of inputs is desired. The inputs are
                deep copied regardless of this value, it is retained for
                backwards compatibility.
            del_global_src: boolean to delete encodings' global source
                attribute prior to merging (as these often conflict)

        Returns:
            An object of this class.

        Raises:
            ValueError: if the inputs have a key in common with non-identical
                data.
        """
        # One deep copy of each input isolates the result from the inputs and
        # allows the global source to be deleted without side effects.
        dict_list = [deepcopy(dd.data) for dd in dd_list]

        if del_global_src:
            for dd in dict_list:
                if "global" not in dd["encoding"]:
                    continue
                if "source" in dd["encoding"]["global"]:
                    del dd["encoding"]["global"]["source"]

        merged_dict = _merge_dicts(dict_list)

        return cls.from_dict(merged_dict)
# DatasetDict
# ---------------------------
# module scope functions


def _is_equal(aa, bb):
    """Return True if aa and bb are equal per np.testing.assert_equal."""
    # How sketchy is this? (honest question)
    try:
        np.testing.assert_equal(aa, bb)
    except Exception:
        # Any failure to establish equality counts as not equal. Exception
        # (not a bare except) lets KeyboardInterrupt/SystemExit through.
        return False
    return True


def _merge_dicts(
    dict_list: list[dict],
    conflicts: Literal["left", "warn", "error"] = "error",
):
    """Recursively merge a list of dictionaries in to a new dictionary.

    Values which are dictionaries on both sides are merged recursively. Other
    values are conflicts unless they are equal.

    Args:
        dict_list: list of the dictionaries to merge, earlier ("left") to
            later ("right")
        conflicts: what to do on a conflict: "error" raises a ValueError,
            "warn" issues a warning and keeps the left value, "left" silently
            keeps the left value

    Returns:
        The merged dictionary. Non-conflicting values are references to the
        values of the inputs.

    Raises:
        ValueError: if dict_list is not a list, on a conflict when conflicts
            is "error", or on a conflict when conflicts is not a valid choice.
    """
    if not isinstance(dict_list, list):
        raise ValueError("argument 'dict_list' is not a list")

    merged = {}
    for dd in dict_list:
        for key, value in dd.items():
            if key not in merged:
                merged[key] = value
            elif isinstance(value, dict) and isinstance(merged[key], dict):
                # Keep the already merged data on the left so that "left"
                # means the same thing at all levels of nesting.
                merged[key] = _merge_dicts(
                    [merged[key], value], conflicts=conflicts
                )
            elif _is_equal(value, merged[key]):
                pass
            else:
                msg = (
                    f"Duplicate key '{key}' with non-identical data:\n"
                    f"    L={merged[key]}\n"
                    f"    R={value}\n"
                )
                if conflicts == "error":
                    raise ValueError(msg)
                elif conflicts == "warn":
                    warnings.warn(msg)
                elif conflicts == "left":
                    pass
                else:
                    raise ValueError(
                        f"Argument 'conflicts' can not be '{conflicts}'"
                    )

    return merged


def xr_ds_to_dd(file_or_ds, schema_only=False, encoding=True) -> dict:
    """Xarray dataset to a pywatershed dataset dict

    The pyws data model moves metadata off the variables to a separate
    var_metadata dictionary with the same keys found in the union of the
    keys of coords and data_vars.

    Args:
        file_or_ds: an xarray Dataset or a file to load with xarray
        schema_only: if True, the data are not included
        encoding: passed to xr.Dataset.to_dict
    """
    if not isinstance(file_or_ds, xr.Dataset):
        xr_ds = xr.load_dataset(file_or_ds)
    else:
        xr_ds = file_or_ds

    data_arg = False if schema_only else "array"

    dd = xr_ds.to_dict(data=data_arg, encoding=encoding)

    # before = xr_ds.time.values.dtype
    # after = dd["coords"]["time"]["data"].dtype
    # assert before == after

    return xr_dd_to_dd(dd)


def xr_dd_to_dd(xr_dd: dict) -> dict:
    """Convert an xarray dataset dictionary to a pywatershed dataset dict.

    The input is deep copied. The data and encoding are moved off each
    variable, what is left of each variable becomes its metadata. The global
    attrs and encoding are moved to the "global" keys of metadata and
    encoding.
    """
    dd = deepcopy(xr_dd)

    # Move the global encoding to a global key of itself
    dd["encoding"] = {"global": dd.get("encoding", {})}

    # rename data to metadata for var and coord
    var_metadata = dd.pop("data_vars")
    coord_metadata = dd.pop("coords")

    # create empty data dicts and move the data out of the metadata
    # and move the encoding to encoding[var]
    for cv, cv_metadata in (
        ("data_vars", var_metadata),
        ("coords", coord_metadata),
    ):
        dd[cv] = {}
        for key, val in cv_metadata.items():
            dd[cv][key] = val.pop("data")
            dd["encoding"][key] = val.pop("encoding", {})

    dd["metadata"] = {**coord_metadata, **var_metadata}
    # global attrs are optional, as the global encoding is above
    dd["metadata"]["global"] = dd.pop("attrs", {})

    return dd


def dd_to_xr_dd(dd: dict) -> dict:
    """Convert a pywatershed dataset dict to an xarray dataset dictionary.

    The input is deep copied. The metadata, data, and encoding of each
    variable are put together on the variable.

    Raises:
        KeyError: if a metadata key is neither in data_vars nor in coords.
    """
    dd = deepcopy(dd)

    # remove metadata and encoding from the dict/model
    meta = dd.pop("metadata")
    encoding = dd.pop("encoding")

    # loop over meta data, putting it back on the data_vars and/or coords
    # and moving the data to "data"
    dd["attrs"] = meta.pop("global", {})
    dd["encoding"] = encoding.pop("global", {})
    for key, val in meta.items():
        # coordinate or variable?
        if key in dd["data_vars"].keys():
            cv = "data_vars"
        elif key in dd["coords"].keys():
            cv = "coords"
        else:
            raise KeyError(
                f"metadata key '{key}' is neither in data_vars nor in coords"
            )

        data_vals = dd[cv][key]
        if np.issubdtype(data_vals.dtype, np.datetime64):
            # conversion to datetime64[ns] silences xr warnings
            # but should be able to be removed in the future
            data_vals = data_vals.astype("datetime64[ns]")

        dd[cv][key] = {
            **val,
            "data": data_vals,
            "encoding": encoding.get(key, {}),
        }

    return dd


def dd_to_xr_ds(dd: dict) -> "xr.Dataset":
    """pywatershed dataset dict to xarray dataset

    The pyws data model moves metadata off the variables to a separate
    var_metadata dictionary with the same keys found in the union of the
    keys of coords and data_vars. This maps the metadata back to the variables.
    """
    return xr.Dataset.from_dict(dd_to_xr_dd(dd))


def _nc4_var_to_datetime64(var, attrs, encoding):
    """netCDF4 conversion of time to numpy.datetime64 based on metadata.

    Variables without a units attribute containing "since" are returned as
    read. For converted variables the calendar and units are moved from attrs
    to encoding.

    Returns:
        A tuple of the data, the attrs, and the encoding.
    """
    if not (hasattr(var, "units") and "since" in var.units):
        return var[:], attrs, encoding

    # Check if the variable has a calendar attribute
    if hasattr(var, "calendar"):
        time_data = nc4.num2date(
            var[:], var.units, var.calendar, only_use_cftime_datetimes=False
        )
    else:
        time_data = nc4.num2date(var[:], var.units)

    if isinstance(time_data, cftime.real_datetime):
        time_data = np.datetime64(time_data)
    else:
        time_data = time_data.filled().astype("datetime64[us]")

    for aa in ["calendar", "units"]:
        if aa in attrs.keys():
            encoding[aa] = attrs.pop(aa)

    return time_data, attrs, encoding


def _datetime64_to_nc4_var(var, units, calendar):
    """vectorized conversion of numpy.datetime64 to a netcdf4 variable."""
    # Based on what xarray does
    # https://github.com/pydata/xarray/blob/
    # a1f5245a48146bd8fc5bdb07ef8ae6077d6e511c/xarray/coding/times.py#L687
    # NOTE(review): the name and docstring describe encoding datetimes but
    # decode_cf_datetime is what is called. Confirm which is intended.
    from xarray.coding.times import decode_cf_datetime

    # This can take encoding info but we are using defaults.
    # (data, units, calendar) = decode_cf_datetime(var, units, calendar)
    # return {"data": data, "units": units, "calendar": calendar}
    data = decode_cf_datetime(var, units, calendar)
    return {"data": data}


def nc4_ds_to_xr_dd(file_or_ds, xr_enc: dict = None) -> dict:
    """Convert a netCDF4 dataset to and xarray dataset dictionary

    The dataset is closed once read, including when it is passed in open.

    Args:
        file_or_ds: a netCDF4 Dataset or a file to open with netCDF4
        xr_enc: optional encodings to add, a dictionary keyed by variable
            name with an optional "global" key (as from _get_xr_encoding).
            It is not modified.
    """
    if not isinstance(file_or_ds, nc4.Dataset):
        ds = nc4.Dataset(file_or_ds, "r")
    else:
        ds = file_or_ds

    # An empty xr_dd dictionary to hold the data
    xr_dd = deepcopy(template_xr_dd)

    # close the dataset even if reading it fails
    try:
        # xr_dd["attrs"] = nc_file.__dict__  # ugly
        for attrname in ds.ncattrs():
            xr_dd["attrs"][attrname] = ds.getncattr(attrname)

        for dimname, dim in ds.dimensions.items():
            xr_dd["dims"][dimname] = len(dim)

        for varname, var in ds.variables.items():
            # _Encoding is used for string encoding in nc4
            var_encoding = {}
            if "_Encoding" in var.__dict__.keys():
                var_encoding["_Encoding"] = var._Encoding

            data_dict = {"dims": var.dimensions}

            var_attrs = {
                attrname: var.getncattr(attrname) for attrname in var.ncattrs()
            }

            var_data, var_attrs, var_encoding = _nc4_var_to_datetime64(
                var,
                var_attrs,
                var_encoding,
            )

            if isinstance(var_data, np.ma.core.MaskedArray):
                var_data = var_data.data

            for aa in ["_FillValue"]:
                if aa in var_attrs.keys():
                    var_encoding[aa] = var_attrs.pop(aa)

            data_dict["data"] = var_data
            data_dict["attrs"] = var_attrs
            data_dict["encoding"] = var_encoding

            if varname in ds.dimensions:
                xr_dd["coords"][varname] = data_dict
            else:
                xr_dd["data_vars"][varname] = data_dict
    finally:
        ds.close()

    # have to promote coord variables here?
    all_coords = [
        vv["attrs"].pop("coordinates")
        for vv in xr_dd["data_vars"].values()
        if "coordinates" in vv["attrs"].keys()
    ]
    all_coords = sorted(set(" ".join(all_coords).split(" ")))
    for cc in all_coords:
        if cc == "":
            continue
        if cc in xr_dd["coords"]:
            # already a coordinate because it is named like a dimension
            continue
        xr_dd["coords"][cc] = xr_dd["data_vars"].pop(cc)

    # handle bools
    for meta in [xr_dd["coords"], xr_dd["data_vars"]]:
        for vv in meta.values():
            # dd_to_nc4_ds writes bools as int8 with a dtype attribute
            if vv["attrs"].get("dtype") == "bool":
                vv["data"] = vv["data"].astype("bool")
                del vv["attrs"]["dtype"]

    # bring in the encoding information using xarray (cheating?)
    if xr_enc:
        xr_dd["encoding"] = {**xr_dd["encoding"], **xr_enc.get("global", {})}
        for cv in ["coords", "data_vars"]:
            for name, var_dict in xr_dd[cv].items():
                if name in xr_enc.keys():
                    var_dict["encoding"] = {
                        **var_dict["encoding"],
                        **xr_enc[name],
                    }

    return xr_dd


def _get_xr_encoding(nc_file) -> dict:
    """Get the encodings xarray finds in a file.

    Returns:
        A dictionary with the dataset encoding on the "global" key and the
        encoding of each variable on the variable's name.
    """
    ds = xr.load_dataset(nc_file)
    encoding = {"global": ds.encoding}
    for vv in ds.variables:
        encoding[vv] = ds[vv].encoding
    return encoding


def nc4_ds_to_dd(
    nc4_file_ds, subset: np.ndarray = None, use_xr_enc=True
) -> dict:
    """netCDF4 dataset to a pywatershed dataset dict.

    Args:
        nc4_file_ds: a netCDF4 Dataset or a file to open with netCDF4
        subset: currently unused
        use_xr_enc: add the encodings xarray finds in the file, this requires
            that a file is passed

    Raises:
        ValueError: if use_xr_enc is requested with an nc4.Dataset.
    """
    xr_enc = None
    if not isinstance(nc4_file_ds, nc4.Dataset):
        if use_xr_enc:
            xr_enc = _get_xr_encoding(nc4_file_ds)
        nc4_file_ds = nc4.Dataset(nc4_file_ds)
    else:
        if use_xr_enc:
            raise ValueError(
                "Pass a file and not an nc4.Dataset to use_xr_enc argument"
            )

    xr_dd = nc4_ds_to_xr_dd(nc4_file_ds, xr_enc=xr_enc)
    dd = xr_dd_to_dd(xr_dd)
    return dd


def _nc4_storage_kwargs(enc: dict) -> dict:
    """Storage keyword arguments of createVariable from a variable encoding."""
    # This is not complete. Defaults from nc4
    return {
        "zlib": enc.get("zlib", False),
        "complevel": enc.get("complevel", 4),
        "shuffle": enc.get("shuffle", True),
        "contiguous": enc.get("contiguous", False),
        "fletcher32": enc.get("fletcher32", False),
        "chunksizes": enc.get("chunksizes", None),
    }


def dd_to_nc4_ds(dd, nc_file):
    """nc4_ds is a bit of a misnomer since it's on disk, dd_to_nc4

    Write a pywatershed dataset dict to a NetCDF file using netCDF4. The
    passed dictionary is not modified.

    Args:
        dd: the pywatershed dataset dict
        nc_file: the file to write
    """
    # import cftime
    from xarray.coding.times import encode_cf_datetime

    # work from xrarray's dict representation of a dataset
    # (dd_to_xr_dd deep copies its input, so edits below do not reach dd)
    xr_dd = dd_to_xr_dd(dd)
    del dd

    # create a new netCDF4 file
    with nc4.Dataset(nc_file, "w") as ds:
        ds.set_fill_on()

        for key, value in xr_dd["attrs"].items():
            setattr(ds, key, value)

        for dim, size in xr_dd["dims"].items():
            ds.createDimension(dim, size)

        for coord_name, values in xr_dd["coords"].items():
            enc = values["encoding"]

            # handle time encoding
            if np.issubdtype(values["data"].dtype, np.datetime64):
                dates_enc, units, calendar = encode_cf_datetime(
                    values["data"].astype("datetime64[ns]"),
                    enc.get("units", None),
                    enc.get("calendar", None),
                )
                # TODO: what if .astype("datetime64[ns]") fails?
                # this would be a manual way that's looped and that
                # dosent guess at units and calendar
                # cf_dates = cftime.date2num(
                #     dates,
                #     units=enc.get("units"),
                #     calendar=enc.get("calendar"),
                # )
                values["data"] = dates_enc
                values["attrs"]["units"] = units
                values["attrs"]["calendar"] = calendar
                # units and calendar are written as attributes instead
                for kk in ["units", "calendar"]:
                    if kk in enc.keys():
                        del enc[kk]

            var_type = values["attrs"].get("type", values["data"].dtype)

            var = ds.createVariable(
                coord_name,
                var_type,
                dimensions=values["dims"],
                fill_value=enc.get("_FillValue", None),
                **_nc4_storage_kwargs(enc),
            )
            var[:] = values["data"]
            var.setncatts(values["attrs"])

        for var_name, values in xr_dd["data_vars"].items():
            enc = values["encoding"]

            np_type = values["data"].dtype
            is_bool = np_type == "bool"
            nc_type = np_type_to_netcdf_type_dict[np_type]
            var_type = values["attrs"].get("type", nc_type)
            if is_bool:
                # bools are written as 1 byte integers
                var_type = "i1"
            default_fill = fill_values_dict[np_type]

            var = ds.createVariable(
                var_name,
                var_type,
                dimensions=values["dims"],
                fill_value=enc.get("_FillValue", default_fill),
                **_nc4_storage_kwargs(enc),
            )
            if is_bool:
                var[:] = values["data"].astype("int8")
            else:
                var[:] = values["data"]

            var.setncatts(values["attrs"])

            # solve coords for this var: any coord which has any of its dims
            coords = []
            # these rules are a bit opaque
            coords_just_cuz = ["reference_time"]
            coords_no_way = ["time"]
            var_dims = set(values["dims"])
            for c_name, c_val in xr_dd["coords"].items():
                if c_name in coords_no_way:
                    continue
                common = var_dims.intersection(set(c_val["dims"]))
                if len(common) or c_name in coords_just_cuz:
                    coords += [c_name]
            if len(coords):
                var.coordinates = " ".join(sorted(coords))

            if is_bool:
                # mark the variable so it is read back as bool
                var.setncattr("dtype", "bool")

    return


def open_datasetdict(nc_file: "fileish", use_xr=True):
    """Convenience method for opening a DatasetDict.

    Args:
        nc_file: the file containing the DatasetDict.
        use_xr: Use xarray or NetCDF4 for opening the NetCDF file?
    """
    return DatasetDict.from_netcdf(nc_file, use_xr=use_xr)