Source code for pywatershed.analysis.model_graph

import pathlib as pl
import tempfile

from ..base.conservative_process import ConservativeProcess
from ..base.model import Model
from ..utils.optional_import import import_optional_dependency
from .utils.colorbrewer import nhm_process_colors


class ModelGraph:
    """Visualize a pywatershed Model as a directed graph.

    Creates a GraphViz visualization showing processes, their
    inputs/variables, and the data flow between them. Uses the dot layout
    engine with left-to-right orientation.

    Args:
        model: The pywatershed Model to visualize
        show_params: Whether to display parameters in process nodes
        process_colors: Dictionary mapping process names to colors for node
            borders. Processes missing from the dictionary fall back to the
            default edge color (edges) or no explicit color (nodes) instead
            of raising a KeyError.
        node_penwidth: Width of node borders (default: 2)
        edge_penwidth: Width of edges/arrows (default: 1.5)
        edge_arrowsize: Size of arrowheads (default: 1.2)
        default_edge_color: Color for edges between processes
            (default: "black")
        from_file_edge_color: Color for edges from file inputs
            (default: same as default_edge_color)
        hide_variables: If True, hide individual variable names and only
            show process-to-process connections (default: True)

    Example:
        >>> import pywatershed as pws
        >>> model = pws.Model(process_list, control=control, parameters=params)
        >>> mg = pws.analysis.ModelGraph(model, hide_variables=False)
        >>> mg.SVG()
    """

    def __init__(
        self,
        model: Model,
        show_params: bool = False,
        process_colors: dict[str, str] | None = None,
        node_penwidth: int = 2,
        edge_penwidth: float = 1.5,
        edge_arrowsize: float = 1.2,
        default_edge_color: str = "black",
        from_file_edge_color: str | None = None,
        hide_variables: bool = True,
    ) -> None:
        self.pydot = import_optional_dependency("pydot")
        # Built on demand by build_graph() (SVG() calls it if still None).
        self.graph = None
        self.model = model
        self.show_params = show_params

        # Auto-generate process colors if not provided
        if process_colors is None:
            self.process_colors = nhm_process_colors(model)
        else:
            self.process_colors = process_colors

        self.node_penwidth = node_penwidth
        self.edge_penwidth = edge_penwidth
        self.edge_arrowsize = edge_arrowsize
        self.default_edge_color = default_edge_color
        # Any falsy value (None or "") falls back to the process edge color.
        if not from_file_edge_color:
            self.from_file_edge_color = default_edge_color
        else:
            self.from_file_edge_color = from_file_edge_color
        self.hide_variables = hide_variables

    @staticmethod
    def _file_port(file_name: str) -> str:
        """Return the GraphViz port name used for an input file.

        The port is the file name up to its first "."; it is shared by the
        Files node rows and the edges that target them so the two agree.
        """
        return file_name.split(".")[0]

    def build_graph(self) -> None:
        """Build the GraphViz graph structure.

        Creates nodes for all processes and files, establishes connections
        between them, and sets up invisible ordering edges to enforce
        left-to-right layout according to model.process_order.

        The graph uses the dot layout engine with:
        - Left-to-right orientation (rankdir="LR")
        - Invisible edges to control node ordering
        - constraint=false on data edges so they don't affect layout

        Side effects: sets self.process_nodes, self.files,
        self.connections, self.file_node and self.graph.
        """
        process_order = self.model.process_order

        # Build the process nodes
        self.process_nodes = {}
        for process in process_order:
            self.process_nodes[process] = self._process_node(
                process,
                self.model.processes[process],
                show_params=self.show_params,
            )

        # Build connections between processes and files
        self.files = []
        self.connections = []
        for process in process_order:
            # Sources already linked to this process; only consulted when
            # hide_variables collapses per-variable edges into one edge.
            frm_already = set()
            for var, frm in self.model.process_input_from[process].items():
                var_con = f":{var}"
                if self.hide_variables:
                    # When hiding variables, connect directly to process
                    # and avoid duplicate edges from same source
                    var_con = ""
                    if frm in frm_already:
                        continue
                    frm_already.add(frm)

                if not isinstance(frm, pl.Path):
                    # Input produced by another process in the model.
                    color = self.default_edge_color
                    if self.process_colors:
                        color = self.process_colors.get(
                            frm, self.default_edge_color
                        )
                    self.connections.append(
                        (f"{frm}{var_con}", f"{process}{var_con}", color)
                    )
                else:
                    # Input read from a file: edge leaves the Files node
                    # at the port named after the file.
                    file_name = frm.name
                    self.files.append(file_name)
                    self.connections.append(
                        (
                            f"Files:{self._file_port(file_name)}",
                            f"{process}{var_con}",
                            self.from_file_edge_color,
                        )
                    )

        # Build the file inputs node
        self.file_node = self._file_node(self.files)

        # Create the GraphViz graph with layout settings
        self.graph = self.pydot.Dot(
            graph_type="digraph",
            layout="dot",
            rankdir="LR",
            nodesep="0.5",
            ranksep="1.0",
        )

        # Add an invisible source node to force Files to leftmost position
        source_node = self.pydot.Node(
            "_source_",
            style="invis",
            width="0",
            height="0",
        )
        self.graph.add_node(source_node)
        self.graph.add_node(self.file_node)

        # Add invisible edge from source to Files to control rank
        self.graph.add_edge(
            self.pydot.Edge("_source_", "Files", style="invis")
        )

        for process in process_order:
            self.graph.add_node(self.process_nodes[process])

        # Add invisible edges to enforce left-to-right process ordering.
        # These edges control the layout while constraint=false on data
        # edges prevents them from affecting node positions.
        if len(process_order) > 0:
            self.graph.add_edge(
                self.pydot.Edge(
                    "Files", process_order[0], style="invis", weight="10"
                )
            )
        for from_process, to_process in zip(
            process_order, process_order[1:]
        ):
            self.graph.add_edge(
                self.pydot.Edge(
                    from_process, to_process, style="invis", weight="10"
                )
            )

        # Add data flow edges with constraint=false so they don't affect layout
        for source, target, color in self.connections:
            self.graph.add_edge(
                self.pydot.Edge(
                    source,
                    target,
                    color=color,
                    constraint="false",
                    penwidth=str(self.edge_penwidth),
                    arrowsize=str(self.edge_arrowsize),
                )
            )

    def SVG(
        self, verbose: bool = False, dpi: int = 45, show_legend: bool = False
    ) -> None:
        """Render and display the graph as SVG in a Jupyter notebook.

        Args:
            verbose: If True, print the temporary file path
            dpi: Resolution for the SVG output (default: 45)
            show_legend: If True, display legend for process colors and
                mass/energy budget terms before the graph

        The graph is built on first use. The SVG is written to a temporary
        file that is kept on disk after a successful render (its path is
        printed when verbose=True). Rendering failures are reported via
        print rather than raised.
        """
        ipdisplay = import_optional_dependency("IPython.display")
        if show_legend:
            self.display_legend()

        # Build before creating the temp file so a failed build leaves no
        # stray file behind.
        if self.graph is None:
            self.build_graph()

        # Reserve the path securely: the file is created atomically and kept
        # (delete=False) so GraphViz can write to it once our handle is
        # closed. Taking .name from a discarded NamedTemporaryFile deletes
        # the file at once and leaves an unreserved, raceable path.
        with tempfile.NamedTemporaryFile(suffix=".svg", delete=False) as tmp:
            tmp_file = pl.Path(tmp.name)

        try:
            self.graph.write_svg(tmp_file, prog=["dot", f"-Gdpi={dpi}"])
            if verbose:
                print(f"Displaying SVG written to temp file: {tmp_file}")
            ipdisplay.display(ipdisplay.SVG(tmp_file))
        except Exception as e:
            # Rendering is best-effort in notebooks: report instead of raise,
            # and do not leave the (possibly empty) temp file behind.
            tmp_file.unlink(missing_ok=True)
            print(
                "GraphViz rendering failed. This can happen on some "
                "machines or with certain graph configurations."
            )
            print(f"Error details: {e}")
            print(
                "\nTip: Try installing or updating GraphViz, or use "
                "show_legend=True to see the model structure without "
                "the visualization."
            )

    def _process_node(
        self, process_name: str, process, show_params: bool = False
    ):
        """Create a node for a process with its inputs, variables, and parameters.

        Args:
            process_name: Name of the process in the model
            process: The process instance
            show_params: Whether to include parameters in the node

        Returns:
            A pydot.Node with an HTML-like table label showing the process
            structure. Variable cells get a PORT equal to the variable name
            so data edges can attach to them.
        """
        inputs = process.get_inputs()
        variables = process.get_variables()
        params = process.get_parameters()
        cls = process.__class__.__name__

        label = (
            '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="1">\n'  # noqa: E501
            f' <TR><TD COLSPAN="6">"{process_name}"</TD></TR>\n'
            f' <TR><TD COLSPAN="6">{cls}</TD></TR>\n'
        )

        category_colors = {
            "inputs": "white",
            "params": "whitesmoke",
            "variables": "lightgray",
        }

        show_categories = {
            "inputs": inputs,
            "params": params,
            "variables": variables,
        }
        if not show_params:
            show_categories.pop("params")

        # Budget membership drives per-variable cell shading below.
        if isinstance(process, ConservativeProcess):
            mass_budget_vars = {
                var
                for terms in process.get_mass_budget_terms().values()
                for var in terms
            }
            energy_budget_vars = {
                var
                for terms in process.get_energy_budget_terms().values()
                for var in terms
            }
        else:
            mass_budget_vars = set()
            energy_budget_vars = set()

        for varset_name, varset in show_categories.items():
            n_vars = len(varset)
            if not n_vars:
                continue
            varset_col_span = 2
            if self.hide_variables:
                # Collapse the category to a single full-width header row.
                varset = []
                n_vars = 0
                varset_col_span = 6

            label += " <TR>\n"
            label += f' <TD COLSPAN="{varset_col_span}" ROWSPAN="{n_vars + 1}"'  # noqa: E501
            label += f' BGCOLOR="{category_colors[varset_name]}">{varset_name}</TD>\n'  # noqa: E501
            label += " </TR>\n"
            for vv in sorted(varset):
                # Determine background color based on budget category
                bg_color = category_colors[varset_name]
                if vv in mass_budget_vars:
                    bg_color = "lightsteelblue"
                elif vv in energy_budget_vars:
                    bg_color = "lightcoral"
                label += " <TR>\n"
                label += f' <TD COLSPAN="4" BGCOLOR="{bg_color}" PORT="{vv}"><FONT POINT-SIZE="9.0">{vv}</FONT></TD>\n'  # noqa: E501
                label += " </TR>\n"

        label += "</TABLE>>\n"

        # A process missing from a user-supplied color map gets no explicit
        # color rather than raising KeyError.
        color_str = ""
        if self.process_colors and process_name in self.process_colors:
            color_str = f'"{self.process_colors[process_name]}"'

        node = self.pydot.Node(
            process_name,
            label=label,
            shape="box",
            color=color_str,
            penwidth=f'"{self.node_penwidth}"',
        )
        return node

    def _file_node(self, files: list[str]):
        """Create a node listing all input files.

        Args:
            files: List of file names that provide inputs to the model;
                duplicates are collapsed.

        Returns:
            A pydot.Node with an HTML-like table label listing all files,
            one row per unique file in first-seen order.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # rows (and the rendered SVG) are identical from run to run; a
        # set's iteration order for strings varies with hash randomization.
        unique_files = list(dict.fromkeys(files))
        label = (
            '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="1">\n'  # noqa
            ' <TR><TD COLSPAN="1">Files</TD></TR>\n'
        )
        for file_name in unique_files:
            port = self._file_port(file_name)
            label += " <TR>\n"
            label += f' <TD COLSPAN="1" BGCOLOR="gray50" PORT="{port}"><FONT POINT-SIZE="9.0">{file_name}</FONT></TD>\n'  # noqa
            label += " </TR>\n"
        label += "</TABLE>>\n"

        node = self.pydot.Node(
            "Files",
            label=label,
            shape="note",
            color=f'"{self.from_file_edge_color}"',
            penwidth=f'"{self.node_penwidth}"',
        )
        return node

    def display_legend(self):
        """Display a complete legend for the ModelGraph visualization.

        Shows:
        - Process colors (node border colors)
        - Mass budget variable colors (lightsteelblue background, if present)
        - Energy budget variable colors (lightcoral background, if present)
        """
        ipdisplay = import_optional_dependency("IPython.display")

        legend_parts = []

        # Process colors legend
        if self.process_colors:
            legend_parts.append("**Process Colors:**")
            for process_name, color in self.process_colors.items():
                legend_parts.append(
                    f'<span style="font-family: monospace">{process_name}: '
                    f'<span style="color: {color}">████████</span></span>'
                )
            legend_parts.append("")  # Blank line

        # Check if model has mass or energy budget variables
        has_mass_budget = False
        has_energy_budget = False
        for process in self.model.processes.values():
            if isinstance(process, ConservativeProcess):
                mass_terms = process.get_mass_budget_terms()
                if any(mass_terms.values()):
                    has_mass_budget = True
                energy_terms = process.get_energy_budget_terms()
                if any(energy_terms.values()):
                    has_energy_budget = True

        # Budget terms legend - only show if present
        budget_items = {}
        if has_mass_budget:
            budget_items["Mass budget variables"] = "lightsteelblue"
        if has_energy_budget:
            budget_items["Energy budget variables"] = "lightcoral"

        if budget_items:
            legend_parts.append("**Variable Background Colors:**")
            for label, color in budget_items.items():
                legend_parts.append(
                    f'<span style="font-family: monospace">{label}: '
                    f'<span style="background-color: {color}; '
                    f'padding: 2px 8px; border: 1px solid #ccc">'
                    f"variable</span></span>"
                )

        ipdisplay.display(ipdisplay.Markdown("<br>".join(legend_parts)))