import logging
import numpy as np
import xarray as xr
import os
import glob
import re
from .utils import extract_tile_prefix
logger = logging.getLogger(__name__)
[docs]
class DataReader:
    """
    Read NetCDF data and extract fields (I/O layer only).

    The reader is configured from a ``data`` object and keeps at most one
    open single-file dataset (``xr_ds``) and one concatenated tile dataset
    (``xr_tile_ds``). Both are opened on first use and reopened only when
    the requested file(s) change; ``close()`` releases them.
    """
def __init__(self, data):
    """
    Bind the reader to a data configuration object.

    Parameters
    ----------
    data : object
        Configuration exposing ``path``, ``filename``, ``file_type``,
        ``z_index`` and ``time_index``. The object itself is also kept
        on ``self.data``.
    """
    # Configuration: keep the source object and mirror the fields this
    # reader accesses directly.
    self.data = data
    for field in ("path", "filename", "file_type", "z_index", "time_index"):
        setattr(self, field, getattr(data, field))

    # Runtime cache: nothing is opened until data is requested.
    self.xr_ds = self.current_file = None
    self.current_tile_files = self.xr_tile_ds = None
# =============================================================== CHJ ===
def _open_dataset(self, files=None):
    """
    Open the single-file dataset lazily, reusing it while the path is unchanged.

    Parameters
    ----------
    files : sequence of str, optional
        Pre-resolved file paths; only the first entry is opened. When
        omitted, ``self.path`` joined with ``self.filename`` is used.

    Raises
    ------
    RuntimeError
        If xarray cannot open the file (or the configured group). The
        underlying exception is chained as ``__cause__``.
    """
    # -------------------------
    # Resolve actual file path
    # -------------------------
    if files is not None:
        file_path = files[0]
    else:
        file_path = os.path.join(self.path, self.filename)

    # Already open on the requested file: nothing to do.
    if self.xr_ds is not None and self.current_file == file_path:
        return

    # -------------------------
    # Close the previous dataset
    # -------------------------
    if self.xr_ds is not None:
        try:
            self.xr_ds.close()
        except Exception as exc:
            # Closing is best-effort, but leave a trace instead of
            # discarding the error silently.
            logger.debug(
                f"Ignoring error while closing {self.current_file}: {exc}"
            )
        # Forget the closed handle right away. If the open below fails,
        # a later call for the old path must reopen it rather than reuse
        # a closed dataset.
        self.xr_ds = None
        self.current_file = None

    logger.info(f"Opening data: {file_path}")
    group = getattr(self.data, "group", None)

    # Only pass ``group`` when one is configured (None or "" means root).
    open_kwargs = {"engine": "netcdf4", "decode_timedelta": True}
    if group:
        logger.info(f"Opening group: {group}")
        open_kwargs["group"] = group

    try:
        dataset = xr.open_dataset(file_path, **open_kwargs)
    except Exception as e:
        raise RuntimeError(
            f"Failed to open dataset "
            f"(group={group}): {e}"
        ) from e

    self.xr_ds = dataset
    self.current_file = file_path
    logger.info(f"Dataset var: {list(self.xr_ds.variables)}")
# =============================================================== CHJ ===
[docs]
def get_data(self, varname, fhr=None, rtag=None):
    """
    Return the raw DataArray for ``varname`` (no styling, no plotting logic).

    The reader is chosen from ``self.data.data_kind`` and ``self.file_type``:
    forecast and restart data first resolve their file list from ``fhr`` /
    ``rtag``; observation data has its own reader; any other kind reads the
    configured file(s) directly.

    Raises
    ------
    ValueError
        If ``fhr`` / ``rtag`` is missing for forecast / restart data, or if
        ``self.file_type`` is neither ``"tile"`` nor ``"file"``.
    """
    logger.debug(f"data file type = {self.file_type}")

    kind = self.data.data_kind

    # Observations bypass the tile/file dispatch entirely.
    if kind == "observation":
        return self._get_data_observation(varname)

    # Forecast / restart pass one extra positional argument (the resolved
    # file list); every other kind passes none.
    extra = ()
    if kind == "forecast":
        if fhr is None:
            raise ValueError("Forecast data requires fhr")
        extra = (self.resolve_filenames_for_fhr(fhr),)
    elif kind == "restart":
        if rtag is None:
            raise ValueError("Restart data requires rtag")
        extra = (self.resolve_filenames_for_restart(rtag),)

    if self.file_type == "tile":
        return self._get_data_tiles(varname, *extra)
    if self.file_type == "file":
        return self._get_data_file(varname, *extra)
    raise ValueError(f"Unsupported: {self.file_type}")
