"""Output collection and statistical analysis for pywatershed models.
Supports PRMS processes and FlowGraph. Collects three types of output:
1. Monthly accumulations - All spatial units (HRUs, segments, nodes)
2. NOI (Nodes of Interest) - Time series and stats at specific network nodes
3. HOI (HRUs of Interest) - Time series and stats for specific HRUs
Nomenclature
------------
This module adopts NOI (Nodes of Interest) and HOI (HRUs of Interest) to
denote subsets of the model's spatial discretization, distinguishing them
from POI (Points of Interest) used in PRMS to denote real-world locations
like gage stations. While POIs represent physical monitoring locations,
NOIs and HOIs represent subsets of the model grid for focused output
collection and analysis.
- NOI: Nodes (river segments in PRMSChannel, segments + features in FlowGraph)
- HOI: HRUs (Hydrologic Response Units)
- POI: Real-world monitoring locations (PRMS parameter poi_gage_segment)
Example
-------
>>> import pywatershed as pws
>>> from pywatershed.analysis.time_stats import mean
>>>
>>> def max_flow(da):
... return da.max(dim="time")
...
>>>
>>> output = pws.base.Output(
... control=control,
... model=model,
... monthly_accum_var_list=["sroff", "hru_actet"],
... noi_var_list=["seg_outflow"],
... noi_ids=[12345, 67890],
... noi_stats={mean: ["seg_outflow"], max_flow: ["seg_outflow"]},
... hoi_var_list=["hru_actet"],
... hoi_ids=[1, 2, 3],
... hoi_stats={mean: ["hru_actet"]},
... )
>>> model.run(finalize=True, output=output)
>>>
>>> # Access hierarchically
>>> output.noi_stats["seg_outflow"]["mean"]
>>> output.hoi_stats["hru_actet"]["mean"]
Notes
-----
- Statistics use dict[function: var_list] pattern
- Results stored hierarchically: output.noi_stats[variable][statistic]
- Each result has metadata: variable, statistic, period_of_record
- IDs can be list (same for all vars) or dict (per-variable)
- Must finalize before accessing statistics
"""
import pathlib as pl
import warnings
from typing import TYPE_CHECKING, Callable, Literal
import numpy as np
import xarray as xr
from ..constants import var_type_to_numpy_type
from . import meta
from .flow_graph import FlowGraph
if TYPE_CHECKING:
from .control import Control
from .model import Model
# Maps the name of a variable's leading (spatial) dimension to the name of
# the parameter coordinate used to label that dimension in output arrays.
spatial_dim_to_coord_name = {
    "nhru": "nhm_id",
    "nsegment": "nhm_seg",
    "nnodes": "node_coord",
}
[docs]
class Output:
"""Output collection and statistical analysis for models.
Collects data during execution, computes statistics after finalization.
Supports PRMS processes and FlowGraph.
Parameters
----------
control : Control
Model control with timing information
model : Model
Pywatershed model instance
monthly_accum_var_list : list[str], optional
Variables to accumulate monthly (all spatial units)
noi_var_list : list[str], optional
Variables to collect at nodes of interest. Required if noi_ids is
a list.
Must NOT be provided if noi_ids is dict (use dict keys).
noi_ids : list[int] or dict[str, list[int]] or list[tuple] or \
dict[str, list[tuple]], optional
Node IDs for NOIs. Supports two ID types and two modes.
ID types:
- Simple IDs: Integer nhm_seg values (e.g., [12345, 67890])
- FlowGraph tuples: (node_maker_name, node_maker_id) for FlowGraph
nodes
(e.g., [("prms_channel", 12345), ("starfit", 0)])
Modes:
- List mode: Same IDs for all vars (requires noi_var_list)
- Dict mode: {var_name: [ids]} per-variable IDs (don't provide
noi_var_list)
noi_stats : dict[Callable, list[str]], optional
Statistics for NOIs: {function: [var1, var2, ...]}
hoi_var_list : list[str], optional
Variables to collect for HRUs of interest. Required if hoi_ids is
a list.
Must NOT be provided if hoi_ids is dict (use dict keys).
hoi_ids : list[int] or dict[str, list[int]], optional
HRU IDs (nhm_id values). Two modes:
- List mode: Same IDs for all vars (requires hoi_var_list)
- Dict mode: {var_name: [ids]} per-variable IDs (don't provide
hoi_var_list)
hoi_stats : dict[Callable, list[str]], optional
Statistics for HOIs: {function: [var1, var2, ...]}
chunked_var_list : list[str], optional
Variables to write in chunks to zarr format. These variables will be
buffered in memory for chunk_sizes['time'] timesteps and written to
zarr file when buffer is full.
chunked_output_file : str or pathlib.Path, optional
Path to zarr output file. Required if chunked_var_list is provided.
chunk_sizes : dict[str, int], optional
Chunk sizes for zarr output. Keys: 'time', 'nhru', 'nsegment',
'nnode'. If not provided, sensible defaults will be used based on
spatial dimensions. Defaults target ~10-100 MB chunks.
chunk_size_auto_warn : bool, optional
If True (default), warn when chunk_sizes are auto-determined.
Attributes
----------
time : np.ndarray
Daily time coordinate
time_months : np.ndarray
Monthly time coordinate
n_days_per_month : xr.DataArray
Days per month (for converting accumulations to means)
monthly_accumulations : dict[str, xr.DataArray]
Monthly values, available after finalization
noi_arrays : dict[str, xr.DataArray]
NOI time series, available after finalization
hoi_arrays : dict[str, xr.DataArray]
HOI time series, available after finalization
noi_stats : dict[str, dict[str, xr.DataArray]]
NOI statistics: noi_stats[variable][statistic]
hoi_stats : dict[str, dict[str, xr.DataArray]]
HOI statistics: hoi_stats[variable][statistic]
Examples
--------
>>> from pywatershed.analysis.time_stats import mean
>>> def max_flow(da):
... return da.max(dim="time")
...
>>>
>>> # List mode: same IDs for all variables
>>> output = pws.base.Output(
... control=control,
... model=model,
... noi_var_list=["seg_outflow"], # Required with list mode
... noi_ids=[12345, 67890],
... noi_stats={mean: ["seg_outflow"]},
... hoi_var_list=["hru_actet"], # Required with list mode
... hoi_ids=[1, 2, 3],
... hoi_stats={mean: ["hru_actet"]},
... )
>>>
>>> # Dict mode: per-variable IDs
>>> output = pws.base.Output(
... control=control,
... model=model,
... noi_ids={ # Dict keys define variables
... "seg_outflow": [12345, 67890],
... "seg_upstream_inflow": [12345],
... },
... noi_stats={mean: ["seg_outflow"]}, # Not all vars need stats
... hoi_ids={ # Dict keys define variables
... "hru_actet": [1, 2],
... "pkwater_equiv": [3, 4, 5],
... },
... hoi_stats={mean: ["hru_actet"]},
... )
>>>
>>> # FlowGraph mode: tuple IDs for nodes
>>> output = pws.base.Output(
... control=control,
... model=flowgraph_model,
... noi_ids={
... "node_outflows": [("prms_channel", 12345), ("starfit", 0)],
... "node_storages": [("starfit", 0)],
... },
... noi_stats={mean: ["node_outflows"]},
... )
>>> model.run(finalize=True, output=output)
>>>
>>> # Access hierarchically
>>> output.noi_stats["seg_outflow"]["mean"]
>>> output.hoi_stats["hru_actet"]["mean"]
"""
[docs]
def __init__(
self,
control: "Control",
model: "Model",
monthly_accum_var_list: list | None = None,
noi_var_list: list | None = None,
noi_ids: list | dict | None = None,
noi_stats: dict[Callable, list[str]] | None = None,
hoi_var_list: list | None = None,
hoi_ids: list | dict | None = None,
hoi_stats: dict[Callable, list[str]] | None = None,
netcdf_output_action: Literal["allow", "error"] = "error",
chunked_var_list: list[str] | None = None,
chunked_output_file: str | pl.Path | None = None,
chunk_sizes: dict[str, int] | None = None,
chunk_size_auto_warn: bool = True,
):
"""Initialize Output and set up data collection."""
self._control = control
self._netcdf_output_action = netcdf_output_action
self._take_netcdf_output_action()
self._finalized = False
self._model = model
self._monthly_accum_var_list = monthly_accum_var_list
# Process NOI IDs (can be list or dict)
self._noi_var_list, self._noi_ids = self._process_noi_ids(
noi_var_list, noi_ids
)
self._noi_stats = noi_stats
# Process HOI IDs (can be list or dict)
self._hoi_var_list, self._hoi_ids = self._process_hoi_ids(
hoi_var_list, hoi_ids
)
self._hoi_stats = hoi_stats
self._current_time = self._control.init_time.copy()
self._time_step = self._control.time_step.copy()
self._init_monthly()
self._init_noi_sub()
self._init_zarr_chunked(
chunked_var_list,
chunked_output_file,
chunk_sizes,
chunk_size_auto_warn,
)
self._build_chunked_iteration_list()
return None
def _take_netcdf_output_action(self):
if (
"netcdf_output_dir" in self._control.options.keys()
and self._netcdf_output_action != "allow"
):
msg = (
"control.options['netcdf_output_dir'] is defined in "
"addition to Output object being intitalized with argument "
f"netcdf_output_action={self._netcdf_output_action}. If you "
"truly want both NetCDF and Zarr output, set "
"netcdf_output_action=allow."
)
raise ValueError(msg)
# ==== ID Processing Methods =========================
def _process_noi_ids(
self, var_list, ids
) -> tuple[list | None, list | dict | None]:
"""Process NOI IDs - handle list or dict for per-variable IDs.
Two modes:
- List mode: var_list required, same IDs for all variables
- Dict mode: dict keys define variables, var_list must be None
"""
if var_list is None and ids is None:
return None, None
if ids is None:
raise ValueError(
"noi_ids must be passed when noi variables are requested."
)
# Handle dict mode
if isinstance(ids, dict):
if var_list is not None:
raise ValueError(
"noi_var_list should not be provided when noi_ids is a "
"dict. Use dict keys to specify variables."
)
var_list = list(ids.keys())
return var_list, ids
# List mode - existing behavior
if var_list is None:
raise ValueError("noi_var_list required when noi_ids is a list")
return var_list, ids
def _process_hoi_ids(
self, var_list, ids
) -> tuple[list | None, list | dict | None]:
"""Process HOI IDs - handle list or dict for per-variable IDs.
Two modes:
- List mode: var_list required, same IDs for all variables
- Dict mode: dict keys define variables, var_list must be None
"""
if var_list is None and ids is None:
return None, None
# Handle dict mode
if isinstance(ids, dict):
if var_list is not None:
raise ValueError(
"hoi_var_list should not be provided when hoi_ids is a "
"dict. Use dict keys to specify variables."
)
var_list = list(ids.keys())
return var_list, ids
# List mode - existing behavior
if ids is not None and var_list is None:
raise ValueError("hoi_var_list required when hoi_ids is a list")
return var_list, ids
# ==== Properties =========================
@property
def time(self) -> np.ndarray | None:
"""Daily time coordinate for NOI/HOI data."""
return self._time
@property
def time_months(self) -> np.ndarray | None:
"""Monthly time coordinate for accumulations."""
return self._time_months
@property
def n_days_per_month(self) -> xr.DataArray | None:
"""Number of days in each month for stats.
Returns
-------
xr.DataArray or None
DataArray of day counts per month with month dimension, useful for
converting accumulations to means. Only available after
finalization.
"""
if self._finalized:
return self._n_days_per_month
else:
warnings.warn(
"n_days_per_month is only available after finalization. "
"Call output.finalize() or model.run(finalize=True)."
)
return None
@property
def monthly_accumulations(self) -> dict[str, xr.DataArray] | None:
"""Monthly accumulations (available after finalization)."""
if self._finalized:
return self._monthly_arrays
else:
warnings.warn(
"monthly_accumulations is only available after finalization. "
"Call output.finalize() or model.run(finalize=True)."
)
return None
@property
def noi_arrays(self) -> dict[str, xr.DataArray] | None:
"""NOI time series (available after finalization)."""
if self._finalized:
return self._noi_arrays
else:
warnings.warn(
"noi_arrays is only available after finalization. "
"Call output.finalize() or model.run(finalize=True)."
)
return None
@property
def hoi_arrays(self) -> dict[str, xr.DataArray] | None:
"""HOI time series (available after finalization)."""
if self._finalized:
return self._hoi_arrays
else:
warnings.warn(
"hoi_arrays is only available after finalization. "
"Call output.finalize() or model.run(finalize=True)."
)
return None
@property
def noi_stats(self) -> dict[str, dict[str, xr.DataArray]] | None:
"""NOI statistics: noi_stats[variable][statistic] (after
finalization)."""
if self._finalized:
return self._noi_stats_results
else:
warnings.warn(
"noi_stats is only available after finalization. "
"Call output.finalize() or model.run(finalize=True)."
)
return None
@property
def hoi_stats(self) -> dict[str, dict[str, xr.DataArray]] | None:
"""HOI statistics: hoi_stats[variable][statistic] (after
finalization)."""
if self._finalized:
return self._hoi_stats_results
else:
warnings.warn(
"hoi_stats is only available after finalization. "
"Call output.finalize() or model.run(finalize=True)."
)
return None
# ==== Validation methods =====================
# ==== Monthly accumulation section =====================
def _init_monthly(self) -> None:
"""Initialize monthly accumulation data structures."""
if self._monthly_accum_var_list is None:
self._time_months = None
self._n_days_per_month = None
return None
self._solve_monthly_time()
self._map_monthly_vars_procs()
self._declare_monthly_arrays()
def _solve_monthly_time(self) -> None:
    """Create the monthly time coordinate and the per-month day counter.

    The coordinate holds one entry per calendar month touched by the run,
    including a partial first month when the run starts mid-month.
    """
    import pandas as pd

    ctl = self._control
    # freq="MS" only yields month starts at or after `start`. Anchoring the
    # range at the first day of the starting month keeps that month when
    # the run starts mid-month, so _get_month_index can always find it.
    first_month_start = (
        pd.Timestamp(ctl.start_time).normalize().replace(day=1)
    )
    self._time_months = pd.date_range(
        start=first_month_start, end=ctl.end_time, freq="MS"
    ).values.astype("datetime64[M]")
    self._n_days_per_month = xr.DataArray(
        data=np.zeros(len(self._time_months), dtype="int32"),
        dims=["month"],
        coords={"month": self._time_months},
        attrs=dict(
            description="Number of days in each month",
            units="days",
        ),
        name="days per month",
    )
