# Source code for pywatershed.base.budget

import pathlib as pl
from copy import deepcopy
from typing import Literal, Union
from warnings import warn

import netCDF4 as nc4
import numpy as np

from pywatershed.base.control import Control

from ..constants import zero
from ..utils.netcdf_utils import NetCdfWrite
from .accessor import Accessor
from .parameters import Parameters

# Todo
# * documentation
# * get time_units / time_step from control
# * decide on the fate of "from storage unit"
# * terminology: "balance" just means "storage change"
# * terminology: what is the term for when the budget doesn't close/zero?


class Budget(Accessor):
    """Budget class for mass and energy conservation.

    Currently no energy budget has been implemented, todo.

    A budget groups its terms into components: ``inputs``, ``outputs``,
    ``storage_changes`` and, optionally, ``exchanges`` (bi-directional
    fluxes that are added to the input side of the balance). Each call to
    :meth:`calculate` sums the terms of every component, accumulates them
    over time and checks, with :func:`numpy.allclose`, that

        inputs (+ exchanges) == outputs + storage_changes
    """

    def __init__(
        self,
        control: Control,
        inputs: Union[list, dict],
        outputs: Union[list, dict],
        storage_changes: Union[list, dict],
        exchanges: Union[list, dict] = None,
        init_accumulations: dict = None,
        accum_start_time: np.datetime64 = None,
        units: dict = None,
        time_unit: str = "D",  # TODO: get this from control
        description: str = None,
        rtol: float = 1e-5,
        atol: float = 1e-5,
        basis: Literal["unit", "global"] = "unit",
        imbalance_fatal: bool = False,
        ignore_nans: bool = False,
        unit_desc: str = "",
        verbose: bool = True,
    ):
        """Initialize the budget.

        Args:
            control: the Control object; its ``current_time``,
                ``itime_step``, ``time_step``, ``init_time`` and ``meta``
                are used.
            inputs: variable names (list) or ``{name: data}`` (dict) of the
                input terms. Names given as a list start with data ``None``
                and are filled in later with :meth:`set`.
            outputs: as ``inputs``, for the output terms.
            storage_changes: as ``inputs``, for the storage change terms.
            exchanges: optional, as ``inputs``, for bi-directional fluxes.
                When ``None`` the budget has no exchanges component.
            init_accumulations: optional ``{component: {name: value}}`` of
                starting accumulations. Missing entries start at zero.
            accum_start_time: start time of the accumulations; only used
                when ``init_accumulations`` is given, otherwise
                ``control.init_time`` is used.
            units: optional ``{name: units}``; its values are used in place
                of the metadata units. Only honoured when metadata is
                available for all terms, otherwise units are set to
                ``None``.
            time_unit: numpy ``timedelta64`` unit code used to convert
                ``control.time_step`` to the integer factor applied when
                accumulating.
            description: label used in messages and in the NetCDF file
                name.
            rtol: relative tolerance passed to ``numpy.allclose``.
            atol: absolute tolerance passed to ``numpy.allclose``.
            basis: "unit" checks the balance element-wise, "global" checks
                the balance of the spatial sums.
            imbalance_fatal: raise ``ValueError`` on imbalance instead of
                issuing a ``UserWarning``.
            ignore_nans: treat NaNs at the same locations on both sides of
                the balance as equal.
            unit_desc: description of the spatial unit, stored in
                parentheses.
            verbose: emit additional messages and warnings.

        Raises:
            ValueError: if the units of the terms are not all identical.
        """
        self.name = "Budget"
        self.control = control
        self.inputs = self.init_component(inputs)
        self.outputs = self.init_component(outputs)
        self.storage_changes = self.init_component(storage_changes)

        # Handle optional exchanges component (for bi-directional fluxes)
        if exchanges is None:
            self.exchanges = {}
            self._has_exchanges = False
        else:
            self.exchanges = self.init_component(exchanges)
            self._has_exchanges = True

        self.units = units
        self.time_unit = time_unit
        self.description = description
        self.rtol = rtol
        self.atol = atol
        self.imbalance_fatal = imbalance_fatal
        self._ignore_nans = ignore_nans
        self._unit_desc = unit_desc
        if self._unit_desc != "":
            self._unit_desc = f" ({self._unit_desc})"
        else:
            self._unit_desc = " "
        self.verbose = verbose
        self.basis = basis

        self._output_netcdf = False

        self._inputs_sum = None
        self._outputs_sum = None
        self._storage_changes_sum = None
        self._exchanges_sum = None
        self._balance = None
        self._accumulations = None
        self._accumulations_sum = None
        self._zero_sum = None

        self._time = self.control.current_time
        self._itime_step = self.control.itime_step

        # metadata
        # NOTE(review): get_components() defaults to has_exchanges=False, so
        # exchange variables are not included in the metadata lookup or the
        # unit consistency check - confirm whether that is intended.
        all_vars = [list(self[cc].keys()) for cc in self.get_components()]
        all_vars = [x for xs in all_vars for x in xs]
        self.meta = self.control.meta.get_vars(all_vars)
        if len(self.meta) == len(all_vars):
            if self.units is None:
                all_units = [val["units"] for val in self.meta.values()]
            else:
                all_units = list(self.units.values())
            # check consistent units
            if not (np.array(all_units) == all_units[0]).all():
                msg = "Units not consistent over all terms"
                raise ValueError(msg)
            self.units = all_units[0]
        else:
            # Only stay silent when every term is private (leading "_") and
            # verbose is off.
            all_vars_private = all(vv[0] == "_" for vv in all_vars)
            if not all_vars_private or verbose:
                msg = (
                    f"Metadata unavailable for some Budget terms in {all_vars}"
                )
                warn(msg)
            self.units = None

        # generate metadata for derived output variables
        self.output_vars_desc = {
            "inputs_sum": f"Sum of input fluxes ({self.basis})",
            "outputs_sum": f"Sum of output fluxes ({self.basis})",
            "storage_changes_sum": f"Sum of storage changes ({self.basis})",
            # "balance":
            # f"Balance of fluxes and storage changes ({self.basis})",
        }
        terms_by_component = self.terms
        # kk[0:-4] strips the "_sum" suffix to get the component name. Drop
        # the sum variables of components that have no terms.
        for kk in list(self.output_vars_desc.keys()):
            terms = terms_by_component[kk[0:-4]]
            if terms is None or not len(terms):
                del self.output_vars_desc[kk]

        # The metadata of a sum variable is copied from the first term of
        # its component, with description and units replaced.
        for kk, desc in self.output_vars_desc.items():
            term = terms_by_component[kk[0:-4]][0]
            if term in self.meta.keys():
                self.meta[kk] = deepcopy(self.meta[term])
                self.meta[kk]["desc"] = desc
                self.meta[kk]["units"] = self.units
            elif verbose:
                msg = f"budget term {term} not available in metadata"
                warn(msg)

        self.set_initial_accumulations(init_accumulations, accum_start_time)
        return

    @staticmethod
    def init_component(data: Union[list, dict]) -> dict:
        """Return a component dict from a list of names or a dict.

        Args:
            data: a dict (returned as is, not copied) or an iterable of
                variable names.

        Returns:
            The dict itself, or ``{name: None}`` for each name in ``data``.
        """
        if isinstance(data, dict):
            return data
        return {dd: None for dd in data}

    def set(self, data: dict):
        """Set the data on the components after initialization

        Args:
            data: a dict of dicts with top level optional keys:
                [inputs, outputs, storage_changes]. Each of those is a dict
                with var: np.ndarray, eg.
                data = {'inputs': {'var': np.ndarray}}

        Raises:
            KeyError: if a variable is not a term of the named component.
            ValueError: if a variable has already been set.
        """
        for comp_name, comp_dict in data.items():
            component = self[comp_name]
            for var_name, var_data in comp_dict.items():
                # Membership must be tested first: indexing an unknown name
                # would raise a bare KeyError without the message below.
                if var_name not in component:
                    msg = (
                        f"Component '{comp_name}' has no variable '{var_name}'"
                    )
                    raise KeyError(msg)
                if component[var_name] is not None:
                    msg = (
                        f"Component '{comp_name}' variable '{var_name}' "
                        f"has already been set and should not be reset."
                    )
                    raise ValueError(msg)
                component[var_name] = var_data

    @classmethod
    def from_storage_unit(cls, storage_unit, quantity="mass", **kwargs):
        """Construct a budget from the budget terms of a storage unit.

        Args:
            storage_unit: object providing ``control``, item access to its
                variables and ``get_mass_budget_terms()`` /
                ``get_energy_budget_terms()``.
            quantity: "mass" or "energy".
            **kwargs: passed on to the constructor; the entries named after
                the budget components are overwritten.

        Returns:
            A new instance of ``cls``.

        Raises:
            ValueError: if ``quantity`` is neither "mass" nor "energy".
        """
        # Get budget terms based on quantity
        if quantity == "mass":
            budget_terms = storage_unit.get_mass_budget_terms()
        elif quantity == "energy":
            budget_terms = storage_unit.get_energy_budget_terms()
        else:
            raise ValueError(f"Unknown quantity: {quantity}")

        for component, var_names in budget_terms.items():
            kwargs[component] = {var: storage_unit[var] for var in var_names}

        # cls rather than Budget so that subclasses construct themselves
        return cls(storage_unit.control, **kwargs)

    @staticmethod
    def get_meta_keys():
        """Return a tuple of the metadata keys used by Budget."""
        return ("desc", "modules", "var_category", "units")

    @staticmethod
    def get_components(has_exchanges=False):
        """Return the tuple of component names.

        Args:
            has_exchanges: include the "exchanges" component.
        """
        if has_exchanges:
            return ("inputs", "exchanges", "outputs", "storage_changes")
        return ("inputs", "outputs", "storage_changes")

    @property
    def components(self):
        """Component names of this budget (with exchanges if present)."""
        return self.get_components(self._has_exchanges)

    @property
    def inputs_sum(self):
        """Sum of the input terms from the last :meth:`calculate`."""
        return self._inputs_sum

    @property
    def outputs_sum(self):
        """Sum of the output terms from the last :meth:`calculate`."""
        return self._outputs_sum

    @property
    def storage_changes_sum(self):
        """Sum of the storage changes from the last :meth:`calculate`."""
        return self._storage_changes_sum

    @property
    def exchanges_sum(self):
        """Sum of the exchange terms from the last :meth:`calculate`."""
        return self._exchanges_sum

    @property
    def terms(self):
        """Dict of component name to the list of its variable names."""
        return {comp: list(self[comp].keys()) for comp in self.components}

    def set_initial_accumulations(self, init_accumulations, accum_start_time):
        """Reset the accumulations and optionally seed them.

        Args:
            init_accumulations: ``None`` or ``{component: {name: value}}``.
                Components or names that are absent stay at zero. The
                values are copied so that accumulating in place does not
                modify the caller's data.
            accum_start_time: accumulation start time, used only when
                ``init_accumulations`` is not ``None``; otherwise
                ``control.init_time`` is used.
        """
        self._itime_accumulated = self._itime_step  # -1
        self._time_accumulated = self._time  # None

        self._accumulations = {}
        self._accumulations_sum = {}

        # init to zero
        self.reset_accumulations()

        if init_accumulations is None:
            self._accum_start_time = self.control.init_time
            return

        self._accum_start_time = accum_start_time
        for component in self.components:
            if component not in init_accumulations.keys():
                continue
            comp_init = init_accumulations[component]
            for var in self[component].keys():
                if var in comp_init.keys():
                    self._accumulations[component][var] = deepcopy(
                        comp_init[var]
                    )

        self._sum_component_accumulations()
        return

    def reset_accumulations(self):
        """Zero all accumulations, restarting at the last accumulated time."""
        self._accum_start_time = self._time_accumulated

        for component in self.components:
            self._accumulations[component] = {
                var: zero for var in self[component].keys()
            }

        self._sum_component_accumulations()
        return

    def advance(self):
        """Advance time (taken from storageUnit)"""
        if self._itime_step >= self.control.itime_step:
            if self.verbose:
                msg = (
                    f"{self.name} did not advance because it is "
                    f"not behind control time"
                )
                print(msg)
            return

        self._itime_step = self.control.itime_step
        self._time = self.control.current_time

    def calculate(self):
        """Accumulate for the timestep.

        Sums each component, adds flux * time step to the accumulations and
        checks the balance on the configured basis.

        Raises:
            ValueError: if already calculated for the current time step, if
                ``basis`` is invalid, or on imbalance when
                ``imbalance_fatal`` is set.
        """
        if self._itime_accumulated >= self._itime_step:
            raise ValueError("Can not accumulate twice per timestep")

        self._inputs_sum = self._sum_inputs()
        self._outputs_sum = self._sum_outputs()
        self._storage_changes_sum = self._sum_storage_changes()
        if self._has_exchanges:
            self._exchanges_sum = self._sum_exchanges()

        # accumulate
        # The integer number of time_unit per time step is the same for
        # every term, so it is computed once.
        n_time_units = self.control.time_step.astype(
            f"timedelta64[{self.time_unit}]"
        ).astype(int)
        for component in self.components:
            comp_accum = self._accumulations[component]
            for var, flux in self[component].items():
                comp_accum[var] += flux * n_time_units

        self._sum_component_accumulations()

        # check balance
        if self.basis == "unit":
            self._balance = self._calc_unit_balance()
        elif self.basis == "global":
            self._balance = self._calc_global_balance()

        self._itime_accumulated = self._itime_step
        self._time_accumulated = self._time
        return

    def _sum_component_accumulations(self):
        """Sum the individual accumulations of each component.

        On the "unit" basis the sum is element-wise, on the "global" basis
        each accumulation is first summed to a scalar. A component without
        terms (or an unrecognized basis) gets ``None``.
        """
        for component in self.components:
            total = None
            for var in self[component].keys():
                accum = self._accumulations[component][var]
                if self.basis == "unit":
                    contribution = accum
                elif self.basis == "global":
                    contribution = accum.sum()
                else:
                    continue
                if total is None:
                    # copy so the in-place adds below never touch a term
                    total = contribution.copy()
                else:
                    total += contribution
            self._accumulations_sum[component] = total
        return

    @property
    def accumulations(self):
        """Dict ``{component: {name: accumulation}}``."""
        return self._accumulations

    def _sum(self, attr):
        """Sum over the individual terms in a budget component.

        Args:
            attr: name of the component.

        Returns:
            The element-wise sum ("unit") or the sum of the per-term totals
            ("global").

        Raises:
            ValueError: if ``self.basis`` is neither "unit" nor "global".
        """
        if self.basis == "unit":
            return sum(list(self[attr].values()))
        if self.basis == "global":
            # in global case, the variable dims dont need to match, collapse
            # to a scalar
            return sum([np.sum(val) for val in self[attr].values()])
        raise ValueError(f"self.basis '{self.basis}' is invalid")

    def _sum_inputs(self):
        return self._sum("inputs")

    def _sum_outputs(self):
        return self._sum("outputs")

    def _sum_storage_changes(self):
        return self._sum("storage_changes")

    def _sum_exchanges(self):
        return self._sum("exchanges")

    def _calc_unit_balance(self):
        """Check the balance element-wise and return inputs - outputs.

        Returns:
            ``inputs (+ exchanges) - outputs``.

        Raises:
            ValueError: on imbalance when ``imbalance_fatal`` is set
                (otherwise a ``UserWarning`` is issued).
            AssertionError: when ignoring NaNs and the NaN locations differ
                between the two sides.
        """
        self._zero_sum = True
        # compare
        #   lhs ?=? rhs
        #   i + e ?=? o + ds   (e is optional term)
        # so that relative errors are not compared to zero when ds is zero
        rhs = self._outputs_sum + self._storage_changes_sum
        # LHS depends on if we have exchanges or not
        if self._has_exchanges:
            lhs = self._inputs_sum + self._exchanges_sum
        else:
            lhs = self._inputs_sum
        unit_balance = lhs - self._outputs_sum

        if not np.allclose(
            lhs,
            rhs,
            rtol=self.rtol,
            atol=self.atol,
            equal_nan=self._ignore_nans,
        ):
            self._zero_sum = False

            if self._ignore_nans:
                actual_nan = np.where(np.isnan(lhs), True, False)
                desired_nan = np.where(np.isnan(rhs), True, False)
                msg = "The nan values are not the same for the two arrays"
                assert (actual_nan == desired_nan).all(), msg

            # locate the offending elements for the message
            abs_diff = abs(lhs - rhs)
            with np.errstate(divide="ignore", invalid="ignore"):
                rel_abs_diff = abs(abs_diff / rhs)
            abs_close = abs_diff < self.atol
            # a NaN relative difference (0/0) never counts as close
            rel_close = np.where(
                np.isnan(rel_abs_diff), False, rel_abs_diff < self.rtol
            )
            close = abs_close | rel_close
            if self._ignore_nans:
                close = np.where(np.isnan(abs_diff), True, close)
            wh_not_close = np.where(~close)

            msg = (
                "The flux unit balance not equal to the change in unit "
                f"storage at time {self.control.current_time} and at the "
                f"following locations for {self.description}: {wh_not_close}"
            )
            if self.imbalance_fatal:
                raise ValueError(msg)
            else:
                warn(msg, UserWarning)

        return unit_balance

    def _calc_global_balance(self):
        """Check the balance of the spatial sums and return inputs - outputs.

        Returns:
            ``inputs (+ exchanges) - outputs``.

        Raises:
            ValueError: on imbalance when ``imbalance_fatal`` is set
                (otherwise a ``UserWarning`` is issued).
        """
        # Exchanges belong on the input side, as in the unit balance
        if self._has_exchanges:
            lhs = self._inputs_sum + self._exchanges_sum
        else:
            lhs = self._inputs_sum
        rhs = self._outputs_sum + self._storage_changes_sum
        global_balance = lhs - self._outputs_sum

        self._zero_sum = True
        # compare i (+ e) ?=? o + ds so that relative errors are not
        # compared to zero when ds is zero
        if not np.allclose(
            lhs,
            rhs,
            rtol=self.rtol,
            atol=self.atol,
            equal_nan=self._ignore_nans,
        ):
            self._zero_sum = False
            msg = (
                "The global flux balance not equal to the change in global "
                f"storage: {self.description}"
            )
            if self.verbose:
                aerr = lhs - rhs
                # lhs may be zero; report inf/nan rather than warn or fail
                with np.errstate(divide="ignore", invalid="ignore"):
                    rerr = np.divide(aerr, lhs)
                msg += f"\n{self.control.current_time}: {aerr=}, {rerr=}"
            if self.imbalance_fatal:
                raise ValueError(msg)
            else:
                warn(msg, UserWarning)

        return global_balance

    @property
    def balance(self):
        """Balance returned by the last :meth:`calculate`."""
        return self._balance

    @staticmethod
    def _scalar_total(value):
        """Collapse an array, scalar or None (treated as zero) to a scalar."""
        if value is None:
            return np.float64(0.0)
        return value.sum() if hasattr(value, "sum") else value

    @staticmethod
    def _repr_rows(values, components_to_display, col_widths, indent, col_sep):
        """Return the 'name: spatial sum' table rows, one column/component.

        Args:
            values: ``{component: {name: array}}`` holding the data to sum.
            components_to_display: list of ``(component, component dict)``.
            col_widths: ``{component: column width}``.
            indent: prefix of every row.
            col_sep: string appended after every column.
        """
        keys_by_comp = {
            name: list(comp.keys()) for name, comp in components_to_display
        }
        max_rows = max((len(kk) for kk in keys_by_comp.values()), default=0)
        rows = []
        for row_idx in range(max_rows):
            line = indent
            for comp_name, _ in components_to_display:
                width = col_widths[comp_name]
                keys = keys_by_comp[comp_name]
                if row_idx < len(keys):
                    key = keys[row_idx]
                    val_sum = values[comp_name][key].sum()
                    # Format value; leave room for the minus sign
                    precision = width - len(key) - 8
                    if val_sum < 0:
                        precision -= 1
                    val_str = np.format_float_scientific(
                        val_sum, precision=precision
                    )
                    line += f"{key}: {val_str}".ljust(width) + col_sep
                else:
                    line += " " * width + col_sep
            rows.append(line)
        return rows

    def _repr_balance(
        self, sums, components_to_display, col_widths, eq_op, col_sep
    ):
        """Return the 'Balance:' line for the given component sums.

        Args:
            sums: ``{component: sum}`` (array, scalar or None).
            components_to_display: list of ``(component, component dict)``.
            col_widths: ``{component: column width}``.
            eq_op: "=" when the budget closes, "!=!" otherwise.
            col_sep: string placed between the columns.
        """
        # For exchanges: storage_changes gets subtracted like outputs
        # For no exchanges: storage_changes gets the eq_op (= or !=!)
        ops = {
            "inputs": "",
            "exchanges": "+",
            "outputs": "-",
            "storage_changes": "-" if self._has_exchanges else eq_op,
        }
        # The operator field is sized by the keys of the first column
        if components_to_display:
            first_keys = components_to_display[0][1].keys()
            key_width = max((len(k) for k in first_keys), default=0)
        else:
            key_width = 0

        balance_parts = []
        for comp_name, _ in components_to_display:
            total = self._scalar_total(sums[comp_name])
            op = ops.get(comp_name, "+")
            precision = col_widths[comp_name] - key_width - 8
            if total < 0:
                precision -= 1
            val_str = np.format_float_scientific(
                total, precision=max(1, precision)
            )
            balance_parts.append(
                f"{op} ".ljust(key_width + 2) + val_str.rjust(10)
            )

        line = "Balance: " + col_sep.join(balance_parts)

        # Only show residual column if there are exchanges
        if self._has_exchanges:
            residual = (
                self._scalar_total(sums.get("inputs"))
                + self._scalar_total(sums.get("exchanges"))
                - self._scalar_total(sums.get("outputs"))
                - self._scalar_total(sums.get("storage_changes"))
            )
            residual_str = np.format_float_scientific(residual, precision=1)
            line += col_sep + f"{eq_op} {residual_str}".rjust(12)
        return line

    def __repr__(self):
        """Budget string representation"""
        if self._itime_step == -1:
            msg = (
                f"Budget (units: {self.units}) of {self.description} "
                f"only initialized. Check back later."
            )
            return msg

        # Determine which components to display: only those with terms
        candidates = [("inputs", self.inputs)]
        if self._has_exchanges:
            candidates.append(("exchanges", self.exchanges))
        candidates.append(("outputs", self.outputs))
        candidates.append(("storage_changes", self.storage_changes))
        components_to_display = [
            (name, comp) for name, comp in candidates if len(comp) > 0
        ]

        # Calculate column widths
        col_extra = 12  # ': ' + scientific notation space
        col_widths = {
            name: max(len(k) for k in comp.keys()) + col_extra
            for name, comp in components_to_display
        }

        indent = " " * 9
        col_sep = " "
        rule_width = sum(col_widths.values()) + len(col_sep) * (
            len(components_to_display) - 1
        )
        total_width = rule_width + 9
        rule = indent + "-" * rule_width

        # Build output
        summary = ["*-" * int(total_width / 2)]

        # Header
        if self.basis == "unit":
            summary += [
                f"Individual spatial unit budget for {self.description} "
                f"(units: {self.units}).\n"
                "Budget is checked on each spatial unit. This is summary shows"
                "\nspatial sums for the entire model domain.\n"
            ]
        elif self.basis == "global":
            summary += [
                f"Global budget of {self.description} (units: {self.units}).\n"
                "Budget is checked on full domain: spatially summed fluxes\n"
                "and storage changes are checked for balance."
            ]

        summary += [f"@ time: {self._time} (itime_step: {self._itime_step})"]
        summary += ["", "This timestep:"]

        # Column headers and separator, shared by both tables
        header_line = indent
        sep_line = indent
        for comp_name, _ in components_to_display:
            display_name = comp_name.replace("_", " ")
            header_line += display_name.ljust(col_widths[comp_name]) + col_sep
            sep_line += "-" * col_widths[comp_name] + col_sep

        eq_op = "=" if self._zero_sum else "!=!"

        # This timestep: data rows and balance line
        summary += [header_line, sep_line]
        summary += self._repr_rows(
            dict(components_to_display),
            components_to_display,
            col_widths,
            indent,
            col_sep,
        )
        summary += [rule]
        timestep_sums = {
            "inputs": self._inputs_sum,
            "exchanges": self._exchanges_sum,
            "outputs": self._outputs_sum,
            "storage_changes": self._storage_changes_sum,
        }
        summary += [
            self._repr_balance(
                timestep_sums,
                components_to_display,
                col_widths,
                eq_op,
                col_sep,
            )
        ]

        # Accumulations: data rows and balance line
        summary += ["", f"Accumulations (since {self._accum_start_time}):"]
        summary += [header_line, sep_line]
        summary += self._repr_rows(
            self._accumulations,
            components_to_display,
            col_widths,
            indent,
            col_sep,
        )
        summary += [rule]
        summary += [
            self._repr_balance(
                self._accumulations_sum,
                components_to_display,
                col_widths,
                eq_op,
                col_sep,
            ),
            "",
        ]

        return "\n".join(summary)

    def output(self) -> None:
        """Output to previously initialized output types.

        Returns:
            None
        """
        if self._output_netcdf:
            self.__output_netcdf()
        return

    def initialize_netcdf(
        self,
        params: Parameters,
        output_dir: str,
        extra_coords: dict = None,
        write_sum_vars: Union[list, bool] = True,
        write_individual_vars: bool = False,
    ) -> None:
        """Initialize NetCDF output

        Args:
            params: Parameters whose ``coords`` are the file coordinates on
                the "unit" basis (not used on the "global" basis).
            output_dir: directory for NetCDF file; created if missing. The
                file is named ``{description}_budget.nc``.
            extra_coords: optional extra coordinates passed to the writer.
            write_sum_vars: True to write all available sum variables, a
                list to select among them, False or None for none.
            write_individual_vars: also write the individual terms, grouped
                by component.

        Returns:
            None

        Raises:
            ValueError: if ``write_sum_vars`` is not a bool, list or None.
        """
        self._output_netcdf = True

        if extra_coords is None:
            extra_coords = {}

        # make working directory
        output_dir = pl.Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        nc_path = pl.Path(output_dir) / f"{self.description}_budget.nc"

        # Construct a dictionary of {term: var}. If the variables are not
        # in a term their term is None
        if write_sum_vars is True:
            nc_out_vars = list(self.output_vars_desc.keys())
        elif isinstance(write_sum_vars, list):
            nc_out_vars = [
                var
                for var in self.output_vars_desc.keys()
                if var in write_sum_vars
            ]
        elif (write_sum_vars is False) or (write_sum_vars is None):
            nc_out_vars = []
        else:
            msg = f"Unexpected value of write_sum_vars: {write_sum_vars}"
            raise ValueError(msg)

        self._netcdf_output_var_dict = {}
        if len(nc_out_vars):
            self._netcdf_output_var_dict = {None: nc_out_vars}

        if write_individual_vars:
            self._netcdf_output_var_dict = {
                None: nc_out_vars,
                **self.terms,
            }

        if len(self._netcdf_output_var_dict) == 0:
            msg = (
                f"Budget for {self.description} has no requested output, "
                "setting self._output_netcdf = False"
            )
            warn(msg)
            self._output_netcdf = False
            return

        global_attrs = {
            "Description": (
                f"pywatershed ({self.basis}) budget for {self.description}"
            ),
            "Budget basis": f"{self.basis} (unit or global)",
        }
        for key, term_names in self.terms.items():
            global_attrs[key] = "[" + ", ".join(term_names) + "]"

        if self.basis == "unit":
            coordinates = params.coords
            meta = self.meta
        else:
            # global sums are scalars: one dimension of length one
            coordinates = {"one": 0}
            meta = deepcopy(self.meta)
            for kk in meta.keys():
                meta[kk]["dims"] = ("one",)

        self._netcdf = NetCdfWrite(
            nc_path,
            coordinates,
            self._netcdf_output_var_dict,
            meta,
            extra_coords=extra_coords,
            global_attrs=global_attrs,
        )
        # todo jlm: put terms in to metadata
        return

    def __output_netcdf(self) -> None:
        """Output variable data for a time step

        Returns:
            None
        """
        if self._output_netcdf:
            itime_step = self.control.itime_step
            self._netcdf.time[itime_step] = nc4.date2num(
                self.control.current_datetime, self._netcdf.time.units
            )

            for nc_group, group_vars in self._netcdf_output_var_dict.items():
                for nc_var in group_vars:
                    if nc_group is None:
                        # sum variables live at the root of the file
                        self._netcdf.dataset[nc_var][itime_step, :] = self[
                            nc_var
                        ]
                    else:
                        # individual terms live in their component's group
                        var_path = f"{nc_group}/{nc_var}"
                        self._netcdf.dataset[var_path][itime_step, :] = self[
                            nc_group
                        ][nc_var]
        return

    def _finalize_netcdf(self) -> None:
        """Close the NetCDF file if NetCDF output is active."""
        if self._output_netcdf:
            self._netcdf.close()
        return