Source code for pywatershed.analysis.process_plot

import pathlib as pl
from textwrap import wrap
from typing import Callable, Tuple, Union

import contextily as cx
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.collections import LineCollection, PatchCollection
from matplotlib.patches import Polygon
from xyzservices import TileProvider

from ..base import meta
from ..base.model import Model
from ..base.process import Process
from ..utils.optional_import import import_optional_dependency


class ProcessPlot:
    """Plot process variables on HRU polygons and stream segments.

    The HRU and segment shapefiles are read once at construction; each
    plotting call joins the requested variable onto those geometries.

    Args:
        gis_dir: Directory containing the shapefiles.
        hru_shp_file_name: Name of the HRU shapefile inside ``gis_dir``.
            Pass ``None`` to skip the HRU setup.
        seg_shp_file_name: Name of the segment shapefile inside
            ``gis_dir``. Pass ``None`` to skip the segment setup.

    Raises:
        ValueError: If the HRU shapefile matches none of the recognized
            column conventions.
    """

    def __init__(
        self,
        gis_dir: Union[str, pl.Path],
        hru_shp_file_name: Union[str, None] = "HRU_subset.shp",
        seg_shp_file_name: Union[str, None] = "Segments_subset.shp",
    ):
        gpd = import_optional_dependency("geopandas")

        self.gis_dir = pl.Path(gis_dir)

        if hru_shp_file_name is not None:
            self.hru_shapefile = self.gis_dir / hru_shp_file_name
        else:
            self.hru_shapefile = None

        if seg_shp_file_name is not None:
            self.seg_shapefile = self.gis_dir / seg_shp_file_name
        else:
            self.seg_shapefile = None

        # HRU one-time setup
        if self.hru_shapefile is not None:
            self.hru_gdf = gpd.read_file(self.hru_shapefile)
            hru_columns = self.hru_gdf.columns

            # Standardize on an "nhm_id" index; the shapefiles seen so far
            # follow a variety of column naming conventions.
            if ("nhm_id" in hru_columns) and ("nhru_v1_1" in hru_columns):
                # In this convention "nhru_v1_1" holds the id to index on,
                # so the existing "nhm_id" column is discarded.
                self.hru_gdf = (
                    self.hru_gdf.drop("nhm_id", axis=1)
                    .rename(columns={"nhru_v1_1": "nhm_id"})
                    .set_index("nhm_id")
                )
            elif "GRID_CODE" in hru_columns:
                self.hru_gdf = self.hru_gdf.rename(
                    columns={"GRID_CODE": "nhm_id"}
                ).set_index("nhm_id")
            else:
                msg = "Unidentified shp file convention, work needed"
                raise ValueError(msg)

        # Segment one-time setup
        if self.seg_shapefile is not None:
            self.seg_gdf = gpd.read_file(self.seg_shapefile)
            # set_crs returns a new GeoDataFrame unless inplace=True, so the
            # result has to be kept for the CRS to take effect.
            # NOTE(review): per the geopandas docs this raises if the file
            # already declares a different CRS (no allow_override) -- confirm
            # that is the desired behavior for the USGS Albers files.
            self.seg_gdf = self.seg_gdf.set_crs("EPSG:5070")
            self.seg_geoms_exploded = (
                self.seg_gdf.explode(index_parts=True)
                .reset_index(level=1, drop=True)
                .drop("model_idx", axis=1)
                .rename(columns={"nsegment_v": "nhm_seg"})
                .set_index("nhm_seg")
            )

    def plot(self, var_name: str, process: Process, **kwargs):
        """Plot a variable, dispatching on its spatial dimension.

        Args:
            var_name: Name of the variable to plot.
            process: Process holding the variable.
            **kwargs: Forwarded to :meth:`plot_seg_var` or
                :meth:`plot_hru_var`.

        Returns:
            Whatever the selected plotting method returns.

        Raises:
            ValueError: If the variable has neither an ``nsegment`` nor an
                ``nhru`` dimension.
        """
        var_dims = list(meta.get_vars(var_name)[var_name]["dims"])
        if "nsegment" in var_dims:
            return self.plot_seg_var(var_name, process, **kwargs)
        if "nhru" in var_dims:
            return self.plot_hru_var(var_name, process, **kwargs)
        raise ValueError(
            f"Cannot plot '{var_name}': its dims {var_dims} include "
            "neither 'nsegment' nor 'nhru'"
        )

    def plot_seg_var(
        self,
        var_name: str,
        process: Process,
        cmap: Union[str, None] = None,
        value_transform: Union[Callable, None] = None,
        figsize: tuple = (7, 10),
        title: Union[str, None] = None,
        aesthetic_width: bool = False,
        cx_map_source: TileProvider = cx.providers.CartoDB.Positron,
        vmin: Union[float, None] = None,
        vmax: Union[float, None] = None,
        aesthetic_width_color="darkblue",
        missing_kwds: Union[dict, None] = None,
    ):
        """Plot a segment variable over a basemap and show the figure.

        Args:
            var_name: Name of the segment variable to plot.
            process: Process holding the variable.
            cmap: Colormap name; defaults to ``"cool"``. Ignored when
                ``aesthetic_width`` is True.
            value_transform: Optional callable applied to the values
                before plotting.
            figsize: Figure size passed to the GeoDataFrame plot.
            title: Plot title; defaults to ``var_name``.
            aesthetic_width: If True, line widths are taken from the
                values instead of coloring by value.
            cx_map_source: contextily tile provider for the basemap.
            vmin: Lower color limit; defaults to the NaN-ignoring minimum
                of the values. Ignored when ``aesthetic_width`` is True.
            vmax: Upper color limit; defaults to the NaN-ignoring maximum
                of the values. Ignored when ``aesthetic_width`` is True.
            aesthetic_width_color: Edge color used when
                ``aesthetic_width`` is True.
            missing_kwds: Style for segments without data; defaults to
                thin light-grey lines. Ignored when ``aesthetic_width``
                is True.
        """
        values = process[var_name]
        if value_transform is not None:
            values = value_transform(values)

        data_df = pd.DataFrame(
            {
                "nhm_seg": process._params.coords["nhm_seg"],
                var_name: values,
            }
        ).set_index("nhm_seg")

        df_plot = self.seg_geoms_exploded.join(data_df).reset_index()

        if aesthetic_width:
            ax = df_plot.plot(
                column=var_name,
                figsize=figsize,
                linewidth=df_plot[var_name],
                edgecolor=aesthetic_width_color,
            )
        else:
            if vmin is None:
                vmin = np.nanmin(values)
            if vmax is None:
                vmax = np.nanmax(values)
            if cmap is None:
                cmap = "cool"
            if missing_kwds is None:
                missing_kwds = {
                    "color": "lightgrey",
                    "linewidth": 0.5,
                }
            ax = df_plot.plot(
                column=var_name,
                figsize=figsize,
                cmap=cmap,
                vmin=vmin,
                vmax=vmax,
                legend=True,
                missing_kwds=missing_kwds,
            )

        cx.add_basemap(
            ax=ax,
            crs=df_plot.crs,
            source=cx_map_source,
        )
        ax.set_axis_off()

        if title is None:
            title = var_name
        _ = ax.set_title(title)

        plt.show()

    def get_hru_var(self, var_name: str, model: Model):
        """Return an HRU variable or parameter as an nhm_id-indexed frame.

        The first process in ``model.processes`` listing ``var_name`` among
        its variables or parameters supplies the data.

        Args:
            var_name: Name of the variable or parameter.
            model: Model whose processes are searched.

        Returns:
            A DataFrame indexed by ``nhm_id`` with a single ``var_name``
            column.

        Raises:
            ValueError: If no process in the model provides ``var_name``.
        """
        for proc in model.processes.values():
            if var_name in proc.variables or var_name in proc.parameters:
                process = proc
                break
        else:
            raise ValueError(
                f"No process in the model provides '{var_name}' as a "
                "variable or parameter"
            )

        data_df = pd.DataFrame(
            {
                "nhm_id": process._params.coords["nhm_id"],
                var_name: process[var_name],
            }
        ).set_index("nhm_id")
        return data_df

    def plot_hru_var(
        self,
        var_name: str,
        process: Process,
        data: Union[np.ndarray, None] = None,
        data_units: Union[str, None] = None,
        nhm_id: Union[np.ndarray, None] = None,
        clim: Union[Tuple[float, float], None] = None,
        **kwargs,
    ):
        """Build an hvplot map of an HRU variable.

        Args:
            var_name: Name of the variable; also used to look up metadata
                for the title and color-bar label.
            process: Process supplying the values and ``nhm_id``
                coordinate when ``data`` is None.
            data: Optional values to plot instead of ``process[var_name]``.
            data_units: Color-bar label used when no metadata is found for
                ``var_name``.
            nhm_id: HRU ids matching ``data``; required when ``data`` is
                given.
            clim: Optional color limits passed to hvplot.
            **kwargs: Extra hvplot options; these override the defaults
                assembled here.

        Returns:
            The object returned by ``GeoDataFrame.hvplot``.

        Raises:
            ValueError: If ``data`` is given without ``nhm_id``.
        """
        # Imported for its side effect: the .hvplot accessor used below.
        _ = import_optional_dependency("hvplot.pandas")
        ccrs = import_optional_dependency("cartopy.crs")

        if data is None:
            data_df = pd.DataFrame(
                {
                    "nhm_id": process._params.coords["nhm_id"],
                    var_name: process[var_name],
                }
            ).set_index("nhm_id")
        else:
            if nhm_id is None:
                raise ValueError("code needs work to handle nhm_id=None")
            data_df = pd.DataFrame(
                {
                    "nhm_id": nhm_id,
                    var_name: data,
                }
            ).set_index("nhm_id")

        plot_df = self.hru_gdf.join(data_df)

        # Variable metadata first, then parameter metadata as a fallback.
        metadata = meta.get_vars(var_name)
        if not metadata:
            metadata = meta.get_params(var_name)
        metadata = metadata[var_name] if metadata else None

        frame_height = 550
        title = f'"{var_name}"\n'
        clabel = data_units
        if metadata is not None:
            title += "\n".join(
                wrap(
                    f"{metadata['desc']}, {metadata['units']}",
                    # textwrap needs an integer width: a float width
                    # raises TypeError once a word exceeds it.
                    width=frame_height // 10,
                )
            )
            clabel = f"{metadata['units']}"

        args = {
            "tiles": True,
            "crs": ccrs.epsg(5070),
            "frame_height": frame_height,
            "c": var_name,
            "line_width": 0,
            "alpha": 0.75,
            "hover_cols": ["nhm_id"],
            "title": title,
            "clabel": clabel,
            "xlabel": "Longitude (degrees East)",
            "ylabel": "Latitude (degrees North)",
        } | kwargs

        if clim is not None:
            args["clim"] = clim

        plot = plot_df.hvplot(**args)
        return plot
