Source code for pywatershed.base.budget

import pathlib as pl
from copy import deepcopy
from typing import Literal, Union
from warnings import warn

import netCDF4 as nc4
import numpy as np

from pywatershed.base.control import Control

from ..constants import zero
from ..utils.formatting import pretty_print
from ..utils.netcdf_utils import NetCdfWrite
from .accessor import Accessor
from .parameters import Parameters

# Todo
# * documentation
# * get time_units / time_step from control
# * decide on the fate of "from storage unit"
# * terminology: "balance" just means "storage change"
# * terminology: what is the term for the budget dosent close/zero?


[docs] class Budget(Accessor): """Budget class for mass and energy conservation. Currently no energy budget has been implmenented, todo. """
[docs] def __init__( self, control: Control, inputs: Union[list, dict], outputs: Union[list, dict], storage_changes: Union[list, dict], init_accumulations: dict = None, accum_start_time: np.datetime64 = None, units: dict = None, time_unit: str = "D", # TODO: get this from control description: str = None, rtol: float = 1e-5, atol: float = 1e-5, basis: Literal["unit", "global"] = "unit", imbalance_fatal: bool = False, ignore_nans: bool = False, unit_desc: str = "volumes", verbose: bool = True, ): self.name = "Budget" self.control = control self.inputs = self.init_component(inputs) self.outputs = self.init_component(outputs) self.storage_changes = self.init_component(storage_changes) self.units = units self.time_unit = time_unit self.description = description self.rtol = rtol self.atol = atol self.imbalance_fatal = imbalance_fatal self._ignore_nans = ignore_nans self._unit_desc = unit_desc self.verbose = verbose self.basis = basis self._output_netcdf = False self._inputs_sum = None self._outputs_sum = None self._storage_changes_sum = None self._balance = None self._accumulations = None self._accumulations_sum = None self._zero_sum = None self._time = self.control.current_time self._itime_step = self.control.itime_step # metadata all_vars = [list(self[cc].keys()) for cc in self.get_components()] all_vars = [x for xs in all_vars for x in xs] self.meta = self.control.meta.get_vars(all_vars) if len(self.meta) == len(all_vars): if self.units is None: all_units = [val["units"] for val in self.meta.values()] else: all_units = list(self.units.values()) # check consistent units if not (np.array(all_units) == all_units[0]).all(): msg = "Units not consistent over all terms" raise ValueError(msg) self.units = all_units[0] else: all_vars_0 = [vv[0] == "_" for vv in all_vars] all_vars_private = all(all_vars_0) if not all_vars_private or verbose: msg = ( f"Metadata unavailable for some Budget terms in {all_vars}" ) warn(msg) self.units = None # generate metadata for derived output variables self.output_vars_desc = { "inputs_sum": f"Sum of input fluxes ({self.basis})", "outputs_sum": f"Sum of output fluxes ({self.basis})", "storage_changes_sum": f"Sum of storage changes ({self.basis})", # "balance": # f"Balance of fluxes and storage changes ({self.basis})", } sum_keys = list(self.output_vars_desc.keys()) for kk in sum_keys: terms = self.terms[kk[0:-4]] if terms is None or not len(terms): del self.output_vars_desc[kk] for kk, desc in self.output_vars_desc.items(): terms = self.terms[kk[0:-4]] term = terms[0] if term in self.meta.keys(): term_meta = self.meta[term] self.meta[kk] = deepcopy(term_meta) self.meta[kk]["desc"] = desc self.meta[kk]["units"] = self.units elif verbose: msg = f"budget term {term} not available in metadata" warn(msg) # if (var == "balance") and (self.basis == "global"): # self.meta[var]["dimensions"] = {0: "one"} self.set_initial_accumulations(init_accumulations, accum_start_time) return
[docs] @staticmethod def init_component(data: Union[list, dict]) -> dict: if isinstance(data, dict): return data else: return {dd: None for dd in data}
[docs] def set(self, data: dict): """Set the data on the components after initialization Args: data: a dict of dicts with top level optional keys: [inputs, outputs, storage_changes]. Each of those is a dict with var: np.ndarray, eg. data = {'inputs': {'var': np.ndarray}} """ for comp_name, comp_dict in data.items(): for var_name, var_data in comp_dict.items(): if self[comp_name][var_name] is not None: msg = ( f"Component '{comp_name}' variable '{var_name}'" f"has already been set and should not be reset." ) raise ValueError(msg) elif var_name not in self[comp_name].keys(): msg = ( f"Component '{comp_name}' has no variable '{var_name}'" ) raise KeyError(msg) else: self[comp_name][var_name] = var_data
[docs] @classmethod def from_storage_unit(cls, storage_unit, **kwargs): mass_budget_terms = storage_unit.get_mass_budget_terms() for component in mass_budget_terms.keys(): kwargs[component] = {} for var in mass_budget_terms[component]: kwargs[component][var] = storage_unit[var] return Budget(storage_unit.control, **kwargs)
[docs] @staticmethod def get_meta_keys(): """Return a tuple of the metadata keys used by Budget.""" return ("desc", "modules", "var_category", "units")
[docs] @staticmethod def get_components(): return ("inputs", "outputs", "storage_changes")
@property def components(self): return self.get_components() @property def inputs_sum(self): return self._inputs_sum @property def outputs_sum(self): return self._outputs_sum @property def storage_changes_sum(self): return self._storage_changes_sum @property def terms(self): return {comp: list(self[comp].keys()) for comp in self.components}
[docs] def set_initial_accumulations(self, init_accumulations, accum_start_time): self._itime_accumulated = self._itime_step # -1 self._time_accumulated = self._time # None self._accumulations = {} self._accumulations_sum = {} # init to zero self.reset_accumulations() if init_accumulations is None: self._accum_start_time = self.control.init_time return self._accum_start_time = accum_start_time for component in self.components: if component not in init_accumulations.keys(): continue for var in self[component].keys(): if var in init_accumulations[component].keys(): self._accumulations[component][var] = init_accumulations[ component ][var] self._sum_component_accumulations() return
[docs] def reset_accumulations(self): self._accum_start_time = self._time_accumulated for component in self.components: self._accumulations[component] = {} for var in self[component].keys(): self._accumulations[component][var] = zero self._sum_component_accumulations() return
[docs] def advance(self): """Advance time (taken from storageUnit)""" if self._itime_step >= self.control.itime_step: if self.verbose: msg = ( f"{self.name} did not advance because it is " f"not behind control time" ) print(msg) return self._itime_step = self.control.itime_step self._time = self.control.current_time
[docs] def calculate(self): """Accumulate for the timestep.""" if self._itime_accumulated >= self._itime_step: raise ValueError("Can not accumulate twice per timestep") self._inputs_sum = self._sum_inputs() self._outputs_sum = self._sum_outputs() self._storage_changes_sum = self._sum_storage_changes() # accumulate for component in self.components: for var in self[component].keys(): self._accumulations[component][var] += self[component][ var ] * self.control.time_step.astype( f"timedelta64[{self.time_unit}]" ).astype(int) self._sum_component_accumulations() # check balance if self.basis == "unit": self._balance = self._calc_unit_balance() elif self.basis == "global": self._balance = self._calc_global_balance() self._itime_accumulated = self._itime_step self._time_accumulated = self._time return
def _sum_component_accumulations(self): # sum the individual component accumulations for component in self.components: self._accumulations_sum[component] = None for var in self[component].keys(): if self._accumulations_sum[component] is None: if self.basis == "unit": self._accumulations_sum[component] = ( self._accumulations[component][var].copy() ) elif self.basis == "global": self._accumulations_sum[component] = ( self._accumulations[component][var].copy().sum() ) else: if self.basis == "unit": self._accumulations_sum[component] += ( self._accumulations[component][var] ) elif self.basis == "global": self._accumulations_sum[component] += ( self._accumulations[component][var].sum() ) return @property def accumulations(self): return self._accumulations def _sum(self, attr): """Sum over the individual terms in a budget component.""" if self.basis == "unit": vals = [val for val in self[attr].values()] the_sum = sum(vals) elif self.basis == "global": # in global case, the variable dims dont need to match, collapse # to a scalar vals = [sum(val) for val in self[attr].values()] the_sum = sum(vals) else: raise ValueError(f"self.basis '{self.basis}' is invalid") return the_sum def _sum_inputs(self): return self._sum("inputs") def _sum_outputs(self): return self._sum("outputs") def _sum_storage_changes(self): return self._sum("storage_changes") def _calc_unit_balance(self): unit_balance = self._inputs_sum - self._outputs_sum self._zero_sum = True # compare i ?=? o + ds so that relative errors are not compared to # zero when ds is zero if not np.allclose( self._inputs_sum, self._outputs_sum + self._storage_changes_sum, rtol=self.rtol, atol=self.atol, equal_nan=self._ignore_nans, ): self._zero_sum = False lhs = self._inputs_sum rhs = self._outputs_sum + self._storage_changes_sum if self._ignore_nans: actual_nan = np.where(np.isnan(lhs), True, False) desired_nan = np.where(np.isnan(rhs), True, False) msg = "The nan values are not the same for the two arrays" assert (actual_nan == desired_nan).all(), msg abs_diff = abs(lhs - rhs) with np.errstate(divide="ignore", invalid="ignore"): rel_abs_diff = abs_diff / rhs abs_close = abs_diff < self.atol rel_close = rel_abs_diff < self.rtol rel_close = np.where(np.isnan(rel_close), False, rel_close) close = abs_close | rel_close if self._ignore_nans: close = np.where(np.isnan(abs_diff), True, close) wh_not_close = np.where(~close) msg = ( "The flux unit balance not equal to the change in unit " f"storage at time {self.control.current_time} and at the " f"following locations for {self.description}: {wh_not_close}" ) if self.imbalance_fatal: raise ValueError(msg) else: warn(msg, UserWarning) # << return unit_balance def _calc_global_balance(self): global_balance = self._inputs_sum - self._outputs_sum self._zero_sum = True # compare i ?=? o + ds so that relative errors are not compared to # zero when ds is zero if not np.allclose( self._inputs_sum, self._outputs_sum + self._storage_changes_sum, rtol=self.rtol, atol=self.atol, ): self._zero_sum = False msg = ( "The global flux balance not equal to the change in global " f"storage: {self.description}" ) if self.verbose: aerr = self._inputs_sum - ( self._outputs_sum + self._storage_changes_sum ) rerr = aerr / self._inputs_sum msg += f"\n{self.control.current_time}: {aerr=}, {rerr=}" if self.imbalance_fatal: raise ValueError(msg) else: warn(msg, UserWarning) return global_balance @property def balance(self): return self._balance def __repr__(self): if self._itime_step == -1: msg = ( f"Budget (units: {self.units}) of {self.description} " f"only initialized. Check back later." ) return msg n_in = len(self.inputs) n_out = len(self.outputs) n_stor = len(self.storage_changes) n_report = max(n_in, n_out, n_stor) in_keys = list(self.inputs.keys()) out_keys = list(self.outputs.keys()) stor_keys = list(self.storage_changes.keys()) # the following do not copy the ndarrays in_vals = list(self.inputs.values()) out_vals = list(self.outputs.values()) stor_vals = list(self.storage_changes.values()) acc_in_vals = list(self._accumulations["inputs"].values()) acc_out_vals = list(self._accumulations["outputs"].values()) acc_stor_vals = list(self._accumulations["storage_changes"].values()) # ': ' + value spaces col_extra_colon = 2 col_extra_vals = 10 col_extra = col_extra_colon + col_extra_vals in_col_key_width = max([len(kk) for kk in in_keys]) out_col_key_width = max([len(kk) for kk in out_keys]) stor_col_key_width = max([len(kk) for kk in stor_keys]) in_col_width = in_col_key_width + col_extra out_col_width = out_col_key_width + col_extra stor_col_width = stor_col_key_width + col_extra indent_width = 9 indent_fill = " " * indent_width in_col_fill = " " * in_col_width out_col_fill = " " * out_col_width stor_col_fill = " " * stor_col_width sep_width = 4 col_sep = " " * sep_width terms_width = ( in_col_width + out_col_width + stor_col_width + (2 * sep_width) ) total_width = indent_width + terms_width # volume/mass or energy # budget name summary = [] summary += ["*-" * int((total_width) / 2)] if self.basis == "unit": summary += [ f"Individual spatial unit budget for {self.description} " f"(units: {self.units}).\n" "Budget is checked on each spatial unit. This is summary shows" "\nspatial sums for the entire model domain.\n" ] elif self.basis == "global": summary += [ f"Global budget of {self.description} (units: {self.units}).\n" "Budget is checked on full domain: spatially summed fluxes\n" "and storage changes are checked for balance." ] # print the model time. this is summary += [f"@ time: {self._time} (itime_step: {self._itime_step})"] # Timestep rates summary += [""] # header summary += ["This timestep, rates:"] summary += [ indent_fill + "input rates".ljust(in_col_width) + col_sep + "output rates".ljust(out_col_width) + col_sep + "storage change rates".ljust(stor_col_width) ] separator = [ (indent_fill + "-" * in_col_width + col_sep) + ("-" * out_col_width + col_sep) + ("-" * stor_col_width) ] summary += separator # terms line by line with cols: in, out, storage term_data = ( (n_in, in_keys, in_vals, in_col_width, in_col_fill), (n_out, out_keys, out_vals, out_col_width, out_col_fill), (n_stor, stor_keys, stor_vals, stor_col_width, stor_col_fill), ) def table_terms_col_wise(): "Fill terms table column wise" table = [] with np.printoptions(precision=2): for ll in range(n_report): line = indent_fill for n_items, keys, vals, col_width, col_fill in term_data: if ll <= (n_items - 1): nk = len(keys[ll]) # subtract ": ", "m." and "e+ee" prec_width = col_width - nk - 2 - 2 - 4 vals_sum = vals[ll].sum() if vals_sum < 0: prec_width -= 1 vv = np.format_float_scientific( vals_sum, precision=prec_width, ) line += ( pretty_print(f"{keys[ll]}: {vv}", col_width) + col_sep ) else: line += col_fill + col_sep line += col_sep table += [line] return table summary += table_terms_col_wise() # balance line summary += [indent_fill + "-" * terms_width] eq_op = "=" if not self._zero_sum: eq_op = "!=!" term_data = ( ("", self._inputs_sum, in_col_width, in_col_key_width), ("-", self._outputs_sum, out_col_width, out_col_key_width), ( eq_op, self._storage_changes_sum, stor_col_width, stor_col_key_width, ), ) def balance_line_col_wise(): # TODO(JLM): This is a hack until i have some time to sort this out bal_line = "Balance: " for oper, vals_sum, col_width, col_key_width in term_data: if vals_sum.sum() > 0: sign_extra = 0 else: sign_extra = 1 bal_line += ( oper + (" " * (col_key_width + col_extra_colon - len(oper))) + pretty_print( np.format_float_scientific( vals_sum.sum(), precision=col_width - col_key_width - col_extra_colon - 6 - sign_extra, ), col_width - col_key_width - col_extra_colon, ) + col_sep ) return bal_line summary += [balance_line_col_wise()] # Accumulated volumes summary += [""] # header summary += [ f"Accumulated {self._unit_desc} (since {self._accum_start_time}):" ] summary += [ indent_fill + f"input {self._unit_desc}".ljust(in_col_width) + col_sep + f"output {self._unit_desc}".ljust(out_col_width) + col_sep + f"storage change {self._unit_desc}".ljust(out_col_width) ] # separator summary += separator # terms line by line with cols: in, out, storage term_data = ( (n_in, in_keys, acc_in_vals, in_col_width, in_col_fill), (n_out, out_keys, acc_out_vals, out_col_width, out_col_fill), (n_stor, stor_keys, acc_stor_vals, stor_col_width, stor_col_fill), ) # fill terms column wise (abuse scope a bit) summary += table_terms_col_wise() # balance line summary += [indent_fill + "-" * terms_width] term_data = ( ( "", self._accumulations_sum["inputs"], in_col_width, in_col_key_width, ), ( "-", self._accumulations_sum["outputs"], out_col_width, out_col_key_width, ), ( eq_op, self._accumulations_sum["storage_changes"], stor_col_width, stor_col_key_width, ), ) summary += [balance_line_col_wise()] # conclude summary += [""] return "\n".join(summary)
[docs] def output(self) -> None: """Output to previously initialized output types. Returns: None """ if self._output_netcdf: self.__output_netcdf() return
[docs] def initialize_netcdf( self, params: Parameters, output_dir: str, extra_coords: dict = None, write_sum_vars: Union[list, bool] = True, write_individual_vars: bool = False, ) -> None: """Initialize NetCDF output Args: output_dir: directory for NetCDF file Returns: None """ self._output_netcdf = True if extra_coords is None: extra_coords = {} # make working directory output_dir = pl.Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) nc_path = pl.Path(output_dir) / f"{self.description}_budget.nc" # Construct a dictionary of {term: var}. If the variables are not # in a term their term is None if write_sum_vars is True: nc_out_vars = list(self.output_vars_desc.keys()) elif isinstance(write_sum_vars, list): nc_out_vars = [ var for var in self.output_vars_desc.keys() if var in write_sum_vars ] elif (write_sum_vars is False) or (write_sum_vars is None): nc_out_vars = [] else: msg = "Unexpected value of write_sum_vars: {write_sum_vars}" raise ValueError(msg) self._netcdf_output_var_dict = {} if len(nc_out_vars): self._netcdf_output_var_dict = {None: nc_out_vars} if write_individual_vars: self._netcdf_output_var_dict = { None: nc_out_vars, **self.terms, } if len(self._netcdf_output_var_dict) == 0: msg = ( f"Budget for {self.description} has no requested output, " "setting self._output_netcdf = False" ) warn(msg) self._output_netcdf = False return global_attrs = { "Description": ( f"pywatershed ({self.basis}) budget for {self.description}" ), "Budget basis": f"{self.basis} (unit or global)", } for key in self.terms.keys(): global_attrs[key] = "[" + ", ".join(self.terms[key]) + "]" if self.basis == "unit": coordinates = params.coords meta = self.meta else: coordinates = {"one": 0} meta = deepcopy(self.meta) for kk, vv in meta.items(): meta[kk]["dims"] = ("one",) self._netcdf = NetCdfWrite( nc_path, coordinates, self._netcdf_output_var_dict, meta, extra_coords=extra_coords, global_attrs=global_attrs, ) # todo jlm: put terms in to metadata return
def __output_netcdf(self) -> None: """Output variable data for a time step Returns: None """ if self._output_netcdf: self._netcdf.time[self.control.itime_step] = nc4.date2num( self.control.current_datetime, self._netcdf.time.units ) for nc_group, group_vars in self._netcdf_output_var_dict.items(): for nc_var in group_vars: var_self_name = nc_var if nc_group is None: var_path = nc_var self._netcdf.dataset[var_path][ self.control.itime_step, : ] = self[var_self_name] else: var_path = f"{nc_group}/{nc_var}" self._netcdf.dataset[var_path][ self.control.itime_step, : ] = self[nc_group][var_self_name] return def _finalize_netcdf(self) -> None: if self._output_netcdf: self._netcdf.close() return