Source code for ufs_plot_utils.geo

import os
import logging
import numpy as np
import xarray as xr
from .utils import (
    extract_tile_prefix,
    normalize_geo_dims,
    resolve_mom6_geo_vars,
    resolve_cice_geo_vars,
)

# Module logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class GeoReader:
    """
    Read geographic coordinates (lat/lon) for a dataset.

    The source is selected from attributes of ``self.dataset`` (see
    ``get_geo``): an observation file, MOM6 / CICE grid files, a single geo
    file, six orography tile files, or six tiled data files.
    """

    def __init__(self, dataset):
        """
        Parameters
        ----------
        dataset
            Configuration object. The readers access ``data_kind``,
            ``data_model``, ``geo_file_type``, ``geo_path``,
            ``geo_filename``, ``path`` and ``filename`` on it.
        """
        self.dataset = dataset

    # =============================================================== CHJ ===
    def get_geo(self, da=None):
        """
        Dispatch to the reader that matches the dataset configuration.

        Precedence: ``data_kind == "observation"`` first, then
        ``data_model`` equal to "mom6" or "cice", then ``geo_file_type``
        (case-insensitive "file", "orog" or "tile").

        Parameters
        ----------
        da : optional
            Data array used by the MOM6/CICE readers to resolve which
            lon/lat variables to read; ignored by the other readers.

        Returns
        -------
        tuple
            ``(lat, lon)`` as returned by the selected reader.

        Raises
        ------
        ValueError
            If ``geo_file_type`` is not one of the supported values.
        """
        # Observation
        if self.dataset.data_kind == "observation":
            return self._get_geo_observation()

        # MOM6
        if self.dataset.data_model == "mom6":
            return self._get_geo_mom6(da)

        # CICE
        if self.dataset.data_model == "cice":
            return self._get_geo_cice(da)

        geo_type = self.dataset.geo_file_type.lower()

        if geo_type == "file":
            return self._get_geo_file()
        elif geo_type == "orog":
            return self._get_geo_orog()
        elif geo_type == "tile":
            return self._get_geo_tile()
        else:
            raise ValueError(f"Unknown geo type: {geo_type}")

    # =============================================================== CHJ ===
    @staticmethod
    def _find_lat_lon(ds, lat_candidates, lon_candidates):
        """
        Return ``(lat_name, lon_name)``: the first name from each candidate
        list that is present in ``ds.variables``, or ``None`` for a list
        with no match.
        """
        lat_name = next((v for v in lat_candidates if v in ds.variables), None)
        lon_name = next((v for v in lon_candidates if v in ds.variables), None)
        return lat_name, lon_name

    @staticmethod
    def _as_2d(lat, lon):
        """
        Return ``(lat2d, lon2d)`` as numpy arrays.

        When both inputs are 1D they are expanded with ``np.meshgrid``, so
        rows follow ``lat`` and columns follow ``lon``. Otherwise their
        ``.values`` are returned unchanged.
        """
        if lat.ndim == 1 and lon.ndim == 1:
            lon2d, lat2d = np.meshgrid(lon.values, lat.values)
            return lat2d, lon2d
        return lat.values, lon.values

    # =============================================================== CHJ ===
    def _get_geo_file(self):
        """
        Extract latitude and longitude arrays from a single geo file.

        The file is ``os.path.join(geo_path, geo_filename)``. ``lat`` /
        ``latitude`` and ``lon`` / ``longitude`` are looked up among all
        variables, coordinates included. A 1D pair is expanded to 2D.

        Returns
        -------
        tuple
            ``(lat, lon)`` after ``normalize_geo_dims``.

        Raises
        ------
        ValueError
            If no latitude or longitude variable is found.
        """
        fpath = os.path.join(self.dataset.geo_path, self.dataset.geo_filename)
        logger.info(f"Opening geo file: {fpath}")

        with xr.open_dataset(fpath, decode_timedelta=True) as ds_geo:
            # Search ds_geo.variables, not data_vars: a variable named after
            # its own dimension (e.g. lat(lat)) is loaded as a coordinate,
            # and data_vars excludes coordinates.
            lat_name, lon_name = self._find_lat_lon(
                ds_geo, ["lat", "latitude"], ["lon", "longitude"]
            )
            if lat_name is None or lon_name is None:
                raise ValueError(f"Could not detect lon/lat in geo file: {fpath}")

            lat = ds_geo[lat_name]
            lon = ds_geo[lon_name]
            # Materialise the values while the file is still open.
            lat2d, lon2d = self._as_2d(lat, lon)

        # Normalize geo dimensions
        lat_all, lon_all = normalize_geo_dims(lat2d, lon2d)

        logger.info(f"lat shape={lat.shape}, lon shape={lon.shape}")
        return lat_all, lon_all

    # =============================================================== CHJ ===
    def _get_geo_orog(self):
        """
        Read six orography tile files and return stacked lat/lon arrays.

        The files are ``{prefix}.tile{1..6}.nc`` in ``geo_path``, where
        ``prefix = extract_tile_prefix(geo_filename)``. The per-tile arrays
        are stacked along a new leading axis in tile order 1..6 and then
        passed through ``normalize_geo_dims``.

        Raises
        ------
        FileNotFoundError
            If any of the six tile files is missing.
        ValueError
            If a tile has no recognised lat/lon variable.
        """
        geo_file = self.dataset.geo_filename
        geo_path = self.dataset.geo_path

        prefix = extract_tile_prefix(geo_file)
        logger.info(f"OROG:: prefix = {prefix}")

        lat_candidates = ["geolat", "y", "lat", "latitude"]
        lon_candidates = ["geolon", "x", "lon", "longitude"]

        lat_tiles = []
        lon_tiles = []
        for itile in range(1, 7):
            fpath = os.path.join(geo_path, f"{prefix}.tile{itile}.nc")
            if not os.path.exists(fpath):
                raise FileNotFoundError(f"Orography tile file not found: {fpath}")

            logger.info(f"Reading orography tile {itile}: {fpath}")
            with xr.open_dataset(fpath, decode_timedelta=True) as ds:
                lat_name, lon_name = self._find_lat_lon(
                    ds, lat_candidates, lon_candidates
                )
                if lat_name is None or lon_name is None:
                    raise ValueError(f"lat/lon not found in {fpath}")
                lat_tiles.append(ds[lat_name].values)
                lon_tiles.append(ds[lon_name].values)

        lat_all = np.stack(lat_tiles, axis=0)
        lon_all = np.stack(lon_tiles, axis=0)

        # Normalize geo dimensions
        lat_all, lon_all = normalize_geo_dims(lat_all, lon_all)

        logger.info(f"Geo lat shape: {lat_all.shape}")
        logger.info(f"Geo lon shape: {lon_all.shape}")
        return lat_all, lon_all

    # =============================================================== CHJ ===
    def _get_geo_tile(self):
        """
        Extract lat/lon from six tiled data files.

        Files matching ``{prefix}.tile*.nc`` in ``geo_path`` are grouped by
        the forecast hour parsed from a ``.fNNN.`` token in the path. Files
        without one go into a "static" group. The first group in sorted
        order is used and must contain exactly six files. Their lat/lon
        (expanded to 2D when 1D) are stacked along a new leading axis and
        passed through ``normalize_geo_dims``.

        Raises
        ------
        ValueError
            If no file matches, the chosen group does not hold six files,
            or a file has no recognised lat/lon variable.
        """
        import glob
        import re

        geo_file = self.dataset.geo_filename
        geo_path = self.dataset.geo_path

        # -------------------------
        # Build tile file list
        # -------------------------
        prefix = extract_tile_prefix(geo_file)
        pattern = os.path.join(geo_path, f"{prefix}.tile*.nc")
        logger.info(f"GEO TILE pattern: {pattern}")

        file_list = sorted(glob.glob(pattern))
        if not file_list:
            raise ValueError(f"No geo tile files found: {pattern}")

        # -------------------------
        # Group by forecast hour
        # -------------------------
        fhr_map = {}
        for f in file_list:
            m = re.search(r"\.f(\d{3})\.", f)
            fhr = m.group(1) if m else "static"  # no forecast in filename
            fhr_map.setdefault(fhr, []).append(f)

        # -------------------------
        # Select one fhr (first); digit strings sort before "static"
        # -------------------------
        selected_fhr = sorted(fhr_map.keys())[0]
        selected_files = sorted(fhr_map[selected_fhr])
        logger.info(f"Geo using forecast hour: {selected_fhr}")

        if len(selected_files) != 6:
            raise ValueError(
                f"Expected 6 tiles for f{selected_fhr}, "
                f"found {len(selected_files)}"
            )

        # NOTE(review): grid_yt/grid_xt take priority over lat/lon when a
        # file has both; confirm they hold geographic coordinates (not grid
        # indices) for the files this reader is used with.
        lat_candidates = ["grid_yt", "lat", "latitude", "y"]
        lon_candidates = ["grid_xt", "lon", "longitude", "x"]

        # -------------------------
        # Read lat/lon
        # -------------------------
        lat_tiles = []
        lon_tiles = []
        for f in selected_files:
            logger.info(f"Reading geo tile: {f}")
            with xr.open_dataset(f, decode_timedelta=True) as ds:
                lat_name, lon_name = self._find_lat_lon(
                    ds, lat_candidates, lon_candidates
                )
                if lat_name is None or lon_name is None:
                    raise ValueError(f"lat/lon not found in {f}")
                lat2d, lon2d = self._as_2d(ds[lat_name], ds[lon_name])
            lat_tiles.append(lat2d)
            lon_tiles.append(lon2d)

        lat_all = np.stack(lat_tiles, axis=0)
        lon_all = np.stack(lon_tiles, axis=0)

        # Normalize geo dimensions
        lat_all, lon_all = normalize_geo_dims(lat_all, lon_all)

        logger.info(f"Geo lat shape: {lat_all.shape}")
        logger.info(f"Geo lon shape: {lon_all.shape}")
        return lat_all, lon_all

    # =============================================================== CHJ ===
    def _get_geo_mom6(self, da):
        """
        Read MOM6 lon/lat from the geo file for the grid of ``da``.

        Variable names come from ``resolve_mom6_geo_vars(da)``. The arrays
        are passed to ``normalize_geo_dims`` with ``add_tile_dim=False``.
        """
        fpath = os.path.join(self.dataset.geo_path, self.dataset.geo_filename)

        lon_name, lat_name = resolve_mom6_geo_vars(da)
        logger.info(
            f"MOM6 grid mapping:: "
            f"{da.dims} -> {lon_name}/{lat_name}"
        )

        with xr.open_dataset(fpath, decode_timedelta=True) as ds:
            lon = ds[lon_name].values
            lat = ds[lat_name].values

        return normalize_geo_dims(lat, lon, add_tile_dim=False)

    # =============================================================== CHJ ===
    def _get_geo_cice(self, da):
        """
        Read CICE lon/lat from the geo file for the variable ``da``.

        Variable names come from ``resolve_cice_geo_vars(da)``. The arrays
        are passed to ``normalize_geo_dims`` with ``add_tile_dim=False``.
        """
        fpath = os.path.join(self.dataset.geo_path, self.dataset.geo_filename)

        lon_name, lat_name = resolve_cice_geo_vars(da)
        logger.info(
            f"CICE grid mapping:: "
            f"{da.name} -> {lon_name}/{lat_name}"
        )

        with xr.open_dataset(fpath, decode_timedelta=True) as ds:
            lon = ds[lon_name].values
            lat = ds[lat_name].values

        return normalize_geo_dims(lat, lon, add_tile_dim=False)

    # =============================================================== CHJ ===
    def _get_geo_observation(self):
        """
        Read 1D lon/lat from an IODA-style observation file.

        The file read is ``os.path.join(dataset.path, dataset.filename)``,
        i.e. the data file rather than the geo file. ``longitude`` and
        ``latitude`` are first read from the ``MetaData`` group. If that
        fails for any reason, the failure is logged and ``lon``/``longitude``
        and ``lat``/``latitude`` are looked up in the root group instead.

        Returns
        -------
        tuple
            ``(lat, lon)`` as 1D numpy arrays of equal shape.

        Raises
        ------
        ValueError
            If the root fallback finds no lon/lat, or the arrays are not
            1D or have different shapes.
        """
        fpath = os.path.join(self.dataset.path, self.dataset.filename)
        logger.info(f"Opening OBS geo file: {fpath}")

        # -------------------------
        # 1. Try MetaData group (IODA standard)
        # -------------------------
        try:
            with xr.open_dataset(
                fpath, group="MetaData", decode_timedelta=True
            ) as ds:
                logger.info("Trying group: MetaData")
                lon = ds["longitude"].values
                lat = ds["latitude"].values
                logger.info("Found lon/lat in MetaData group")
        except Exception as e:
            # Deliberately broad: any failure (missing group, missing
            # variable, backend error) triggers the root-group fallback.
            logger.warning(f"MetaData group read failed: {e}")

            # -------------------------
            # 2. Fallback: root
            # -------------------------
            with xr.open_dataset(fpath, decode_timedelta=True) as ds:
                logger.info("Falling back to ROOT group")
                lat_name, lon_name = self._find_lat_lon(
                    ds, ["lat", "latitude"], ["lon", "longitude"]
                )
                if lon_name is None or lat_name is None:
                    raise ValueError("Could not detect lon/lat in OBS file")
                lon = ds[lon_name].values
                lat = ds[lat_name].values

        # -------------------------
        # Validation
        # -------------------------
        if lon.ndim != 1 or lat.ndim != 1:
            raise ValueError("Observation lon/lat must be 1D")
        if lon.shape != lat.shape:
            raise ValueError("lon/lat shape mismatch")

        logger.info(f"OBS geo size = {lon.size}")
        return lat, lon