def _map_monthly_vars_procs(self) -> None:
"""Map monthly variables to their source processes.
Raises
------
ValueError
If any requested variable is not found in model processes
"""
self._monthly_vars_procs = {}
for vv in self._monthly_accum_var_list:
for pp in self._model.processes.keys():
proc_vars = self._model.processes[pp].get_variables()
if vv in proc_vars:
self._monthly_vars_procs[vv] = pp
elif (
isinstance(
proc := self._model.processes[pp],
FlowGraph,
)
and vv in proc["_addtl_output_vars"]
):
self._monthly_vars_procs[vv] = pp
if not set(self._monthly_vars_procs.keys()) == set(
self._monthly_accum_var_list
):
raise ValueError(
"Not all monthly accumulation variables were found among the "
"model processes."
)
def _declare_monthly_arrays(self):
    """Declare zero-filled xarray DataArrays for monthly accumulations.

    One ``(month, spatial)`` array is created per requested variable.
    Metadata comes from ``meta.find_variables``; for a process's
    additional output variables that are unknown there, the process's own
    metadata is used with a placeholder description and units.
    """
    self._monthly_arrays = {}
    for vv in self._monthly_accum_var_list:
        proc_name = self._monthly_vars_procs[vv]
        proc = self._model.processes[proc_name]
        var_meta = meta.find_variables(vv)
        if (
            not var_meta
            and hasattr(proc, "_addtl_output_vars")
            and vv in proc._addtl_output_vars
        ):
            # Work on a copy so the placeholder desc/units are not written
            # back into the process's own metadata.
            var_meta = dict(proc.meta[vv])
            var_meta["desc"] = vv
            var_meta["units"] = "unknown"
        else:
            var_meta = var_meta[vv]

        spatial_dim_len = proc[vv].shape[0]
        spatial_dim_name = var_meta["dims"][0]
        spatial_coord_name = spatial_dim_to_coord_name[spatial_dim_name]
        spatial_coord = proc._params.coords[spatial_coord_name]
        new_shape = (len(self._time_months), spatial_dim_len)
        self._monthly_arrays[vv] = xr.DataArray(
            # zeros required for accumulations
            data=np.zeros(new_shape, dtype=var_meta["type"]),
            dims=["month", spatial_dim_name],
            coords={
                "month": self._time_months,
                spatial_coord_name: ([spatial_dim_name], spatial_coord),
            },
            attrs=dict(
                description=var_meta["desc"],
                units=var_meta["units"],
                resolution="Monthly",
            ),
            name=f"{vv} monthly accumulation",
        )