def plot_line_collection(
    ax,
    geoms,
    values=None,
    cmap=None,
    norm=None,
    vary_width=False,
    vary_color=True,
    colors=None,
    alpha=1.0,
    linewidth=1.0,
    **kwargs,
):
    """Plot a collection of line geometries.

    Args:
        ax: Matplotlib axes to draw on.
        geoms: Iterable of line geometries exposing ``.coords``.
        values: Optional per-line values used for coloring and/or width
            scaling.
        cmap: Colormap used when ``vary_color`` is True.
        norm: Normalization used when ``vary_color`` is True.
        vary_width: If True, scale each line width by its value relative
            to the maximum value.
        vary_color: If True, color lines by ``values`` through ``cmap``;
            otherwise use ``colors``.
        colors: Explicit line colors used when ``vary_color`` is False.
        alpha: Line transparency.
        linewidth: Line width, or the maximum-width scale when
            ``vary_width`` is True.
        **kwargs: Accepted for call compatibility; currently unused.

    Returns:
        The ``LineCollection`` added to ``ax``.

    Raises:
        ValueError: If ``vary_width`` is True and ``values`` is None.
    """
    # Hand matplotlib plain (n, 2) coordinate arrays rather than shapely
    # objects, which do not expose the array interface in shapely >= 2.0.
    # Slicing to two columns also drops any z coordinate.
    segments = [np.asarray(geom.coords)[:, :2] for geom in geoms]

    if vary_width:
        if values is None:
            raise ValueError("values must be provided when vary_width=True")
        scaled = np.asarray(values, dtype=float)
        # The small offset keeps zero-valued lines visible.
        lwidths = (scaled / np.nanmax(scaled) + 0.01) * linewidth
    else:
        lwidths = linewidth

    if vary_color:
        lines = LineCollection(
            segments, linewidths=lwidths, cmap=cmap, norm=norm, alpha=alpha
        )
        if values is not None:
            lines.set_array(values)
    else:
        lines = LineCollection(
            segments, linewidths=lwidths, colors=colors, alpha=alpha
        )

    ax.add_collection(lines, autolim=True)
    ax.autoscale_view()
    return lines


