Source code for ufs_plot_utils.plot

import logging
import numpy as np
import matplotlib.pyplot as plt
import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class Plotter:
    """
    Plot data using Cartopy.

    All rendering options are read from the ``plot`` section of the config
    object passed to the constructor. The recognised sub-keys are
    ``projection``, ``figure``, ``colorbar``, ``title``, ``scatter``,
    ``background``, ``extent`` and ``cartopy_ne_path``.

    Colormap, value range and colorbar label come from a style resolver.
    It must be attached with ``set_style_resolver`` before any
    ``plot_data_*`` method is called.
    """

    def __init__(self, cfg):
        """
        Read the plotting sections out of ``cfg``.

        Parameters
        ----------
        cfg
            Configuration object. Its ``get`` must accept a ``default``
            keyword, as in ``cfg.get("plot", default={})``. The returned
            ``plot`` section is then used as a plain mapping.
        """
        self.cfg = cfg
        # ``or {}``: a key that is present but null (e.g. an empty YAML
        # section) would otherwise yield None and break every ``.get`` below.
        plot_cfg = cfg.get("plot", default={}) or {}
        # Keep the whole plot section so optional keys such as "extent" can
        # be read later without going back through ``cfg``.
        self.plot_cfg = plot_cfg
        self.proj_cfg = plot_cfg.get("projection") or {}
        self.fig_cfg = plot_cfg.get("figure") or {}
        self.cb_cfg = plot_cfg.get("colorbar") or {}
        self.title_cfg = plot_cfg.get("title") or {}
        self.scatter_cfg = plot_cfg.get("scatter") or {}
        self.bg_cfg = plot_cfg.get("background") or {}
        self.style_resolver = None

        # Set the Cartopy Natural Earth data path. NOTE: cartopy.config is
        # module-global, so this affects all Cartopy use in the process.
        cartopy_ne_path = plot_cfg.get("cartopy_ne_path")
        if cartopy_ne_path:
            cartopy.config['data_dir'] = cartopy_ne_path
            logger.info(f'''Cartopy data_dir set to: {cartopy_ne_path}''')

    # =============================================================== CHJ ===
    def set_style_resolver(self, resolver):
        """
        Attach the resolver that supplies per-variable plot styling.

        ``resolver.resolve(varname, da)`` must return an object exposing
        ``cmap``, ``vmin``, ``vmax`` and ``label`` attributes.
        """
        self.style_resolver = resolver

    # =============================================================== CHJ ===
    def _require_style_resolver(self):
        """Raise RuntimeError if no style resolver has been attached."""
        if self.style_resolver is None:
            raise RuntimeError("StyleResolver not set.")

    def _create_map_axes(self):
        """Create a figure with one map axes: projection, extent and background."""
        projection = self.build_projection()
        figsize = self.fig_cfg.get("figsize", [10, 5])
        dpi = self.fig_cfg.get("dpi", 150)
        fig, ax = plt.subplots(
            1, 1,
            figsize=figsize,
            dpi=dpi,
            subplot_kw=dict(projection=projection),
        )
        # Start global; apply_extent narrows the view if "extent" is configured.
        ax.set_global()
        self.apply_extent(ax)
        self.plot_background(ax)
        return fig, ax

    def _set_title(self, ax, output_title):
        """Set the axes title using the configured font size."""
        ax.set_title(output_title, fontsize=self.title_cfg.get("fontsize", 8))

    def _add_colorbar(self, fig, ax, mappable, label):
        """Add a colorbar for ``mappable`` in a new axes to the right of ``ax``."""
        divider = make_axes_locatable(ax)
        ax_cb = divider.new_horizontal(
            size=self.cb_cfg.get("size", "3%"),
            pad=self.cb_cfg.get("pad", 0.1),
            axes_class=plt.Axes,
        )
        fig.add_axes(ax_cb)
        # fig.colorbar instead of plt.colorbar, so the result does not depend
        # on pyplot's "current figure" still being ``fig``.
        cbar = fig.colorbar(
            mappable, cax=ax_cb, extend=self.cb_cfg.get("extend", "both")
        )
        cbar.ax.tick_params(labelsize=self.cb_cfg.get("tick_fontsize", 6))
        cbar.set_label(label, fontsize=self.cb_cfg.get("label_fontsize", 7))
        return cbar

    # =============================================================== CHJ ===
    def plot_data_tiles(
        self, lat, lon, da, varname, output_title, dataset
    ):
        """
        Plot cubed-sphere tiled data.

        Parameters
        ----------
        lat, lon
            Coordinate arrays indexed as ``lat[tile, :, :]``. Both must have
            the same shape as ``da``.
        da
            xarray-like data with a ``tile`` dimension (uses ``dims``,
            ``sizes``, ``isel`` and ``values``).
        varname
            Variable name passed to the style resolver.
        output_title
            Axes title.
        dataset
            Unused; kept for signature parity with the other plot methods.

        Returns
        -------
        The matplotlib Figure.

        Raises
        ------
        RuntimeError
            If no style resolver is set.
        ValueError
            On coordinate/data shape mismatch, a missing ``tile`` dimension,
            or an empty ``tile`` dimension.
        """
        # Validate everything before creating a figure.
        self._require_style_resolver()
        if lat.shape != da.shape:
            raise ValueError(
                f'''Geo/data mismatch: lat={lat.shape}, data={da.shape}'''
            )
        if lon.shape != lat.shape:
            raise ValueError(
                f"Geo/data mismatch: lon={lon.shape}, lat={lat.shape}"
            )
        if "tile" not in da.dims:
            raise ValueError(f'''Expected "tile" dimension, got {da.dims}''')
        num_tiles = da.sizes["tile"]
        if num_tiles == 0:
            # Without this, the colorbar would receive a None mappable.
            raise ValueError("No tiles to plot: 'tile' dimension is empty")

        logger.info("Plotting seamless global map")

        fig, ax = self._create_map_axes()
        style = self.style_resolver.resolve(varname, da)
        self._set_title(ax, output_title)

        # NOTE(review): lat/lon are indexed on their first axis while da is
        # indexed by name; this assumes "tile" is da's first dimension.
        cs = None
        for it in range(num_tiles):
            lon_tile = np.asarray(lon[it, :, :])
            lat_tile = np.asarray(lat[it, :, :])
            var_tile = np.ma.masked_invalid(da.isel(tile=it).values)

            # Wrap longitudes into [-180, 180).
            lon_tile = (lon_tile + 180) % 360 - 180

            cs = ax.pcolormesh(
                lon_tile,
                lat_tile,
                var_tile,
                cmap=style.cmap,
                vmin=style.vmin,
                vmax=style.vmax,
                transform=ccrs.PlateCarree(),
                shading="auto",
            )

        # All tiles share cmap/vmin/vmax, so the last mesh represents them all.
        self._add_colorbar(fig, ax, cs, style.label)
        return fig

    # =============================================================== CHJ ===
    def plot_data_scatter(
        self, lat, lon, da, varname, output_title, dataset
    ):
        """
        Scatter plot for observation data.

        Parameters
        ----------
        lat, lon
            Point coordinates in degrees (plotted with a PlateCarree transform).
        da
            Values used for point colours (``da.values``).
        varname
            Variable name passed to the style resolver.
        output_title
            Axes title.
        dataset
            Unused; kept for signature parity with the other plot methods.

        Returns
        -------
        The matplotlib Figure.

        Raises
        ------
        RuntimeError
            If no style resolver is set.
        """
        self._require_style_resolver()
        logger.info("Plotting observation scatter")

        fig, ax = self._create_map_axes()
        style = self.style_resolver.resolve(varname, da)
        self._set_title(ax, output_title)

        cs = ax.scatter(
            lon,
            lat,
            c=da.values,
            cmap=style.cmap,
            vmin=style.vmin,
            vmax=style.vmax,
            s=self.scatter_cfg.get("marker_size", 5),
            transform=ccrs.PlateCarree(),
        )

        self._add_colorbar(fig, ax, cs, style.label)
        return fig

    # =============================================================== CHJ ===
    def plot_data_grid(
        self, lat, lon, da, varname, output_title, dataset
    ):
        """
        Plot regular/curvilinear 2D grid data.

        Parameters
        ----------
        lat, lon
            Grid coordinates in degrees, as accepted by ``pcolormesh``
            with ``shading="auto"``.
        da
            Data values (``da.values``); NaN/inf cells are masked.
        varname
            Variable name passed to the style resolver.
        output_title
            Axes title.
        dataset
            Unused; kept for signature parity with the other plot methods.

        Returns
        -------
        The matplotlib Figure.

        Raises
        ------
        RuntimeError
            If no style resolver is set.
        """
        self._require_style_resolver()
        logger.info("Plotting structured grid map")

        fig, ax = self._create_map_axes()
        style = self.style_resolver.resolve(varname, da)

        cs = ax.pcolormesh(
            lon,
            lat,
            np.ma.masked_invalid(da.values),
            cmap=style.cmap,
            vmin=style.vmin,
            vmax=style.vmax,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )

        self._set_title(ax, output_title)
        self._add_colorbar(fig, ax, cs, style.label)
        return fig

    # =============================================================== CHJ ===
    def build_projection(self):
        """
        Build the Cartopy map projection from the ``projection`` config.

        Config keys:

        * ``name``: one of Robinson (default), PlateCarree, Mollweide,
          NorthPolarStereo, SouthPolarStereo or Stereographic.
        * ``central_longitude``: defaults to -77.0369.
        * ``central_latitude``: defaults to 90; only used by Stereographic.

        PlateCarree is always built with Cartopy's defaults, so
        ``central_longitude`` is not applied to it. Unknown names fall back
        to Robinson with a warning.
        """
        proj_name = self.proj_cfg.get("name", "Robinson")
        central_lon = self.proj_cfg.get("central_longitude", -77.0369)

        if proj_name == "PlateCarree":
            return ccrs.PlateCarree()
        if proj_name == "Stereographic":
            return ccrs.Stereographic(
                central_longitude=central_lon,
                central_latitude=self.proj_cfg.get("central_latitude", 90),
            )

        proj_map = {
            "Robinson": ccrs.Robinson,
            "Mollweide": ccrs.Mollweide,
            "NorthPolarStereo": ccrs.NorthPolarStereo,
            "SouthPolarStereo": ccrs.SouthPolarStereo,
        }
        proj_class = proj_map.get(proj_name)
        if proj_class is None:
            # Previously a silent fallback; keep the fallback but make it visible.
            logger.warning(
                f"Unknown projection {proj_name!r}; falling back to Robinson"
            )
            proj_class = ccrs.Robinson
        return proj_class(central_longitude=central_lon)

    # =============================================================== CHJ ===
    def apply_extent(self, ax):
        """
        Restrict ``ax`` to the configured ``plot.extent`` region, if any.

        ``extent`` is expected as ``{"lon": [lon0, lon1], "lat": [lat0, lat1]}``
        in PlateCarree (degree) coordinates. The call is a no-op when
        ``extent`` is absent or empty.
        """
        # Read from the stored plot section. The old cfg.get("plot", "extent")
        # returned the whole plot section (or the string "extent") under
        # dict-style get semantics.
        extent_cfg = self.plot_cfg.get("extent")
        if not extent_cfg:
            return
        ax.set_extent(
            [
                extent_cfg["lon"][0],
                extent_cfg["lon"][1],
                extent_cfg["lat"][0],
                extent_cfg["lat"][1],
            ],
            crs=ccrs.PlateCarree(),
        )

    # =============================================================== CHJ ===
    def plot_background(self, ax):
        """
        Add background features (config-driven).

        ``background.features`` selects from coastline, borders, states,
        lakes and land. It may be a list or a single string. Other
        ``background`` keys are ``resolution`` (default "50m"),
        ``linewidth`` (0.5) and ``alpha`` (0.7).
        """
        raw_features = self.bg_cfg.get("features", [])
        # A bare string would otherwise be split into characters by set().
        if isinstance(raw_features, str):
            raw_features = [raw_features]
        features = set(raw_features or [])
        res = self.bg_cfg.get("resolution", "50m")
        lw = self.bg_cfg.get("linewidth", 0.5)
        alpha = self.bg_cfg.get("alpha", 0.7)

        logger.info(f'''Background features: {features}''')

        # (name, feature, add_feature kwargs), added in this order.
        feature_specs = (
            ("coastline", cfeature.COASTLINE,
             dict(linewidth=lw, alpha=alpha)),
            ("borders", cfeature.BORDERS,
             dict(linewidth=lw, alpha=alpha)),
            ("states", cfeature.STATES,
             dict(linewidth=lw, linestyle=":", alpha=alpha)),
            ("lakes", cfeature.LAKES,
             dict(linewidth=lw, facecolor="none", edgecolor="blue",
                  alpha=alpha)),
            ("land", cfeature.LAND,
             dict(facecolor=cfeature.COLORS["land"], edgecolor="face",
                  alpha=alpha)),
        )

        unknown = features - {name for name, _, _ in feature_specs}
        if unknown:
            logger.warning(f"Ignoring unknown background features: {sorted(unknown)}")

        for name, feature, kwargs in feature_specs:
            if name in features:
                ax.add_feature(feature.with_scale(res), **kwargs)