def _get_month_index(self) -> int:
"""Determine current month index for accumulation."""
current_month = self._current_time.astype("datetime64[M]")
self._current_month_index = np.where(
self._time_months == current_month
)[0][0]
return self._current_month_index
def _accumulate_monthly_values(self) -> None:
"""Accumulate current timestep values into monthly arrays."""
if not self._monthly_accum_var_list:
return
self._get_month_index()
mon_ind = self._current_month_index
self._n_days_per_month.values[mon_ind] += 1
for vv in self._monthly_accum_var_list:
proc_name = self._monthly_vars_procs[vv]
self._monthly_arrays[vv][mon_ind, :] += self._model.processes[
proc_name
][vv]
# ==== NOI + HRU SUB section =====================
def _init_noi_sub(self) -> None:
"""Initialize NOI and HRU subset data structures."""
if not self._noi_var_list and not self._hoi_var_list:
# Initialize empty lists so other methods don't error
self._noi_hoi_data_list = []
self._noi_hoi_stats_list = []
return None
self._solve_time()
self._solve_noi_stat_list()
self._solve_hoi_stat_list()
self._map_noi_vars_procs()
self._map_hoi_vars_procs()
# Initialize empty dicts first
self._noi_arrays = {}
self._hoi_arrays = {}
# Build iteration lists (references the empty dicts)
self._build_noi_hoi_iteration_lists()
# Now populate the arrays
self._declare_noi_hoi_arrays()
def _solve_time(self) -> None:
"""Create daily time coordinate for full time series data."""
import pandas as pd
ctl = self._control
self._time = pd.date_range(
start=ctl.start_time, end=ctl.end_time, freq="D"
).values.astype("datetime64[D]")
def _solve_noi_stat_list(self) -> None:
"""Build NOI statistics dict from function: var_list mapping."""
self._noi_stat_func_vars = {}
if self._noi_stats is None:
return
for func, var_list in self._noi_stats.items():
if not callable(func):
raise ValueError("noi_stats keys must be callable functions.")
self._noi_stat_func_vars[func] = var_list
def _solve_hoi_stat_list(self) -> None:
"""Build HRU subset statistics dict from function: var_list mapping."""
self._hoi_stat_func_vars = {}
if self._hoi_stats is None:
return
for func, var_list in self._hoi_stats.items():
if not callable(func):
raise ValueError("hoi_stats keys must be callable functions.")
self._hoi_stat_func_vars[func] = var_list
@staticmethod
def _solve_flowgraph_inds(tup_list, params, check=True):
"""Resolve FlowGraph node indices from (node_maker_name,
node_maker_id) tuples.
For FlowGraph nodes, IDs are specified as 2-tuples:
(node_maker_name, node_maker_id) rather than simple integer IDs.
Parameters
----------
tup_list : list of tuple
List of (node_maker_name, node_maker_id) tuples
params : dict
Process parameters containing node_maker_name and node_maker_id
arrays
check : bool, optional
Whether to validate results (default True)
Returns
-------
list of int
Indices matching the requested tuples
"""
flowgraph_inds = []
for tup in tup_list:
matches = np.where(
(params["node_maker_name"] == tup[0])
& (params["node_maker_id"] == tup[1])
)[0]
if len(matches) == 0:
raise ValueError(
f"FlowGraph node not found: "
f"node_maker_name='{tup[0]}', node_maker_id={tup[1]}"
)
flowgraph_inds += matches.tolist()
if check:
# Validate: check that found indices match requested tuples
found_names = params["node_maker_name"][flowgraph_inds].tolist()
expected_names = [tt for tt, _ in tup_list]
if found_names != expected_names:
raise ValueError(
f"FlowGraph index resolution failed: "
f"node_maker_name mismatch. Expected {expected_names}, "
f"got {found_names}"
)
found_ids = params["node_maker_id"][flowgraph_inds].tolist()
expected_ids = [tt for _, tt in tup_list]
if found_ids != expected_ids:
raise ValueError(
f"FlowGraph index resolution failed: "
f"node_maker_id mismatch. Expected {expected_ids}, "
f"got {found_ids}"
)
return flowgraph_inds
def _map_noi_vars_procs(self) -> None:
"""Map NOI variables to processes and resolve indices."""
if self._noi_var_list is None:
return
self._noi_vars_procs = {}
# Map variables to processes
for vv in self._noi_var_list:
for pp in self._model.processes.keys():
proc = self._model.processes[pp]
proc_vars = proc.get_variables()
if hasattr(proc, "_addtl_output_vars"):
proc_vars += proc._addtl_output_vars
if vv in proc_vars:
self._noi_vars_procs[vv] = pp
vv_dims = proc.meta[vv]["dims"][0]
if vv_dims != "nsegment" and vv_dims != "nnodes":
raise ValueError(
f"Variable '{vv}' does not have dimension "
"'nsegment' nor 'nnodes'."
)
if not self._noi_vars_procs:
return
# Handle dict mode (per-variable IDs) or list mode (same IDs for all)
# IDs can be simple integers (nhm_seg) or tuples for FlowGraph nodes
if isinstance(self._noi_ids, dict):
# Dict mode: per-variable IDs
self._noi_inds = {}
for vv in self._noi_var_list:
proc_name = self._noi_vars_procs[vv]
proc = self._model.processes[proc_name]
if not isinstance(self._noi_ids[vv][0], tuple):
self._noi_inds[vv] = np.where(
np.isin(
proc._params.parameters["nhm_seg"],
self._noi_ids[vv],
)
)[0]
else:
tup_list = self._noi_ids[vv]
self._noi_inds[vv] = self._solve_flowgraph_inds(
tup_list, proc._params.parameters
)
else:
# List mode: same IDs for all variables
proc_name = list(self._noi_vars_procs.values())[0]
proc = self._model.processes[proc_name]
if not isinstance(self._noi_ids, tuple):
# proc_coord = spatial_dim_to_coord_name[
# list(proc._params.dims.keys())[0]
# ]
self._noi_inds = np.where(
np.isin(
proc._params.parameters["nhm_seg"],
self._noi_ids,
)
)[0]
else:
tup_list = self._noi_ids
self._noi_inds = self._solve_flowgraph_inds(
tup_list, proc._params.parameters
)
def _map_hoi_vars_procs(self) -> None:
"""Map HRU subset variables to processes and resolve indices."""
if self._hoi_var_list is None:
return
self._hoi_vars_procs = {}
# Map variables to processes
for vv in self._hoi_var_list:
for pp in self._model.processes.keys():
proc_vars = self._model.processes[pp].get_variables()
if vv in proc_vars:
self._hoi_vars_procs[vv] = pp
vv_dims = meta.find_variables(vv)[vv]["dims"][0]
if vv_dims != "nhru":
raise ValueError(
f"Variable '{vv}' does not have dimension 'nhru'."
)
if not self._hoi_vars_procs:
return
# Handle dict mode (per-variable IDs) or list mode (same IDs for all)
if isinstance(self._hoi_ids, dict):
# Dict mode: per-variable IDs
self._hoi_inds = {}
for vv in self._hoi_var_list:
proc_name = self._hoi_vars_procs[vv]
self._hoi_inds[vv] = np.where(
np.isin(
self._model.processes[proc_name]._params.parameters[
"nhm_id"
],
self._hoi_ids[vv],
)
)[0]
else:
# List mode: same IDs for all variables
proc_name = list(self._hoi_vars_procs.values())[0]
self._hoi_inds = np.where(
np.isin(
self._model.processes[proc_name]._params.parameters[
"nhm_id"
],
self._hoi_ids,
)
)[0]
def _build_noi_hoi_iteration_lists(self) -> None:
"""Build and cache iteration lists to avoid rebuilding each
timestep."""
# For _add_noi_hoi_data (called every timestep)
self._noi_hoi_data_list = []
if self._noi_var_list is not None:
self._noi_hoi_data_list.append(
(
self._noi_arrays,
self._noi_var_list,
self._noi_vars_procs,
self._noi_inds,
)
)
if self._hoi_var_list is not None:
self._noi_hoi_data_list.append(
(
self._hoi_arrays,
self._hoi_var_list,
self._hoi_vars_procs,
self._hoi_inds,
)
)
# For _calculate_noi_hoi_stats (called once at finalization)
self._noi_hoi_stats_list = []
if self._noi_stats is not None:
self._noi_hoi_stats_list.append(
(
"noi", # marker to identify which stats dict to use
self._noi_arrays,
self._noi_stat_func_vars,
)
)
if self._hoi_stats is not None:
self._noi_hoi_stats_list.append(
(
"hoi", # marker to identify which stats dict to use
self._hoi_arrays,
self._hoi_stat_func_vars,
)
)
def _declare_noi_hoi_arrays(self) -> None:
    """Declare NaN-filled xarray DataArrays for NOI and HOI variables.

    One ``(time, spatial)`` array is created per variable of every entry
    in the cached data list, restricted to the resolved indices. Metadata
    comes from ``meta.find_variables``; for a process's additional output
    variables that are unknown there, the process's own metadata is used
    with a placeholder description and units. Arrays on the
    ``node_coord`` coordinate also carry ``node_maker_name`` and
    ``node_maker_id`` coordinates.
    """
    for arrays, var_list, vars_procs, inds in self._noi_hoi_data_list:
        for vv in var_list:
            proc_name = vars_procs[vv]
            proc = self._model.processes[proc_name]
            var_meta = meta.find_variables(vv)
            if (
                not var_meta
                and hasattr(proc, "_addtl_output_vars")
                and vv in proc._addtl_output_vars
            ):
                # Work on a copy so the placeholder desc/units are not
                # written back into the process's own metadata.
                var_meta = dict(proc.meta[vv])
                var_meta["desc"] = vv
                var_meta["units"] = "unknown"
            else:
                var_meta = var_meta[vv]

            # Handle dict mode (per-variable indices) or list mode
            var_inds = inds[vv] if isinstance(inds, dict) else inds
            spatial_dim_len = proc[vv][var_inds].shape[0]
            spatial_dim_name = var_meta["dims"][0]
            spatial_coord_name = spatial_dim_to_coord_name[
                spatial_dim_name
            ]
            spatial_coord = proc._params.coords[spatial_coord_name][
                var_inds
            ]
            coords = {
                "time": self._time,
                spatial_coord_name: (
                    [spatial_dim_name],
                    spatial_coord,
                ),
            }
            if spatial_coord_name == "node_coord":
                coords["node_maker_name"] = (
                    [spatial_dim_name],
                    proc._params.parameters["node_maker_name"][var_inds],
                )
                coords["node_maker_id"] = (
                    [spatial_dim_name],
                    proc._params.parameters["node_maker_id"][var_inds],
                )

            new_shape = (len(self._time), spatial_dim_len)
            arrays[vv] = xr.DataArray(
                data=np.full(new_shape, np.nan, dtype=var_meta["type"]),
                dims=["time", spatial_dim_name],
                coords=coords,
                attrs=dict(
                    description=var_meta["desc"],
                    units=var_meta["units"],
                ),
                name=vv,
            )