# =============================================================== CHJ ===
[docs]
def get_observation_channels(self, varname):
    """
    Return the channel dimension of ``varname`` and its indices.

    Returns
    -------
    tuple
        ``(dim_name, [0, 1, ..., n-1])`` for the first dimension whose
        lower-cased name is one of ``channel``, ``chan``, ``nchan`` or
        ``band``; ``(None, [None])`` when the variable has no such
        dimension.

    Raises
    ------
    ValueError
        If ``varname`` is not a variable of the opened dataset.
    """
    self._open_dataset()
    dataset = self.xr_ds

    if varname not in dataset.variables:
        raise ValueError(
            f"{varname} not found in dataset. "
            f" Available: {list(dataset.variables)}"
        )

    da = dataset[varname]
    channel_names = ["channel", "chan", "nchan", "band"]

    # First dimension (in declaration order) that looks like a channel axis.
    hit = next(
        (dim for dim in da.dims if dim.lower() in channel_names),
        None,
    )
    if hit is None:
        return None, [None]

    count = da.sizes[hit]
    logger.info(f'Detected channel dim "{hit}" size={count}')
    return hit, list(range(count))
# =============================================================== CHJ ===
def _get_data_file(self, varname, files=None):
    """
    Read one variable from a single NetCDF file and return it as a DataArray.

    Parameters
    ----------
    varname : str
        Name of the variable to read.
    files : sequence of str, optional
        Pre-resolved file list forwarded to ``_open_dataset``.

    Returns
    -------
    DataArray
        The variable after ``_slice_data`` has applied ``self.z_index`` and
        ``self.time_index``.

    Raises
    ------
    ValueError
        If the variable is missing, or the sliced array does not have the
        expected rank (3-D when it has a ``tile`` dimension, otherwise
        2-D or 3-D).
    """
    self._open_dataset(files=files)
    logger.info(f"Reading variable: {varname}")

    if varname not in self.xr_ds:
        raise ValueError(f"{varname} not found in dataset")

    da = self.xr_ds[varname]
    logger.debug(f"{varname} dims = {da.dims}")
    logger.debug(f"{varname} shape = {da.shape}")

    # apply slicing (data-layer only)
    da = self._slice_data(da, self.z_index, self.time_index)

    # -------------------------
    # validation
    # -------------------------
    if "tile" in da.dims:
        if da.ndim != 3:
            raise ValueError(
                f"{varname} expected (tile,y,x), got {da.dims}"
            )
    elif da.ndim not in (2, 3):
        raise ValueError(
            f"{varname} expected 2D/3D, got {da.dims}"
        )

    logger.info(f"{varname} final shape = {da.shape}")

    # Materialize the values once and reuse them for both statistics,
    # as the tiled reader does.
    vals = da.values
    logger.info(
        f"{varname} min={np.nanmin(vals)}, max={np.nanmax(vals)}"
    )
    return da
# =============================================================== CHJ ===
def _get_data_tiles(self, varname, files=None):
    """
    Read a 6-tile NetCDF set and return a DataArray shaped (tile, y, x).

    Parameters
    ----------
    varname : str
        Name of the variable to read.
    files : sequence of str, optional
        Already-resolved tile files (forecast / restart). When omitted,
        the tiles are found by globbing ``<prefix>.tile*.nc`` under
        ``self.path``, with the prefix derived from ``self.filename``.

    Returns
    -------
    DataArray
        The variable concatenated along a new ``tile`` dimension, after
        ``_slice_data`` has applied ``self.z_index`` and ``self.time_index``.

    Raises
    ------
    ValueError
        If there are not exactly six tile files, the variable is missing,
        or the sliced array is not 3-D.
    """
    # ---------------------------------
    # INCREMENT / ANALYSIS: use prefix
    # (forecast / restart arrive with files already resolved)
    # ---------------------------------
    if files is None:
        prefix = extract_tile_prefix(self.filename)
        pattern = os.path.join(self.path, f"{prefix}.tile*.nc")
        logger.debug(f"Tile pattern: {pattern}")
        files = sorted(glob.glob(pattern))

    if len(files) != 6:
        raise ValueError(f"Expected 6 tiles, found {len(files)}")
    logger.debug(f"Tile files: {files}")

    # ---------------------------------------
    # Reopen tiled datasets if files changed
    # ---------------------------------------
    file_key = tuple(files)
    if self.xr_tile_ds is None or self.current_tile_files != file_key:
        if self.xr_tile_ds is not None:
            try:
                self.xr_tile_ds.close()
            except Exception as exc:
                # Closing is best-effort, but do not hide the error.
                logger.debug(f"Ignoring error while closing tiles: {exc}")
            # Forget the closed dataset so a failed reopen below cannot
            # leave a closed handle cached under the old file list.
            self.xr_tile_ds = None
            self.current_tile_files = None

        logger.info(f"Opening {len(files)} tile files")
        datasets = []
        try:
            for f in files:
                datasets.append(
                    xr.open_dataset(
                        f,
                        engine="netcdf4",
                        decode_timedelta=True,
                    )
                )
            combined = xr.concat(datasets, dim="tile")
        except Exception:
            # Do not leak the tiles that were opened before the failure.
            for part in datasets:
                part.close()
            raise

        # xr.concat returns a new Dataset, so closing it would not by
        # itself release the per-tile files. Register a closer so that
        # ``xr_tile_ds.close()`` closes every tile dataset.
        def _close_tiles(parts=tuple(datasets)):
            for part in parts:
                part.close()

        combined.set_close(_close_tiles)
        self.xr_tile_ds = combined
        self.current_tile_files = file_key

    # ---------------------------------------
    # COMMON DATA ACCESS
    # ---------------------------------------
    ds = self.xr_tile_ds
    if varname not in ds:
        raise ValueError(f"{varname} not found in tiled dataset")

    da = ds[varname]
    logger.debug(f"{varname} dims = {da.dims}")
    logger.debug(f"{varname} shape = {da.shape}")

    da = self._slice_data(da, self.z_index, self.time_index)
    if da.ndim != 3:
        raise ValueError(
            f"{varname} expected (tile, y, x), got {da.dims}"
        )

    vals = da.values
    logger.info(f"{varname} final shape = {da.shape}")
    logger.info(
        f"{varname} min={np.nanmin(vals)}, max={np.nanmax(vals)}"
    )
    return da