def plot_polygon_collection(
    ax,
    geoms,
    values=None,
    cmap=None,
    norm=None,
    facecolor=None,
    edgecolor=None,
    alpha=1.0,
    linewidth=1.0,
    **kwargs,
):
    """Plot a collection of Polygon geometries.

    Only each polygon's exterior ring is drawn.

    Adapted from https://stackoverflow.com/questions/33714050/
    geopandas-plotting-any-way-to-speed-things-up

    Args:
        ax: Matplotlib axes to draw on.
        geoms: Iterable of polygon geometries exposing
            ``.exterior.coords``.
        values: Optional per-polygon values mapped to colors through
            ``cmap`` and ``norm``.
        cmap: Colormap for ``values``.
        norm: Normalization for ``values``.
        facecolor: Patch face color.
        edgecolor: Patch edge color.
        alpha: Patch transparency.
        linewidth: Patch edge width.
        **kwargs: Accepted for call compatibility; currently unused.

    Returns:
        The ``PatchCollection`` added to ``ax``.
    """
    # Build matplotlib patches from plain (n, 2) coordinate arrays; shapely
    # objects do not expose the array interface in shapely >= 2.0. Slicing
    # to two columns also drops any z coordinate.
    polygon_patches = [
        Polygon(np.asarray(poly.exterior.coords)[:, :2]) for poly in geoms
    ]

    patches = PatchCollection(
        polygon_patches,
        facecolor=facecolor,
        linewidth=linewidth,
        edgecolor=edgecolor,
        alpha=alpha,
        cmap=cmap,
        norm=norm,
    )

    if values is not None:
        patches.set_array(values)

    ax.add_collection(patches, autolim=True)
    ax.autoscale_view()
    return patches