def _add_noi_hoi_data(self) -> None:
"""Add current timestep data to NOI and HRU subset arrays."""
if not self._noi_hoi_data_list:
return
time_ind = self._control.itime_step
for arrays, var_list, vars_procs, inds in self._noi_hoi_data_list:
for vv in var_list:
proc_name = vars_procs[vv]
# Handle dict mode (per-variable indices) or list mode
var_inds = inds[vv] if isinstance(inds, dict) else inds
arrays[vv][time_ind, :] = self._model.processes[proc_name][vv][
var_inds
]
def _calculate_noi_hoi_stats(self) -> None:
"""Calculate NOI/HOI statistics, store as stats[var][func_name]."""
self._noi_stats_results = {}
self._hoi_stats_results = {}
for stats_type, arrays, stat_func_vars in self._noi_hoi_stats_list:
# Get the appropriate stats dict
stats = (
self._noi_stats_results
if stats_type == "noi"
else self._hoi_stats_results
)
for func, var_list in stat_func_vars.items():
func_name = func.__name__
for vv in var_list:
if vv not in arrays:
continue
# Create nested dict structure: stats[var][stat_name]
if vv not in stats:
stats[vv] = {}
# Calculate the statistic
result = func(arrays[vv])
# Set the name to the variable name
result.name = vv
# Add metadata attributes
result.attrs["variable"] = vv
result.attrs["statistic"] = func_name
# Add period of record from original time series
time_coord = arrays[vv].coords["time"]
period_start = str(time_coord.values[0])
period_end = str(time_coord.values[-1])
result.attrs["period_of_record"] = (
f"{period_start} to {period_end}"
)
stats[vv][func_name] = result
# ==== Zarr Methods ================
def _add_zarr_data(self) -> None:
"""Add current timestep data to zarr chunk buffers."""
if not self._chunked_data_list:
return
buffer_ind = self._control.itime_step % self._chunk_time
for vv, proc_name in self._chunked_data_list:
var_obj = self._model.processes[proc_name][vv]
# Handle TimeseriesArray objects
if hasattr(var_obj, "current"):
self._zarr_buffers[vv][buffer_ind, :] = var_obj.current
else:
self._zarr_buffers[vv][buffer_ind, :] = var_obj
# Write buffer to zarr when full
if (self._control.itime_step + 1) % self._chunk_time == 0:
self._write_zarr_buffer()
def _build_chunked_iteration_list(self) -> None:
"""Build and cache iteration list for zarr chunked output."""
self._chunked_data_list = []
if self._chunked_var_list is not None:
for vv in self._chunked_var_list:
proc_name = self._chunked_vars_procs[vv]
self._chunked_data_list.append((vv, proc_name))
def _init_zarr_chunked(
self,
chunked_var_list: list[str] | None,
chunked_output_file: str | pl.Path | None,
chunk_sizes: dict[str, int] | None,
chunk_size_auto_warn: bool,
) -> None:
"""Initialize zarr chunked output."""
self._chunked_var_list = chunked_var_list
self._chunked_output_file = chunked_output_file
self._zarr_buffers = {}
self._zarr_ds = None
self._zarr_initialized = False
if chunked_var_list is None:
return
if chunked_output_file is None:
raise ValueError(
"chunked_output_file must be provided when "
"chunked_var_list is specified"
)
self._chunked_output_file = pl.Path(chunked_output_file)
self._zarr_store = None # Will be opened after file initialization
# Determine which processes own which variables
self._chunked_vars_procs = {}
for vv in chunked_var_list:
found = False
for proc_name, proc in self._model.processes.items():
if vv in proc.variables:
self._chunked_vars_procs[vv] = proc_name
found = True
break
if not found:
raise ValueError(
f"Variable '{vv}' not found in any model process"
)
# Auto-determine chunk sizes if not provided
if chunk_sizes is None:
chunk_sizes = self._auto_determine_chunk_sizes(
chunk_size_auto_warn
)
self._chunk_sizes = chunk_sizes
self._chunk_time = chunk_sizes["time"]
# Get variable metadata for dtypes
var_metadata = meta.find_variables(chunked_var_list)
# Initialize buffers for each variable
for vv in chunked_var_list:
proc_name = self._chunked_vars_procs[vv]
proc = self._model.processes[proc_name]
# Handle TimeseriesArray objects
var_obj = proc[vv]
if hasattr(var_obj, "current"):
# TimeseriesArray - use .current.shape
spatial_shape = var_obj.current.shape
else:
# Regular array - use .shape
spatial_shape = var_obj.shape
# Get dtype from metadata, default to float64
if vv in var_metadata and "type" in var_metadata[vv]:
yaml_type = var_metadata[vv]["type"]
dtype = var_type_to_numpy_type.get(yaml_type, np.float64)
else:
dtype = np.float64
# Create buffer: (chunk_time, ...spatial_dims)
buffer_shape = (self._chunk_time,) + spatial_shape
self._zarr_buffers[vv] = np.zeros(buffer_shape, dtype=dtype)
def _auto_determine_chunk_sizes(self, warn: bool = True) -> dict[str, int]:
"""Auto-determine sensible chunk sizes based on model dimensions."""
chunk_sizes = {"time": 365} # Default: 1 year for daily data
# Get spatial dimensions from first chunked variable
first_var = self._chunked_var_list[0]
proc_name = self._chunked_vars_procs[first_var]
proc = self._model.processes[proc_name]
# Determine spatial dimensions
if hasattr(proc, "_params"):
dims = proc._params.dims
if "nhru" in dims:
chunk_sizes["nhru"] = min(5000, max(1000, dims["nhru"] // 2))
if "nsegment" in dims:
chunk_sizes["nsegment"] = min(
5000, max(1000, dims["nsegment"] // 2)
)
if "nnode" in dims:
chunk_sizes["nnode"] = min(5000, max(1000, dims["nnode"] // 2))
# Estimate chunk size in MB
bytes_per_value = 8 # float64
n_vars = len(self._chunked_var_list)
first_var_size = proc[first_var].size
chunk_mb = (
chunk_sizes["time"] * first_var_size * bytes_per_value * n_vars
) / (1024**2)
if warn:
warnings.warn(
f"Chunk sizes not specified. Using auto-determined values: "
f"{chunk_sizes}. Estimated chunk size: ~{chunk_mb:.1f} MB. "
"Consider tuning for your use case. "
"See documentation for chunking guidance.",
UserWarning,
)
return chunk_sizes
def _initialize_zarr_file(self) -> None:
    """Initialize zarr file with appropriate structure and chunks.

    Creates (overwriting, ``mode="w"``) the zarr output with a full-length
    time coordinate and one zero-filled placeholder array per chunked
    variable, then reopens it both as an xarray dataset (``_zarr_ds``) and
    as a writable zarr group (``_zarr_store``) for incremental writes.
    Does nothing if the file was already initialized.

    Raises
    ------
    ValueError
        If a variable has more than one spatial dimension.
    """
    if self._zarr_initialized:
        return
    # Build time coordinate: n_times steps starting at the control start
    # time
    time_array = np.arange(
        self._control.start_time,
        self._control.start_time + self._time_step * self._control.n_times,
        self._time_step,
    ).astype("datetime64[ns]")
    # Get variable metadata for dtypes
    var_metadata = meta.find_variables(self._chunked_var_list)
    # Create dataset dict
    data_vars = {}
    coords = {"time": time_array}
    encoding = {}
    for vv in self._chunked_var_list:
        proc_name = self._chunked_vars_procs[vv]
        proc = self._model.processes[proc_name]
        # Handle TimeseriesArray objects
        var_obj = proc[vv]
        if hasattr(var_obj, "current"):
            # TimeseriesArray - use .current.shape
            spatial_shape = var_obj.current.shape
        else:
            # Regular array - use .shape
            spatial_shape = var_obj.shape
        # Determine dimension names and add spatial coordinates
        if spatial_shape == ():
            dims = ["time"]
            chunks_tuple = (self._chunk_sizes["time"],)
        elif len(spatial_shape) == 1:
            # Determine if nhru, nsegment, or nnode.
            # NOTE(review): the dimension is inferred only by comparing the
            # variable's length with nhm_id, then nhm_seg; anything else
            # is labelled "nnode". Equal HRU and segment counts would be
            # resolved as "nhru".
            if "nhm_id" in proc._params.parameters and spatial_shape[
                0
            ] == len(proc._params.parameters["nhm_id"]):
                spatial_dim = "nhru"
                spatial_coord_name = "nhm_id"
                spatial_coord_values = proc._params.parameters["nhm_id"]
            elif "nhm_seg" in proc._params.parameters and spatial_shape[
                0
            ] == len(proc._params.parameters["nhm_seg"]):
                spatial_dim = "nsegment"
                spatial_coord_name = "nhm_seg"
                spatial_coord_values = proc._params.parameters["nhm_seg"]
            else:
                spatial_dim = "nnode"
                spatial_coord_name = "node_coord"
                # For nodes, create integer indices
                spatial_coord_values = np.arange(spatial_shape[0])
            # Add spatial coordinate (only the first time its name is seen)
            if spatial_coord_name not in coords:
                coords[spatial_coord_name] = (
                    spatial_dim,
                    spatial_coord_values,
                )
            dims = ["time", spatial_dim]
            # Fall back to the full spatial length when no chunk size was
            # given for this dimension
            spatial_chunk = self._chunk_sizes.get(
                spatial_dim, spatial_shape[0]
            )
            chunks_tuple = (self._chunk_sizes["time"], spatial_chunk)
        else:
            raise ValueError(
                f"Variable '{vv}' has unsupported shape: {spatial_shape}"
            )
        # Get dtype from metadata, default to float64
        if vv in var_metadata and "type" in var_metadata[vv]:
            yaml_type = var_metadata[vv]["type"]
            dtype = var_type_to_numpy_type.get(yaml_type, np.float64)
        else:
            dtype = np.float64
        # Create placeholder data (will be filled incrementally).
        # NOTE(review): this allocates a full (n_times x spatial) array in
        # memory for every variable before the file is written.
        data_vars[vv] = (
            dims,
            np.zeros((len(time_array),) + spatial_shape, dtype=dtype),
        )
        # Store chunking for encoding
        encoding[vv] = {"chunks": chunks_tuple}
    # Create xarray Dataset
    ds = xr.Dataset(data_vars, coords=coords)
    # Write to zarr with chunking
    # consolidated=False to avoid Zarr v3 spec warning about
    # consolidated metadata
    ds.to_zarr(
        self._chunked_output_file,
        mode="w",
        encoding=encoding,
        consolidated=False,
    )
    self._zarr_ds = xr.open_zarr(
        self._chunked_output_file, consolidated=False
    )
    # Open zarr store for direct writing
    import zarr

    self._zarr_store = zarr.open(str(self._chunked_output_file), mode="r+")
    self._zarr_initialized = True
def _write_zarr_buffer(self) -> None:
"""Write current buffer to zarr file."""
if not self._zarr_initialized:
self._initialize_zarr_file()
# Determine time slice for this chunk
current_step = self._control.itime_step + 1
chunk_start = (current_step // self._chunk_time - 1) * self._chunk_time
chunk_end = chunk_start + self._chunk_time
# Write directly to zarr arrays without copying
for vv in self._chunked_var_list:
self._zarr_store[vv][chunk_start:chunk_end] = self._zarr_buffers[
vv
]
def _finalize_zarr(self) -> None:
"""Finalize zarr output, writing any remaining buffered data."""
if self._chunked_var_list is None:
return
# Initialize zarr file if not already initialized
if not self._zarr_initialized:
self._initialize_zarr_file()
# Write any remaining data in buffer
remaining_steps = (self._control.itime_step + 1) % self._chunk_time
if remaining_steps > 0:
current_step = self._control.itime_step + 1
chunk_start = (current_step // self._chunk_time) * self._chunk_time
chunk_end = chunk_start + remaining_steps
for vv in self._chunked_var_list:
self._zarr_store[vv][chunk_start:chunk_end] = (
self._zarr_buffers[vv][:remaining_steps]
)
# Close zarr resources if opened
if self._zarr_ds is not None:
self._zarr_ds.close()
# Note: zarr stores don't need explicit closing, but set to None for
# clarity
self._zarr_store = None
# ==== General methods ================
[docs]
def calculate(self) -> None:
"""Collect data for current timestep (called by model.run())."""
# The control.advance() must happen before the this calculate() method.
if self._control.current_time != self._current_time + self._time_step:
raise ValueError(
"Calculation time requested does not match with control "
)
else:
self._current_time = self._control.current_time.copy()
self._accumulate_monthly_values()
self._add_noi_hoi_data()
self._add_zarr_data()
[docs]
def finalize(self) -> None:
"""Finalize and calculate statistics (called by model.run())."""
self._finalized = True
self._finalize_zarr()
self._calculate_noi_hoi_stats()
[docs]
def to_netcdf(self, output_dir: pl.Path) -> None:
"""Write output to netCDF files (not yet implemented)."""
if not self._finalized:
warnings.warn(
"Output can only be written once the Output object is "
"finalized"
)
return
# self._monthly_to_netcdf(self)
# self._noi_to_netcdf(self)
raise NotImplementedError("YET.")