# =============================================================== CHJ ===
def _get_data_observation(self, varname):
    """
    Read an observation variable and mask its non-finite values.

    The variable is looked up in the opened dataset first. If it is not
    there and the dataset object exposes a ``groups`` mapping, each group
    is searched in turn.

    Parameters
    ----------
    varname : str
        Name of the variable to read.

    Returns
    -------
    DataArray
        The variable with every non-finite entry replaced through
        ``da.where(np.isfinite(da))``.

    Raises
    ------
    ValueError
        If the variable is found neither in the dataset nor in any group.
    """
    self._open_dataset()
    logger.info(f"Reading OBS variable: {varname}")

    # -------------------------
    # group handling (ObsValue / MetaData safe access)
    # -------------------------
    if varname in self.xr_ds:
        da = self.xr_ds[varname]
    else:
        # An xarray Dataset has no ``groups`` mapping, so it is read
        # defensively: a missing variable must end in the ValueError
        # below, not in an AttributeError. (The group configured on
        # ``self.data`` is applied when the dataset is opened.)
        groups = getattr(self.xr_ds, "groups", None)
        group_items = groups.items() if hasattr(groups, "items") else ()
        for _gname, group in group_items:
            if varname in group.variables:
                da = group[varname]
                break
        else:
            raise ValueError(f"{varname} not found in any group")

    logger.info(f"{varname} dims = {da.dims}, shape = {da.shape}")

    # -------------------------
    # NaN / fill value handling
    # -------------------------
    return da.where(np.isfinite(da))
# =============================================================== CHJ ===
@staticmethod
def _slice_data(da, z_index=None, time_index=0):
    """
    Select one time step and, optionally, one vertical level (data-layer only).

    Parameters
    ----------
    da : DataArray
        Array to slice.
    z_index : int, optional
        Index along the first vertical dimension found; ``None`` leaves
        the vertical dimension untouched.
    time_index : int
        Index along ``time`` / ``Time`` when the array has that dimension.

    Returns
    -------
    DataArray
        The array with the selected dimensions removed.
    """
    # -------------------------
    # time slicing
    # -------------------------
    for t_name in ("time", "Time"):
        if t_name in da.dims:
            if da.sizes.get(t_name, 1) > 1:
                logger.debug(
                    f"{t_name} > 1, selecting index {time_index}"
                )
            da = da.isel({t_name: time_index})
            break

    # -------------------------
    # vertical slicing
    # -------------------------
    if z_index is None:
        return da

    vertical_names = (
        "pfull", "zaxis_1", "zaxis_2", "zaxis_3",
        "zaxis_4", "lev", "level", "depth", "z",
    )
    # Names are tried in this order; only the first one present is sliced.
    for z_name in vertical_names:
        if z_name in da.dims:
            return da.isel({z_name: z_index})
    return da
