import datetime as dt
import pathlib as pl
from math import ceil
from typing import Union
import netCDF4 as nc4
import numpy as np
import xarray as xr
from ..base.accessor import Accessor
from ..base.meta import meta_dimensions, meta_netcdf_type
from ..constants import np_type_to_netcdf_type_dict
from ..utils.time_utils import datetime_doy
fileish = Union[str, pl.Path]
listish = Union[list, tuple]
arrayish = Union[list, tuple, np.ndarray]
ATOL = np.finfo(np.float32).eps
# JLM TODO: the implied time dimension seems like a bad idea, it should be
# an argument.
class NetCdfRead(Accessor):
"""NetCDF file reader (for input/forcing data)
Args:
name: the netcdf file path to open
start_time: optional np.datetime64 which is the start of the simulation
end_time: optional np.datetime64 which is the end of the simulation
nc_read_vars: a subset of available variables to read from the file
load_n_times: optional integer for the length of time to load into memory
from file see load_n_time_batches for more details. default is None. A
value of -1 gives the same result as load_n_time_batches = 1.
load_n_time_batches: optional integer for the number to batches (time
partitions) of the input data to use. Default value is 1 whic loads all
the input data on the first time advance. This may not work well for
domains large in space and/or time as it will require large amounts of
memory. These load options exist to optimize IO patterns against memory
usage: specify EXACTLY ONE of load_n_times or load_ntime_batches,
whichever is more convenient for you. If both _load_times and
_load_n_time_batches are None, then no time batching is used. This has
proven an inefficient pattern. The time batching is not implemented for
DOY (cyclic) variables, only for variables with time dimension "time".
"""
def __init__(
self,
name: fileish,
start_time: np.datetime64 = None,
end_time: np.datetime64 = None,
nc_read_vars: list = None,
load_n_times: int = None,
load_n_time_batches: int = 1,
) -> "NetCdfRead":
self.name = "NetCdfRead"
self._nc_file = name
self._nc_read_vars = nc_read_vars
self._start_time = start_time
self._end_time = end_time
if (load_n_times is not None) and (load_n_time_batches is not None):
msg = "Can only specify one of load_n_times or load_ntime_batches"
raise ValueError(msg)
self._load_n_times = load_n_times
self._load_n_time_batches = load_n_time_batches
self._open_nc_file()
self._itime_step = {}
for variable in self.variables:
self._itime_step[variable] = 0
def __del__(self):
self.close()
def close(self):
if self.dataset.isopen():
self.dataset.close()
def _open_nc_file(self):
self.dataset = nc4.Dataset(self._nc_file, "r")
self.ds_var_list = list(self.dataset.variables.keys())
if self._nc_read_vars is None:
self._nc_read_vars = self.ds_var_list
if "time" in self.dataset.variables:
self._time = (
nc4.num2date(
self.dataset.variables["time"][:],
units=self.dataset.variables["time"].units,
calendar="standard",
only_use_cftime_datetimes=False,
)
.filled()
.astype("datetime64[s]")
# JLM: the global time type as in cbh_utils, define somewhere
)
if self._start_time is None:
self._start_index = 0
else:
wh_start = np.where(self._time == self._start_time)
self._start_index = wh_start[0][0]
if self._end_time is None:
self._end_index = self._time.shape[0] - 1
else:
wh_end = np.where(self._time == self._end_time)
self._end_index = wh_end[0][0]
self._time = self._time[self._start_index : (self._end_index + 1)]
self._ntimes = self._end_index - self._start_index + 1
# time batching
# data are actually loaded in get_data
if self._load_n_time_batches is not None:
# use ceil because we want exactly the requested # of batches
self._load_n_times = ceil(
self._ntimes / self._load_n_time_batches
)
self._data_loaded = {}
elif self._load_n_times is not None:
# Use ceil to account for the remainder batch
if self._load_n_times == -1:
self._load_n_times = self._ntimes
self._load_time_batches = ceil(
self._ntimes / self._load_n_times
)
self._data_loaded = {}
# Note that if neither _load variables is specified, then no time
# batching is used
if "doy" in self.dataset.variables:
self._doy = self.dataset.variables["doy"][:].data
self._ntimes = self._doy.shape[0]
self._start_index = 0
self._end_index = 365
spatial_id_names = []
if "nhm_id" in self.ds_var_list:
spatial_id_names.append("nhm_id")
elif "hru_id" in self.ds_var_list:
spatial_id_names.append("hru_id")
elif "grand_id" in self.ds_var_list:
spatial_id_names.append("grand_id")
if "nhm_seg" in self.ds_var_list:
spatial_id_names.append("nhm_seg")
elif "hru_seg" in self.ds_var_list:
spatial_id_names.append("hru_seg")
# set default spatial id
if len(spatial_id_names) < 1:
spatial_id_names.append("hru_id")
# set spatial id dictionary
self._spatial_ids = {}
for spatial_id_name in spatial_id_names:
self._spatial_ids[spatial_id_name] = self.dataset.variables[
spatial_id_name
][:]
self._variables = [
name
for name in self.ds_var_list
if name != "time" and name not in spatial_id_names
]
return
@property
def ntimes(
self,
) -> int:
"""Get number of times in the netcdf file
Returns:
ntimes: number of times in the NetCDF file
"""
return self._ntimes
@property
def times(
self,
) -> np.ndarray:
"""Get the times in the NetCDF file
Returns:
data_times: numpy array of datetimes in the NetCDF file
"""
if hasattr(self, "_time"):
return self._time
elif hasattr(self, "_doy"):
return self._doy
else:
raise KeyError(f"{self.name} has neither _time nor _doy")
@property
def nhru(
self,
) -> int:
"""Get number of HRUs in the NetCDF file
Returns:
nhru_shape: number of HRUs in the NetCDF file
"""
hru_key = None
if "nhm_id" in self._spatial_ids.keys():
hru_key = "nhm_id"
elif "hru_id" in self._spatial_ids.keys():
hru_key = "hru_id"
if hru_key is not None:
nhru_shape = self._spatial_ids[hru_key].shape[0]
else:
nhru_shape = 0
return nhru_shape
@property
def nsegment(
self,
) -> int:
"""Get number of segments in the NetCDF file
Returns:
seg_shape: number of segments in the NetCDF file
"""
seg_key = None
if "nhm_seg" in self._spatial_ids.keys():
seg_key = "nhm_seg"
if seg_key is not None:
seg_shape = self._spatial_ids[seg_key].shape[0]
else:
seg_shape = 0
return seg_shape
@property
def spatial_ids(
self,
) -> np.ndarray:
"""Get the spatial IDs in the NetCDF file
Returns:
arr: numpy array with the spatial IDs in the NetCDF file
"""
return self._spatial_ids
@property
def variables(self):
"""Get a list of variable names
Returns:
variables: list of variable names, excluding the time and
nhru variables
"""
return self._variables
def all_time(self, variable):
return self.get_data(variable)
def get_data(
self,
variable: str,
itime_step: int = None,
) -> np.ndarray:
"""Get data for a variable
Args:
variable: variable name
itime_step: time step to return. If itime_step is None all of the
data for a variable is returned
Returns:
arr: numpy array with the data for a variable
"""
if variable not in self._nc_read_vars:
raise ValueError(
f"'{variable}' not in list of available variables"
)
if itime_step is None:
return self.dataset[variable][
self._start_index : (self._end_index + 1), :
]
else:
if itime_step >= self._ntimes:
raise ValueError(
f"requested time step {itime_step} but only "
+ f"{self._ntimes} time steps are available."
)
if hasattr(self, "_data_loaded"):
# load when needed, at the beginning of each batch
batch_index = itime_step % self._load_n_times
if batch_index == 0:
ith_batch = itime_step // self._load_n_times
# print(
# f"load batch "
# f"#{ith_batch}/{self._load_n_time_batches-1}: "
# f"{variable}"
# )
start_ind = self._start_index + (
ith_batch * self._load_n_times
)
end_ind = start_ind + self._load_n_times
self._data_loaded[variable] = self.dataset[variable][
start_ind:end_ind, :
]
return self._data_loaded[variable][batch_index, :]
else:
# no time batching
return self.dataset[variable][itime_step, :]
def advance(
self, variable: str, current_time: np.datetime64 = None
) -> np.ndarray:
"""Get the data for a variable for the next time step
Args:
variable: variable name
Returns:
arr: numpy array with the data for a variable for the current
time step
"""
if "time" in self.dataset.variables:
arr = self.get_data(
variable,
itime_step=self._itime_step[variable],
)
if "doy" in self.dataset.variables:
arr = self.get_data(
variable,
itime_step=datetime_doy(current_time) - 1,
)
self._itime_step[variable] += 1
return arr
class NetCdfWrite(Accessor):
def __init__(
self,
name: fileish,
coordinates: dict,
variables: listish,
var_meta: dict,
extra_coords: dict = None,
global_attrs: dict = None,
time_units: str = "days since 1970-01-01 00:00:00",
clobber: bool = True,
zlib: bool = True,
complevel: int = 4,
chunk_sizes: dict = {"time": 1, "hruid": 0},
):
from netCDF4 import stringtochar
"""Output the csv output data to a netcdf file
Args:
name: path for netcdf output file
extra_coords: A dictionary keyed by dimension with the values being
a dictionary of var_name: data pairs. Not for multi-dimensional
coordinates.
clobber: boolean indicating if an existing netcdf file should
be overwritten
zlib: boolean indicating if the data should be compressed
(default is True)
complevel: compression level (default is 4)
chunk_sizes: dictionary defining chunk sizes for the data
"""
if isinstance(variables, dict):
group_variables = []
for group, vars in variables.items():
for var_name in vars:
if group is None:
group_variables += [var_name]
else:
group_variables += [f"{group}/{var_name}"]
v2 = [
var_name
for group, vars in variables.items()
for var_name in vars
]
variables = v2
else:
group_variables = variables
self.dataset = nc4.Dataset(name, "w", clobber=clobber)
self.dataset.setncattr("Description", "pywatershed output data")
if extra_coords is None:
extra_coords = {}
if global_attrs is None:
global_attrs = {}
for att_key, att_val in global_attrs.items():
self.dataset.setncattr(att_key, att_val)
variable_dimensions = {}
nhru_coordinate = False
nsegment_coordinate = False
one_coordinate = False
nreservoirs_coordinate = False
nnodes_coordinate = False
for var_name in variables:
dimension_name = meta_dimensions(var_meta[var_name])
variable_dimensions[var_name] = dimension_name
if (
"nhru" in dimension_name
or "ngw" in dimension_name
or "nssr" in dimension_name
):
nhru_coordinate = True
if "nsegment" in dimension_name:
nsegment_coordinate = True
if "one" in dimension_name:
one_coordinate = True
if "nreservoirs" in dimension_name:
nreservoirs_coordinate = True
if "nnodes" in dimension_name:
nnodes_coordinate = True
if nhru_coordinate:
hru_ids = coordinates["nhm_id"]
if isinstance(hru_ids, (list, tuple)):
nhrus = len(hru_ids)
elif isinstance(hru_ids, np.ndarray):
nhrus = hru_ids.shape[0]
self.nhrus = nhrus
self.hru_ids = hru_ids
if nsegment_coordinate:
segment_ids = coordinates["nhm_seg"]
if isinstance(segment_ids, (list, tuple)):
nsegments = len(segment_ids)
elif isinstance(segment_ids, np.ndarray):
nsegments = segment_ids.shape[0]
self.nsegments = nsegments
self.segment_ids = segment_ids
if one_coordinate:
self.one_ids = coordinates["one"]
if nreservoirs_coordinate:
self.nreservoirs = len(coordinates["grand_id"])
if nnodes_coordinate:
self.nnodes = len(coordinates["node_coord"])
# Dimensions
# Time is an implied dimension in the netcdf file for most variables
# time is necessary if an alternative time
# dimenison does not appear in even one variable
doy_time_vars = [
"soltab_potsw",
"soltab_horad_potsw",
"soltab_sunhrs",
]
for var_name in variables:
if var_name in doy_time_vars:
continue
# None for the len argument gives an unlimited dim
self.dataset.createDimension("time", None)
self.time = self.dataset.createVariable("time", "f4", ("time",))
self.time.units = time_units
break
# similarly, if alternative time dimenions exist.. define them
for var_name in variables:
if var_name not in doy_time_vars:
continue
self.dataset.createDimension("doy", 366)
self.doy = self.dataset.createVariable("doy", "i4", ("doy",))
self.doy.units = "Day of year"
break
if nhru_coordinate:
self.dataset.createDimension("nhm_id", self.nhrus)
if nsegment_coordinate:
self.dataset.createDimension("nhm_seg", self.nsegments)
if one_coordinate:
self.dataset.createDimension("one", 1)
if nreservoirs_coordinate:
self.dataset.createDimension("grand_id", self.nreservoirs)
if nnodes_coordinate:
self.dataset.createDimension("node_coord", self.nnodes)
if nhru_coordinate:
self.hruid = self.dataset.createVariable(
"nhm_id", "i4", ("nhm_id",)
)
self.hruid[:] = np.array(self.hru_ids, dtype=int)
if nsegment_coordinate:
self.segid = self.dataset.createVariable(
"nhm_seg", "i4", ("nhm_seg",)
)
self.segid[:] = np.array(self.segment_ids, dtype=int)
if one_coordinate:
self.oneid = self.dataset.createVariable("one", "i4", ("one"))
self.oneid[:] = np.array(self.one_ids, dtype=int)
if nreservoirs_coordinate:
self.grandid = self.dataset.createVariable(
"grand_id", "i4", ("grand_id",)
)
self.grandid[:] = coordinates["grand_id"]
if nnodes_coordinate:
self.node_coord = self.dataset.createVariable(
"node_coord", "i4", ("node_coord",)
)
self.node_coord[:] = coordinates["node_coord"]
char_dims_created = []
for x_dim, x_data_dict in extra_coords.items():
for x_var_name, x_data in x_data_dict.items():
type = x_data.dtype
type_str = str(type)
dim = (x_dim,)
if "U" in type_str or "S" in type_str:
# https://unidata.github.io/netcdf4-python/#dealing-with-strings # noqa: E501
# S1 gives "char" type in the file whereas another
# number gives "string" type. The former is properly
# handled by xarray
nc_type = "S1"
# if it is a string array, convert it to a character array
# I dont understand the particulars here, may need more
# work
# Convert Unicode strings to UTF-8 byte strings first
if "U" in type_str:
x_data = np.char.encode(x_data, "utf-8")
char_array = stringtochar(x_data)
char_dim_len = char_array.shape[1]
dim_name = f"char{char_dim_len}"
if dim_name in char_dims_created:
continue
dim = (x_dim, dim_name)
_ = self.dataset.createDimension(
dimname=dim_name, size=char_dim_len
)
char_dims_created += [dim_name]
else:
nc_type = np_type_to_netcdf_type_dict[type]
# <
self[x_var_name] = self.dataset.createVariable(
varname=x_var_name, datatype=nc_type, dimensions=dim
)
if "S1" == nc_type:
self[x_var_name][:, :] = char_array
self[x_var_name]._Encoding = "utf-8"
else:
self[x_var_name][:] = x_data
self.variables = {}
for var_name, group_var_name in zip(variables, group_variables):
variabletype = meta_netcdf_type(var_meta[var_name])
if len(
set(["nhru", "ngw", "nssr"]).intersection(
set(variable_dimensions[var_name])
)
):
spatial_coordinate = "nhm_id"
elif "nsegment" in variable_dimensions[var_name]:
spatial_coordinate = "nhm_seg"
elif "one" in variable_dimensions[var_name]:
spatial_coordinate = "one"
elif "nreservoirs" in variable_dimensions[var_name]:
spatial_coordinate = "grand_id"
elif "nnodes" in variable_dimensions[var_name]:
spatial_coordinate = "node_coord"
else:
msg = (
"Undefined spatial coordinate name in "
f"{variable_dimensions[var_name]}"
)
raise ValueError(msg)
if var_name in doy_time_vars:
time_dim = "doy"
else:
time_dim = "time"
var_dims = (time_dim, spatial_coordinate)
self.variables[var_name] = self.dataset.createVariable(
group_var_name,
variabletype,
var_dims,
fill_value=nc4.default_fillvals[variabletype],
zlib=zlib,
complevel=complevel,
chunksizes=tuple(chunk_sizes.values()),
)
for key, val in var_meta[var_name].items():
if isinstance(val, dict):
continue
self.variables[var_name].setncattr(key, val)
# https://docs.xarray.dev/en/stable/user-guide/io.html#coordinates
var_encoding = []
for x_dim, x_data_dict in extra_coords.items():
if x_dim in var_dims:
var_encoding += x_data_dict.keys()
if len(var_encoding):
var_encoding = " ".join(var_encoding)
self.variables[var_name].setncattr("coordinates", var_encoding)
return
def __del__(self):
self.close()
return
def close(self):
if self.dataset.isopen():
self.dataset.close()
return
def add_simulation_time(self, itime_step: int, simulation_time: float):
self.time[itime_step] = nc4.date2num(simulation_time, self.time.units)
return
def add_data(
self, name: str, itime_step: int, current: np.ndarray
) -> None:
"""Add data for a time step to a NetCDF variable
Args:
name:
itime_step:
Returns:
"""
if name not in self.variables.keys():
raise KeyError(f"{name} not a valid variable name")
var = self.variables[name]
var[itime_step, :] = current[:]
return
def add_all_data(
self,
name: str,
data: np.ndarray,
time_data: np.ndarray,
time_coord: str = "time",
) -> None:
"""Add data to a NetCDF variable
Args:
name:
Returns:
"""
if name not in self.variables.keys():
raise KeyError(f"{name} not a valid variable name")
if time_coord == "time":
start_date = (
time_data[0].astype(dt.datetime).strftime("%Y-%m-%d %H:%M:%S")
)
self[time_coord].units = f"days since {start_date}"
self[time_coord][:] = nc4.date2num(
time_data.astype(dt.datetime),
units=self[time_coord].units,
calendar="standard",
)
else:
# currently just doy
self[time_coord][:] = time_data
self.variables[name][:, :] = data[:, :]
return
[docs]
def subset_netcdf_file(
file_name: Union[pl.Path, str],
new_file_name: Union[pl.Path, str],
start_time: np.datetime64 = None,
end_time: np.datetime64 = None,
coord_dim_name: str = None,
coord_dim_values_keep: np.ndarray = None,
) -> None:
"""Subset a netcdf file on to coordinate or dimension values.
Args:
file_name: The name/path of the input file.
new_file_name: The name/path of the output file.
start_time: Optional start time if a "time" coord is present.
end_time: Optional end time if a "time" coord is present.
coord_dim_name: Optional coord or dimension name to subset on.
coord_dim_values_keep: Optional values on the coord or dimension to
retain in teh subset.
This currently works for 1-D coordinates, more dimensions not tested.
Note: This uses the function
:func:`pywatershed.utils.netcdf_utils.subset_xr`
under the hood, which can be called if you want to subset xr.Datasets in
memory. There seem to beseveral edge cases lurking around here with zero
length dimensions and xarray's broadcasting rules. This function is a
convenience function because xarray's functionality is not ideal for our
use cases and is confusing with pitfalls. See
https://github.com/pydata/xarray/issues/8796
for additional discussion.
"""
ds = xr.load_dataset(file_name)
ds = subset_xr(
ds=ds,
start_time=start_time,
end_time=end_time,
coord_dim_name=coord_dim_name,
coord_dim_values_keep=coord_dim_values_keep,
)
ds.to_netcdf(new_file_name)
return
[docs]
def subset_xr(
ds: Union[xr.Dataset, xr.DataArray],
start_time: np.datetime64 = None,
end_time: np.datetime64 = None,
coord_dim_name: str = None,
coord_dim_values_keep: np.array = None,
) -> Union[xr.Dataset, xr.DataArray]:
"""Subset an xarray Dataset or DataArray on to coord or dim values.
Args:
start_time: Optional start time if a "time" coord is present.
end_time: Optional end time if a "time" coord is present.
coord_dim_name: Optional coord or dimension name to subset on.
coord_dim_values_keep: Optional values on the coord or dimension to
retain in teh subset.
This currently works for 1-D coordinates, more dimensions not tested.
To work with files rather than memory see
:func:`pywatershed.utils.netcdf_utils.subset_netcdf_file`.
Note: There seem to be several edge cases lurking around here with zero
length dimensions and xarray's broadcasting rules. This function is a
convenience function because xarray's functionality is not ideal for our
use cases and is confusing with pitfalls. See
https://github.com/pydata/xarray/issues/8796 for additional discussion.
"""
if isinstance(ds, xr.DataArray):
var_dims_orig = ds.dims
else:
var_dims_orig = {key: ds[key].dims for key in ds.variables}
if coord_dim_name is not None or coord_dim_values_keep is not None:
msg = (
"Neither or both of coord_dim_name and coord_dim_values_keep "
"must be supplied."
)
assert (
coord_dim_name is not None and coord_dim_values_keep is not None
), msg
# <
if coord_dim_name is not None:
msg = f"{coord_dim_values_keep=} not in {coord_dim_name=}"
assert ds[coord_dim_name].isin(coord_dim_values_keep).any(), msg
ds = ds.where(
ds[coord_dim_name].isin(coord_dim_values_keep), drop=True
)
if isinstance(ds, xr.DataArray):
dims_orig = set(var_dims_orig)
dims_new = set(ds.dims)
extra_dims = list(dims_new - dims_orig)
if len(extra_dims):
for dd in extra_dims:
ds = ds.isel({dd: 0}).squeeze()
else:
for var in list(ds.variables):
dims_orig = set(var_dims_orig[var])
dims_new = set(ds[var].dims)
extra_dims = list(dims_new - dims_orig)
# if "scalar" in dims_orig:
# asdf
if len(extra_dims):
# a headache to deal with when it broadcasts to a zero
# or non-zero length dimension
dim_dict = dict(zip(ds[var].dims, ds[var].shape))
extra_dim_lens = np.array(
[dim_dict[dd] for dd in extra_dims]
)
if (extra_dim_lens <= 1).all():
ds[var] = ds[var].squeeze(extra_dims)
else:
for dd in extra_dims:
if dd in ds[var].dims:
ds[var] = ds[var].isel({dd: 0}).squeeze()
# <<<
if start_time is not None or end_time is not None:
msg = "Neither or both of start_time and end_time must be supplied."
assert start_time is not None and end_time is not None, msg
# <
# does sel work correctly here?
if start_time is not None:
if "time" in ds.dims:
ds = ds.sel(time=slice(start_time, end_time))
return ds