# =============================================================== CHJ ===
[docs]
def detect_forecast_hours(self):
    """
    Detect forecast hours from the files matching the filename pattern.

    ``self.filename`` is used as a glob pattern under ``self.path``. A
    ``[1-6]`` tile character class is collapsed to ``1`` so that only the
    tile-1 files are globbed. Forecast hours are taken from ``.fNNN.``
    tokens (2 to 4 digits) in the matched file names.

    Returns
    -------
    list of str
        Unique forecast-hour strings exactly as they appear in the file
        names (zero padding kept), ordered by numeric value.

    Raises
    ------
    ValueError
        If no file matches the pattern, or no matched file name contains
        a forecast-hour token.
    """
    # If user used [1-6], reduce to tile1 for detection
    glob_pattern = re.sub(r"\[1-6\]", "1", self.filename)
    search_path = os.path.join(self.path, glob_pattern)
    logger.info(f"Detecting forecast files: {search_path}")

    files = glob.glob(search_path)
    if not files:
        raise ValueError(f"No files found for pattern: {search_path}")

    # robust match: f000, f012, f120 etc.
    fhr_token = re.compile(r"\.f(\d{2,4})\.")
    matches = (fhr_token.search(os.path.basename(f)) for f in files)
    fhrs = {m.group(1) for m in matches if m}
    if not fhrs:
        raise ValueError("Could not detect forecast hours from filenames")

    # Order by numeric value: a plain string sort would put "120" before
    # "96" when the zero padding is not uniform. The string itself breaks
    # ties (e.g. "06" vs "006") so the order is deterministic.
    fhrs = sorted(fhrs, key=lambda tag: (int(tag), tag))
    logger.info(f"Detected forecast hours: {fhrs}")
    return fhrs
# =============================================================== CHJ ===
[docs]
def resolve_filenames_for_fhr(self, fhr):
    """
    Resolve the forecast file name(s) for one forecast hour.

    Every ``*`` in ``self.filename`` is replaced by ``fhr``. If the result
    contains a tile token (``tile<digits>`` or the ``tile[1-6]`` character
    class), it is expanded to the six files ``tile1`` .. ``tile6``;
    otherwise a single file is returned.

    Supports:
    - FV3 tiled forecasts
    - MOM6/CICE/WW3 single-file forecasts

    Parameters
    ----------
    fhr : str
        Forecast-hour string as it appears in the file names. Other
        values are converted with ``str()``.

    Returns
    -------
    list of str
        Paths under ``self.path``: six for tiled patterns, one otherwise.

    Raises
    ------
    FileNotFoundError
        If any resolved path does not exist.
    """
    pattern = self.filename.replace("*", str(fhr))

    # One token covers "tile1", any other "tile<digits>" and the
    # "tile[1-6]" glob class that detect_forecast_hours also accepts.
    tile_token = re.compile(r"tile(?:\d+|\[1-6\])")

    if tile_token.search(pattern):
        # FV3 tiled forecasts
        files = [
            os.path.join(self.path, tile_token.sub(f"tile{i}", pattern))
            for i in range(1, 7)
        ]
    else:
        # Single-file forecasts (MOM6, WW3, CICE, etc.)
        files = [os.path.join(self.path, pattern)]

    # -------------------------
    # Validate existence
    # -------------------------
    missing = [f for f in files if not os.path.exists(f)]
    if missing:
        raise FileNotFoundError(
            f"Missing forecast files for f{fhr}: {missing}"
        )
    return files
# =============================================================== CHJ ===
# =============================================================== CHJ ===
[docs]
def resolve_filenames_for_restart(self, tag):
    """
    Resolve the six restart tile file names for one restart tag.

    Every ``*`` in ``self.filename`` is replaced by ``tag``. The result
    must contain a tile token (``tile<digits>`` or the ``tile[1-6]``
    character class), which is expanded to ``tile1`` .. ``tile6``.
    File existence is not checked here.

    Parameters
    ----------
    tag : str
        Restart tag substituted for ``*``. Other values are converted
        with ``str()``.

    Returns
    -------
    list of str
        Six paths under ``self.path``, one per tile.

    Raises
    ------
    ValueError
        If the substituted pattern contains no tile token.
    """
    pattern = self.filename.replace("*", str(tag))

    # Accept the same tile spellings as the forecast resolver, not only
    # the literal "tile1".
    tile_token = re.compile(r"tile(?:\d+|\[1-6\])")
    if not tile_token.search(pattern):
        raise ValueError(f"Invalid restart tile pattern: {pattern}")

    return [
        os.path.join(self.path, tile_token.sub(f"tile{i}", pattern))
        for i in range(1, 7)
    ]
# =============================================================== CHJ ===
[docs]
def close(self):
    """
    Close the cached single-file and tiled datasets, if any are open.

    Both handles are always cleared, and the tiled dataset is closed even
    when closing the single-file dataset raises. In that case the error
    still propagates to the caller afterwards. Calling this method again
    is a no-op.
    """
    try:
        if self.xr_ds is not None:
            try:
                self.xr_ds.close()
            finally:
                self.xr_ds = None
                self.current_file = None
    finally:
        # Runs even if the close above raised, so a failure on the first
        # dataset cannot leak the second one.
        if self.xr_tile_ds is not None:
            try:
                self.xr_tile_ds.close()
            finally:
                self.xr_tile_ds = None
                self.current_tile_files = None