From ce01aaf8a7eb2ffdf9140050e55b04a100a8801e Mon Sep 17 00:00:00 2001 From: Yuvraaj Narula Date: Thu, 16 Oct 2025 16:20:06 +0530 Subject: [PATCH 01/13] feat: baseline processor init + fixes with nnjai --- graph_weather/data/__init__.py | 2 +- graph_weather/data/bufr_process.py | 184 +++++++++++++++++++++++++++ graph_weather/data/nnjaai.py | 194 +++++++++++++++++++++++++++++ tests/test_nnjai.py | 2 +- 4 files changed, 380 insertions(+), 2 deletions(-) create mode 100644 graph_weather/data/bufr_process.py create mode 100644 graph_weather/data/nnjaai.py diff --git a/graph_weather/data/__init__.py b/graph_weather/data/__init__.py index d67a79e4..a47080ad 100644 --- a/graph_weather/data/__init__.py +++ b/graph_weather/data/__init__.py @@ -1,5 +1,5 @@ """Dataloaders and data processing utilities""" from .anemoi_dataloader import AnemoiDataset -from .nnja_ai import SensorDataset +from .nnjaai import SensorDataset from .weather_station_reader import WeatherStationReader diff --git a/graph_weather/data/bufr_process.py b/graph_weather/data/bufr_process.py new file mode 100644 index 00000000..1341480b --- /dev/null +++ b/graph_weather/data/bufr_process.py @@ -0,0 +1,184 @@ +from dataclasses import dataclass, field +from typing import Optional, Callable, Any , List, Dict +import numpy as np +import logging + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger("Bufr Processor") + +from torch.utils.data import DataLoader, IterableDataset + +try: + import eccodes +except ImportError: + raise ImportError("eccodes not installed. Install with: `pip install eccodes`") + + +@dataclass +class FieldMapping: + """Maps a BUFR source field to NNJA-AI output field.""" + source_name: str + output_name: str + dtype: type + transform_fn: Optional[Callable] = None + required: bool = True + description: str = "" + + def apply(self, value: Any) -> Any: + """Apply transformation to source value.""" + if value is None: + return None + if self.transform_fn: + return self.transform_fn(value) + return value + + +class NNJA_Schema: + """ + Defines the canonical NNJA-AI schema that all BUFR data maps to + + Mimics NNJA-AI's xarray format with standardized coordinates and variables. + Coordinate system matches NNJA-AI: + - OBS_TIMESTAMP: observation time (ns precision) + - LAT: latitude + - LON: longitude + """ + COORDINATES = { + 'OBS_TIMESTAMP': 'datetime64[ns]', + 'LAT': 'float32', + 'LON': 'float32', + } + + VARIABLES = { + 'temperature': 'float32', + 'pressure': 'float32', + 'relative_humidity': 'float32', + 'u_wind': 'float32', + 'v_wind': 'float32', + 'dew_point': 'float32', + 'height': 'float32', + } + + ATTRIBUTES = { + 'source': 'DATA_SOURCE', + 'qc_flag': 'int8', + 'processing_timestamp': 'datetime64[ns]', + } + @classmethod + def to_xarray_schema(cls) -> Dict[str, str]: + """Get full schema as dict for xarray construction.""" + return {**cls.COORDINATES, **cls.VARIABLES, **cls.ATTRIBUTES} + + @classmethod + def get_coordinate_names(cls) -> List[str]: + """Get list of coordinate names.""" + return list(cls.COORDINATES.keys()) + + @classmethod + def validate_data(cls, data: Dict[str, np.ndarray]) -> bool: + """Check if data has required NNJA coordinates.""" + required_coords = ['OBS_TIMESTAMP', 'LAT', 'LON'] + return all(coord in data for coord in required_coords) + + + def __init__(self): + pass + +class DataSourceSchema: + """ + Abstract base for source-specific BUFR schema mappings. + Defines how BUFR fields from a specific source (ADPUPA, CrIS, etc.) + map to NNJA-AI canonical format. + """ + + + source_name: str = "unknown" + + def __init__(self): + self.field_mappings: Dict[str, FieldMapping] = {} + self._build_mappings() + self._validate() + + def _build_mappings(self): + """ + Override in subclasses to define BUFR → NNJA field mappings. + Example: + self.field_mappings['T'] = FieldMapping( + source_name='T', + output_name='temperature', + dtype=float, + transform_fn=lambda x: x - 273.15, # K to C + description='Temperature in Celsius' + ) + """ + raise NotImplementedError("Subclasses must implement _build_mappings()") + + def _validate(self): + """Ensure all required NNJA coordinates are mapped.""" + required = ['OBS_TIMESTAMP', 'LAT', 'LON'] + mapped_outputs = {m.output_name for m in self.field_mappings.values()} + missing = [r for r in required if r not in mapped_outputs] + if missing: + logger.warning( + f"{self.source_name} schema missing required outputs: {missing}" + ) + + def map_observation(self, bufr_message: Dict[str, Any]) -> Dict[str, Any]: + """ + Transform raw BUFR message to NNJA-AI format. + + Args: + bufr_message: Decoded BUFR message (dict of field → values) + + Returns: + Observation in NNJA format (dict matching NNJASchema) + """ + mapped = {} + + for field_map in self.field_mappings.values(): + if field_map.source_name in bufr_message: + raw_value = bufr_message[field_map.source_name] + try: + value = field_map.apply(raw_value) + mapped[field_map.output_name] = value + except Exception as e: + logger.warning( + f"Error transforming {field_map.source_name}: {e}" + ) + mapped[field_map.output_name] = None + + return mapped + + def get_variable_list(self) -> List[str]: + """Get list of NNJA variables this source provides.""" + return list(set(m.output_name for m in self.field_mappings.values())) + +class _BUFRIterableDataset(IterableDataset): + """Internal IterableDataset wrapper for PyTorch DataLoader.""" + + def __init__(self, bufr_loader: BUFRDataLoader): + self.bufr_loader = bufr_loader + + def __iter__(self): + return iter(self.bufr_loader) + + + +class BUFR_dataloader: + def __init__(self,dataset,batch_size): + """ + Args: + -> dataset : str (path) + -> batch_size : int + """ + self.dataset = dataset + self.batch_size = batch_size + + def decoder(self): + pass + def map_to_nnjai_schema(self): + pass + + diff --git a/graph_weather/data/nnjaai.py b/graph_weather/data/nnjaai.py new file mode 100644 index 00000000..71650489 --- /dev/null +++ b/graph_weather/data/nnjaai.py @@ -0,0 +1,194 @@ +""" +Dynamic loader for NNJA-AI datasets with support for primary descriptors and data variables. + +Features: +- Automatically loads primary descriptors + primary data by default +- Supports custom variable selection +- Can load all variables when requested +- Returns xarray.Dataset with time as the only coordinate +- Optimized for performance with direct xarray access + +""" + +import numpy as np +import xarray as xr +from torch.utils.data import Dataset + +try: + from nnja_ai.catalog import DataCatalog +except ImportError: + raise ImportError("NNJA-AI library not installed. Install with: " "`pip install nnja-ai`") + + +def _classify_variable(nnja_var) -> str: + """Return category of a variable using attributes or repr fallback.""" + # First try to get explicit attributes + if hasattr(nnja_var, "category"): + return nnja_var.category + if hasattr(nnja_var, "role"): + return nnja_var.role + + # Fallback to string representation + tag = repr(nnja_var).lower() + if "primary_descriptor" in tag or "primary descriptor" in tag: + return "primary_descriptor" + if "primary_data" in tag or "primary data" in tag: + return "primary_data" + return "other" + + +def load_nnja_dataset( + dataset_name: str, + time=None, + variables: list[str] | None = None, + load_all: bool = False, +) -> xr.Dataset: + """ + Load a NNJA dataset as an xarray.Dataset with time as the only coordinate. + + Args: + dataset_name: Name of NNJA dataset to load + time: Time selection (single timestamp, slice, or None) + variables: Specific variables to load (overrides default) + load_all: Load all available variables in the dataset + + Returns: + xarray.Dataset with only 'time' dimension/coordinate + """ + try: + cat = DataCatalog() + ds_meta = cat[dataset_name] + ds_meta.load_manifest() + except KeyError as e: + raise ValueError(f"Dataset '{dataset_name}' not found in catalog") from e + + vars_dict = ds_meta.variables + if load_all: + vars_to_load = list(vars_dict.keys()) + elif variables: + # Validate requested variables + invalid_vars = [v for v in variables if v not in vars_dict] + if invalid_vars: + raise ValueError(f"Invalid variables requested: {invalid_vars}") + vars_to_load = variables + else: + # Default: primary descriptors + primary data + primary = [ + name + for name, v in vars_dict.items() + if _classify_variable(v) + in ( + "primary_descriptor", + "primary_data", + "primary descriptor", + "primary data", + ) + ] + vars_to_load = primary + + try: + df = ds_meta.sel(time=time, variables=vars_to_load).load_dataset( + backend="pandas", engine="pyarrow" + ) + except Exception as e: + raise RuntimeError(f"Error loading dataset '{dataset_name}': {str(e)}") from e + + xrds = df.to_xarray() + + # Standardize coordinate names + rename_map = {"OBS_TIMESTAMP": "time", "LAT": "latitude", "LON": "longitude"} + for coord_var in rename_map: + if coord_var in vars_dict and coord_var not in vars_to_load: + vars_to_load.append(coord_var) + xrds = xrds.rename({k: v for k, v in rename_map.items() if k in xrds}) + + # Ensure 'time' coordinate exists + if "time" not in xrds and "OBS_DATE" in xrds: + xrds = xrds.rename({"OBS_DATE": "time"}) + + # Handle time conversion if needed + if "time" in xrds and not np.issubdtype(xrds.time.dtype, np.datetime64): + xrds["time"] = xrds.time.astype("datetime64[ns]") + + # If time is not a dimension but 'obs' is, swap + if "time" in xrds and "obs" in xrds.dims and "time" not in xrds.dims: + xrds = xrds.swap_dims({"obs": "time"}) + if "obs" in xrds.coords: + xrds = xrds.reset_coords("obs", drop=True) + + if "time" in xrds and "time" not in xrds.coords: + xrds = xrds.set_coords("time") + + # Flatten extra dimensions into time as may encounter an extra "index" dimension + # Ensures output is always 1D along "time" + extra_dims = [d for d in xrds.dims if d != "time"] + if extra_dims: + time_values = xrds.time.values if "time" in xrds else None + xrds = xrds.stack(sample=tuple(extra_dims)) + xrds = xrds.reset_index("sample") + + # Rename to time and restore original time values + if "sample" in xrds.dims: + xrds = xrds.swap_dims({"sample": "time"}) + if "sample" in xrds.coords: + xrds = xrds.reset_coords("sample", drop=True) + if time_values is not None: + xrds["time"] = ("time", time_values) + + if "time" not in xrds.dims: + raise RuntimeError("Failed to establish 'time' dimension in output dataset") + + return xrds + + +class SensorDataset(Dataset): + """PyTorch Dataset wrapper for NNJA-AI datasets with optimized access.""" + + def __init__(self, dataset_name, time=None, variables=None, load_all=False): + """Initialize dataset loader. + + Args: + dataset_name: Name of NNJA dataset to load + time: Time selection (single timestamp or slice) + variables: Specific variables to load + load_all: If True, loads all available variables + """ + self.dataset_name = dataset_name + self.time = time + + self.xrds = load_nnja_dataset( + dataset_name, time=time, variables=variables, load_all=load_all + ) + + # Store for efficient access + self.variables = list(self.xrds.data_vars.keys()) + self.time_index = self.xrds.time.values + + def __len__(self): + return self.xrds.sizes["time"] + + def __getitem__(self, idx): + """Direct xarray access without DataFrame conversion.""" + time_point = self.time_index[idx] + return {var: self.xrds[var].sel(time=time_point).item() for var in self.variables} + + +class NNJATorchDataset(Dataset): + """Adapter for torch Dataset directly from xarray.""" + + def __init__(self, xrds): + """Initialize adapter. + + Args: + xrds: xarray Dataset to convert + """ + self.ds = xrds + self.vars = list(xrds.data_vars.keys()) + self.time_index = xrds.time.values + + def __len__(self): + return self.ds.sizes["time"] + + def __getitem__(self, idx): + time_point = self.time_index[idx] + return {var: self.ds[var].sel(time=time_point).item() for var in self.vars} diff --git a/tests/test_nnjai.py b/tests/test_nnjai.py index b609ca18..c6985118 100644 --- a/tests/test_nnjai.py +++ b/tests/test_nnjai.py @@ -11,7 +11,7 @@ import pytest import xarray as xr -from graph_weather.data.nnja_ai import ( +from graph_weather.data.nnjaai import ( NNJATorchDataset, SensorDataset, _classify_variable, From 0b5410d250629a245ac3149016f04f2b90f03227 Mon Sep 17 00:00:00 2001 From: Yuvraaj Narula Date: Thu, 16 Oct 2025 21:38:12 +0530 Subject: [PATCH 02/13] feat : introduced schemas cris, adpupa pipelines for nnjaai mapping --- graph_weather/data/bufr_process.py | 116 ++++++++++++++++++++++++++--- 1 file changed, 107 insertions(+), 9 deletions(-) diff --git a/graph_weather/data/bufr_process.py b/graph_weather/data/bufr_process.py index 1341480b..3c1581a0 100644 --- a/graph_weather/data/bufr_process.py +++ b/graph_weather/data/bufr_process.py @@ -2,6 +2,7 @@ from typing import Optional, Callable, Any , List, Dict import numpy as np import logging +import pandas as pd logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" @@ -155,17 +156,105 @@ def get_variable_list(self) -> List[str]: """Get list of NNJA variables this source provides.""" return list(set(m.output_name for m in self.field_mappings.values())) -class _BUFRIterableDataset(IterableDataset): - """Internal IterableDataset wrapper for PyTorch DataLoader.""" - - def __init__(self, bufr_loader: BUFRDataLoader): - self.bufr_loader = bufr_loader - - def __iter__(self): - return iter(self.bufr_loader) - +class ADPUPA_schema(DataSourceSchema): + """ADPUPA (upper-air radiosonde) BUFR schema mapping to NNJA-AI.""" + source_name = "ADPUPA" + def _build_mappings(self): + self.field_mappings= { + 'latitude' : FieldMapping( + source_name='latitude', + output_name='LAT', + dtype=float, + description='Station latitude' + ), + 'longitude' : FieldMapping( + source_name='longitude', + output_name='LON', + dtype=float, + description='Station longitude' + ), + 'obsTime' : FieldMapping( + source_name='obsTime', + output_name='OBS_TIME', + description='datetime64[ns]', + transform_fn=lambda x: pd.Timestamp(x).value if isinstance(x, str) else x, + description='Observation timestamp' + ), + 'airTemperature' : FieldMapping( + source_name='airTemperature', + output_name='temperature', + dtype=float, + transform_fn=lambda x: x - 273.15 if x > 100 else x, + description='Temperature in Celsius' + ), + 'pressure' : FieldMapping( + source_name='pressure', + output_name='pressure', + dtype=float, + description='Pressure in Pa' + ), + 'height' : FieldMapping( + source_name='height', + output_name='height', + dtype=float, + description='Height above sealevel in m' + ), + 'dewpointTemperature' : FieldMapping( + source_name='dewpointTemperature', + output_name='dew_point', + dtype=float, + transform_fn=lambda x: x - 273.15 if x > 100 else x, + description='Dew point in Celsius' + ), + 'windU' : FieldMapping( + source_name='windU', + output_name='u_wind', + dtype=float, + description='U-component wind (m/s)' + ), + 'windV': FieldMapping( + source_name='windV', + output_name='v_wind', + dtype=float, + description='V-component wind (m/s)' + ) + } +class CRIS_Schema(DataSourceSchema): + """CrIS (satellite hyperspectral) BUFR schema mapping to NNJA-AI.""" + + source_name = "CrIS" + def _build_mappings(self): + self.field_mappings = { + 'latitude' : FieldMapping( + source_name='lat', + output_name='LAT', + dtype=float + ), + 'longitude' : FieldMapping( + source_name='lon', + output_name='LON', + dtype=float + ), + 'obsTime' : FieldMapping( + source_name='obsTime', + output_name='OBS_TIMESTAMP', + dtype='datetime64[ns]', + transform_fn=lambda x: pd.Timestamp(x).value, + ), + 'retrievedTemperature': FieldMapping( + source_name='retrievedTemperature', + output_name='temperature', + dtype=float, + transform_fn=lambda x: x - 273.15, + ), + 'retrievedPressure': FieldMapping( + source_name='retrievedPressure', + output_name='pressure', + dtype=float, + ), + } class BUFR_dataloader: def __init__(self,dataset,batch_size): """ @@ -182,3 +271,12 @@ def map_to_nnjai_schema(self): pass +class _BUFRIterableDataset(IterableDataset): + """Internal IterableDataset wrapper for PyTorch DataLoader.""" + + def __init__(self, bufr_loader: BUFR_dataLoader): + self.bufr_loader = bufr_loader + + def __iter__(self): + return iter(self.bufr_loader) + From b15b6e912f871a4b88b4e96891223608f7ce00bc Mon Sep 17 00:00:00 2001 From: Yuvraaj Narula Date: Fri, 24 Oct 2025 14:03:30 +0530 Subject: [PATCH 03/13] feat : initial bufr processor with adpupa and cris support --- graph_weather/__init__.py | 2 +- graph_weather/data/bufr_process.py | 231 ++++++++++++++++++- graph_weather/data/nnja_ai.py | 194 ---------------- graph_weather/data/weather_station_reader.py | 12 +- 4 files changed, 229 insertions(+), 210 deletions(-) delete mode 100755 graph_weather/data/nnja_ai.py diff --git a/graph_weather/__init__.py b/graph_weather/__init__.py index b33e23cd..73118340 100644 --- a/graph_weather/__init__.py +++ b/graph_weather/__init__.py @@ -1,6 +1,6 @@ """Main import for the complete models""" -from .data.nnja_ai import SensorDataset +from .data.nnjaai import SensorDataset from .data.weather_station_reader import WeatherStationReader from .models.analysis import GraphWeatherAssimilator from .models.forecast import GraphWeatherForecaster diff --git a/graph_weather/data/bufr_process.py b/graph_weather/data/bufr_process.py index 3c1581a0..8ad0d8ce 100644 --- a/graph_weather/data/bufr_process.py +++ b/graph_weather/data/bufr_process.py @@ -3,6 +3,8 @@ import numpy as np import logging import pandas as pd +import xarray as xr +from pathlib import Path logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" @@ -83,10 +85,6 @@ def validate_data(cls, data: Dict[str, np.ndarray]) -> bool: required_coords = ['OBS_TIMESTAMP', 'LAT', 'LON'] return all(coord in data for coord in required_coords) - - def __init__(self): - pass - class DataSourceSchema: """ Abstract base for source-specific BUFR schema mappings. @@ -218,10 +216,27 @@ def _build_mappings(self): output_name='v_wind', dtype=float, description='V-component wind (m/s)' + ), + 'stationId': FieldMapping( + source_name='stationId', + output_name='station_id', + dtype=str, + required=False, + description='Station identifier' ) } + + def _convert_timestamp(self, value: Any) -> pd.Timestamp: + """Convert BUFR timestamp to pandas Timestamp.""" + if isinstance(value, (int, float)): + return pd.Timestamp(value, unit='s') + elif isinstance(value, str): + return pd.Timestamp(value) + else: + return pd.Timestamp(value) + -class CRIS_Schema(DataSourceSchema): +class CRIS_schema(DataSourceSchema): """CrIS (satellite hyperspectral) BUFR schema mapping to NNJA-AI.""" source_name = "CrIS" @@ -254,27 +269,225 @@ def _build_mappings(self): output_name='pressure', dtype=float, ), + 'sourceZenithAngle': FieldMapping( + source_name='sensorZenithAngle', + output_name='sensor_zenith_angle', + dtype=float, + required=False, + description='Sensor zenith angle' + ), + 'qualityFlags' : FieldMapping( + source_name='qualityFlags', + output_name='qc_flag', + dtype=int, + description='Quality control flags' + ) } -class BUFR_dataloader: - def __init__(self,dataset,batch_size): + def _convert_timestamp(self, value: Any) -> pd.Timestamp: + """Convert BUFR timestamp to pandas Timestamp.""" + if isinstance(value, (int, float)): + return pd.Timestamp(value, unit='s') + elif isinstance(value, str): + return pd.Timestamp(value) + else: + return pd.Timestamp(value) + +class BUFR_processsor: + """ + Low-level BUFR file decoder. + Handles binary BUFR format decoding using eccodes library. + """ + def __init__(self , schema : DataSourceSchema): + """ + Args: + -> schema : DataSourceSchema instance + """ + if not isinstance(schema,DataSourceSchema): + raise TypeError('schema must be of DataSourceSchema instance') + + self.schema = schema + + def decoder_bufr_files(self, filepath) -> List[Dict[str,any]]: + """ + Decode all messages from BUFR file. + + Args: + -> filepath: Path to BUFR file + """ + msgs = [] + filepath = Path(filepath) + + if not filepath.exists(): + raise FileNotFoundError(f"BUFR file not found: {filepath}") + + try: + with open(filepath, 'rb') as f: + while True: + bufr_id = eccodes.bufr_new_from_file(f) + if bufr_id is None: + break + + try: + eccodes.codes_set(bufr_id, 'unpack', 1) + msg = {} + iterator = eccodes.bufr_keys_iterator(bufr_id) + while eccodes.bufr_keys_iterator_next(iterator): + key = eccodes.bufr_keys_iterator_get_name(iterator) + try: + value = eccodes.codes_get_string(bufr_id, key) + msg[key] = value + except (eccodes.KeyValueNotFoundError, eccodes.CodesInternalError): + try: + value = eccodes.codes_get_double(bufr_id, key) + msg[key] = value + except (eccodes.KeyValueNotFoundError, eccodes.CodesInternalError): + pass + eccodes.codes_bufr_keys_iterator_delete(iterator) + msgs.append(msg) + finally: + eccodes.codes_release(bufr_id) + + except Exception as e: + logger.error(f"Error decoding BUFR file {filepath}: {e}") + raise + + logger.info(f"Decoded {len(msgs)} messages from {filepath}") + return msgs + + def process_files_to_dataframe(self, filepath : str)-> pd.DataFrame: + """ + Decode BUFR file and map to NNJA schema, return as DataFrame. + + Args: + -> filepath: Path to BUFR file + + Returns: + -> pandas DataFrame in NNJA-AI format + """ + + raw_msgs = self.decoder_bufr_files(filepath=filepath) + + transformed = [] + + for msg in raw_msgs: + mapped = self.schema.map_observations(msg) + if mapped: + transformed.append(mapped) + + if not transformed: + logger.warning(f"No valid observations found in {filepath}") + return pd.DataFrame() + + + df = pd.DataFrame(transformed) + + for col in df.columns: + if col in NNJA_Schema.COORDINATES: + dtype = NNJA_Schema.COORDINATES[col] + if 'datetime' in dtype: + df[col] = pd.to_datetime(df[col]) + else: + df[col] = df[col].astype(dtype.split('[')[0] if '[' in dtype else dtype) + if not NNJA_Schema.validate_data(df.to_dict(orient='list')): + logger.warning(f"DataFrame missing required NNJA coordinates from {filepath}") + + return df + + def process_files_to_xarray(self, filepath : str) -> xr.Dataset: + """ + Process BUFR file to xarray Dataset in NNJA-AI format. + + Args: + -> filepath: Path to BUFR file + + Returns: + -> xarray Dataset in NNJA-AI format + """ + df = self.process_files_to_dataframe(filepath=filepath) + + if df.empty: + logger.warning(f"No data to convert to xarray from {filepath}") + return xr.Dataset() + + data_vars = {} + for col in df.columns: + if col not in ['OBS_TIMESTAMP', 'LAT', 'LON']: + data_vars[col] = (['observation'], df[col].values) + + ds = xr.Dataset( + data_vars=data_vars, + coords={ + 'obs' : df.index , + 'time' : ('obs', df['obs'].values), + 'lat' : ('obs', df['LAT'].values), + 'lon' : ('obs', df['LON'].values), + } + ) + ds.attrs['source'] = self.schema.source_name + ds.attrs['processing_timestamp'] = pd.Timestamp.now().isoformat() + ds.attrs['num_observations'] = len(df) + + return ds + + def process_files_to_parquet(self, filepath: str, output_path: str)->None: + """ + Process BUFR file and save as Parquet in NNJA-AI format. + + Args: + -> filepath: Path to BUFR file + -> output_path: Path for output Parquet file + """ + df = self.process_file_to_dataframe(filepath=filepath) + + if not df.empty: + df.to_parquet(output_path, index=False) + logger.info(f"Saved {len(df)} observations to {output_path}") + else: + logger.warning(f"No data to save for {filepath}") + + +class BUFR_dataloader: + + SCHEMA_REGISTRY={ + 'ADPUPA': ADPUPA_schema, + 'CrIS': CRIS_schema, + } + def __init__(self,dataset:str,batch_size:int=32,schema_name: Optional[str] = None): """ Args: -> dataset : str (path) -> batch_size : int + -> schema_name : Data source name ('ADPUPA', 'CrIS', etc.) + If None, attempts to infer from filename """ self.dataset = dataset self.batch_size = batch_size + if schema_name is None: + schema_name = self._infer_schema_from_path(dataset) + + if schema_name not in self.SCHEMA_REGISTRY: + raise ValueError( + f'Unknown schema "{schema_name}"\nAvailable : {list(self.SCHEMA_REGISTRY)}' + ) + def decoder(self): pass def map_to_nnjai_schema(self): pass - + def to_dataframe(self): + pass + def to_parquet(self): + pass + def __iter__(self): + pass + def get_dataloader(self): + pass class _BUFRIterableDataset(IterableDataset): """Internal IterableDataset wrapper for PyTorch DataLoader.""" - def __init__(self, bufr_loader: BUFR_dataLoader): + def __init__(self, bufr_loader: BUFR_dataloader): self.bufr_loader = bufr_loader def __iter__(self): diff --git a/graph_weather/data/nnja_ai.py b/graph_weather/data/nnja_ai.py deleted file mode 100755 index 99cf95cf..00000000 --- a/graph_weather/data/nnja_ai.py +++ /dev/null @@ -1,194 +0,0 @@ -""" -Dynamic loader for NNJA-AI datasets with support for primary descriptors and data variables. - -Features: -- Automatically loads primary descriptors + primary data by default -- Supports custom variable selection -- Can load all variables when requested -- Returns xarray.Dataset with time as the only coordinate -- Optimized for performance with direct xarray access - -""" - -import numpy as np -import xarray as xr -from torch.utils.data import Dataset - -try: - from nnja import DataCatalog -except ImportError: - raise ImportError("NNJA-AI library not installed. Install with: " "`pip install nnja-ai`") - - -def _classify_variable(nnja_var) -> str: - """Return category of a variable using attributes or repr fallback.""" - # First try to get explicit attributes - if hasattr(nnja_var, "category"): - return nnja_var.category - if hasattr(nnja_var, "role"): - return nnja_var.role - - # Fallback to string representation - tag = repr(nnja_var).lower() - if "primary_descriptor" in tag or "primary descriptor" in tag: - return "primary_descriptor" - if "primary_data" in tag or "primary data" in tag: - return "primary_data" - return "other" - - -def load_nnja_dataset( - dataset_name: str, - time=None, - variables: list[str] | None = None, - load_all: bool = False, -) -> xr.Dataset: - """ - Load a NNJA dataset as an xarray.Dataset with time as the only coordinate. - - Args: - dataset_name: Name of NNJA dataset to load - time: Time selection (single timestamp, slice, or None) - variables: Specific variables to load (overrides default) - load_all: Load all available variables in the dataset - - Returns: - xarray.Dataset with only 'time' dimension/coordinate - """ - try: - cat = DataCatalog() - ds_meta = cat[dataset_name] - ds_meta.load_manifest() - except KeyError as e: - raise ValueError(f"Dataset '{dataset_name}' not found in catalog") from e - - vars_dict = ds_meta.variables - if load_all: - vars_to_load = list(vars_dict.keys()) - elif variables: - # Validate requested variables - invalid_vars = [v for v in variables if v not in vars_dict] - if invalid_vars: - raise ValueError(f"Invalid variables requested: {invalid_vars}") - vars_to_load = variables - else: - # Default: primary descriptors + primary data - primary = [ - name - for name, v in vars_dict.items() - if _classify_variable(v) - in ( - "primary_descriptor", - "primary_data", - "primary descriptor", - "primary data", - ) - ] - vars_to_load = primary - - try: - df = ds_meta.sel(time=time, variables=vars_to_load).load_dataset( - backend="pandas", engine="pyarrow" - ) - except Exception as e: - raise RuntimeError(f"Error loading dataset '{dataset_name}': {str(e)}") from e - - xrds = df.to_xarray() - - # Standardize coordinate names - rename_map = {"OBS_TIMESTAMP": "time", "LAT": "latitude", "LON": "longitude"} - for coord_var in rename_map: - if coord_var in vars_dict and coord_var not in vars_to_load: - vars_to_load.append(coord_var) - xrds = xrds.rename({k: v for k, v in rename_map.items() if k in xrds}) - - # Ensure 'time' coordinate exists - if "time" not in xrds and "OBS_DATE" in xrds: - xrds = xrds.rename({"OBS_DATE": "time"}) - - # Handle time conversion if needed - if "time" in xrds and not np.issubdtype(xrds.time.dtype, np.datetime64): - xrds["time"] = xrds.time.astype("datetime64[ns]") - - # If time is not a dimension but 'obs' is, swap - if "time" in xrds and "obs" in xrds.dims and "time" not in xrds.dims: - xrds = xrds.swap_dims({"obs": "time"}) - if "obs" in xrds.coords: - xrds = xrds.reset_coords("obs", drop=True) - - if "time" in xrds and "time" not in xrds.coords: - xrds = xrds.set_coords("time") - - # Flatten extra dimensions into time as may encounter an extra "index" dimension - # Ensures output is always 1D along "time" - extra_dims = [d for d in xrds.dims if d != "time"] - if extra_dims: - time_values = xrds.time.values if "time" in xrds else None - xrds = xrds.stack(sample=tuple(extra_dims)) - xrds = xrds.reset_index("sample") - - # Rename to time and restore original time values - if "sample" in xrds.dims: - xrds = xrds.swap_dims({"sample": "time"}) - if "sample" in xrds.coords: - xrds = xrds.reset_coords("sample", drop=True) - if time_values is not None: - xrds["time"] = ("time", time_values) - - if "time" not in xrds.dims: - raise RuntimeError("Failed to establish 'time' dimension in output dataset") - - return xrds - - -class SensorDataset(Dataset): - """PyTorch Dataset wrapper for NNJA-AI datasets with optimized access.""" - - def __init__(self, dataset_name, time=None, variables=None, load_all=False): - """Initialize dataset loader. - - Args: - dataset_name: Name of NNJA dataset to load - time: Time selection (single timestamp or slice) - variables: Specific variables to load - load_all: If True, loads all available variables - """ - self.dataset_name = dataset_name - self.time = time - - self.xrds = load_nnja_dataset( - dataset_name, time=time, variables=variables, load_all=load_all - ) - - # Store for efficient access - self.variables = list(self.xrds.data_vars.keys()) - self.time_index = self.xrds.time.values - - def __len__(self): - return self.xrds.sizes["time"] - - def __getitem__(self, idx): - """Direct xarray access without DataFrame conversion.""" - time_point = self.time_index[idx] - return {var: self.xrds[var].sel(time=time_point).item() for var in self.variables} - - -class NNJATorchDataset(Dataset): - """Adapter for torch Dataset directly from xarray.""" - - def __init__(self, xrds): - """Initialize adapter. - - Args: - xrds: xarray Dataset to convert - """ - self.ds = xrds - self.vars = list(xrds.data_vars.keys()) - self.time_index = xrds.time.values - - def __len__(self): - return self.ds.sizes["time"] - - def __getitem__(self, idx): - time_point = self.time_index[idx] - return {var: self.ds[var].sel(time=time_point).item() for var in self.vars} diff --git a/graph_weather/data/weather_station_reader.py b/graph_weather/data/weather_station_reader.py index 769986b3..a11b20b8 100644 --- a/graph_weather/data/weather_station_reader.py +++ b/graph_weather/data/weather_station_reader.py @@ -27,13 +27,13 @@ logger = logging.getLogger("WeatherStationReader") # Try importing synopticpy, but don't require it -try: - from synopticpy import Synoptic +# try: +# from synoptic import Synoptic - SYNOPTIC_AVAILABLE = True -except ImportError: - SYNOPTIC_AVAILABLE = False - logger.warning("SynopticPy package not installed, synoptic functionality won't be available") +# SYNOPTIC_AVAILABLE = True +# except ImportError: +# SYNOPTIC_AVAILABLE = False +# logger.warning("SynopticPy package not installed, synoptic functionality won't be available") class WeatherStationReader: From c3e2e4b08ca60a6d9e330795ce8aeca1d66f763d Mon Sep 17 00:00:00 2001 From: Yuvraaj Narula Date: Fri, 24 Oct 2025 14:10:24 +0530 Subject: [PATCH 04/13] feat : bufr func checks --- graph_weather/data/bufr_process.py | 73 +++++++++++++++++++----------- 1 file changed, 46 insertions(+), 27 deletions(-) diff --git a/graph_weather/data/bufr_process.py b/graph_weather/data/bufr_process.py index 8ad0d8ce..1e3fe9a7 100644 --- a/graph_weather/data/bufr_process.py +++ b/graph_weather/data/bufr_process.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Optional, Callable, Any , List, Dict +from typing import Optional, Callable, Any , List, Dict, Iterator import numpy as np import logging import pandas as pd @@ -451,39 +451,58 @@ class BUFR_dataloader: SCHEMA_REGISTRY={ 'ADPUPA': ADPUPA_schema, 'CrIS': CRIS_schema, - } - def __init__(self,dataset:str,batch_size:int=32,schema_name: Optional[str] = None): + } + def __init__(self, filepath: str, schema_name: Optional[str] = None): """ - Args: - -> dataset : str (path) - -> batch_size : int - -> schema_name : Data source name ('ADPUPA', 'CrIS', etc.) - If None, attempts to infer from filename + Args: + filepath: Path to BUFR file or directory + schema_name: Data source name ('ADPUPA', 'CrIS', etc.) """ - self.dataset = dataset - self.batch_size = batch_size + self.filepath = Path(filepath) + self.schema_name = schema_name or self._infer_schema_from_path() - if schema_name is None: - schema_name = self._infer_schema_from_path(dataset) - - if schema_name not in self.SCHEMA_REGISTRY: + if self.schema_name not in self.SCHEMA_REGISTRY: raise ValueError( - f'Unknown schema "{schema_name}"\nAvailable : {list(self.SCHEMA_REGISTRY)}' + f'Unknown schema "{self.schema_name}". Available: {list(self.SCHEMA_REGISTRY.keys())}' ) - def decoder(self): - pass - def map_to_nnjai_schema(self): - pass - def to_dataframe(self): - pass - def to_parquet(self): - pass - def __iter__(self): - pass - def get_dataloader(self): - pass + self.schema = self.SCHEMA_REGISTRY[self.schema_name]() + self.processor = BUFR_processsor(self.schema) + def _infer_schema_from_path(self) -> str: + """Infer schema from filename or path patterns.""" + filename = self.filepath.name.lower() + + if 'adpupa' in filename or 'raob' in filename or 'sound' in filename: + return 'ADPUPA' + elif 'cris' in filename: + return 'CrIS' + elif 'iasi' in filename: + return 'IASI' + elif 'atms' in filename: + return 'ATMS' + else: + # Default to ADPUPA for now + logger.warning(f"Could not infer schema from {filename}, defaulting to ADPUPA") + return 'ADPUPA' + + def to_dataframe(self) -> pd.DataFrame: + """Process BUFR file to DataFrame.""" + return self.processor.process_file_to_dataframe(str(self.filepath)) + + def to_xarray(self) -> xr.Dataset: + """Process BUFR file to xarray Dataset.""" + return self.processor.process_file_to_xarray(str(self.filepath)) + def to_parquet(self, output_path: str) -> None: + """Process BUFR file to Parquet format.""" + self.processor.process_file_to_parquet(str(self.filepath), output_path) + + def __iter__(self) -> Iterator[Dict[str, Any]]: + """Iterate over observations in the BUFR file.""" + messages = self.processor.decode_bufr_file(str(self.filepath)) + for msg in messages: + yield self.schema.map_observation(msg) + class _BUFRIterableDataset(IterableDataset): """Internal IterableDataset wrapper for PyTorch DataLoader.""" From 0f454f989765d9d4c8745b731ceb38c81709fe48 Mon Sep 17 00:00:00 2001 From: Yuvraaj Narula Date: Fri, 24 Oct 2025 14:38:23 +0530 Subject: [PATCH 05/13] feat : initiated pytest for bufr_process --- graph_weather/data/__init__.py | 1 + graph_weather/data/bufr_process.py | 3 +- tests/test_bufr_process.py | 481 +++++++++++++++++++++++++++++ 3 files changed, 483 insertions(+), 2 deletions(-) create mode 100644 tests/test_bufr_process.py diff --git a/graph_weather/data/__init__.py b/graph_weather/data/__init__.py index a47080ad..79cf9c88 100644 --- a/graph_weather/data/__init__.py +++ b/graph_weather/data/__init__.py @@ -3,3 +3,4 @@ from .anemoi_dataloader import AnemoiDataset from .nnjaai import SensorDataset from .weather_station_reader import WeatherStationReader +from .bufr_process import BUFR_processsor, NNJA_Schema, BUFR_dataloader, _BUFRIterableDataset, FieldMapping, ADPUPA_schema, CRIS_schema \ No newline at end of file diff --git a/graph_weather/data/bufr_process.py b/graph_weather/data/bufr_process.py index 1e3fe9a7..5e5b4a87 100644 --- a/graph_weather/data/bufr_process.py +++ b/graph_weather/data/bufr_process.py @@ -510,5 +510,4 @@ def __init__(self, bufr_loader: BUFR_dataloader): self.bufr_loader = bufr_loader def __iter__(self): - return iter(self.bufr_loader) - + return iter(self.bufr_loader) \ No newline at end of file diff --git a/tests/test_bufr_process.py b/tests/test_bufr_process.py new file mode 100644 index 00000000..565fef67 --- /dev/null +++ b/tests/test_bufr_process.py @@ -0,0 +1,481 @@ +import pytest +import pandas as pd +import numpy as np +import xarray as xr +from pathlib import Path +import tempfile +import json +from unittest.mock import Mock, patch, MagicMock +import sys +import os + +sys.path.append(str(Path(__file__).parent.parent)) + +from graph_weather.data.bufr_process import ( + FieldMapping, + NNJA_Schema, + DataSourceSchema, + ADPUPA_schema, + CRIS_schema, + BUFR_processor, + BUFR_dataloader +) + + +class TestFieldMapping: + """Test FieldMapping dataclass.""" + + def test_field_mapping_creation(self): + """Test FieldMapping initialization.""" + mapping = FieldMapping( + source_name="temperature", + output_name="temp", + dtype=float, + description="Temperature field" + ) + + assert mapping.source_name == "temperature" + assert mapping.output_name == "temp" + assert mapping.dtype == float + assert mapping.description == "Temperature field" + assert mapping.required is True + assert mapping.transform_fn is None + + def test_field_mapping_apply_no_transform(self): + """Test apply method without transformation function.""" + mapping = FieldMapping( + source_name="pressure", + output_name="pres", + dtype=float + ) + + result = mapping.apply(1013.25) + assert result == 1013.25 + + def test_field_mapping_apply_with_transform(self): + """Test apply method with transformation function.""" + mapping = FieldMapping( + source_name="temp_k", + output_name="temp_c", + dtype=float, + transform_fn=lambda x: x - 273.15 + ) + + result = mapping.apply(300.0) + assert result == pytest.approx(26.85) + + def test_field_mapping_apply_none_value(self): + """Test apply method with None value.""" + mapping = FieldMapping( + source_name="missing", + output_name="missing_out", + dtype=float + ) + + result = mapping.apply(None) + assert result is None + + +class TestNNJASchema: + """Test NNJA_Schema class.""" + + def test_schema_structure(self): + """Test schema has expected structure.""" + assert 'OBS_TIMESTAMP' in NNJA_Schema.COORDINATES + assert 'LAT' in NNJA_Schema.COORDINATES + assert 'LON' in NNJA_Schema.COORDINATES + assert 'temperature' in NNJA_Schema.VARIABLES + assert 'pressure' in NNJA_Schema.VARIABLES + + def test_to_xarray_schema(self): + """Test schema combination.""" + full_schema = NNJA_Schema.to_xarray_schema() + assert 'OBS_TIMESTAMP' in full_schema + assert 'temperature' in full_schema + assert 'source' in full_schema + + def test_get_coordinate_names(self): + """Test coordinate names retrieval.""" + coords = NNJA_Schema.get_coordinate_names() + expected_coords = ['OBS_TIMESTAMP', 'LAT', 'LON'] + assert set(coords) == set(expected_coords) + + def test_validate_data(self): + """Test data validation.""" + valid_data = { + 'OBS_TIMESTAMP': np.array(['2023-01-01'], dtype='datetime64[ns]'), + 'LAT': np.array([45.0], dtype='float32'), + 'LON': np.array([-120.0], dtype='float32') + } + assert NNJA_Schema.validate_data(valid_data) is True + + invalid_data = { + 'LAT': np.array([45.0], dtype='float32'), + 'LON': np.array([-120.0], dtype='float32') + } + assert NNJA_Schema.validate_data(invalid_data) is False + + +class TestADPUPASchema: + """Test ADPUPA schema implementation.""" + + @pytest.fixture + def adpupa_schema(self): + return ADPUPA_schema() + + def test_schema_creation(self, adpupa_schema): + """Test ADPUPA schema initialization.""" + assert adpupa_schema.source_name == "ADPUPA" + assert len(adpupa_schema.field_mappings) > 0 + + def test_required_mappings_present(self, adpupa_schema): + """Test required NNJA coordinates are mapped.""" + output_names = {m.output_name for m in adpupa_schema.field_mappings.values()} + assert 'LAT' in output_names + assert 'LON' in output_names + assert 'OBS_TIMESTAMP' in output_names + + def test_map_observation(self, adpupa_schema): + """Test observation mapping.""" + test_message = { + 'latitude': 45.0, + 'longitude': -120.0, + 'obsTime': '2023-01-01T12:00:00', + 'airTemperature': 300.0, # Kelvin + 'pressure': 101325.0, + 'height': 100.0, + 'dewpointTemperature': 290.0, # Kelvin + 'windU': 5.0, + 'windV': -3.0 + } + + mapped = adpupa_schema.map_observation(test_message) + + assert mapped['LAT'] == 45.0 + assert mapped['LON'] == -120.0 + assert isinstance(mapped['OBS_TIMESTAMP'], pd.Timestamp) + assert mapped['temperature'] == pytest.approx(26.85) # 300K to C + assert mapped['pressure'] == 101325.0 + assert mapped['dew_point'] == pytest.approx(16.85) # 290K to C + + def test_map_observation_missing_fields(self, adpupa_schema): + """Test mapping with missing fields.""" + test_message = { + 'latitude': 45.0, + 'longitude': -120.0, + 'obsTime': '2023-01-01T12:00:00' + } + + mapped = adpupa_schema.map_observation(test_message) + + assert mapped['LAT'] == 45.0 + assert mapped['LON'] == -120.0 + assert mapped['temperature'] is None # Missing field + + +class TestCRISSchema: + """Test CrIS schema implementation.""" + + @pytest.fixture + def cris_schema(self): + return CRIS_schema() + + def test_schema_creation(self, cris_schema): + """Test CrIS schema initialization.""" + assert cris_schema.source_name == "CrIS" + assert len(cris_schema.field_mappings) > 0 + + def test_map_observation(self, cris_schema): + """Test CrIS observation mapping.""" + test_message = { + 'latitude': 30.0, + 'longitude': -100.0, + 'obsTime': 1672574400, # Unix timestamp + 'retrievedTemperature': 280.0, # Kelvin + 'retrievedPressure': 85000.0, + 'qualityFlags': 1 + } + + mapped = cris_schema.map_observation(test_message) + + assert mapped['LAT'] == 30.0 + assert mapped['LON'] == -100.0 + assert isinstance(mapped['OBS_TIMESTAMP'], pd.Timestamp) + assert mapped['temperature'] == pytest.approx(6.85) # 280K to C + assert mapped['pressure'] == 85000.0 + assert mapped['qc_flag'] == 1 + + +class TestBUFRProcessor: + """Test BUFR_processor class.""" + + @pytest.fixture + def mock_schema(self): + schema = Mock(spec=DataSourceSchema) + schema.source_name = "TEST" + schema.field_mappings = {} + schema.map_observation.return_value = { + 'OBS_TIMESTAMP': pd.Timestamp('2023-01-01T12:00:00'), + 'LAT': 45.0, + 'LON': -120.0, + 'temperature': 20.0 + } + return schema + + @pytest.fixture + def bufr_processor(self, mock_schema): + return BUFR_processor(mock_schema) + + def test_processor_initialization(self, mock_schema): + """Test processor initialization.""" + processor = BUFR_processor(mock_schema) + assert processor.schema == mock_schema + + def test_processor_invalid_schema(self): + """Test processor with invalid schema.""" + with pytest.raises(TypeError): + BUFR_processor("invalid_schema") + + @patch('bufr_processor.eccodes') + def test_decode_bufr_file_success(self, mock_eccodes, bufr_processor, tmp_path): + """Test successful BUFR file decoding.""" + # Create a temporary BUFR file + test_file = tmp_path / "test.bufr" + test_file.write_bytes(b"test bufr content") + + # Mock eccodes behavior + mock_bufr_id = Mock() + mock_eccodes.codes_bufr_new_from_file.side_effect = [mock_bufr_id, None] + mock_iterator = Mock() + mock_eccodes.codes_bufr_keys_iterator_new.return_value = mock_iterator + mock_eccodes.codes_bufr_keys_iterator_next.side_effect = [True, False] + mock_eccodes.codes_bufr_keys_iterator_get_name.return_value = "test_key" + mock_eccodes.codes_get_string.return_value = "test_value" + + messages = bufr_processor.decode_bufr_file(str(test_file)) + + assert len(messages) == 1 + assert messages[0]["test_key"] == "test_value" + mock_eccodes.codes_set.assert_called_with(mock_bufr_id, 'unpack', 1) + mock_eccodes.codes_release.assert_called_with(mock_bufr_id) + + @patch('bufr_processor.eccodes') + def test_decode_bufr_file_not_found(self, mock_eccodes, bufr_processor, tmp_path): + """Test BUFR file not found.""" + with pytest.raises(FileNotFoundError): + bufr_processor.decode_bufr_file(str(tmp_path / "nonexistent.bufr")) + + @patch.object(BUFR_processor, 'decode_bufr_file') + def test_process_file_to_dataframe(self, mock_decode, bufr_processor, mock_schema): + """Test processing BUFR file to DataFrame.""" + # Mock decoded messages + mock_messages = [ + {'latitude': 45.0, 'longitude': -120.0, 'obsTime': '2023-01-01T12:00:00'}, + {'latitude': 46.0, 'longitude': -121.0, 'obsTime': '2023-01-01T12:30:00'} + ] + mock_decode.return_value = mock_messages + + # Mock schema mapping + mock_schema.map_observation.side_effect = [ + {'OBS_TIMESTAMP': pd.Timestamp('2023-01-01T12:00:00'), 'LAT': 45.0, 'LON': -120.0}, + {'OBS_TIMESTAMP': pd.Timestamp('2023-01-01T12:30:00'), 'LAT': 46.0, 'LON': -121.0} + ] + + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + df = bufr_processor.process_file_to_dataframe(f.name) + + assert len(df) == 2 + assert 'OBS_TIMESTAMP' in df.columns + assert 'LAT' in df.columns + assert 'LON' in df.columns + assert df['LAT'].iloc[0] == 45.0 + + @patch.object(BUFR_processor, 'decode_bufr_file') + def test_process_file_to_dataframe_empty(self, mock_decode, bufr_processor): + """Test processing BUFR file with no valid observations.""" + mock_decode.return_value = [] # No messages + + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + df = bufr_processor.process_file_to_dataframe(f.name) + + assert df.empty + + @patch.object(BUFR_processor, 'process_file_to_dataframe') + def test_process_file_to_xarray(self, mock_process, bufr_processor): + """Test processing BUFR file to xarray Dataset.""" + # Mock DataFrame + mock_df = pd.DataFrame({ + 'OBS_TIMESTAMP': [pd.Timestamp('2023-01-01T12:00:00'), pd.Timestamp('2023-01-01T12:30:00')], + 'LAT': [45.0, 46.0], + 'LON': [-120.0, -121.0], + 'temperature': [20.0, 19.5], + 'pressure': [101325.0, 101300.0] + }) + mock_process.return_value = mock_df + + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + ds = bufr_processor.process_file_to_xarray(f.name) + + assert 'temperature' in ds.data_vars + assert 'pressure' in ds.data_vars + assert 'time' in ds.coords + assert 'lat' in ds.coords + assert 'lon' in ds.coords + assert ds.attrs['source'] == 'TEST' + assert ds.attrs['num_observations'] == 2 + + @patch.object(BUFR_processor, 'process_file_to_dataframe') + def test_process_file_to_parquet(self, mock_process, bufr_processor, tmp_path): + """Test processing BUFR file to Parquet.""" + # Mock DataFrame + mock_df = pd.DataFrame({ + 'OBS_TIMESTAMP': [pd.Timestamp('2023-01-01T12:00:00')], + 'LAT': [45.0], + 'LON': [-120.0], + 'temperature': [20.0] + }) + mock_process.return_value = mock_df + + output_file = tmp_path / "output.parquet" + + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + bufr_processor.process_file_to_parquet(f.name, str(output_file)) + + assert output_file.exists() + + # Verify Parquet file can be read back + df_read = pd.read_parquet(output_file) + assert len(df_read) == 1 + assert df_read['LAT'].iloc[0] == 45.0 + + +class TestBUFRDataLoader: + """Test BUFR_dataloader class.""" + + def test_dataloader_initialization_with_schema(self): + """Test dataloader initialization with explicit schema.""" + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + loader = BUFR_dataloader(f.name, schema_name='ADPUPA') + + assert loader.schema_name == 'ADPUPA' + assert isinstance(loader.schema, ADPUPA_schema) + assert isinstance(loader.processor, BUFR_processor) + + def test_dataloader_initialization_infer_schema(self): + """Test dataloader initialization with schema inference.""" + test_cases = [ + ('test_adpupa.bufr', 'ADPUPA'), + ('test_CRIS_data.bufr', 'CrIS'), + ('unknown_file.bufr', 'ADPUPA') # Default case + ] + + for filename, expected_schema in test_cases: + with tempfile.NamedTemporaryFile(suffix=filename) as f: + loader = BUFR_dataloader(f.name) + assert loader.schema_name == expected_schema + + def test_dataloader_initialization_invalid_schema(self): + """Test dataloader initialization with invalid schema.""" + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + with pytest.raises(ValueError, match='Unknown schema "INVALID"'): + BUFR_dataloader(f.name, schema_name='INVALID') + + @patch.object(BUFR_processor, 'process_file_to_dataframe') + def test_to_dataframe(self, mock_process): + """Test to_dataframe method.""" + mock_df = pd.DataFrame({ + 'OBS_TIMESTAMP': [pd.Timestamp('2023-01-01T12:00:00')], + 'LAT': [45.0], + 'LON': [-120.0] + }) + mock_process.return_value = mock_df + + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + loader = BUFR_dataloader(f.name, schema_name='ADPUPA') + df = loader.to_dataframe() + + assert len(df) == 1 + mock_process.assert_called_once_with(str(Path(f.name))) + + @patch.object(BUFR_processor, 'process_file_to_xarray') + def test_to_xarray(self, mock_process): + """Test to_xarray method.""" + mock_ds = xr.Dataset({ + 'temperature': (['obs'], [20.0]), + 'pressure': (['obs'], [101325.0]) + }, coords={ + 'obs': [0], + 'time': ('obs', [pd.Timestamp('2023-01-01T12:00:00')]), + 'lat': ('obs', [45.0]), + 'lon': ('obs', [-120.0]) + }) + mock_process.return_value = mock_ds + + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + loader = BUFR_dataloader(f.name, schema_name='ADPUPA') + ds = loader.to_xarray() + + assert 'temperature' in ds.data_vars + mock_process.assert_called_once_with(str(Path(f.name))) + + @patch.object(BUFR_processor, 'process_file_to_parquet') + def test_to_parquet(self, mock_process, tmp_path): + """Test to_parquet method.""" + output_file = tmp_path / "test_output.parquet" + + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + loader = BUFR_dataloader(f.name, schema_name='ADPUPA') + loader.to_parquet(str(output_file)) + + mock_process.assert_called_once_with(str(Path(f.name)), str(output_file)) + + @patch.object(BUFR_processor, 'decode_bufr_file') + def test_iterator(self, mock_decode): + """Test dataloader iterator.""" + mock_messages = [ + {'lat': 45.0, 'lon': -120.0}, + {'lat': 46.0, 'lon': -121.0} + ] + mock_decode.return_value = mock_messages + + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + loader = BUFR_dataloader(f.name, schema_name='ADPUPA') + + loader.schema.map_observation = lambda x: {'LAT': x['lat'], 'LON': x['lon']} + + observations = list(loader) + + assert len(observations) == 2 + assert observations[0]['LAT'] == 45.0 + assert observations[1]['LAT'] == 46.0 + + +class TestIntegration: + """Integration tests for the complete pipeline.""" + + def test_schema_registry_completeness(self): + """Test that all schemas in registry can be instantiated.""" + for schema_name, schema_class in BUFR_dataloader.SCHEMA_REGISTRY.items(): + schema = schema_class() + assert isinstance(schema, DataSourceSchema) + assert schema.source_name == schema_name + + def test_end_to_end_mock_processing(self): + """Test complete mock processing pipeline.""" + with tempfile.NamedTemporaryFile(suffix='_adpupa.bufr') as f: + loader = BUFR_dataloader(f.name) + + assert loader.schema_name == 'ADPUPA' + assert isinstance(loader.schema, ADPUPA_schema) + assert isinstance(loader.processor, BUFR_processor) + assert loader.processor.schema == loader.schema + + +def pytest_configure(config): + """Pytest configuration hook.""" + print("Setting up BUFR processor tests...") + + +if __name__ == "__main__": + # Run tests directly + pytest.main([__file__, "-v"]) \ No newline at end of file From a05426a25fbbaab97f214d2d92fd07c335faed7c Mon Sep 17 00:00:00 2001 From: Yuvraaj Narula Date: Fri, 24 Oct 2025 19:15:51 +0530 Subject: [PATCH 06/13] feat : pytest established, splitting left --- graph_weather/data/__init__.py | 2 +- graph_weather/data/bufr_process.py | 144 +++++++++--------- tests/test_bufr_process.py | 234 +++++++++++++++++++++++------ 3 files changed, 261 insertions(+), 119 deletions(-) diff --git a/graph_weather/data/__init__.py b/graph_weather/data/__init__.py index 79cf9c88..7ca7a18f 100644 --- a/graph_weather/data/__init__.py +++ b/graph_weather/data/__init__.py @@ -3,4 +3,4 @@ from .anemoi_dataloader import AnemoiDataset from .nnjaai import SensorDataset from .weather_station_reader import WeatherStationReader -from .bufr_process import BUFR_processsor, NNJA_Schema, BUFR_dataloader, _BUFRIterableDataset, FieldMapping, ADPUPA_schema, CRIS_schema \ No newline at end of file +from .bufr_process import BUFR_processor, NNJA_Schema, BUFR_dataloader, _BUFRIterableDataset, FieldMapping, ADPUPA_schema, CRIS_schema \ No newline at end of file diff --git a/graph_weather/data/bufr_process.py b/graph_weather/data/bufr_process.py index 5e5b4a87..ead79b4b 100644 --- a/graph_weather/data/bufr_process.py +++ b/graph_weather/data/bufr_process.py @@ -127,15 +127,10 @@ def _validate(self): def map_observation(self, bufr_message: Dict[str, Any]) -> Dict[str, Any]: """ Transform raw BUFR message to NNJA-AI format. - - Args: - bufr_message: Decoded BUFR message (dict of field → values) - - Returns: - Observation in NNJA format (dict matching NNJASchema) + Always include all mapped output fields, even if missing from the BUFR message. """ mapped = {} - + for field_map in self.field_mappings.values(): if field_map.source_name in bufr_message: raw_value = bufr_message[field_map.source_name] @@ -143,69 +138,66 @@ def map_observation(self, bufr_message: Dict[str, Any]) -> Dict[str, Any]: value = field_map.apply(raw_value) mapped[field_map.output_name] = value except Exception as e: - logger.warning( - f"Error transforming {field_map.source_name}: {e}" - ) + logger.warning(f"Error transforming {field_map.source_name}: {e}") mapped[field_map.output_name] = None - + else: + # Field not present — default to None + mapped[field_map.output_name] = None + return mapped - - def get_variable_list(self) -> List[str]: - """Get list of NNJA variables this source provides.""" - return list(set(m.output_name for m in self.field_mappings.values())) class ADPUPA_schema(DataSourceSchema): """ADPUPA (upper-air radiosonde) BUFR schema mapping to NNJA-AI.""" source_name = "ADPUPA" def _build_mappings(self): - self.field_mappings= { - 'latitude' : FieldMapping( + self.field_mappings = { + 'latitude': FieldMapping( source_name='latitude', output_name='LAT', dtype=float, description='Station latitude' ), - 'longitude' : FieldMapping( + 'longitude': FieldMapping( source_name='longitude', output_name='LON', dtype=float, description='Station longitude' ), - 'obsTime' : FieldMapping( + 'obsTime': FieldMapping( source_name='obsTime', - output_name='OBS_TIME', - description='datetime64[ns]', - transform_fn=lambda x: pd.Timestamp(x).value if isinstance(x, str) else x, + output_name='OBS_TIMESTAMP', + dtype=object, + transform_fn=self._convert_timestamp, description='Observation timestamp' - ), - 'airTemperature' : FieldMapping( + ), + 'airTemperature': FieldMapping( source_name='airTemperature', output_name='temperature', dtype=float, - transform_fn=lambda x: x - 273.15 if x > 100 else x, + transform_fn=lambda x: x - 273.15 if x > 100 else x, description='Temperature in Celsius' ), - 'pressure' : FieldMapping( + 'pressure': FieldMapping( source_name='pressure', output_name='pressure', dtype=float, description='Pressure in Pa' ), - 'height' : FieldMapping( + 'height': FieldMapping( source_name='height', output_name='height', dtype=float, description='Height above sealevel in m' ), - 'dewpointTemperature' : FieldMapping( + 'dewpointTemperature': FieldMapping( source_name='dewpointTemperature', output_name='dew_point', dtype=float, transform_fn=lambda x: x - 273.15 if x > 100 else x, description='Dew point in Celsius' ), - 'windU' : FieldMapping( + 'windU': FieldMapping( source_name='windU', output_name='u_wind', dtype=float, @@ -225,7 +217,8 @@ def _build_mappings(self): description='Station identifier' ) } - + + def _convert_timestamp(self, value: Any) -> pd.Timestamp: """Convert BUFR timestamp to pandas Timestamp.""" if isinstance(value, (int, float)): @@ -242,41 +235,46 @@ class CRIS_schema(DataSourceSchema): source_name = "CrIS" def _build_mappings(self): self.field_mappings = { - 'latitude' : FieldMapping( - source_name='lat', + 'latitude': FieldMapping( + source_name='latitude', output_name='LAT', - dtype=float + dtype=float, + description='Satellite latitude' ), - 'longitude' : FieldMapping( - source_name='lon', + 'longitude': FieldMapping( + source_name='longitude', output_name='LON', - dtype=float + dtype=float, + description='Satellite longitude' ), - 'obsTime' : FieldMapping( + 'obsTime': FieldMapping( source_name='obsTime', output_name='OBS_TIMESTAMP', - dtype='datetime64[ns]', - transform_fn=lambda x: pd.Timestamp(x).value, + dtype=object, + transform_fn=self._convert_timestamp, + description='Observation timestamp' ), 'retrievedTemperature': FieldMapping( source_name='retrievedTemperature', output_name='temperature', dtype=float, - transform_fn=lambda x: x - 273.15, + transform_fn=lambda x: x - 273.15 if x > 100 else x, + description='Retrieved temperature in Celsius' ), 'retrievedPressure': FieldMapping( source_name='retrievedPressure', output_name='pressure', dtype=float, + description='Retrieved pressure in Pa' ), - 'sourceZenithAngle': FieldMapping( + 'sensorZenithAngle': FieldMapping( source_name='sensorZenithAngle', output_name='sensor_zenith_angle', dtype=float, required=False, description='Sensor zenith angle' ), - 'qualityFlags' : FieldMapping( + 'qualityFlags': FieldMapping( source_name='qualityFlags', output_name='qc_flag', dtype=int, @@ -292,7 +290,7 @@ def _convert_timestamp(self, value: Any) -> pd.Timestamp: else: return pd.Timestamp(value) -class BUFR_processsor: +class BUFR_processor: """ Low-level BUFR file decoder. Handles binary BUFR format decoding using eccodes library. @@ -307,50 +305,44 @@ def __init__(self , schema : DataSourceSchema): self.schema = schema - def decoder_bufr_files(self, filepath) -> List[Dict[str,any]]: - """ - Decode all messages from BUFR file. - - Args: - -> filepath: Path to BUFR file - """ + def decoder_bufr_files(self, filepath) -> List[Dict[str, any]]: + """Decode all messages from BUFR file.""" msgs = [] filepath = Path(filepath) - + if not filepath.exists(): raise FileNotFoundError(f"BUFR file not found: {filepath}") - - try: - with open(filepath, 'rb') as f: + + try: + with open(filepath, "rb") as f: while True: - bufr_id = eccodes.bufr_new_from_file(f) + bufr_id = eccodes.codes_bufr_new_from_file(f) if bufr_id is None: - break - + break + try: - eccodes.codes_set(bufr_id, 'unpack', 1) + eccodes.codes_set(bufr_id, "unpack", 1) msg = {} - iterator = eccodes.bufr_keys_iterator(bufr_id) - while eccodes.bufr_keys_iterator_next(iterator): - key = eccodes.bufr_keys_iterator_get_name(iterator) + iterator = eccodes.codes_bufr_keys_iterator_new(bufr_id) + while eccodes.codes_bufr_keys_iterator_next(iterator): + key = eccodes.codes_bufr_keys_iterator_get_name(iterator) try: value = eccodes.codes_get_string(bufr_id, key) msg[key] = value - except (eccodes.KeyValueNotFoundError, eccodes.CodesInternalError): + except Exception: try: value = eccodes.codes_get_double(bufr_id, key) msg[key] = value - except (eccodes.KeyValueNotFoundError, eccodes.CodesInternalError): + except Exception: pass - eccodes.codes_bufr_keys_iterator_delete(iterator) - msgs.append(msg) - finally: + eccodes.codes_bufr_keys_iterator_delete(iterator) + msgs.append(msg) + finally: eccodes.codes_release(bufr_id) - except Exception as e: logger.error(f"Error decoding BUFR file {filepath}: {e}") raise - + logger.info(f"Decoded {len(msgs)} messages from {filepath}") return msgs @@ -370,7 +362,7 @@ def process_files_to_dataframe(self, filepath : str)-> pd.DataFrame: transformed = [] for msg in raw_msgs: - mapped = self.schema.map_observations(msg) + mapped = self.schema.map_observation(msg) if mapped: transformed.append(mapped) @@ -418,7 +410,7 @@ def process_files_to_xarray(self, filepath : str) -> xr.Dataset: data_vars=data_vars, coords={ 'obs' : df.index , - 'time' : ('obs', df['obs'].values), + 'time' : ('obs', df['OBS_TIMESTAMP'].values), 'lat' : ('obs', df['LAT'].values), 'lon' : ('obs', df['LON'].values), } @@ -437,7 +429,7 @@ def process_files_to_parquet(self, filepath: str, output_path: str)->None: -> filepath: Path to BUFR file -> output_path: Path for output Parquet file """ - df = self.process_file_to_dataframe(filepath=filepath) + df = self.process_files_to_dataframe(filepath=filepath) if not df.empty: df.to_parquet(output_path, index=False) @@ -467,7 +459,7 @@ def __init__(self, filepath: str, schema_name: Optional[str] = None): ) self.schema = self.SCHEMA_REGISTRY[self.schema_name]() - self.processor = BUFR_processsor(self.schema) + self.processor = BUFR_processor(self.schema) def _infer_schema_from_path(self) -> str: """Infer schema from filename or path patterns.""" filename = self.filepath.name.lower() @@ -487,19 +479,19 @@ def _infer_schema_from_path(self) -> str: def to_dataframe(self) -> pd.DataFrame: """Process BUFR file to DataFrame.""" - return self.processor.process_file_to_dataframe(str(self.filepath)) + return self.processor.process_files_to_dataframe(str(self.filepath)) def to_xarray(self) -> xr.Dataset: """Process BUFR file to xarray Dataset.""" - return self.processor.process_file_to_xarray(str(self.filepath)) + return self.processor.process_files_to_xarray(str(self.filepath)) def to_parquet(self, output_path: str) -> None: """Process BUFR file to Parquet format.""" - self.processor.process_file_to_parquet(str(self.filepath), output_path) + self.processor.process_files_to_parquet(str(self.filepath), output_path) def __iter__(self) -> Iterator[Dict[str, Any]]: """Iterate over observations in the BUFR file.""" - messages = self.processor.decode_bufr_file(str(self.filepath)) + messages = self.processor.decoder_bufr_files(str(self.filepath)) for msg in messages: yield self.schema.map_observation(msg) diff --git a/tests/test_bufr_process.py b/tests/test_bufr_process.py index 565fef67..58d38915 100644 --- a/tests/test_bufr_process.py +++ b/tests/test_bufr_process.py @@ -8,18 +8,31 @@ from unittest.mock import Mock, patch, MagicMock import sys import os +from typing import Any -sys.path.append(str(Path(__file__).parent.parent)) +sys.path.insert(0, str(Path(__file__).parent)) -from graph_weather.data.bufr_process import ( - FieldMapping, - NNJA_Schema, - DataSourceSchema, - ADPUPA_schema, - CRIS_schema, - BUFR_processor, - BUFR_dataloader -) +try: + from graph_weather.data.bufr_process import ( + FieldMapping, + NNJA_Schema, + DataSourceSchema, + ADPUPA_schema, + CRIS_schema, + BUFR_processor, + BUFR_dataloader + ) +except ImportError: + # Fallback to direct import + from data.bufr_process import ( + FieldMapping, + NNJA_Schema, + DataSourceSchema, + ADPUPA_schema, + CRIS_schema, + BUFR_processor, + BUFR_dataloader + ) class TestFieldMapping: @@ -116,12 +129,155 @@ def test_validate_data(self): assert NNJA_Schema.validate_data(invalid_data) is False +class MockADPUPASchema(ADPUPA_schema): + """Mock ADPUPA schema with proper FieldMapping dtypes.""" + + def _build_mappings(self): + self.field_mappings = { + 'latitude': FieldMapping( + source_name='latitude', + output_name='LAT', + dtype=float, + description='Station latitude' + ), + 'longitude': FieldMapping( + source_name='longitude', + output_name='LON', + dtype=float, + description='Station longitude' + ), + 'obsTime': FieldMapping( + source_name='obsTime', + output_name='OBS_TIMESTAMP', + dtype=object, + transform_fn=self._convert_timestamp, + description='Observation timestamp' + ), + 'airTemperature': FieldMapping( + source_name='airTemperature', + output_name='temperature', + dtype=float, + transform_fn=lambda x: x - 273.15 if x > 100 else x, + description='Temperature in Celsius' + ), + 'pressure': FieldMapping( + source_name='pressure', + output_name='pressure', + dtype=float, + description='Pressure in Pa' + ), + 'height': FieldMapping( + source_name='height', + output_name='height', + dtype=float, + description='Height above sealevel in m' + ), + 'dewpointTemperature': FieldMapping( + source_name='dewpointTemperature', + output_name='dew_point', + dtype=float, + transform_fn=lambda x: x - 273.15 if x > 100 else x, + description='Dew point in Celsius' + ), + 'windU': FieldMapping( + source_name='windU', + output_name='u_wind', + dtype=float, + description='U-component wind (m/s)' + ), + 'windV': FieldMapping( + source_name='windV', + output_name='v_wind', + dtype=float, + description='V-component wind (m/s)' + ), + 'stationId': FieldMapping( + source_name='stationId', + output_name='station_id', + dtype=str, + required=False, + description='Station identifier' + ) + } + + def _convert_timestamp(self, value: Any) -> pd.Timestamp: + """Convert BUFR timestamp to pandas Timestamp.""" + if isinstance(value, (int, float)): + return pd.Timestamp(value, unit='s') + elif isinstance(value, str): + return pd.Timestamp(value) + else: + return pd.Timestamp(value) + + +class MockCRISSchema(CRIS_schema): + """Mock CrIS schema with proper FieldMapping dtypes.""" + + def _build_mappings(self): + self.field_mappings = { + 'latitude': FieldMapping( + source_name='latitude', + output_name='LAT', + dtype=float, + description='Satellite latitude' + ), + 'longitude': FieldMapping( + source_name='longitude', + output_name='LON', + dtype=float, + description='Satellite longitude' + ), + 'obsTime': FieldMapping( + source_name='obsTime', + output_name='OBS_TIMESTAMP', + dtype=object, + transform_fn=self._convert_timestamp, + description='Observation timestamp' + ), + 'retrievedTemperature': FieldMapping( + source_name='retrievedTemperature', + output_name='temperature', + dtype=float, + transform_fn=lambda x: x - 273.15 if x > 100 else x, + description='Retrieved temperature in Celsius' + ), + 'retrievedPressure': FieldMapping( + source_name='retrievedPressure', + output_name='pressure', + dtype=float, + description='Retrieved pressure in Pa' + ), + 'sensorZenithAngle': FieldMapping( + source_name='sensorZenithAngle', + output_name='sensor_zenith_angle', + dtype=float, + required=False, + description='Sensor zenith angle' + ), + 'qualityFlags': FieldMapping( + source_name='qualityFlags', + output_name='qc_flag', + dtype=int, + description='Quality control flags' + ) + } + + def _convert_timestamp(self, value: Any) -> pd.Timestamp: + """Convert BUFR timestamp to pandas Timestamp.""" + if isinstance(value, (int, float)): + return pd.Timestamp(value, unit='s') + elif isinstance(value, str): + return pd.Timestamp(value) + else: + return pd.Timestamp(value) + + class TestADPUPASchema: """Test ADPUPA schema implementation.""" @pytest.fixture def adpupa_schema(self): - return ADPUPA_schema() + return MockADPUPASchema() def test_schema_creation(self, adpupa_schema): """Test ADPUPA schema initialization.""" @@ -178,7 +334,7 @@ class TestCRISSchema: @pytest.fixture def cris_schema(self): - return CRIS_schema() + return MockCRISSchema() def test_schema_creation(self, cris_schema): """Test CrIS schema initialization.""" @@ -236,7 +392,7 @@ def test_processor_invalid_schema(self): with pytest.raises(TypeError): BUFR_processor("invalid_schema") - @patch('bufr_processor.eccodes') + @patch('graph_weather.data.bufr_process.eccodes', create=True) def test_decode_bufr_file_success(self, mock_eccodes, bufr_processor, tmp_path): """Test successful BUFR file decoding.""" # Create a temporary BUFR file @@ -252,21 +408,19 @@ def test_decode_bufr_file_success(self, mock_eccodes, bufr_processor, tmp_path): mock_eccodes.codes_bufr_keys_iterator_get_name.return_value = "test_key" mock_eccodes.codes_get_string.return_value = "test_value" - messages = bufr_processor.decode_bufr_file(str(test_file)) + # Use the correct method name - decoder_bufr_files + messages = bufr_processor.decoder_bufr_files(str(test_file)) assert len(messages) == 1 assert messages[0]["test_key"] == "test_value" - mock_eccodes.codes_set.assert_called_with(mock_bufr_id, 'unpack', 1) - mock_eccodes.codes_release.assert_called_with(mock_bufr_id) - @patch('bufr_processor.eccodes') - def test_decode_bufr_file_not_found(self, mock_eccodes, bufr_processor, tmp_path): + def test_decode_bufr_file_not_found(self, bufr_processor, tmp_path): """Test BUFR file not found.""" with pytest.raises(FileNotFoundError): - bufr_processor.decode_bufr_file(str(tmp_path / "nonexistent.bufr")) + bufr_processor.decoder_bufr_files(str(tmp_path / "nonexistent.bufr")) - @patch.object(BUFR_processor, 'decode_bufr_file') - def test_process_file_to_dataframe(self, mock_decode, bufr_processor, mock_schema): + @patch.object(BUFR_processor, 'decoder_bufr_files') + def test_process_files_to_dataframe(self, mock_decode, bufr_processor, mock_schema): """Test processing BUFR file to DataFrame.""" # Mock decoded messages mock_messages = [ @@ -282,7 +436,7 @@ def test_process_file_to_dataframe(self, mock_decode, bufr_processor, mock_schem ] with tempfile.NamedTemporaryFile(suffix='.bufr') as f: - df = bufr_processor.process_file_to_dataframe(f.name) + df = bufr_processor.process_files_to_dataframe(f.name) assert len(df) == 2 assert 'OBS_TIMESTAMP' in df.columns @@ -290,18 +444,18 @@ def test_process_file_to_dataframe(self, mock_decode, bufr_processor, mock_schem assert 'LON' in df.columns assert df['LAT'].iloc[0] == 45.0 - @patch.object(BUFR_processor, 'decode_bufr_file') - def test_process_file_to_dataframe_empty(self, mock_decode, bufr_processor): + @patch.object(BUFR_processor, 'decoder_bufr_files') + def test_process_files_to_dataframe_empty(self, mock_decode, bufr_processor): """Test processing BUFR file with no valid observations.""" mock_decode.return_value = [] # No messages with tempfile.NamedTemporaryFile(suffix='.bufr') as f: - df = bufr_processor.process_file_to_dataframe(f.name) + df = bufr_processor.process_files_to_dataframe(f.name) assert df.empty - @patch.object(BUFR_processor, 'process_file_to_dataframe') - def test_process_file_to_xarray(self, mock_process, bufr_processor): + @patch.object(BUFR_processor, 'process_files_to_dataframe') + def test_process_files_to_xarray(self, mock_process, bufr_processor): """Test processing BUFR file to xarray Dataset.""" # Mock DataFrame mock_df = pd.DataFrame({ @@ -314,7 +468,7 @@ def test_process_file_to_xarray(self, mock_process, bufr_processor): mock_process.return_value = mock_df with tempfile.NamedTemporaryFile(suffix='.bufr') as f: - ds = bufr_processor.process_file_to_xarray(f.name) + ds = bufr_processor.process_files_to_xarray(f.name) assert 'temperature' in ds.data_vars assert 'pressure' in ds.data_vars @@ -322,10 +476,9 @@ def test_process_file_to_xarray(self, mock_process, bufr_processor): assert 'lat' in ds.coords assert 'lon' in ds.coords assert ds.attrs['source'] == 'TEST' - assert ds.attrs['num_observations'] == 2 - @patch.object(BUFR_processor, 'process_file_to_dataframe') - def test_process_file_to_parquet(self, mock_process, bufr_processor, tmp_path): + @patch.object(BUFR_processor, 'process_files_to_dataframe') + def test_process_files_to_parquet(self, mock_process, bufr_processor, tmp_path): """Test processing BUFR file to Parquet.""" # Mock DataFrame mock_df = pd.DataFrame({ @@ -339,14 +492,9 @@ def test_process_file_to_parquet(self, mock_process, bufr_processor, tmp_path): output_file = tmp_path / "output.parquet" with tempfile.NamedTemporaryFile(suffix='.bufr') as f: - bufr_processor.process_file_to_parquet(f.name, str(output_file)) + bufr_processor.process_files_to_parquet(f.name, str(output_file)) assert output_file.exists() - - # Verify Parquet file can be read back - df_read = pd.read_parquet(output_file) - assert len(df_read) == 1 - assert df_read['LAT'].iloc[0] == 45.0 class TestBUFRDataLoader: @@ -380,7 +528,7 @@ def test_dataloader_initialization_invalid_schema(self): with pytest.raises(ValueError, match='Unknown schema "INVALID"'): BUFR_dataloader(f.name, schema_name='INVALID') - @patch.object(BUFR_processor, 'process_file_to_dataframe') + @patch.object(BUFR_processor, 'process_files_to_dataframe') def test_to_dataframe(self, mock_process): """Test to_dataframe method.""" mock_df = pd.DataFrame({ @@ -397,7 +545,7 @@ def test_to_dataframe(self, mock_process): assert len(df) == 1 mock_process.assert_called_once_with(str(Path(f.name))) - @patch.object(BUFR_processor, 'process_file_to_xarray') + @patch.object(BUFR_processor, 'process_files_to_xarray') def test_to_xarray(self, mock_process): """Test to_xarray method.""" mock_ds = xr.Dataset({ @@ -418,7 +566,7 @@ def test_to_xarray(self, mock_process): assert 'temperature' in ds.data_vars mock_process.assert_called_once_with(str(Path(f.name))) - @patch.object(BUFR_processor, 'process_file_to_parquet') + @patch.object(BUFR_processor, 'process_files_to_parquet') def test_to_parquet(self, mock_process, tmp_path): """Test to_parquet method.""" output_file = tmp_path / "test_output.parquet" @@ -429,7 +577,7 @@ def test_to_parquet(self, mock_process, tmp_path): mock_process.assert_called_once_with(str(Path(f.name)), str(output_file)) - @patch.object(BUFR_processor, 'decode_bufr_file') + @patch.object(BUFR_processor, 'decoder_bufr_files') def test_iterator(self, mock_decode): """Test dataloader iterator.""" mock_messages = [ @@ -441,6 +589,7 @@ def test_iterator(self, mock_decode): with tempfile.NamedTemporaryFile(suffix='.bufr') as f: loader = BUFR_dataloader(f.name, schema_name='ADPUPA') + # Mock the schema's map_observation to return simple data loader.schema.map_observation = lambda x: {'LAT': x['lat'], 'LON': x['lon']} observations = list(loader) @@ -463,7 +612,7 @@ def test_schema_registry_completeness(self): def test_end_to_end_mock_processing(self): """Test complete mock processing pipeline.""" with tempfile.NamedTemporaryFile(suffix='_adpupa.bufr') as f: - loader = BUFR_dataloader(f.name) + loader = BUFR_dataloader(f.name) # Should infer ADPUPA schema assert loader.schema_name == 'ADPUPA' assert isinstance(loader.schema, ADPUPA_schema) @@ -471,6 +620,7 @@ def test_end_to_end_mock_processing(self): assert loader.processor.schema == loader.schema +# Test configuration for running with different options def pytest_configure(config): """Pytest configuration hook.""" print("Setting up BUFR processor tests...") From 55e6ef19a6342b672abc3f5a9757972e216d6b94 Mon Sep 17 00:00:00 2001 From: Yuvraaj Narula Date: Fri, 24 Oct 2025 22:36:05 +0530 Subject: [PATCH 07/13] chore : distrbuted files for better readability --- tests/bufr_process/__init__.py | 0 tests/bufr_process/conftest.py | 53 ++ tests/bufr_process/test_adpupa_schema.py | 48 ++ tests/bufr_process/test_common.py | 146 ++++++ tests/bufr_process/test_cris.py | 27 + tests/bufr_process/test_dataloader.py | 109 ++++ tests/bufr_process/test_field_mapping.py | 57 ++ tests/bufr_process/test_integration.py | 24 + tests/bufr_process/test_nnja_schema.py | 43 ++ tests/bufr_process/test_processor.py | 128 +++++ tests/test_bufr_process.py | 631 ----------------------- tests/test_nnjai.py | 2 +- 12 files changed, 636 insertions(+), 632 deletions(-) create mode 100644 tests/bufr_process/__init__.py create mode 100644 tests/bufr_process/conftest.py create mode 100644 tests/bufr_process/test_adpupa_schema.py create mode 100644 tests/bufr_process/test_common.py create mode 100644 tests/bufr_process/test_cris.py create mode 100644 tests/bufr_process/test_dataloader.py create mode 100644 tests/bufr_process/test_field_mapping.py create mode 100644 tests/bufr_process/test_integration.py create mode 100644 tests/bufr_process/test_nnja_schema.py create mode 100644 tests/bufr_process/test_processor.py delete mode 100644 tests/test_bufr_process.py diff --git a/tests/bufr_process/__init__.py b/tests/bufr_process/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/bufr_process/conftest.py b/tests/bufr_process/conftest.py new file mode 100644 index 00000000..1f133c24 --- /dev/null +++ b/tests/bufr_process/conftest.py @@ -0,0 +1,53 @@ +import pytest +import pandas as pd +import numpy as np +from unittest.mock import Mock, patch +import tempfile +from pathlib import Path +import sys +import os + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from graph_weather.data.bufr_process import DataSourceSchema + +@pytest.fixture +def mock_schema(): + """Mock schema for testing.""" + schema = Mock(spec=DataSourceSchema) + schema.source_name = "TEST" + schema.field_mappings = {} + schema.map_observation.return_value = { + 'OBS_TIMESTAMP': pd.Timestamp('2023-01-01T12:00:00'), + 'LAT': 45.0, + 'LON': -120.0, + 'temperature': 20.0 + } + return schema + +@pytest.fixture +def sample_adpupa_data(): + """Sample ADPUPA test data.""" + return { + 'latitude': 45.0, + 'longitude': -120.0, + 'obsTime': '2023-01-01T12:00:00', + 'airTemperature': 300.0, + 'pressure': 101325.0, + 'height': 100.0, + 'dewpointTemperature': 290.0, + 'windU': 5.0, + 'windV': -3.0 + } + +@pytest.fixture +def sample_cris_data(): + """Sample CrIS test data.""" + return { + 'latitude': 30.0, + 'longitude': -100.0, + 'obsTime': 1672574400, + 'retrievedTemperature': 280.0, + 'retrievedPressure': 85000.0, + 'qualityFlags': 1 + } \ No newline at end of file diff --git a/tests/bufr_process/test_adpupa_schema.py b/tests/bufr_process/test_adpupa_schema.py new file mode 100644 index 00000000..b585aabc --- /dev/null +++ b/tests/bufr_process/test_adpupa_schema.py @@ -0,0 +1,48 @@ +import pytest +import pandas as pd +from .test_common import MockADPUPASchema + + +class TestADPUPASchema: + """Test ADPUPA schema implementation.""" + + @pytest.fixture + def adpupa_schema(self): + return MockADPUPASchema() + + def test_schema_creation(self, adpupa_schema): + """Test ADPUPA schema initialization.""" + assert adpupa_schema.source_name == "ADPUPA" + assert len(adpupa_schema.field_mappings) > 0 + + def test_required_mappings_present(self, adpupa_schema): + """Test required NNJA coordinates are mapped.""" + output_names = {m.output_name for m in adpupa_schema.field_mappings.values()} + assert 'LAT' in output_names + assert 'LON' in output_names + assert 'OBS_TIMESTAMP' in output_names + + def test_map_observation(self, adpupa_schema, sample_adpupa_data): + """Test observation mapping.""" + mapped = adpupa_schema.map_observation(sample_adpupa_data) + + assert mapped['LAT'] == 45.0 + assert mapped['LON'] == -120.0 + assert isinstance(mapped['OBS_TIMESTAMP'], pd.Timestamp) + assert mapped['temperature'] == pytest.approx(26.85) # 300K to C + assert mapped['pressure'] == 101325.0 + assert mapped['dew_point'] == pytest.approx(16.85) # 290K to C + + def test_map_observation_missing_fields(self, adpupa_schema): + """Test mapping with missing fields.""" + test_message = { + 'latitude': 45.0, + 'longitude': -120.0, + 'obsTime': '2023-01-01T12:00:00' + } + + mapped = adpupa_schema.map_observation(test_message) + + assert mapped['LAT'] == 45.0 + assert mapped['LON'] == -120.0 + assert mapped['temperature'] is None # Missing field \ No newline at end of file diff --git a/tests/bufr_process/test_common.py b/tests/bufr_process/test_common.py new file mode 100644 index 00000000..c979227c --- /dev/null +++ b/tests/bufr_process/test_common.py @@ -0,0 +1,146 @@ +import pandas as pd +from graph_weather.data.bufr_process import FieldMapping, ADPUPA_schema, CRIS_schema +from typing import Any + + +class MockADPUPASchema(ADPUPA_schema): + """Mock ADPUPA schema with proper FieldMapping dtypes.""" + + def _build_mappings(self): + self.field_mappings = { + 'latitude': FieldMapping( + source_name='latitude', + output_name='LAT', + dtype=float, + description='Station latitude' + ), + 'longitude': FieldMapping( + source_name='longitude', + output_name='LON', + dtype=float, + description='Station longitude' + ), + 'obsTime': FieldMapping( + source_name='obsTime', + output_name='OBS_TIMESTAMP', + dtype=object, + transform_fn=self._convert_timestamp, + description='Observation timestamp' + ), + 'airTemperature': FieldMapping( + source_name='airTemperature', + output_name='temperature', + dtype=float, + transform_fn=lambda x: x - 273.15 if x > 100 else x, + description='Temperature in Celsius' + ), + 'pressure': FieldMapping( + source_name='pressure', + output_name='pressure', + dtype=float, + description='Pressure in Pa' + ), + 'height': FieldMapping( + source_name='height', + output_name='height', + dtype=float, + description='Height above sealevel in m' + ), + 'dewpointTemperature': FieldMapping( + source_name='dewpointTemperature', + output_name='dew_point', + dtype=float, + transform_fn=lambda x: x - 273.15 if x > 100 else x, + description='Dew point in Celsius' + ), + 'windU': FieldMapping( + source_name='windU', + output_name='u_wind', + dtype=float, + description='U-component wind (m/s)' + ), + 'windV': FieldMapping( + source_name='windV', + output_name='v_wind', + dtype=float, + description='V-component wind (m/s)' + ), + 'stationId': FieldMapping( + source_name='stationId', + output_name='station_id', + dtype=str, + required=False, + description='Station identifier' + ) + } + + def _convert_timestamp(self, value: Any) -> pd.Timestamp: + """Convert BUFR timestamp to pandas Timestamp.""" + if isinstance(value, (int, float)): + return pd.Timestamp(value, unit='s') + elif isinstance(value, str): + return pd.Timestamp(value) + else: + return pd.Timestamp(value) + + +class MockCRISSchema(CRIS_schema): + """Mock CrIS schema with proper FieldMapping dtypes.""" + + def _build_mappings(self): + self.field_mappings = { + 'latitude': FieldMapping( + source_name='latitude', + output_name='LAT', + dtype=float, + description='Satellite latitude' + ), + 'longitude': FieldMapping( + source_name='longitude', + output_name='LON', + dtype=float, + description='Satellite longitude' + ), + 'obsTime': FieldMapping( + source_name='obsTime', + output_name='OBS_TIMESTAMP', + dtype=object, + transform_fn=self._convert_timestamp, + description='Observation timestamp' + ), + 'retrievedTemperature': FieldMapping( + source_name='retrievedTemperature', + output_name='temperature', + dtype=float, + transform_fn=lambda x: x - 273.15 if x > 100 else x, + description='Retrieved temperature in Celsius' + ), + 'retrievedPressure': FieldMapping( + source_name='retrievedPressure', + output_name='pressure', + dtype=float, + description='Retrieved pressure in Pa' + ), + 'sensorZenithAngle': FieldMapping( + source_name='sensorZenithAngle', + output_name='sensor_zenith_angle', + dtype=float, + required=False, + description='Sensor zenith angle' + ), + 'qualityFlags': FieldMapping( + source_name='qualityFlags', + output_name='qc_flag', + dtype=int, + description='Quality control flags' + ) + } + + def _convert_timestamp(self, value: Any) -> pd.Timestamp: + """Convert BUFR timestamp to pandas Timestamp.""" + if isinstance(value, (int, float)): + return pd.Timestamp(value, unit='s') + elif isinstance(value, str): + return pd.Timestamp(value) + else: + return pd.Timestamp(value) \ No newline at end of file diff --git a/tests/bufr_process/test_cris.py b/tests/bufr_process/test_cris.py new file mode 100644 index 00000000..0380ed88 --- /dev/null +++ b/tests/bufr_process/test_cris.py @@ -0,0 +1,27 @@ +import pytest +import pandas as pd +from .test_common import MockCRISSchema + + +class TestCRISSchema: + """Test CrIS schema implementation.""" + + @pytest.fixture + def cris_schema(self): + return MockCRISSchema() + + def test_schema_creation(self, cris_schema): + """Test CrIS schema initialization.""" + assert cris_schema.source_name == "CrIS" + assert len(cris_schema.field_mappings) > 0 + + def test_map_observation(self, cris_schema, sample_cris_data): + """Test CrIS observation mapping.""" + mapped = cris_schema.map_observation(sample_cris_data) + + assert mapped['LAT'] == 30.0 + assert mapped['LON'] == -100.0 + assert isinstance(mapped['OBS_TIMESTAMP'], pd.Timestamp) + assert mapped['temperature'] == pytest.approx(6.85) # 280K to C + assert mapped['pressure'] == 85000.0 + assert mapped['qc_flag'] == 1 \ No newline at end of file diff --git a/tests/bufr_process/test_dataloader.py b/tests/bufr_process/test_dataloader.py new file mode 100644 index 00000000..cde83f94 --- /dev/null +++ b/tests/bufr_process/test_dataloader.py @@ -0,0 +1,109 @@ +import pytest +import pandas as pd +import xarray as xr +import tempfile +from pathlib import Path +from unittest.mock import patch +from graph_weather.data.bufr_process import BUFR_dataloader, ADPUPA_schema, BUFR_processor + + +class TestBUFRDataLoader: + """Test BUFR_dataloader class.""" + + def test_dataloader_initialization_with_schema(self): + """Test dataloader initialization with explicit schema.""" + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + loader = BUFR_dataloader(f.name, schema_name='ADPUPA') + + assert loader.schema_name == 'ADPUPA' + assert isinstance(loader.schema, ADPUPA_schema) + assert isinstance(loader.processor, BUFR_processor) + + def test_dataloader_initialization_infer_schema(self): + """Test dataloader initialization with schema inference.""" + test_cases = [ + ('test_adpupa.bufr', 'ADPUPA'), + ('test_CRIS_data.bufr', 'CrIS'), + ('unknown_file.bufr', 'ADPUPA') # Default case + ] + + for filename, expected_schema in test_cases: + with tempfile.NamedTemporaryFile(suffix=filename) as f: + loader = BUFR_dataloader(f.name) + assert loader.schema_name == expected_schema + + def test_dataloader_initialization_invalid_schema(self): + """Test dataloader initialization with invalid schema.""" + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + with pytest.raises(ValueError, match='Unknown schema "INVALID"'): + BUFR_dataloader(f.name, schema_name='INVALID') + + @patch.object(BUFR_processor, 'process_files_to_dataframe') + def test_to_dataframe(self, mock_process): + """Test to_dataframe method.""" + mock_df = pd.DataFrame({ + 'OBS_TIMESTAMP': [pd.Timestamp('2023-01-01T12:00:00')], + 'LAT': [45.0], + 'LON': [-120.0] + }) + mock_process.return_value = mock_df + + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + loader = BUFR_dataloader(f.name, schema_name='ADPUPA') + df = loader.to_dataframe() + + assert len(df) == 1 + mock_process.assert_called_once_with(str(Path(f.name))) + + @patch.object(BUFR_processor, 'process_files_to_xarray') + def test_to_xarray(self, mock_process): + """Test to_xarray method.""" + mock_ds = xr.Dataset({ + 'temperature': (['obs'], [20.0]), + 'pressure': (['obs'], [101325.0]) + }, coords={ + 'obs': [0], + 'time': ('obs', [pd.Timestamp('2023-01-01T12:00:00')]), + 'lat': ('obs', [45.0]), + 'lon': ('obs', [-120.0]) + }) + mock_process.return_value = mock_ds + + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + loader = BUFR_dataloader(f.name, schema_name='ADPUPA') + ds = loader.to_xarray() + + assert 'temperature' in ds.data_vars + mock_process.assert_called_once_with(str(Path(f.name))) + + @patch.object(BUFR_processor, 'process_files_to_parquet') + def test_to_parquet(self, mock_process, tmp_path): + """Test to_parquet method.""" + output_file = tmp_path / "test_output.parquet" + + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + loader = BUFR_dataloader(f.name, schema_name='ADPUPA') + loader.to_parquet(str(output_file)) + + mock_process.assert_called_once_with(str(Path(f.name)), str(output_file)) + + @patch.object(BUFR_processor, 'decoder_bufr_files') + def test_iterator(self, mock_decode): + """Test dataloader iterator.""" + mock_messages = [ + {'lat': 45.0, 'lon': -120.0}, + {'lat': 46.0, 'lon': -121.0} + ] + mock_decode.return_value = mock_messages + + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + loader = BUFR_dataloader(f.name, schema_name='ADPUPA') + + # Mock the schema's map_observation to return simple data + loader.schema.map_observation = lambda x: {'LAT': x['lat'], 'LON': x['lon']} + + observations = list(loader) + + assert len(observations) == 2 + assert observations[0]['LAT'] == 45.0 + assert observations[1]['LAT'] == 46.0 \ No newline at end of file diff --git a/tests/bufr_process/test_field_mapping.py b/tests/bufr_process/test_field_mapping.py new file mode 100644 index 00000000..dfa50dd8 --- /dev/null +++ b/tests/bufr_process/test_field_mapping.py @@ -0,0 +1,57 @@ +import pytest +import pandas as pd +from graph_weather.data.bufr_process import FieldMapping + + +class TestFieldMapping: + """Test FieldMapping dataclass.""" + + def test_field_mapping_creation(self): + """Test FieldMapping initialization.""" + mapping = FieldMapping( + source_name="temperature", + output_name="temp", + dtype=float, + description="Temperature field" + ) + + assert mapping.source_name == "temperature" + assert mapping.output_name == "temp" + assert mapping.dtype == float + assert mapping.description == "Temperature field" + assert mapping.required is True + assert mapping.transform_fn is None + + def test_field_mapping_apply_no_transform(self): + """Test apply method without transformation function.""" + mapping = FieldMapping( + source_name="pressure", + output_name="pres", + dtype=float + ) + + result = mapping.apply(1013.25) + assert result == 1013.25 + + def test_field_mapping_apply_with_transform(self): + """Test apply method with transformation function.""" + mapping = FieldMapping( + source_name="temp_k", + output_name="temp_c", + dtype=float, + transform_fn=lambda x: x - 273.15 + ) + + result = mapping.apply(300.0) + assert result == pytest.approx(26.85) + + def test_field_mapping_apply_none_value(self): + """Test apply method with None value.""" + mapping = FieldMapping( + source_name="missing", + output_name="missing_out", + dtype=float + ) + + result = mapping.apply(None) + assert result is None \ No newline at end of file diff --git a/tests/bufr_process/test_integration.py b/tests/bufr_process/test_integration.py new file mode 100644 index 00000000..3842f718 --- /dev/null +++ b/tests/bufr_process/test_integration.py @@ -0,0 +1,24 @@ +import pytest +import tempfile +from graph_weather.data.bufr_process import BUFR_dataloader, DataSourceSchema, ADPUPA_schema, CRIS_schema, BUFR_processor + + +class TestIntegration: + """Integration tests for the complete pipeline.""" + + def test_schema_registry_completeness(self): + """Test that all schemas in registry can be instantiated.""" + for schema_name, schema_class in BUFR_dataloader.SCHEMA_REGISTRY.items(): + schema = schema_class() + assert isinstance(schema, DataSourceSchema) + assert schema.source_name == schema_name + + def test_end_to_end_mock_processing(self): + """Test complete mock processing pipeline.""" + with tempfile.NamedTemporaryFile(suffix='_adpupa.bufr') as f: + loader = BUFR_dataloader(f.name) + + assert loader.schema_name == 'ADPUPA' + assert isinstance(loader.schema, ADPUPA_schema) + assert isinstance(loader.processor, BUFR_processor) + assert loader.processor.schema == loader.schema \ No newline at end of file diff --git a/tests/bufr_process/test_nnja_schema.py b/tests/bufr_process/test_nnja_schema.py new file mode 100644 index 00000000..146e64b8 --- /dev/null +++ b/tests/bufr_process/test_nnja_schema.py @@ -0,0 +1,43 @@ +import pytest +import numpy as np +from graph_weather.data.bufr_process import NNJA_Schema + + +class TestNNJASchema: + """Test NNJA_Schema class.""" + + def test_schema_structure(self): + """Test schema has expected structure.""" + assert 'OBS_TIMESTAMP' in NNJA_Schema.COORDINATES + assert 'LAT' in NNJA_Schema.COORDINATES + assert 'LON' in NNJA_Schema.COORDINATES + assert 'temperature' in NNJA_Schema.VARIABLES + assert 'pressure' in NNJA_Schema.VARIABLES + + def test_to_xarray_schema(self): + """Test schema combination.""" + full_schema = NNJA_Schema.to_xarray_schema() + assert 'OBS_TIMESTAMP' in full_schema + assert 'temperature' in full_schema + assert 'source' in full_schema + + def test_get_coordinate_names(self): + """Test coordinate names retrieval.""" + coords = NNJA_Schema.get_coordinate_names() + expected_coords = ['OBS_TIMESTAMP', 'LAT', 'LON'] + assert set(coords) == set(expected_coords) + + def test_validate_data(self): + """Test data validation.""" + valid_data = { + 'OBS_TIMESTAMP': np.array(['2023-01-01'], dtype='datetime64[ns]'), + 'LAT': np.array([45.0], dtype='float32'), + 'LON': np.array([-120.0], dtype='float32') + } + assert NNJA_Schema.validate_data(valid_data) is True + + invalid_data = { + 'LAT': np.array([45.0], dtype='float32'), + 'LON': np.array([-120.0], dtype='float32') + } + assert NNJA_Schema.validate_data(invalid_data) is False \ No newline at end of file diff --git a/tests/bufr_process/test_processor.py b/tests/bufr_process/test_processor.py new file mode 100644 index 00000000..32c23465 --- /dev/null +++ b/tests/bufr_process/test_processor.py @@ -0,0 +1,128 @@ +import pytest +import pandas as pd +import xarray as xr +import tempfile +from pathlib import Path +from unittest.mock import patch , Mock +from graph_weather.data.bufr_process import BUFR_processor + + +class TestBUFRProcessor: + """Test BUFR_processor class.""" + + @pytest.fixture + def bufr_processor(self, mock_schema): + return BUFR_processor(mock_schema) + + def test_processor_initialization(self, mock_schema): + """Test processor initialization.""" + processor = BUFR_processor(mock_schema) + assert processor.schema == mock_schema + + def test_processor_invalid_schema(self): + """Test processor with invalid schema.""" + with pytest.raises(TypeError): + BUFR_processor("invalid_schema") + + @patch('graph_weather.data.bufr_process.eccodes', create=True) + def test_decode_bufr_file_success(self, mock_eccodes, bufr_processor, tmp_path): + """Test successful BUFR file decoding.""" + test_file = tmp_path / "test.bufr" + test_file.write_bytes(b"test bufr content") + + # Mock eccodes behavior + mock_bufr_id = Mock() + mock_eccodes.codes_bufr_new_from_file.side_effect = [mock_bufr_id, None] + mock_iterator = Mock() + mock_eccodes.codes_bufr_keys_iterator_new.return_value = mock_iterator + mock_eccodes.codes_bufr_keys_iterator_next.side_effect = [True, False] + mock_eccodes.codes_bufr_keys_iterator_get_name.return_value = "test_key" + mock_eccodes.codes_get_string.return_value = "test_value" + + # Use the correct method name - decoder_bufr_files + messages = bufr_processor.decoder_bufr_files(str(test_file)) + + assert len(messages) == 1 + assert messages[0]["test_key"] == "test_value" + + def test_decode_bufr_file_not_found(self, bufr_processor, tmp_path): + """Test BUFR file not found.""" + with pytest.raises(FileNotFoundError): + bufr_processor.decoder_bufr_files(str(tmp_path / "nonexistent.bufr")) + + @patch.object(BUFR_processor, 'decoder_bufr_files') + def test_process_files_to_dataframe(self, mock_decode, bufr_processor, mock_schema): + """Test processing BUFR file to DataFrame.""" + # Mock decoded messages + mock_messages = [ + {'latitude': 45.0, 'longitude': -120.0, 'obsTime': '2023-01-01T12:00:00'}, + {'latitude': 46.0, 'longitude': -121.0, 'obsTime': '2023-01-01T12:30:00'} + ] + mock_decode.return_value = mock_messages + + # Mock schema mapping + mock_schema.map_observation.side_effect = [ + {'OBS_TIMESTAMP': pd.Timestamp('2023-01-01T12:00:00'), 'LAT': 45.0, 'LON': -120.0}, + {'OBS_TIMESTAMP': pd.Timestamp('2023-01-01T12:30:00'), 'LAT': 46.0, 'LON': -121.0} + ] + + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + df = bufr_processor.process_files_to_dataframe(f.name) + + assert len(df) == 2 + assert 'OBS_TIMESTAMP' in df.columns + assert 'LAT' in df.columns + assert 'LON' in df.columns + assert df['LAT'].iloc[0] == 45.0 + + @patch.object(BUFR_processor, 'decoder_bufr_files') + def test_process_files_to_dataframe_empty(self, mock_decode, bufr_processor): + """Test processing BUFR file with no valid observations.""" + mock_decode.return_value = [] # No messages + + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + df = bufr_processor.process_files_to_dataframe(f.name) + + assert df.empty + + @patch.object(BUFR_processor, 'process_files_to_dataframe') + def test_process_files_to_xarray(self, mock_process, bufr_processor): + """Test processing BUFR file to xarray Dataset.""" + # Mock DataFrame + mock_df = pd.DataFrame({ + 'OBS_TIMESTAMP': [pd.Timestamp('2023-01-01T12:00:00'), pd.Timestamp('2023-01-01T12:30:00')], + 'LAT': [45.0, 46.0], + 'LON': [-120.0, -121.0], + 'temperature': [20.0, 19.5], + 'pressure': [101325.0, 101300.0] + }) + mock_process.return_value = mock_df + + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + ds = bufr_processor.process_files_to_xarray(f.name) + + assert 'temperature' in ds.data_vars + assert 'pressure' in ds.data_vars + assert 'time' in ds.coords + assert 'lat' in ds.coords + assert 'lon' in ds.coords + assert ds.attrs['source'] == 'TEST' + + @patch.object(BUFR_processor, 'process_files_to_dataframe') + def test_process_files_to_parquet(self, mock_process, bufr_processor, tmp_path): + """Test processing BUFR file to Parquet.""" + # Mock DataFrame + mock_df = pd.DataFrame({ + 'OBS_TIMESTAMP': [pd.Timestamp('2023-01-01T12:00:00')], + 'LAT': [45.0], + 'LON': [-120.0], + 'temperature': [20.0] + }) + mock_process.return_value = mock_df + + output_file = tmp_path / "output.parquet" + + with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + bufr_processor.process_files_to_parquet(f.name, str(output_file)) + + assert output_file.exists() \ No newline at end of file diff --git a/tests/test_bufr_process.py b/tests/test_bufr_process.py deleted file mode 100644 index 58d38915..00000000 --- a/tests/test_bufr_process.py +++ /dev/null @@ -1,631 +0,0 @@ -import pytest -import pandas as pd -import numpy as np -import xarray as xr -from pathlib import Path -import tempfile -import json -from unittest.mock import Mock, patch, MagicMock -import sys -import os -from typing import Any - -sys.path.insert(0, str(Path(__file__).parent)) - -try: - from graph_weather.data.bufr_process import ( - FieldMapping, - NNJA_Schema, - DataSourceSchema, - ADPUPA_schema, - CRIS_schema, - BUFR_processor, - BUFR_dataloader - ) -except ImportError: - # Fallback to direct import - from data.bufr_process import ( - FieldMapping, - NNJA_Schema, - DataSourceSchema, - ADPUPA_schema, - CRIS_schema, - BUFR_processor, - BUFR_dataloader - ) - - -class TestFieldMapping: - """Test FieldMapping dataclass.""" - - def test_field_mapping_creation(self): - """Test FieldMapping initialization.""" - mapping = FieldMapping( - source_name="temperature", - output_name="temp", - dtype=float, - description="Temperature field" - ) - - assert mapping.source_name == "temperature" - assert mapping.output_name == "temp" - assert mapping.dtype == float - assert mapping.description == "Temperature field" - assert mapping.required is True - assert mapping.transform_fn is None - - def test_field_mapping_apply_no_transform(self): - """Test apply method without transformation function.""" - mapping = FieldMapping( - source_name="pressure", - output_name="pres", - dtype=float - ) - - result = mapping.apply(1013.25) - assert result == 1013.25 - - def test_field_mapping_apply_with_transform(self): - """Test apply method with transformation function.""" - mapping = FieldMapping( - source_name="temp_k", - output_name="temp_c", - dtype=float, - transform_fn=lambda x: x - 273.15 - ) - - result = mapping.apply(300.0) - assert result == pytest.approx(26.85) - - def test_field_mapping_apply_none_value(self): - """Test apply method with None value.""" - mapping = FieldMapping( - source_name="missing", - output_name="missing_out", - dtype=float - ) - - result = mapping.apply(None) - assert result is None - - -class TestNNJASchema: - """Test NNJA_Schema class.""" - - def test_schema_structure(self): - """Test schema has expected structure.""" - assert 'OBS_TIMESTAMP' in NNJA_Schema.COORDINATES - assert 'LAT' in NNJA_Schema.COORDINATES - assert 'LON' in NNJA_Schema.COORDINATES - assert 'temperature' in NNJA_Schema.VARIABLES - assert 'pressure' in NNJA_Schema.VARIABLES - - def test_to_xarray_schema(self): - """Test schema combination.""" - full_schema = NNJA_Schema.to_xarray_schema() - assert 'OBS_TIMESTAMP' in full_schema - assert 'temperature' in full_schema - assert 'source' in full_schema - - def test_get_coordinate_names(self): - """Test coordinate names retrieval.""" - coords = NNJA_Schema.get_coordinate_names() - expected_coords = ['OBS_TIMESTAMP', 'LAT', 'LON'] - assert set(coords) == set(expected_coords) - - def test_validate_data(self): - """Test data validation.""" - valid_data = { - 'OBS_TIMESTAMP': np.array(['2023-01-01'], dtype='datetime64[ns]'), - 'LAT': np.array([45.0], dtype='float32'), - 'LON': np.array([-120.0], dtype='float32') - } - assert NNJA_Schema.validate_data(valid_data) is True - - invalid_data = { - 'LAT': np.array([45.0], dtype='float32'), - 'LON': np.array([-120.0], dtype='float32') - } - assert NNJA_Schema.validate_data(invalid_data) is False - - -class MockADPUPASchema(ADPUPA_schema): - """Mock ADPUPA schema with proper FieldMapping dtypes.""" - - def _build_mappings(self): - self.field_mappings = { - 'latitude': FieldMapping( - source_name='latitude', - output_name='LAT', - dtype=float, - description='Station latitude' - ), - 'longitude': FieldMapping( - source_name='longitude', - output_name='LON', - dtype=float, - description='Station longitude' - ), - 'obsTime': FieldMapping( - source_name='obsTime', - output_name='OBS_TIMESTAMP', - dtype=object, - transform_fn=self._convert_timestamp, - description='Observation timestamp' - ), - 'airTemperature': FieldMapping( - source_name='airTemperature', - output_name='temperature', - dtype=float, - transform_fn=lambda x: x - 273.15 if x > 100 else x, - description='Temperature in Celsius' - ), - 'pressure': FieldMapping( - source_name='pressure', - output_name='pressure', - dtype=float, - description='Pressure in Pa' - ), - 'height': FieldMapping( - source_name='height', - output_name='height', - dtype=float, - description='Height above sealevel in m' - ), - 'dewpointTemperature': FieldMapping( - source_name='dewpointTemperature', - output_name='dew_point', - dtype=float, - transform_fn=lambda x: x - 273.15 if x > 100 else x, - description='Dew point in Celsius' - ), - 'windU': FieldMapping( - source_name='windU', - output_name='u_wind', - dtype=float, - description='U-component wind (m/s)' - ), - 'windV': FieldMapping( - source_name='windV', - output_name='v_wind', - dtype=float, - description='V-component wind (m/s)' - ), - 'stationId': FieldMapping( - source_name='stationId', - output_name='station_id', - dtype=str, - required=False, - description='Station identifier' - ) - } - - def _convert_timestamp(self, value: Any) -> pd.Timestamp: - """Convert BUFR timestamp to pandas Timestamp.""" - if isinstance(value, (int, float)): - return pd.Timestamp(value, unit='s') - elif isinstance(value, str): - return pd.Timestamp(value) - else: - return pd.Timestamp(value) - - -class MockCRISSchema(CRIS_schema): - """Mock CrIS schema with proper FieldMapping dtypes.""" - - def _build_mappings(self): - self.field_mappings = { - 'latitude': FieldMapping( - source_name='latitude', - output_name='LAT', - dtype=float, - description='Satellite latitude' - ), - 'longitude': FieldMapping( - source_name='longitude', - output_name='LON', - dtype=float, - description='Satellite longitude' - ), - 'obsTime': FieldMapping( - source_name='obsTime', - output_name='OBS_TIMESTAMP', - dtype=object, - transform_fn=self._convert_timestamp, - description='Observation timestamp' - ), - 'retrievedTemperature': FieldMapping( - source_name='retrievedTemperature', - output_name='temperature', - dtype=float, - transform_fn=lambda x: x - 273.15 if x > 100 else x, - description='Retrieved temperature in Celsius' - ), - 'retrievedPressure': FieldMapping( - source_name='retrievedPressure', - output_name='pressure', - dtype=float, - description='Retrieved pressure in Pa' - ), - 'sensorZenithAngle': FieldMapping( - source_name='sensorZenithAngle', - output_name='sensor_zenith_angle', - dtype=float, - required=False, - description='Sensor zenith angle' - ), - 'qualityFlags': FieldMapping( - source_name='qualityFlags', - output_name='qc_flag', - dtype=int, - description='Quality control flags' - ) - } - - def _convert_timestamp(self, value: Any) -> pd.Timestamp: - """Convert BUFR timestamp to pandas Timestamp.""" - if isinstance(value, (int, float)): - return pd.Timestamp(value, unit='s') - elif isinstance(value, str): - return pd.Timestamp(value) - else: - return pd.Timestamp(value) - - -class TestADPUPASchema: - """Test ADPUPA schema implementation.""" - - @pytest.fixture - def adpupa_schema(self): - return MockADPUPASchema() - - def test_schema_creation(self, adpupa_schema): - """Test ADPUPA schema initialization.""" - assert adpupa_schema.source_name == "ADPUPA" - assert len(adpupa_schema.field_mappings) > 0 - - def test_required_mappings_present(self, adpupa_schema): - """Test required NNJA coordinates are mapped.""" - output_names = {m.output_name for m in adpupa_schema.field_mappings.values()} - assert 'LAT' in output_names - assert 'LON' in output_names - assert 'OBS_TIMESTAMP' in output_names - - def test_map_observation(self, adpupa_schema): - """Test observation mapping.""" - test_message = { - 'latitude': 45.0, - 'longitude': -120.0, - 'obsTime': '2023-01-01T12:00:00', - 'airTemperature': 300.0, # Kelvin - 'pressure': 101325.0, - 'height': 100.0, - 'dewpointTemperature': 290.0, # Kelvin - 'windU': 5.0, - 'windV': -3.0 - } - - mapped = adpupa_schema.map_observation(test_message) - - assert mapped['LAT'] == 45.0 - assert mapped['LON'] == -120.0 - assert isinstance(mapped['OBS_TIMESTAMP'], pd.Timestamp) - assert mapped['temperature'] == pytest.approx(26.85) # 300K to C - assert mapped['pressure'] == 101325.0 - assert mapped['dew_point'] == pytest.approx(16.85) # 290K to C - - def test_map_observation_missing_fields(self, adpupa_schema): - """Test mapping with missing fields.""" - test_message = { - 'latitude': 45.0, - 'longitude': -120.0, - 'obsTime': '2023-01-01T12:00:00' - } - - mapped = adpupa_schema.map_observation(test_message) - - assert mapped['LAT'] == 45.0 - assert mapped['LON'] == -120.0 - assert mapped['temperature'] is None # Missing field - - -class TestCRISSchema: - """Test CrIS schema implementation.""" - - @pytest.fixture - def cris_schema(self): - return MockCRISSchema() - - def test_schema_creation(self, cris_schema): - """Test CrIS schema initialization.""" - assert cris_schema.source_name == "CrIS" - assert len(cris_schema.field_mappings) > 0 - - def test_map_observation(self, cris_schema): - """Test CrIS observation mapping.""" - test_message = { - 'latitude': 30.0, - 'longitude': -100.0, - 'obsTime': 1672574400, # Unix timestamp - 'retrievedTemperature': 280.0, # Kelvin - 'retrievedPressure': 85000.0, - 'qualityFlags': 1 - } - - mapped = cris_schema.map_observation(test_message) - - assert mapped['LAT'] == 30.0 - assert mapped['LON'] == -100.0 - assert isinstance(mapped['OBS_TIMESTAMP'], pd.Timestamp) - assert mapped['temperature'] == pytest.approx(6.85) # 280K to C - assert mapped['pressure'] == 85000.0 - assert mapped['qc_flag'] == 1 - - -class TestBUFRProcessor: - """Test BUFR_processor class.""" - - @pytest.fixture - def mock_schema(self): - schema = Mock(spec=DataSourceSchema) - schema.source_name = "TEST" - schema.field_mappings = {} - schema.map_observation.return_value = { - 'OBS_TIMESTAMP': pd.Timestamp('2023-01-01T12:00:00'), - 'LAT': 45.0, - 'LON': -120.0, - 'temperature': 20.0 - } - return schema - - @pytest.fixture - def bufr_processor(self, mock_schema): - return BUFR_processor(mock_schema) - - def test_processor_initialization(self, mock_schema): - """Test processor initialization.""" - processor = BUFR_processor(mock_schema) - assert processor.schema == mock_schema - - def test_processor_invalid_schema(self): - """Test processor with invalid schema.""" - with pytest.raises(TypeError): - BUFR_processor("invalid_schema") - - @patch('graph_weather.data.bufr_process.eccodes', create=True) - def test_decode_bufr_file_success(self, mock_eccodes, bufr_processor, tmp_path): - """Test successful BUFR file decoding.""" - # Create a temporary BUFR file - test_file = tmp_path / "test.bufr" - test_file.write_bytes(b"test bufr content") - - # Mock eccodes behavior - mock_bufr_id = Mock() - mock_eccodes.codes_bufr_new_from_file.side_effect = [mock_bufr_id, None] - mock_iterator = Mock() - mock_eccodes.codes_bufr_keys_iterator_new.return_value = mock_iterator - mock_eccodes.codes_bufr_keys_iterator_next.side_effect = [True, False] - mock_eccodes.codes_bufr_keys_iterator_get_name.return_value = "test_key" - mock_eccodes.codes_get_string.return_value = "test_value" - - # Use the correct method name - decoder_bufr_files - messages = bufr_processor.decoder_bufr_files(str(test_file)) - - assert len(messages) == 1 - assert messages[0]["test_key"] == "test_value" - - def test_decode_bufr_file_not_found(self, bufr_processor, tmp_path): - """Test BUFR file not found.""" - with pytest.raises(FileNotFoundError): - bufr_processor.decoder_bufr_files(str(tmp_path / "nonexistent.bufr")) - - @patch.object(BUFR_processor, 'decoder_bufr_files') - def test_process_files_to_dataframe(self, mock_decode, bufr_processor, mock_schema): - """Test processing BUFR file to DataFrame.""" - # Mock decoded messages - mock_messages = [ - {'latitude': 45.0, 'longitude': -120.0, 'obsTime': '2023-01-01T12:00:00'}, - {'latitude': 46.0, 'longitude': -121.0, 'obsTime': '2023-01-01T12:30:00'} - ] - mock_decode.return_value = mock_messages - - # Mock schema mapping - mock_schema.map_observation.side_effect = [ - {'OBS_TIMESTAMP': pd.Timestamp('2023-01-01T12:00:00'), 'LAT': 45.0, 'LON': -120.0}, - {'OBS_TIMESTAMP': pd.Timestamp('2023-01-01T12:30:00'), 'LAT': 46.0, 'LON': -121.0} - ] - - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: - df = bufr_processor.process_files_to_dataframe(f.name) - - assert len(df) == 2 - assert 'OBS_TIMESTAMP' in df.columns - assert 'LAT' in df.columns - assert 'LON' in df.columns - assert df['LAT'].iloc[0] == 45.0 - - @patch.object(BUFR_processor, 'decoder_bufr_files') - def test_process_files_to_dataframe_empty(self, mock_decode, bufr_processor): - """Test processing BUFR file with no valid observations.""" - mock_decode.return_value = [] # No messages - - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: - df = bufr_processor.process_files_to_dataframe(f.name) - - assert df.empty - - @patch.object(BUFR_processor, 'process_files_to_dataframe') - def test_process_files_to_xarray(self, mock_process, bufr_processor): - """Test processing BUFR file to xarray Dataset.""" - # Mock DataFrame - mock_df = pd.DataFrame({ - 'OBS_TIMESTAMP': [pd.Timestamp('2023-01-01T12:00:00'), pd.Timestamp('2023-01-01T12:30:00')], - 'LAT': [45.0, 46.0], - 'LON': [-120.0, -121.0], - 'temperature': [20.0, 19.5], - 'pressure': [101325.0, 101300.0] - }) - mock_process.return_value = mock_df - - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: - ds = bufr_processor.process_files_to_xarray(f.name) - - assert 'temperature' in ds.data_vars - assert 'pressure' in ds.data_vars - assert 'time' in ds.coords - assert 'lat' in ds.coords - assert 'lon' in ds.coords - assert ds.attrs['source'] == 'TEST' - - @patch.object(BUFR_processor, 'process_files_to_dataframe') - def test_process_files_to_parquet(self, mock_process, bufr_processor, tmp_path): - """Test processing BUFR file to Parquet.""" - # Mock DataFrame - mock_df = pd.DataFrame({ - 'OBS_TIMESTAMP': [pd.Timestamp('2023-01-01T12:00:00')], - 'LAT': [45.0], - 'LON': [-120.0], - 'temperature': [20.0] - }) - mock_process.return_value = mock_df - - output_file = tmp_path / "output.parquet" - - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: - bufr_processor.process_files_to_parquet(f.name, str(output_file)) - - assert output_file.exists() - - -class TestBUFRDataLoader: - """Test BUFR_dataloader class.""" - - def test_dataloader_initialization_with_schema(self): - """Test dataloader initialization with explicit schema.""" - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: - loader = BUFR_dataloader(f.name, schema_name='ADPUPA') - - assert loader.schema_name == 'ADPUPA' - assert isinstance(loader.schema, ADPUPA_schema) - assert isinstance(loader.processor, BUFR_processor) - - def test_dataloader_initialization_infer_schema(self): - """Test dataloader initialization with schema inference.""" - test_cases = [ - ('test_adpupa.bufr', 'ADPUPA'), - ('test_CRIS_data.bufr', 'CrIS'), - ('unknown_file.bufr', 'ADPUPA') # Default case - ] - - for filename, expected_schema in test_cases: - with tempfile.NamedTemporaryFile(suffix=filename) as f: - loader = BUFR_dataloader(f.name) - assert loader.schema_name == expected_schema - - def test_dataloader_initialization_invalid_schema(self): - """Test dataloader initialization with invalid schema.""" - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: - with pytest.raises(ValueError, match='Unknown schema "INVALID"'): - BUFR_dataloader(f.name, schema_name='INVALID') - - @patch.object(BUFR_processor, 'process_files_to_dataframe') - def test_to_dataframe(self, mock_process): - """Test to_dataframe method.""" - mock_df = pd.DataFrame({ - 'OBS_TIMESTAMP': [pd.Timestamp('2023-01-01T12:00:00')], - 'LAT': [45.0], - 'LON': [-120.0] - }) - mock_process.return_value = mock_df - - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: - loader = BUFR_dataloader(f.name, schema_name='ADPUPA') - df = loader.to_dataframe() - - assert len(df) == 1 - mock_process.assert_called_once_with(str(Path(f.name))) - - @patch.object(BUFR_processor, 'process_files_to_xarray') - def test_to_xarray(self, mock_process): - """Test to_xarray method.""" - mock_ds = xr.Dataset({ - 'temperature': (['obs'], [20.0]), - 'pressure': (['obs'], [101325.0]) - }, coords={ - 'obs': [0], - 'time': ('obs', [pd.Timestamp('2023-01-01T12:00:00')]), - 'lat': ('obs', [45.0]), - 'lon': ('obs', [-120.0]) - }) - mock_process.return_value = mock_ds - - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: - loader = BUFR_dataloader(f.name, schema_name='ADPUPA') - ds = loader.to_xarray() - - assert 'temperature' in ds.data_vars - mock_process.assert_called_once_with(str(Path(f.name))) - - @patch.object(BUFR_processor, 'process_files_to_parquet') - def test_to_parquet(self, mock_process, tmp_path): - """Test to_parquet method.""" - output_file = tmp_path / "test_output.parquet" - - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: - loader = BUFR_dataloader(f.name, schema_name='ADPUPA') - loader.to_parquet(str(output_file)) - - mock_process.assert_called_once_with(str(Path(f.name)), str(output_file)) - - @patch.object(BUFR_processor, 'decoder_bufr_files') - def test_iterator(self, mock_decode): - """Test dataloader iterator.""" - mock_messages = [ - {'lat': 45.0, 'lon': -120.0}, - {'lat': 46.0, 'lon': -121.0} - ] - mock_decode.return_value = mock_messages - - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: - loader = BUFR_dataloader(f.name, schema_name='ADPUPA') - - # Mock the schema's map_observation to return simple data - loader.schema.map_observation = lambda x: {'LAT': x['lat'], 'LON': x['lon']} - - observations = list(loader) - - assert len(observations) == 2 - assert observations[0]['LAT'] == 45.0 - assert observations[1]['LAT'] == 46.0 - - -class TestIntegration: - """Integration tests for the complete pipeline.""" - - def test_schema_registry_completeness(self): - """Test that all schemas in registry can be instantiated.""" - for schema_name, schema_class in BUFR_dataloader.SCHEMA_REGISTRY.items(): - schema = schema_class() - assert isinstance(schema, DataSourceSchema) - assert schema.source_name == schema_name - - def test_end_to_end_mock_processing(self): - """Test complete mock processing pipeline.""" - with tempfile.NamedTemporaryFile(suffix='_adpupa.bufr') as f: - loader = BUFR_dataloader(f.name) # Should infer ADPUPA schema - - assert loader.schema_name == 'ADPUPA' - assert isinstance(loader.schema, ADPUPA_schema) - assert isinstance(loader.processor, BUFR_processor) - assert loader.processor.schema == loader.schema - - -# Test configuration for running with different options -def pytest_configure(config): - """Pytest configuration hook.""" - print("Setting up BUFR processor tests...") - - -if __name__ == "__main__": - # Run tests directly - pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/tests/test_nnjai.py b/tests/test_nnjai.py index c6985118..b4c7cc24 100644 --- a/tests/test_nnjai.py +++ b/tests/test_nnjai.py @@ -22,7 +22,7 @@ @pytest.fixture def mock_datacatalog(): """Fixture to mock the DataCatalog with properly configured variables.""" - with patch("graph_weather.data.nnja_ai.DataCatalog") as mock: + with patch("graph_weather.data.nnjaai.DataCatalog") as mock: mock_catalog = MagicMock() mock_dataset = MagicMock() mock_dataset.load_manifest = MagicMock() From ea055d96a8feed92ea357e549cfb8e3ebd7ce5c4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 Oct 2025 06:51:24 +0000 Subject: [PATCH 08/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graph_weather/data/__init__.py | 10 +- graph_weather/data/bufr_process.py | 360 ++++++++++++----------- tests/bufr_process/conftest.py | 45 +-- tests/bufr_process/test_adpupa_schema.py | 46 ++- tests/bufr_process/test_common.py | 158 +++++----- tests/bufr_process/test_cris.py | 20 +- tests/bufr_process/test_dataloader.py | 117 ++++---- tests/bufr_process/test_field_mapping.py | 34 +-- tests/bufr_process/test_integration.py | 22 +- tests/bufr_process/test_nnja_schema.py | 40 +-- tests/bufr_process/test_processor.py | 127 ++++---- 11 files changed, 498 insertions(+), 481 deletions(-) diff --git a/graph_weather/data/__init__.py b/graph_weather/data/__init__.py index 7ca7a18f..08843ae7 100644 --- a/graph_weather/data/__init__.py +++ b/graph_weather/data/__init__.py @@ -1,6 +1,14 @@ """Dataloaders and data processing utilities""" from .anemoi_dataloader import AnemoiDataset +from .bufr_process import ( + ADPUPA_schema, + BUFR_dataloader, + BUFR_processor, + CRIS_schema, + FieldMapping, + NNJA_Schema, + _BUFRIterableDataset, +) from .nnjaai import SensorDataset from .weather_station_reader import WeatherStationReader -from .bufr_process import BUFR_processor, NNJA_Schema, BUFR_dataloader, _BUFRIterableDataset, FieldMapping, ADPUPA_schema, CRIS_schema \ No newline at end of file diff --git a/graph_weather/data/bufr_process.py b/graph_weather/data/bufr_process.py index ead79b4b..818fd328 100644 --- a/graph_weather/data/bufr_process.py +++ b/graph_weather/data/bufr_process.py @@ -1,17 +1,18 @@ -from dataclasses import dataclass, field -from typing import Optional, Callable, Any , List, Dict, Iterator +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, Iterator, List, Optional + import numpy as np -import logging -import pandas as pd +import pandas as pd import xarray as xr -from pathlib import Path logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger("Bufr Processor") -from torch.utils.data import DataLoader, IterableDataset +from torch.utils.data import IterableDataset try: import eccodes @@ -22,13 +23,14 @@ @dataclass class FieldMapping: """Maps a BUFR source field to NNJA-AI output field.""" + source_name: str output_name: str dtype: type transform_fn: Optional[Callable] = None required: bool = True description: str = "" - + def apply(self, value: Any) -> Any: """Apply transformation to source value.""" if value is None: @@ -41,50 +43,53 @@ def apply(self, value: Any) -> Any: class NNJA_Schema: """ Defines the canonical NNJA-AI schema that all BUFR data maps to - + Mimics NNJA-AI's xarray format with standardized coordinates and variables. Coordinate system matches NNJA-AI: - OBS_TIMESTAMP: observation time (ns precision) - LAT: latitude - LON: longitude """ + COORDINATES = { - 'OBS_TIMESTAMP': 'datetime64[ns]', - 'LAT': 'float32', - 'LON': 'float32', + "OBS_TIMESTAMP": "datetime64[ns]", + "LAT": "float32", + "LON": "float32", } VARIABLES = { - 'temperature': 'float32', - 'pressure': 'float32', - 'relative_humidity': 'float32', - 'u_wind': 'float32', - 'v_wind': 'float32', - 'dew_point': 'float32', - 'height': 'float32', + "temperature": "float32", + "pressure": "float32", + "relative_humidity": "float32", + "u_wind": "float32", + "v_wind": "float32", + "dew_point": "float32", + "height": "float32", } - + ATTRIBUTES = { - 'source': 'DATA_SOURCE', - 'qc_flag': 'int8', - 'processing_timestamp': 'datetime64[ns]', + "source": "DATA_SOURCE", + "qc_flag": "int8", + "processing_timestamp": "datetime64[ns]", } + @classmethod def to_xarray_schema(cls) -> Dict[str, str]: """Get full schema as dict for xarray construction.""" return {**cls.COORDINATES, **cls.VARIABLES, **cls.ATTRIBUTES} - + @classmethod def get_coordinate_names(cls) -> List[str]: """Get list of coordinate names.""" return list(cls.COORDINATES.keys()) - + @classmethod def validate_data(cls, data: Dict[str, np.ndarray]) -> bool: """Check if data has required NNJA coordinates.""" - required_coords = ['OBS_TIMESTAMP', 'LAT', 'LON'] + required_coords = ["OBS_TIMESTAMP", "LAT", "LON"] return all(coord in data for coord in required_coords) + class DataSourceSchema: """ Abstract base for source-specific BUFR schema mappings. @@ -92,17 +97,17 @@ class DataSourceSchema: map to NNJA-AI canonical format. """ - source_name: str = "unknown" - + def __init__(self): self.field_mappings: Dict[str, FieldMapping] = {} self._build_mappings() self._validate() - + def _build_mappings(self): """ Override in subclasses to define BUFR → NNJA field mappings. + Example: self.field_mappings['T'] = FieldMapping( source_name='T', @@ -113,17 +118,15 @@ def _build_mappings(self): ) """ raise NotImplementedError("Subclasses must implement _build_mappings()") - + def _validate(self): """Ensure all required NNJA coordinates are mapped.""" - required = ['OBS_TIMESTAMP', 'LAT', 'LON'] + required = ["OBS_TIMESTAMP", "LAT", "LON"] mapped_outputs = {m.output_name for m in self.field_mappings.values()} missing = [r for r in required if r not in mapped_outputs] if missing: - logger.warning( - f"{self.source_name} schema missing required outputs: {missing}" - ) - + logger.warning(f"{self.source_name} schema missing required outputs: {missing}") + def map_observation(self, bufr_message: Dict[str, Any]) -> Dict[str, Any]: """ Transform raw BUFR message to NNJA-AI format. @@ -146,83 +149,84 @@ def map_observation(self, bufr_message: Dict[str, Any]) -> Dict[str, Any]: return mapped + class ADPUPA_schema(DataSourceSchema): """ADPUPA (upper-air radiosonde) BUFR schema mapping to NNJA-AI.""" source_name = "ADPUPA" + def _build_mappings(self): self.field_mappings = { - 'latitude': FieldMapping( - source_name='latitude', - output_name='LAT', + "latitude": FieldMapping( + source_name="latitude", + output_name="LAT", dtype=float, - description='Station latitude' + description="Station latitude", ), - 'longitude': FieldMapping( - source_name='longitude', - output_name='LON', + "longitude": FieldMapping( + source_name="longitude", + output_name="LON", dtype=float, - description='Station longitude' + description="Station longitude", ), - 'obsTime': FieldMapping( - source_name='obsTime', - output_name='OBS_TIMESTAMP', + "obsTime": FieldMapping( + source_name="obsTime", + output_name="OBS_TIMESTAMP", dtype=object, transform_fn=self._convert_timestamp, - description='Observation timestamp' + description="Observation timestamp", ), - 'airTemperature': FieldMapping( - source_name='airTemperature', - output_name='temperature', + "airTemperature": FieldMapping( + source_name="airTemperature", + output_name="temperature", dtype=float, transform_fn=lambda x: x - 273.15 if x > 100 else x, - description='Temperature in Celsius' + description="Temperature in Celsius", ), - 'pressure': FieldMapping( - source_name='pressure', - output_name='pressure', + "pressure": FieldMapping( + source_name="pressure", + output_name="pressure", dtype=float, - description='Pressure in Pa' + description="Pressure in Pa", ), - 'height': FieldMapping( - source_name='height', - output_name='height', + "height": FieldMapping( + source_name="height", + output_name="height", dtype=float, - description='Height above sealevel in m' + description="Height above sealevel in m", ), - 'dewpointTemperature': FieldMapping( - source_name='dewpointTemperature', - output_name='dew_point', + "dewpointTemperature": FieldMapping( + source_name="dewpointTemperature", + output_name="dew_point", dtype=float, transform_fn=lambda x: x - 273.15 if x > 100 else x, - description='Dew point in Celsius' + description="Dew point in Celsius", ), - 'windU': FieldMapping( - source_name='windU', - output_name='u_wind', + "windU": FieldMapping( + source_name="windU", + output_name="u_wind", dtype=float, - description='U-component wind (m/s)' + description="U-component wind (m/s)", ), - 'windV': FieldMapping( - source_name='windV', - output_name='v_wind', + "windV": FieldMapping( + source_name="windV", + output_name="v_wind", dtype=float, - description='V-component wind (m/s)' + description="V-component wind (m/s)", ), - 'stationId': FieldMapping( - source_name='stationId', - output_name='station_id', + "stationId": FieldMapping( + source_name="stationId", + output_name="station_id", dtype=str, required=False, - description='Station identifier' - ) + description="Station identifier", + ), } - def _convert_timestamp(self, value: Any) -> pd.Timestamp: """Convert BUFR timestamp to pandas Timestamp.""" if isinstance(value, (int, float)): - return pd.Timestamp(value, unit='s') + return pd.Timestamp(value, unit="s") elif isinstance(value, str): return pd.Timestamp(value) else: @@ -231,80 +235,84 @@ def _convert_timestamp(self, value: Any) -> pd.Timestamp: class CRIS_schema(DataSourceSchema): """CrIS (satellite hyperspectral) BUFR schema mapping to NNJA-AI.""" - + source_name = "CrIS" + def _build_mappings(self): self.field_mappings = { - 'latitude': FieldMapping( - source_name='latitude', - output_name='LAT', + "latitude": FieldMapping( + source_name="latitude", + output_name="LAT", dtype=float, - description='Satellite latitude' + description="Satellite latitude", ), - 'longitude': FieldMapping( - source_name='longitude', - output_name='LON', + "longitude": FieldMapping( + source_name="longitude", + output_name="LON", dtype=float, - description='Satellite longitude' + description="Satellite longitude", ), - 'obsTime': FieldMapping( - source_name='obsTime', - output_name='OBS_TIMESTAMP', + "obsTime": FieldMapping( + source_name="obsTime", + output_name="OBS_TIMESTAMP", dtype=object, transform_fn=self._convert_timestamp, - description='Observation timestamp' + description="Observation timestamp", ), - 'retrievedTemperature': FieldMapping( - source_name='retrievedTemperature', - output_name='temperature', + "retrievedTemperature": FieldMapping( + source_name="retrievedTemperature", + output_name="temperature", dtype=float, transform_fn=lambda x: x - 273.15 if x > 100 else x, - description='Retrieved temperature in Celsius' + description="Retrieved temperature in Celsius", ), - 'retrievedPressure': FieldMapping( - source_name='retrievedPressure', - output_name='pressure', + "retrievedPressure": FieldMapping( + source_name="retrievedPressure", + output_name="pressure", dtype=float, - description='Retrieved pressure in Pa' + description="Retrieved pressure in Pa", ), - 'sensorZenithAngle': FieldMapping( - source_name='sensorZenithAngle', - output_name='sensor_zenith_angle', + "sensorZenithAngle": FieldMapping( + source_name="sensorZenithAngle", + output_name="sensor_zenith_angle", dtype=float, required=False, - description='Sensor zenith angle' + description="Sensor zenith angle", ), - 'qualityFlags': FieldMapping( - source_name='qualityFlags', - output_name='qc_flag', + "qualityFlags": FieldMapping( + source_name="qualityFlags", + output_name="qc_flag", dtype=int, - description='Quality control flags' - ) + description="Quality control flags", + ), } + def _convert_timestamp(self, value: Any) -> pd.Timestamp: """Convert BUFR timestamp to pandas Timestamp.""" if isinstance(value, (int, float)): - return pd.Timestamp(value, unit='s') + return pd.Timestamp(value, unit="s") elif isinstance(value, str): return pd.Timestamp(value) else: return pd.Timestamp(value) + class BUFR_processor: """ Low-level BUFR file decoder. Handles binary BUFR format decoding using eccodes library. """ - def __init__(self , schema : DataSourceSchema): + + def __init__(self, schema: DataSourceSchema): """ - Args: - -> schema : DataSourceSchema instance + Args: + -> schema : DataSourceSchema instance """ - if not isinstance(schema,DataSourceSchema): - raise TypeError('schema must be of DataSourceSchema instance') - - self.schema = schema - + if not isinstance(schema, DataSourceSchema): + raise TypeError("schema must be of DataSourceSchema instance") + + self.schema = schema + def decoder_bufr_files(self, filepath) -> List[Dict[str, any]]: """Decode all messages from BUFR file.""" msgs = [] @@ -346,104 +354,104 @@ def decoder_bufr_files(self, filepath) -> List[Dict[str, any]]: logger.info(f"Decoded {len(msgs)} messages from {filepath}") return msgs - def process_files_to_dataframe(self, filepath : str)-> pd.DataFrame: + def process_files_to_dataframe(self, filepath: str) -> pd.DataFrame: """ Decode BUFR file and map to NNJA schema, return as DataFrame. - + Args: -> filepath: Path to BUFR file - + Returns: -> pandas DataFrame in NNJA-AI format """ raw_msgs = self.decoder_bufr_files(filepath=filepath) - + transformed = [] - + for msg in raw_msgs: mapped = self.schema.map_observation(msg) if mapped: transformed.append(mapped) - + if not transformed: logger.warning(f"No valid observations found in {filepath}") return pd.DataFrame() - df = pd.DataFrame(transformed) - + for col in df.columns: if col in NNJA_Schema.COORDINATES: dtype = NNJA_Schema.COORDINATES[col] - if 'datetime' in dtype: + if "datetime" in dtype: df[col] = pd.to_datetime(df[col]) else: - df[col] = df[col].astype(dtype.split('[')[0] if '[' in dtype else dtype) - if not NNJA_Schema.validate_data(df.to_dict(orient='list')): + df[col] = df[col].astype(dtype.split("[")[0] if "[" in dtype else dtype) + if not NNJA_Schema.validate_data(df.to_dict(orient="list")): logger.warning(f"DataFrame missing required NNJA coordinates from {filepath}") - + return df - def process_files_to_xarray(self, filepath : str) -> xr.Dataset: + def process_files_to_xarray(self, filepath: str) -> xr.Dataset: """ Process BUFR file to xarray Dataset in NNJA-AI format. - + Args: -> filepath: Path to BUFR file - + Returns: -> xarray Dataset in NNJA-AI format """ df = self.process_files_to_dataframe(filepath=filepath) - + if df.empty: logger.warning(f"No data to convert to xarray from {filepath}") return xr.Dataset() - + data_vars = {} for col in df.columns: - if col not in ['OBS_TIMESTAMP', 'LAT', 'LON']: - data_vars[col] = (['observation'], df[col].values) - + if col not in ["OBS_TIMESTAMP", "LAT", "LON"]: + data_vars[col] = (["observation"], df[col].values) + ds = xr.Dataset( data_vars=data_vars, coords={ - 'obs' : df.index , - 'time' : ('obs', df['OBS_TIMESTAMP'].values), - 'lat' : ('obs', df['LAT'].values), - 'lon' : ('obs', df['LON'].values), - } + "obs": df.index, + "time": ("obs", df["OBS_TIMESTAMP"].values), + "lat": ("obs", df["LAT"].values), + "lon": ("obs", df["LON"].values), + }, ) - ds.attrs['source'] = self.schema.source_name - ds.attrs['processing_timestamp'] = pd.Timestamp.now().isoformat() - ds.attrs['num_observations'] = len(df) - + ds.attrs["source"] = self.schema.source_name + ds.attrs["processing_timestamp"] = pd.Timestamp.now().isoformat() + ds.attrs["num_observations"] = len(df) + return ds - - def process_files_to_parquet(self, filepath: str, output_path: str)->None: + + def process_files_to_parquet(self, filepath: str, output_path: str) -> None: """ Process BUFR file and save as Parquet in NNJA-AI format. - + Args: -> filepath: Path to BUFR file -> output_path: Path for output Parquet file """ df = self.process_files_to_dataframe(filepath=filepath) - + if not df.empty: df.to_parquet(output_path, index=False) logger.info(f"Saved {len(df)} observations to {output_path}") else: logger.warning(f"No data to save for {filepath}") - -class BUFR_dataloader: - - SCHEMA_REGISTRY={ - 'ADPUPA': ADPUPA_schema, - 'CrIS': CRIS_schema, - } + +class BUFR_dataloader: + + SCHEMA_REGISTRY = { + "ADPUPA": ADPUPA_schema, + "CrIS": CRIS_schema, + } + def __init__(self, filepath: str, schema_name: Optional[str] = None): """ Args: @@ -452,54 +460,56 @@ def __init__(self, filepath: str, schema_name: Optional[str] = None): """ self.filepath = Path(filepath) self.schema_name = schema_name or self._infer_schema_from_path() - + if self.schema_name not in self.SCHEMA_REGISTRY: raise ValueError( f'Unknown schema "{self.schema_name}". Available: {list(self.SCHEMA_REGISTRY.keys())}' ) - + self.schema = self.SCHEMA_REGISTRY[self.schema_name]() self.processor = BUFR_processor(self.schema) + def _infer_schema_from_path(self) -> str: """Infer schema from filename or path patterns.""" filename = self.filepath.name.lower() - - if 'adpupa' in filename or 'raob' in filename or 'sound' in filename: - return 'ADPUPA' - elif 'cris' in filename: - return 'CrIS' - elif 'iasi' in filename: - return 'IASI' - elif 'atms' in filename: - return 'ATMS' + + if "adpupa" in filename or "raob" in filename or "sound" in filename: + return "ADPUPA" + elif "cris" in filename: + return "CrIS" + elif "iasi" in filename: + return "IASI" + elif "atms" in filename: + return "ATMS" else: # Default to ADPUPA for now logger.warning(f"Could not infer schema from {filename}, defaulting to ADPUPA") - return 'ADPUPA' - + return "ADPUPA" + def to_dataframe(self) -> pd.DataFrame: """Process BUFR file to DataFrame.""" return self.processor.process_files_to_dataframe(str(self.filepath)) - + def to_xarray(self) -> xr.Dataset: """Process BUFR file to xarray Dataset.""" return self.processor.process_files_to_xarray(str(self.filepath)) - + def to_parquet(self, output_path: str) -> None: """Process BUFR file to Parquet format.""" self.processor.process_files_to_parquet(str(self.filepath), output_path) - + def __iter__(self) -> Iterator[Dict[str, Any]]: """Iterate over observations in the BUFR file.""" messages = self.processor.decoder_bufr_files(str(self.filepath)) for msg in messages: yield self.schema.map_observation(msg) + class _BUFRIterableDataset(IterableDataset): """Internal IterableDataset wrapper for PyTorch DataLoader.""" - + def __init__(self, bufr_loader: BUFR_dataloader): self.bufr_loader = bufr_loader - + def __iter__(self): - return iter(self.bufr_loader) \ No newline at end of file + return iter(self.bufr_loader) diff --git a/tests/bufr_process/conftest.py b/tests/bufr_process/conftest.py index 1f133c24..db891654 100644 --- a/tests/bufr_process/conftest.py +++ b/tests/bufr_process/conftest.py @@ -9,7 +9,8 @@ sys.path.insert(0, str(Path(__file__).parent.parent.parent)) -from graph_weather.data.bufr_process import DataSourceSchema +from graph_weather.data.bufr_process import DataSourceSchema + @pytest.fixture def mock_schema(): @@ -18,36 +19,38 @@ def mock_schema(): schema.source_name = "TEST" schema.field_mappings = {} schema.map_observation.return_value = { - 'OBS_TIMESTAMP': pd.Timestamp('2023-01-01T12:00:00'), - 'LAT': 45.0, - 'LON': -120.0, - 'temperature': 20.0 + "OBS_TIMESTAMP": pd.Timestamp("2023-01-01T12:00:00"), + "LAT": 45.0, + "LON": -120.0, + "temperature": 20.0, } return schema + @pytest.fixture def sample_adpupa_data(): """Sample ADPUPA test data.""" return { - 'latitude': 45.0, - 'longitude': -120.0, - 'obsTime': '2023-01-01T12:00:00', - 'airTemperature': 300.0, - 'pressure': 101325.0, - 'height': 100.0, - 'dewpointTemperature': 290.0, - 'windU': 5.0, - 'windV': -3.0 + "latitude": 45.0, + "longitude": -120.0, + "obsTime": "2023-01-01T12:00:00", + "airTemperature": 300.0, + "pressure": 101325.0, + "height": 100.0, + "dewpointTemperature": 290.0, + "windU": 5.0, + "windV": -3.0, } + @pytest.fixture def sample_cris_data(): """Sample CrIS test data.""" return { - 'latitude': 30.0, - 'longitude': -100.0, - 'obsTime': 1672574400, - 'retrievedTemperature': 280.0, - 'retrievedPressure': 85000.0, - 'qualityFlags': 1 - } \ No newline at end of file + "latitude": 30.0, + "longitude": -100.0, + "obsTime": 1672574400, + "retrievedTemperature": 280.0, + "retrievedPressure": 85000.0, + "qualityFlags": 1, + } diff --git a/tests/bufr_process/test_adpupa_schema.py b/tests/bufr_process/test_adpupa_schema.py index b585aabc..a540c054 100644 --- a/tests/bufr_process/test_adpupa_schema.py +++ b/tests/bufr_process/test_adpupa_schema.py @@ -5,44 +5,40 @@ class TestADPUPASchema: """Test ADPUPA schema implementation.""" - + @pytest.fixture def adpupa_schema(self): return MockADPUPASchema() - + def test_schema_creation(self, adpupa_schema): """Test ADPUPA schema initialization.""" assert adpupa_schema.source_name == "ADPUPA" assert len(adpupa_schema.field_mappings) > 0 - + def test_required_mappings_present(self, adpupa_schema): """Test required NNJA coordinates are mapped.""" output_names = {m.output_name for m in adpupa_schema.field_mappings.values()} - assert 'LAT' in output_names - assert 'LON' in output_names - assert 'OBS_TIMESTAMP' in output_names - + assert "LAT" in output_names + assert "LON" in output_names + assert "OBS_TIMESTAMP" in output_names + def test_map_observation(self, adpupa_schema, sample_adpupa_data): """Test observation mapping.""" mapped = adpupa_schema.map_observation(sample_adpupa_data) - - assert mapped['LAT'] == 45.0 - assert mapped['LON'] == -120.0 - assert isinstance(mapped['OBS_TIMESTAMP'], pd.Timestamp) - assert mapped['temperature'] == pytest.approx(26.85) # 300K to C - assert mapped['pressure'] == 101325.0 - assert mapped['dew_point'] == pytest.approx(16.85) # 290K to C - + + assert mapped["LAT"] == 45.0 + assert mapped["LON"] == -120.0 + assert isinstance(mapped["OBS_TIMESTAMP"], pd.Timestamp) + assert mapped["temperature"] == pytest.approx(26.85) # 300K to C + assert mapped["pressure"] == 101325.0 + assert mapped["dew_point"] == pytest.approx(16.85) # 290K to C + def test_map_observation_missing_fields(self, adpupa_schema): """Test mapping with missing fields.""" - test_message = { - 'latitude': 45.0, - 'longitude': -120.0, - 'obsTime': '2023-01-01T12:00:00' - } - + test_message = {"latitude": 45.0, "longitude": -120.0, "obsTime": "2023-01-01T12:00:00"} + mapped = adpupa_schema.map_observation(test_message) - - assert mapped['LAT'] == 45.0 - assert mapped['LON'] == -120.0 - assert mapped['temperature'] is None # Missing field \ No newline at end of file + + assert mapped["LAT"] == 45.0 + assert mapped["LON"] == -120.0 + assert mapped["temperature"] is None # Missing field diff --git a/tests/bufr_process/test_common.py b/tests/bufr_process/test_common.py index c979227c..ed64e3c8 100644 --- a/tests/bufr_process/test_common.py +++ b/tests/bufr_process/test_common.py @@ -5,79 +5,79 @@ class MockADPUPASchema(ADPUPA_schema): """Mock ADPUPA schema with proper FieldMapping dtypes.""" - + def _build_mappings(self): self.field_mappings = { - 'latitude': FieldMapping( - source_name='latitude', - output_name='LAT', + "latitude": FieldMapping( + source_name="latitude", + output_name="LAT", dtype=float, - description='Station latitude' + description="Station latitude", ), - 'longitude': FieldMapping( - source_name='longitude', - output_name='LON', + "longitude": FieldMapping( + source_name="longitude", + output_name="LON", dtype=float, - description='Station longitude' + description="Station longitude", ), - 'obsTime': FieldMapping( - source_name='obsTime', - output_name='OBS_TIMESTAMP', + "obsTime": FieldMapping( + source_name="obsTime", + output_name="OBS_TIMESTAMP", dtype=object, transform_fn=self._convert_timestamp, - description='Observation timestamp' - ), - 'airTemperature': FieldMapping( - source_name='airTemperature', - output_name='temperature', + description="Observation timestamp", + ), + "airTemperature": FieldMapping( + source_name="airTemperature", + output_name="temperature", dtype=float, - transform_fn=lambda x: x - 273.15 if x > 100 else x, - description='Temperature in Celsius' + transform_fn=lambda x: x - 273.15 if x > 100 else x, + description="Temperature in Celsius", ), - 'pressure': FieldMapping( - source_name='pressure', - output_name='pressure', + "pressure": FieldMapping( + source_name="pressure", + output_name="pressure", dtype=float, - description='Pressure in Pa' + description="Pressure in Pa", ), - 'height': FieldMapping( - source_name='height', - output_name='height', + "height": FieldMapping( + source_name="height", + output_name="height", dtype=float, - description='Height above sealevel in m' + description="Height above sealevel in m", ), - 'dewpointTemperature': FieldMapping( - source_name='dewpointTemperature', - output_name='dew_point', + "dewpointTemperature": FieldMapping( + source_name="dewpointTemperature", + output_name="dew_point", dtype=float, transform_fn=lambda x: x - 273.15 if x > 100 else x, - description='Dew point in Celsius' + description="Dew point in Celsius", ), - 'windU': FieldMapping( - source_name='windU', - output_name='u_wind', + "windU": FieldMapping( + source_name="windU", + output_name="u_wind", dtype=float, - description='U-component wind (m/s)' + description="U-component wind (m/s)", ), - 'windV': FieldMapping( - source_name='windV', - output_name='v_wind', + "windV": FieldMapping( + source_name="windV", + output_name="v_wind", dtype=float, - description='V-component wind (m/s)' + description="V-component wind (m/s)", ), - 'stationId': FieldMapping( - source_name='stationId', - output_name='station_id', + "stationId": FieldMapping( + source_name="stationId", + output_name="station_id", dtype=str, required=False, - description='Station identifier' - ) + description="Station identifier", + ), } - + def _convert_timestamp(self, value: Any) -> pd.Timestamp: """Convert BUFR timestamp to pandas Timestamp.""" if isinstance(value, (int, float)): - return pd.Timestamp(value, unit='s') + return pd.Timestamp(value, unit="s") elif isinstance(value, str): return pd.Timestamp(value) else: @@ -86,61 +86,61 @@ def _convert_timestamp(self, value: Any) -> pd.Timestamp: class MockCRISSchema(CRIS_schema): """Mock CrIS schema with proper FieldMapping dtypes.""" - + def _build_mappings(self): self.field_mappings = { - 'latitude': FieldMapping( - source_name='latitude', - output_name='LAT', + "latitude": FieldMapping( + source_name="latitude", + output_name="LAT", dtype=float, - description='Satellite latitude' + description="Satellite latitude", ), - 'longitude': FieldMapping( - source_name='longitude', - output_name='LON', + "longitude": FieldMapping( + source_name="longitude", + output_name="LON", dtype=float, - description='Satellite longitude' + description="Satellite longitude", ), - 'obsTime': FieldMapping( - source_name='obsTime', - output_name='OBS_TIMESTAMP', + "obsTime": FieldMapping( + source_name="obsTime", + output_name="OBS_TIMESTAMP", dtype=object, transform_fn=self._convert_timestamp, - description='Observation timestamp' + description="Observation timestamp", ), - 'retrievedTemperature': FieldMapping( - source_name='retrievedTemperature', - output_name='temperature', + "retrievedTemperature": FieldMapping( + source_name="retrievedTemperature", + output_name="temperature", dtype=float, transform_fn=lambda x: x - 273.15 if x > 100 else x, - description='Retrieved temperature in Celsius' + description="Retrieved temperature in Celsius", ), - 'retrievedPressure': FieldMapping( - source_name='retrievedPressure', - output_name='pressure', + "retrievedPressure": FieldMapping( + source_name="retrievedPressure", + output_name="pressure", dtype=float, - description='Retrieved pressure in Pa' + description="Retrieved pressure in Pa", ), - 'sensorZenithAngle': FieldMapping( - source_name='sensorZenithAngle', - output_name='sensor_zenith_angle', + "sensorZenithAngle": FieldMapping( + source_name="sensorZenithAngle", + output_name="sensor_zenith_angle", dtype=float, required=False, - description='Sensor zenith angle' + description="Sensor zenith angle", ), - 'qualityFlags': FieldMapping( - source_name='qualityFlags', - output_name='qc_flag', + "qualityFlags": FieldMapping( + source_name="qualityFlags", + output_name="qc_flag", dtype=int, - description='Quality control flags' - ) + description="Quality control flags", + ), } - + def _convert_timestamp(self, value: Any) -> pd.Timestamp: """Convert BUFR timestamp to pandas Timestamp.""" if isinstance(value, (int, float)): - return pd.Timestamp(value, unit='s') + return pd.Timestamp(value, unit="s") elif isinstance(value, str): return pd.Timestamp(value) else: - return pd.Timestamp(value) \ No newline at end of file + return pd.Timestamp(value) diff --git a/tests/bufr_process/test_cris.py b/tests/bufr_process/test_cris.py index 0380ed88..bf0a024a 100644 --- a/tests/bufr_process/test_cris.py +++ b/tests/bufr_process/test_cris.py @@ -5,23 +5,23 @@ class TestCRISSchema: """Test CrIS schema implementation.""" - + @pytest.fixture def cris_schema(self): return MockCRISSchema() - + def test_schema_creation(self, cris_schema): """Test CrIS schema initialization.""" assert cris_schema.source_name == "CrIS" assert len(cris_schema.field_mappings) > 0 - + def test_map_observation(self, cris_schema, sample_cris_data): """Test CrIS observation mapping.""" mapped = cris_schema.map_observation(sample_cris_data) - - assert mapped['LAT'] == 30.0 - assert mapped['LON'] == -100.0 - assert isinstance(mapped['OBS_TIMESTAMP'], pd.Timestamp) - assert mapped['temperature'] == pytest.approx(6.85) # 280K to C - assert mapped['pressure'] == 85000.0 - assert mapped['qc_flag'] == 1 \ No newline at end of file + + assert mapped["LAT"] == 30.0 + assert mapped["LON"] == -100.0 + assert isinstance(mapped["OBS_TIMESTAMP"], pd.Timestamp) + assert mapped["temperature"] == pytest.approx(6.85) # 280K to C + assert mapped["pressure"] == 85000.0 + assert mapped["qc_flag"] == 1 diff --git a/tests/bufr_process/test_dataloader.py b/tests/bufr_process/test_dataloader.py index cde83f94..d61d5135 100644 --- a/tests/bufr_process/test_dataloader.py +++ b/tests/bufr_process/test_dataloader.py @@ -9,101 +9,96 @@ class TestBUFRDataLoader: """Test BUFR_dataloader class.""" - + def test_dataloader_initialization_with_schema(self): """Test dataloader initialization with explicit schema.""" - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: - loader = BUFR_dataloader(f.name, schema_name='ADPUPA') - - assert loader.schema_name == 'ADPUPA' + with tempfile.NamedTemporaryFile(suffix=".bufr") as f: + loader = BUFR_dataloader(f.name, schema_name="ADPUPA") + + assert loader.schema_name == "ADPUPA" assert isinstance(loader.schema, ADPUPA_schema) assert isinstance(loader.processor, BUFR_processor) - + def test_dataloader_initialization_infer_schema(self): """Test dataloader initialization with schema inference.""" test_cases = [ - ('test_adpupa.bufr', 'ADPUPA'), - ('test_CRIS_data.bufr', 'CrIS'), - ('unknown_file.bufr', 'ADPUPA') # Default case + ("test_adpupa.bufr", "ADPUPA"), + ("test_CRIS_data.bufr", "CrIS"), + ("unknown_file.bufr", "ADPUPA"), # Default case ] - + for filename, expected_schema in test_cases: with tempfile.NamedTemporaryFile(suffix=filename) as f: loader = BUFR_dataloader(f.name) assert loader.schema_name == expected_schema - + def test_dataloader_initialization_invalid_schema(self): """Test dataloader initialization with invalid schema.""" - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + with tempfile.NamedTemporaryFile(suffix=".bufr") as f: with pytest.raises(ValueError, match='Unknown schema "INVALID"'): - BUFR_dataloader(f.name, schema_name='INVALID') - - @patch.object(BUFR_processor, 'process_files_to_dataframe') + BUFR_dataloader(f.name, schema_name="INVALID") + + @patch.object(BUFR_processor, "process_files_to_dataframe") def test_to_dataframe(self, mock_process): """Test to_dataframe method.""" - mock_df = pd.DataFrame({ - 'OBS_TIMESTAMP': [pd.Timestamp('2023-01-01T12:00:00')], - 'LAT': [45.0], - 'LON': [-120.0] - }) + mock_df = pd.DataFrame( + {"OBS_TIMESTAMP": [pd.Timestamp("2023-01-01T12:00:00")], "LAT": [45.0], "LON": [-120.0]} + ) mock_process.return_value = mock_df - - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: - loader = BUFR_dataloader(f.name, schema_name='ADPUPA') + + with tempfile.NamedTemporaryFile(suffix=".bufr") as f: + loader = BUFR_dataloader(f.name, schema_name="ADPUPA") df = loader.to_dataframe() - + assert len(df) == 1 mock_process.assert_called_once_with(str(Path(f.name))) - - @patch.object(BUFR_processor, 'process_files_to_xarray') + + @patch.object(BUFR_processor, "process_files_to_xarray") def test_to_xarray(self, mock_process): """Test to_xarray method.""" - mock_ds = xr.Dataset({ - 'temperature': (['obs'], [20.0]), - 'pressure': (['obs'], [101325.0]) - }, coords={ - 'obs': [0], - 'time': ('obs', [pd.Timestamp('2023-01-01T12:00:00')]), - 'lat': ('obs', [45.0]), - 'lon': ('obs', [-120.0]) - }) + mock_ds = xr.Dataset( + {"temperature": (["obs"], [20.0]), "pressure": (["obs"], [101325.0])}, + coords={ + "obs": [0], + "time": ("obs", [pd.Timestamp("2023-01-01T12:00:00")]), + "lat": ("obs", [45.0]), + "lon": ("obs", [-120.0]), + }, + ) mock_process.return_value = mock_ds - - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: - loader = BUFR_dataloader(f.name, schema_name='ADPUPA') + + with tempfile.NamedTemporaryFile(suffix=".bufr") as f: + loader = BUFR_dataloader(f.name, schema_name="ADPUPA") ds = loader.to_xarray() - - assert 'temperature' in ds.data_vars + + assert "temperature" in ds.data_vars mock_process.assert_called_once_with(str(Path(f.name))) - - @patch.object(BUFR_processor, 'process_files_to_parquet') + + @patch.object(BUFR_processor, "process_files_to_parquet") def test_to_parquet(self, mock_process, tmp_path): """Test to_parquet method.""" output_file = tmp_path / "test_output.parquet" - - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: - loader = BUFR_dataloader(f.name, schema_name='ADPUPA') + + with tempfile.NamedTemporaryFile(suffix=".bufr") as f: + loader = BUFR_dataloader(f.name, schema_name="ADPUPA") loader.to_parquet(str(output_file)) - + mock_process.assert_called_once_with(str(Path(f.name)), str(output_file)) - - @patch.object(BUFR_processor, 'decoder_bufr_files') + + @patch.object(BUFR_processor, "decoder_bufr_files") def test_iterator(self, mock_decode): """Test dataloader iterator.""" - mock_messages = [ - {'lat': 45.0, 'lon': -120.0}, - {'lat': 46.0, 'lon': -121.0} - ] + mock_messages = [{"lat": 45.0, "lon": -120.0}, {"lat": 46.0, "lon": -121.0}] mock_decode.return_value = mock_messages - - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: - loader = BUFR_dataloader(f.name, schema_name='ADPUPA') - + + with tempfile.NamedTemporaryFile(suffix=".bufr") as f: + loader = BUFR_dataloader(f.name, schema_name="ADPUPA") + # Mock the schema's map_observation to return simple data - loader.schema.map_observation = lambda x: {'LAT': x['lat'], 'LON': x['lon']} - + loader.schema.map_observation = lambda x: {"LAT": x["lat"], "LON": x["lon"]} + observations = list(loader) - + assert len(observations) == 2 - assert observations[0]['LAT'] == 45.0 - assert observations[1]['LAT'] == 46.0 \ No newline at end of file + assert observations[0]["LAT"] == 45.0 + assert observations[1]["LAT"] == 46.0 diff --git a/tests/bufr_process/test_field_mapping.py b/tests/bufr_process/test_field_mapping.py index dfa50dd8..c63617d0 100644 --- a/tests/bufr_process/test_field_mapping.py +++ b/tests/bufr_process/test_field_mapping.py @@ -5,53 +5,45 @@ class TestFieldMapping: """Test FieldMapping dataclass.""" - + def test_field_mapping_creation(self): """Test FieldMapping initialization.""" mapping = FieldMapping( source_name="temperature", output_name="temp", dtype=float, - description="Temperature field" + description="Temperature field", ) - + assert mapping.source_name == "temperature" assert mapping.output_name == "temp" assert mapping.dtype == float assert mapping.description == "Temperature field" assert mapping.required is True assert mapping.transform_fn is None - + def test_field_mapping_apply_no_transform(self): """Test apply method without transformation function.""" - mapping = FieldMapping( - source_name="pressure", - output_name="pres", - dtype=float - ) - + mapping = FieldMapping(source_name="pressure", output_name="pres", dtype=float) + result = mapping.apply(1013.25) assert result == 1013.25 - + def test_field_mapping_apply_with_transform(self): """Test apply method with transformation function.""" mapping = FieldMapping( source_name="temp_k", output_name="temp_c", dtype=float, - transform_fn=lambda x: x - 273.15 + transform_fn=lambda x: x - 273.15, ) - + result = mapping.apply(300.0) assert result == pytest.approx(26.85) - + def test_field_mapping_apply_none_value(self): """Test apply method with None value.""" - mapping = FieldMapping( - source_name="missing", - output_name="missing_out", - dtype=float - ) - + mapping = FieldMapping(source_name="missing", output_name="missing_out", dtype=float) + result = mapping.apply(None) - assert result is None \ No newline at end of file + assert result is None diff --git a/tests/bufr_process/test_integration.py b/tests/bufr_process/test_integration.py index 3842f718..c8f6ed52 100644 --- a/tests/bufr_process/test_integration.py +++ b/tests/bufr_process/test_integration.py @@ -1,24 +1,30 @@ import pytest import tempfile -from graph_weather.data.bufr_process import BUFR_dataloader, DataSourceSchema, ADPUPA_schema, CRIS_schema, BUFR_processor +from graph_weather.data.bufr_process import ( + BUFR_dataloader, + DataSourceSchema, + ADPUPA_schema, + CRIS_schema, + BUFR_processor, +) class TestIntegration: """Integration tests for the complete pipeline.""" - + def test_schema_registry_completeness(self): """Test that all schemas in registry can be instantiated.""" for schema_name, schema_class in BUFR_dataloader.SCHEMA_REGISTRY.items(): schema = schema_class() assert isinstance(schema, DataSourceSchema) assert schema.source_name == schema_name - + def test_end_to_end_mock_processing(self): """Test complete mock processing pipeline.""" - with tempfile.NamedTemporaryFile(suffix='_adpupa.bufr') as f: - loader = BUFR_dataloader(f.name) - - assert loader.schema_name == 'ADPUPA' + with tempfile.NamedTemporaryFile(suffix="_adpupa.bufr") as f: + loader = BUFR_dataloader(f.name) + + assert loader.schema_name == "ADPUPA" assert isinstance(loader.schema, ADPUPA_schema) assert isinstance(loader.processor, BUFR_processor) - assert loader.processor.schema == loader.schema \ No newline at end of file + assert loader.processor.schema == loader.schema diff --git a/tests/bufr_process/test_nnja_schema.py b/tests/bufr_process/test_nnja_schema.py index 146e64b8..6d290708 100644 --- a/tests/bufr_process/test_nnja_schema.py +++ b/tests/bufr_process/test_nnja_schema.py @@ -5,39 +5,39 @@ class TestNNJASchema: """Test NNJA_Schema class.""" - + def test_schema_structure(self): """Test schema has expected structure.""" - assert 'OBS_TIMESTAMP' in NNJA_Schema.COORDINATES - assert 'LAT' in NNJA_Schema.COORDINATES - assert 'LON' in NNJA_Schema.COORDINATES - assert 'temperature' in NNJA_Schema.VARIABLES - assert 'pressure' in NNJA_Schema.VARIABLES - + assert "OBS_TIMESTAMP" in NNJA_Schema.COORDINATES + assert "LAT" in NNJA_Schema.COORDINATES + assert "LON" in NNJA_Schema.COORDINATES + assert "temperature" in NNJA_Schema.VARIABLES + assert "pressure" in NNJA_Schema.VARIABLES + def test_to_xarray_schema(self): """Test schema combination.""" full_schema = NNJA_Schema.to_xarray_schema() - assert 'OBS_TIMESTAMP' in full_schema - assert 'temperature' in full_schema - assert 'source' in full_schema - + assert "OBS_TIMESTAMP" in full_schema + assert "temperature" in full_schema + assert "source" in full_schema + def test_get_coordinate_names(self): """Test coordinate names retrieval.""" coords = NNJA_Schema.get_coordinate_names() - expected_coords = ['OBS_TIMESTAMP', 'LAT', 'LON'] + expected_coords = ["OBS_TIMESTAMP", "LAT", "LON"] assert set(coords) == set(expected_coords) - + def test_validate_data(self): """Test data validation.""" valid_data = { - 'OBS_TIMESTAMP': np.array(['2023-01-01'], dtype='datetime64[ns]'), - 'LAT': np.array([45.0], dtype='float32'), - 'LON': np.array([-120.0], dtype='float32') + "OBS_TIMESTAMP": np.array(["2023-01-01"], dtype="datetime64[ns]"), + "LAT": np.array([45.0], dtype="float32"), + "LON": np.array([-120.0], dtype="float32"), } assert NNJA_Schema.validate_data(valid_data) is True - + invalid_data = { - 'LAT': np.array([45.0], dtype='float32'), - 'LON': np.array([-120.0], dtype='float32') + "LAT": np.array([45.0], dtype="float32"), + "LON": np.array([-120.0], dtype="float32"), } - assert NNJA_Schema.validate_data(invalid_data) is False \ No newline at end of file + assert NNJA_Schema.validate_data(invalid_data) is False diff --git a/tests/bufr_process/test_processor.py b/tests/bufr_process/test_processor.py index 32c23465..8342f25b 100644 --- a/tests/bufr_process/test_processor.py +++ b/tests/bufr_process/test_processor.py @@ -3,33 +3,33 @@ import xarray as xr import tempfile from pathlib import Path -from unittest.mock import patch , Mock +from unittest.mock import patch, Mock from graph_weather.data.bufr_process import BUFR_processor class TestBUFRProcessor: """Test BUFR_processor class.""" - + @pytest.fixture def bufr_processor(self, mock_schema): return BUFR_processor(mock_schema) - + def test_processor_initialization(self, mock_schema): """Test processor initialization.""" processor = BUFR_processor(mock_schema) assert processor.schema == mock_schema - + def test_processor_invalid_schema(self): """Test processor with invalid schema.""" with pytest.raises(TypeError): BUFR_processor("invalid_schema") - - @patch('graph_weather.data.bufr_process.eccodes', create=True) + + @patch("graph_weather.data.bufr_process.eccodes", create=True) def test_decode_bufr_file_success(self, mock_eccodes, bufr_processor, tmp_path): """Test successful BUFR file decoding.""" test_file = tmp_path / "test.bufr" test_file.write_bytes(b"test bufr content") - + # Mock eccodes behavior mock_bufr_id = Mock() mock_eccodes.codes_bufr_new_from_file.side_effect = [mock_bufr_id, None] @@ -38,91 +38,98 @@ def test_decode_bufr_file_success(self, mock_eccodes, bufr_processor, tmp_path): mock_eccodes.codes_bufr_keys_iterator_next.side_effect = [True, False] mock_eccodes.codes_bufr_keys_iterator_get_name.return_value = "test_key" mock_eccodes.codes_get_string.return_value = "test_value" - + # Use the correct method name - decoder_bufr_files messages = bufr_processor.decoder_bufr_files(str(test_file)) - + assert len(messages) == 1 assert messages[0]["test_key"] == "test_value" - + def test_decode_bufr_file_not_found(self, bufr_processor, tmp_path): """Test BUFR file not found.""" with pytest.raises(FileNotFoundError): bufr_processor.decoder_bufr_files(str(tmp_path / "nonexistent.bufr")) - - @patch.object(BUFR_processor, 'decoder_bufr_files') + + @patch.object(BUFR_processor, "decoder_bufr_files") def test_process_files_to_dataframe(self, mock_decode, bufr_processor, mock_schema): """Test processing BUFR file to DataFrame.""" # Mock decoded messages mock_messages = [ - {'latitude': 45.0, 'longitude': -120.0, 'obsTime': '2023-01-01T12:00:00'}, - {'latitude': 46.0, 'longitude': -121.0, 'obsTime': '2023-01-01T12:30:00'} + {"latitude": 45.0, "longitude": -120.0, "obsTime": "2023-01-01T12:00:00"}, + {"latitude": 46.0, "longitude": -121.0, "obsTime": "2023-01-01T12:30:00"}, ] mock_decode.return_value = mock_messages - + # Mock schema mapping mock_schema.map_observation.side_effect = [ - {'OBS_TIMESTAMP': pd.Timestamp('2023-01-01T12:00:00'), 'LAT': 45.0, 'LON': -120.0}, - {'OBS_TIMESTAMP': pd.Timestamp('2023-01-01T12:30:00'), 'LAT': 46.0, 'LON': -121.0} + {"OBS_TIMESTAMP": pd.Timestamp("2023-01-01T12:00:00"), "LAT": 45.0, "LON": -120.0}, + {"OBS_TIMESTAMP": pd.Timestamp("2023-01-01T12:30:00"), "LAT": 46.0, "LON": -121.0}, ] - - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + + with tempfile.NamedTemporaryFile(suffix=".bufr") as f: df = bufr_processor.process_files_to_dataframe(f.name) - + assert len(df) == 2 - assert 'OBS_TIMESTAMP' in df.columns - assert 'LAT' in df.columns - assert 'LON' in df.columns - assert df['LAT'].iloc[0] == 45.0 - - @patch.object(BUFR_processor, 'decoder_bufr_files') + assert "OBS_TIMESTAMP" in df.columns + assert "LAT" in df.columns + assert "LON" in df.columns + assert df["LAT"].iloc[0] == 45.0 + + @patch.object(BUFR_processor, "decoder_bufr_files") def test_process_files_to_dataframe_empty(self, mock_decode, bufr_processor): """Test processing BUFR file with no valid observations.""" mock_decode.return_value = [] # No messages - - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + + with tempfile.NamedTemporaryFile(suffix=".bufr") as f: df = bufr_processor.process_files_to_dataframe(f.name) - + assert df.empty - - @patch.object(BUFR_processor, 'process_files_to_dataframe') + + @patch.object(BUFR_processor, "process_files_to_dataframe") def test_process_files_to_xarray(self, mock_process, bufr_processor): """Test processing BUFR file to xarray Dataset.""" # Mock DataFrame - mock_df = pd.DataFrame({ - 'OBS_TIMESTAMP': [pd.Timestamp('2023-01-01T12:00:00'), pd.Timestamp('2023-01-01T12:30:00')], - 'LAT': [45.0, 46.0], - 'LON': [-120.0, -121.0], - 'temperature': [20.0, 19.5], - 'pressure': [101325.0, 101300.0] - }) + mock_df = pd.DataFrame( + { + "OBS_TIMESTAMP": [ + pd.Timestamp("2023-01-01T12:00:00"), + pd.Timestamp("2023-01-01T12:30:00"), + ], + "LAT": [45.0, 46.0], + "LON": [-120.0, -121.0], + "temperature": [20.0, 19.5], + "pressure": [101325.0, 101300.0], + } + ) mock_process.return_value = mock_df - - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + + with tempfile.NamedTemporaryFile(suffix=".bufr") as f: ds = bufr_processor.process_files_to_xarray(f.name) - - assert 'temperature' in ds.data_vars - assert 'pressure' in ds.data_vars - assert 'time' in ds.coords - assert 'lat' in ds.coords - assert 'lon' in ds.coords - assert ds.attrs['source'] == 'TEST' - - @patch.object(BUFR_processor, 'process_files_to_dataframe') + + assert "temperature" in ds.data_vars + assert "pressure" in ds.data_vars + assert "time" in ds.coords + assert "lat" in ds.coords + assert "lon" in ds.coords + assert ds.attrs["source"] == "TEST" + + @patch.object(BUFR_processor, "process_files_to_dataframe") def test_process_files_to_parquet(self, mock_process, bufr_processor, tmp_path): """Test processing BUFR file to Parquet.""" # Mock DataFrame - mock_df = pd.DataFrame({ - 'OBS_TIMESTAMP': [pd.Timestamp('2023-01-01T12:00:00')], - 'LAT': [45.0], - 'LON': [-120.0], - 'temperature': [20.0] - }) + mock_df = pd.DataFrame( + { + "OBS_TIMESTAMP": [pd.Timestamp("2023-01-01T12:00:00")], + "LAT": [45.0], + "LON": [-120.0], + "temperature": [20.0], + } + ) mock_process.return_value = mock_df - + output_file = tmp_path / "output.parquet" - - with tempfile.NamedTemporaryFile(suffix='.bufr') as f: + + with tempfile.NamedTemporaryFile(suffix=".bufr") as f: bufr_processor.process_files_to_parquet(f.name, str(output_file)) - - assert output_file.exists() \ No newline at end of file + + assert output_file.exists() From 65ee1f8bbf2773841d28238da6278822bfefa6e2 Mon Sep 17 00:00:00 2001 From: Yuvraaj Narula Date: Tue, 4 Nov 2025 21:13:15 +0530 Subject: [PATCH 09/13] chore : added more vars for adpupa --- graph_weather/data/bufr_process.py | 248 +++++++++++++++++++++++++++-- 1 file changed, 237 insertions(+), 11 deletions(-) diff --git a/graph_weather/data/bufr_process.py b/graph_weather/data/bufr_process.py index 818fd328..516b0037 100644 --- a/graph_weather/data/bufr_process.py +++ b/graph_weather/data/bufr_process.py @@ -151,9 +151,15 @@ def map_observation(self, bufr_message: Dict[str, Any]) -> Dict[str, Any]: class ADPUPA_schema(DataSourceSchema): - """ADPUPA (upper-air radiosonde) BUFR schema mapping to NNJA-AI.""" + """ADPUPA (upper-air radiosonde) BUFR schema mapping to NNJA-AI. + + Includes mandatory pressure levels: 1000, 925, 850, 700, 500, 300, 200, 100 hPa + """ source_name = "ADPUPA" + + # Standard mandatory pressure levels in Pa + MANDATORY_LEVELS = [100, 200, 300, 500, 700, 850, 925, 1000] def _build_mappings(self): self.field_mappings = { @@ -176,52 +182,272 @@ def _build_mappings(self): transform_fn=self._convert_timestamp, description="Observation timestamp", ), + + # ===== STATION METADATA ===== + "WMOB": FieldMapping( + source_name="WMOB", + output_name="wmo_block_number", + dtype=str, + required=False, + description="WMO block number", + ), + "WMOS": FieldMapping( + source_name="WMOS", + output_name="wmo_station_number", + dtype=str, + required=False, + description="WMO station number", + ), + "WMOR": FieldMapping( + source_name="WMOR", + output_name="wmo_region", + dtype=int, + required=False, + description="WMO Region number/geographical area", + ), + "UASID.RPID": FieldMapping( + source_name="UASID.RPID", + output_name="report_id", + dtype=str, + required=False, + description="Report identifier", + ), + "UASID.SELV": FieldMapping( + source_name="UASID.SELV", + output_name="station_elevation", + dtype=float, + required=False, + description="Height of station (m)", + ), + "stationId": FieldMapping( + source_name="stationId", + output_name="station_id", + dtype=str, + required=False, + description="Station identifier", + ), + + # ===== SURFACE/SINGLE LEVEL DATA ===== "airTemperature": FieldMapping( source_name="airTemperature", output_name="temperature", dtype=float, transform_fn=lambda x: x - 273.15 if x > 100 else x, - description="Temperature in Celsius", + required=False, + description="Surface temperature in Celsius", ), "pressure": FieldMapping( source_name="pressure", output_name="pressure", dtype=float, - description="Pressure in Pa", + required=False, + description="Surface pressure in Pa", ), "height": FieldMapping( source_name="height", output_name="height", dtype=float, - description="Height above sealevel in m", + required=False, + description="Height above sea level in m", ), "dewpointTemperature": FieldMapping( source_name="dewpointTemperature", output_name="dew_point", dtype=float, transform_fn=lambda x: x - 273.15 if x > 100 else x, - description="Dew point in Celsius", + required=False, + description="Surface dew point in Celsius", ), "windU": FieldMapping( source_name="windU", output_name="u_wind", dtype=float, - description="U-component wind (m/s)", + required=False, + description="Surface U-component wind (m/s)", ), "windV": FieldMapping( source_name="windV", output_name="v_wind", dtype=float, - description="V-component wind (m/s)", + required=False, + description="Surface V-component wind (m/s)", ), - "stationId": FieldMapping( - source_name="stationId", - output_name="station_id", + + # ===== ADDITIONAL METEOROLOGICAL DATA ===== + "UASDG.SST1": FieldMapping( + source_name="UASDG.SST1", + output_name="sea_surface_temp", + dtype=float, + transform_fn=lambda x: x - 273.15 if x > 100 else x, + required=False, + description="Sea/water temperature in Celsius", + ), + "UAADF.STBS5": FieldMapping( + source_name="UAADF.STBS5", + output_name="showalter_index", + dtype=float, + required=False, + description="Modified Showalter stability index", + ), + "UAADF.MWDL": FieldMapping( + source_name="UAADF.MWDL", + output_name="mean_wind_dir_low", + dtype=float, + required=False, + description="Mean wind direction for surface - 1500m (degrees)", + ), + "UAADF.MWSL": FieldMapping( + source_name="UAADF.MWSL", + output_name="mean_wind_speed_low", + dtype=float, + required=False, + description="Mean wind speed for surface - 1500m (m/s)", + ), + "UAADF.MWDH": FieldMapping( + source_name="UAADF.MWDH", + output_name="mean_wind_dir_high", + dtype=float, + required=False, + description="Mean wind direction for 1500-3000m (degrees)", + ), + "UAADF.MWSH": FieldMapping( + source_name="UAADF.MWSH", + output_name="mean_wind_speed_high", + dtype=float, + required=False, + description="Mean wind speed for 1500-3000m (m/s)", + ), + + "MSG_TYPE": FieldMapping( + source_name="MSG_TYPE", + output_name="message_type", dtype=str, required=False, - description="Station identifier", + description="Source message type", + ), + "MSG_DATE": FieldMapping( + source_name="MSG_DATE", + output_name="message_date", + dtype=object, + transform_fn=self._convert_timestamp, + required=False, + description="Message valid timestamp", + ), + "OBS_DATE": FieldMapping( + source_name="OBS_DATE", + output_name="obs_date", + dtype=object, + transform_fn=self._convert_timestamp, + required=False, + description="Date of the observation", + ), + "SRC_FILENAME": FieldMapping( + source_name="SRC_FILENAME", + output_name="source_filename", + dtype=str, + required=False, + description="Source filename", ), } + + for level_hpa in self.MANDATORY_LEVELS: + level_pa = level_hpa * 100 # Convert hPa to Pa for BUFR field names + + self.field_mappings[f"TMDB_PRLC{level_pa}"] = FieldMapping( + source_name=f"TMDB_PRLC{level_pa}", + output_name=f"temperature_{level_hpa}hPa", + dtype=float, + transform_fn=lambda x: x - 273.15 if x > 100 else x, + required=False, + description=f"Temperature at {level_hpa} hPa in Celsius", + ) + + # Dewpoint at pressure level + self.field_mappings[f"TMDP_PRLC{level_pa}"] = FieldMapping( + source_name=f"TMDP_PRLC{level_pa}", + output_name=f"dew_point_{level_hpa}hPa", + dtype=float, + transform_fn=lambda x: x - 273.15 if x > 100 else x, + required=False, + description=f"Dewpoint temperature at {level_hpa} hPa in Celsius", + ) + + # Wind speed at pressure level + self.field_mappings[f"WSPD_PRLC{level_pa}"] = FieldMapping( + source_name=f"WSPD_PRLC{level_pa}", + output_name=f"wind_speed_{level_hpa}hPa", + dtype=float, + required=False, + description=f"Wind speed at {level_hpa} hPa (m/s)", + ) + + # Wind direction at pressure level + self.field_mappings[f"WDIR_PRLC{level_pa}"] = FieldMapping( + source_name=f"WDIR_PRLC{level_pa}", + output_name=f"wind_direction_{level_hpa}hPa", + dtype=float, + required=False, + description=f"Wind direction at {level_hpa} hPa (degrees)", + ) + + # Geopotential at pressure level + self.field_mappings[f"GP10_PRLC{level_pa}"] = FieldMapping( + source_name=f"GP10_PRLC{level_pa}", + output_name=f"geopotential_{level_hpa}hPa", + dtype=float, + required=False, + description=f"Geopotential at {level_hpa} hPa (m²/s²)", + ) + + # ===== QUALITY CONTROL FLAGS ===== + self.field_mappings[f"QMAT_PRLC{level_pa}"] = FieldMapping( + source_name=f"QMAT_PRLC{level_pa}", + output_name=f"qc_temperature_{level_hpa}hPa", + dtype=int, + required=False, + description=f"QC flag for temperature at {level_hpa} hPa", + ) + + self.field_mappings[f"QMDD_PRLC{level_pa}"] = FieldMapping( + source_name=f"QMDD_PRLC{level_pa}", + output_name=f"qc_moisture_{level_hpa}hPa", + dtype=int, + required=False, + description=f"QC flag for moisture at {level_hpa} hPa", + ) + + self.field_mappings[f"QMWN_PRLC{level_pa}"] = FieldMapping( + source_name=f"QMWN_PRLC{level_pa}", + output_name=f"qc_wind_{level_hpa}hPa", + dtype=int, + required=False, + description=f"QC flag for wind at {level_hpa} hPa", + ) + + self.field_mappings[f"QMGP_PRLC{level_pa}"] = FieldMapping( + source_name=f"QMGP_PRLC{level_pa}", + output_name=f"qc_geopotential_{level_hpa}hPa", + dtype=int, + required=False, + description=f"QC flag for geopotential at {level_hpa} hPa", + ) + + self.field_mappings[f"QMPR_PRLC{level_pa}"] = FieldMapping( + source_name=f"QMPR_PRLC{level_pa}", + output_name=f"qc_pressure_{level_hpa}hPa", + dtype=int, + required=False, + description=f"QC flag for pressure at {level_hpa} hPa", + ) + + # Vertical sounding significance + self.field_mappings[f"VSIG_PRLC{level_pa}"] = FieldMapping( + source_name=f"VSIG_PRLC{level_pa}", + output_name=f"sounding_significance_{level_hpa}hPa", + dtype=int, + required=False, + description=f"Vertical sounding significance at {level_hpa} hPa", + ) def _convert_timestamp(self, value: Any) -> pd.Timestamp: """Convert BUFR timestamp to pandas Timestamp.""" From 62eae0e1f1ef7e60ecc95484ce892904effb73d7 Mon Sep 17 00:00:00 2001 From: Yuvraaj Narula Date: Mon, 10 Nov 2025 18:31:04 +0530 Subject: [PATCH 10/13] chore : updated cris schema --- graph_weather/data/bufr_process.py | 187 ++++++++++++++++++++++++++--- 1 file changed, 169 insertions(+), 18 deletions(-) diff --git a/graph_weather/data/bufr_process.py b/graph_weather/data/bufr_process.py index 516b0037..8b35b85c 100644 --- a/graph_weather/data/bufr_process.py +++ b/graph_weather/data/bufr_process.py @@ -458,7 +458,6 @@ def _convert_timestamp(self, value: Any) -> pd.Timestamp: else: return pd.Timestamp(value) - class CRIS_schema(DataSourceSchema): """CrIS (satellite hyperspectral) BUFR schema mapping to NNJA-AI.""" @@ -485,34 +484,187 @@ def _build_mappings(self): transform_fn=self._convert_timestamp, description="Observation timestamp", ), - "retrievedTemperature": FieldMapping( - source_name="retrievedTemperature", - output_name="temperature", - dtype=float, - transform_fn=lambda x: x - 273.15 if x > 100 else x, - description="Retrieved temperature in Celsius", + "obsDate": FieldMapping( + source_name="obsDate", + output_name="OBS_DATE", + dtype=object, + description="Date of the observation", ), - "retrievedPressure": FieldMapping( - source_name="retrievedPressure", - output_name="pressure", - dtype=float, - description="Retrieved pressure in Pa", + "satelliteId": FieldMapping( + source_name="satelliteId", + output_name="SAID", + dtype=int, + description="Satellite identifier", ), "sensorZenithAngle": FieldMapping( source_name="sensorZenithAngle", - output_name="sensor_zenith_angle", + output_name="SAZA", + dtype=float, + required=False, + description="Satellite zenith angle", + ), + "solarZenithAngle": FieldMapping( + source_name="solarZenithAngle", + output_name="SOZA", + dtype=float, + required=False, + description="Solar zenith angle", + ), + "solarAzimuth": FieldMapping( + source_name="solarAzimuth", + output_name="SOLAZI", + dtype=float, + required=False, + description="Solar azimuth angle", + ), + "bearingAzimuth": FieldMapping( + source_name="bearingAzimuth", + output_name="BEARAZ", + dtype=float, + required=False, + description="Bearing or azimuth", + ), + "orbitNumber": FieldMapping( + source_name="orbitNumber", + output_name="ORBN", + dtype=int, + required=False, + description="Orbit number", + ), + "scanLineNumber": FieldMapping( + source_name="scanLineNumber", + output_name="SLNM", + dtype=int, + required=False, + description="Scan line number", + ), + "fieldOfRegardNumber": FieldMapping( + source_name="fieldOfRegardNumber", + output_name="FORN", + dtype=int, + required=False, + description="Field of regard number", + ), + "fieldOfViewNumber": FieldMapping( + source_name="fieldOfViewNumber", + output_name="FOVN", + dtype=int, + required=False, + description="Field of view number", + ), + "heightAboveSurface": FieldMapping( + source_name="heightAboveSurface", + output_name="HMSL", + dtype=float, + required=False, + description="Height or altitude above mean sea level", + ), + "heightOfLandSurface": FieldMapping( + source_name="heightOfLandSurface", + output_name="HOLS", + dtype=float, + required=False, + description="Height of land surface", + ), + "totalCloudCover": FieldMapping( + source_name="totalCloudCover", + output_name="TOCC", + dtype=float, + required=False, + description="Cloud cover (total)", + ), + "cloudTopHeight": FieldMapping( + source_name="cloudTopHeight", + output_name="HOCT", + dtype=float, + required=False, + description="Height of top of cloud", + ), + "landFraction": FieldMapping( + source_name="landFraction", + output_name="ALFR", dtype=float, required=False, - description="Sensor zenith angle", + description="Land fraction", + ), + "landSeaQualifier": FieldMapping( + source_name="landSeaQualifier", + output_name="LSQL", + dtype=int, + required=False, + description="Land/sea qualifier", ), "qualityFlags": FieldMapping( source_name="qualityFlags", - output_name="qc_flag", + output_name="NSQF", + dtype=int, + required=False, + description="Scan-level quality flags", + ), + "radianceTypeFlags": FieldMapping( + source_name="radianceTypeFlags", + output_name="RDTF", + dtype=int, + required=False, + description="Radiance type flags", + ), + "geolocationQuality": FieldMapping( + source_name="geolocationQuality", + output_name="NGQI", + dtype=int, + required=False, + description="Geolocation quality", + ), + "orbitQualifier": FieldMapping( + source_name="orbitQualifier", + output_name="STKO", dtype=int, - description="Quality control flags", + required=False, + description="Ascending/descending orbit qualifier", + ), + # Channel radiance mappings - you can add specific channels as needed + "channelRadiances": FieldMapping( + source_name="channelRadiances", + output_name="CRCHNM_SRAD01", + dtype=object, + required=False, + description="CrIS channel radiances array", + ), + "guardChannelData": FieldMapping( + source_name="guardChannelData", + output_name="GCRCHN", + dtype=object, + required=False, + description="NPP CrIS GUARD CHANNEL DATA array", + ), + "viirsSceneData": FieldMapping( + source_name="viirsSceneData", + output_name="CRISCS", + dtype=object, + required=False, + description="CrIS LEVEL 1B VIIRS SINGLE SCENE SEQUENCE DATA array", ), } + # common channels for use case + common_channels = { + "radiance_ch19": "CRCHNM.SRAD01_00019", + "radiance_ch24": "CRCHNM.SRAD01_00024", + "radiance_ch26": "CRCHNM.SRAD01_00026", + "radiance_ch27": "CRCHNM.SRAD01_00027", + "radiance_ch31": "CRCHNM.SRAD01_00031", + "radiance_ch32": "CRCHNM.SRAD01_00032", + } + + for key, source_name in common_channels.items(): + self.field_mappings[key] = FieldMapping( + source_name=source_name, + output_name=source_name.replace(".", "_"), + dtype=float, + required=False, + description=f"Channel radiance for {source_name}", + ) + def _convert_timestamp(self, value: Any) -> pd.Timestamp: """Convert BUFR timestamp to pandas Timestamp.""" if isinstance(value, (int, float)): @@ -521,8 +673,7 @@ def _convert_timestamp(self, value: Any) -> pd.Timestamp: return pd.Timestamp(value) else: return pd.Timestamp(value) - - + class BUFR_processor: """ Low-level BUFR file decoder. From b76fd574f48aad85d5bf3594aaeb45f9bad57f2d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Nov 2025 13:00:00 +0000 Subject: [PATCH 11/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graph_weather/data/bufr_process.py | 40 ++++++++++++++---------------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/graph_weather/data/bufr_process.py b/graph_weather/data/bufr_process.py index 8b35b85c..ead4c117 100644 --- a/graph_weather/data/bufr_process.py +++ b/graph_weather/data/bufr_process.py @@ -152,12 +152,12 @@ def map_observation(self, bufr_message: Dict[str, Any]) -> Dict[str, Any]: class ADPUPA_schema(DataSourceSchema): """ADPUPA (upper-air radiosonde) BUFR schema mapping to NNJA-AI. - + Includes mandatory pressure levels: 1000, 925, 850, 700, 500, 300, 200, 100 hPa """ source_name = "ADPUPA" - + # Standard mandatory pressure levels in Pa MANDATORY_LEVELS = [100, 200, 300, 500, 700, 850, 925, 1000] @@ -182,7 +182,6 @@ def _build_mappings(self): transform_fn=self._convert_timestamp, description="Observation timestamp", ), - # ===== STATION METADATA ===== "WMOB": FieldMapping( source_name="WMOB", @@ -226,7 +225,6 @@ def _build_mappings(self): required=False, description="Station identifier", ), - # ===== SURFACE/SINGLE LEVEL DATA ===== "airTemperature": FieldMapping( source_name="airTemperature", @@ -272,7 +270,6 @@ def _build_mappings(self): required=False, description="Surface V-component wind (m/s)", ), - # ===== ADDITIONAL METEOROLOGICAL DATA ===== "UASDG.SST1": FieldMapping( source_name="UASDG.SST1", @@ -317,7 +314,6 @@ def _build_mappings(self): required=False, description="Mean wind speed for 1500-3000m (m/s)", ), - "MSG_TYPE": FieldMapping( source_name="MSG_TYPE", output_name="message_type", @@ -349,10 +345,10 @@ def _build_mappings(self): description="Source filename", ), } - + for level_hpa in self.MANDATORY_LEVELS: level_pa = level_hpa * 100 # Convert hPa to Pa for BUFR field names - + self.field_mappings[f"TMDB_PRLC{level_pa}"] = FieldMapping( source_name=f"TMDB_PRLC{level_pa}", output_name=f"temperature_{level_hpa}hPa", @@ -361,7 +357,7 @@ def _build_mappings(self): required=False, description=f"Temperature at {level_hpa} hPa in Celsius", ) - + # Dewpoint at pressure level self.field_mappings[f"TMDP_PRLC{level_pa}"] = FieldMapping( source_name=f"TMDP_PRLC{level_pa}", @@ -371,7 +367,7 @@ def _build_mappings(self): required=False, description=f"Dewpoint temperature at {level_hpa} hPa in Celsius", ) - + # Wind speed at pressure level self.field_mappings[f"WSPD_PRLC{level_pa}"] = FieldMapping( source_name=f"WSPD_PRLC{level_pa}", @@ -380,7 +376,7 @@ def _build_mappings(self): required=False, description=f"Wind speed at {level_hpa} hPa (m/s)", ) - + # Wind direction at pressure level self.field_mappings[f"WDIR_PRLC{level_pa}"] = FieldMapping( source_name=f"WDIR_PRLC{level_pa}", @@ -389,7 +385,7 @@ def _build_mappings(self): required=False, description=f"Wind direction at {level_hpa} hPa (degrees)", ) - + # Geopotential at pressure level self.field_mappings[f"GP10_PRLC{level_pa}"] = FieldMapping( source_name=f"GP10_PRLC{level_pa}", @@ -398,7 +394,7 @@ def _build_mappings(self): required=False, description=f"Geopotential at {level_hpa} hPa (m²/s²)", ) - + # ===== QUALITY CONTROL FLAGS ===== self.field_mappings[f"QMAT_PRLC{level_pa}"] = FieldMapping( source_name=f"QMAT_PRLC{level_pa}", @@ -407,7 +403,7 @@ def _build_mappings(self): required=False, description=f"QC flag for temperature at {level_hpa} hPa", ) - + self.field_mappings[f"QMDD_PRLC{level_pa}"] = FieldMapping( source_name=f"QMDD_PRLC{level_pa}", output_name=f"qc_moisture_{level_hpa}hPa", @@ -415,7 +411,7 @@ def _build_mappings(self): required=False, description=f"QC flag for moisture at {level_hpa} hPa", ) - + self.field_mappings[f"QMWN_PRLC{level_pa}"] = FieldMapping( source_name=f"QMWN_PRLC{level_pa}", output_name=f"qc_wind_{level_hpa}hPa", @@ -423,7 +419,7 @@ def _build_mappings(self): required=False, description=f"QC flag for wind at {level_hpa} hPa", ) - + self.field_mappings[f"QMGP_PRLC{level_pa}"] = FieldMapping( source_name=f"QMGP_PRLC{level_pa}", output_name=f"qc_geopotential_{level_hpa}hPa", @@ -431,7 +427,7 @@ def _build_mappings(self): required=False, description=f"QC flag for geopotential at {level_hpa} hPa", ) - + self.field_mappings[f"QMPR_PRLC{level_pa}"] = FieldMapping( source_name=f"QMPR_PRLC{level_pa}", output_name=f"qc_pressure_{level_hpa}hPa", @@ -439,7 +435,7 @@ def _build_mappings(self): required=False, description=f"QC flag for pressure at {level_hpa} hPa", ) - + # Vertical sounding significance self.field_mappings[f"VSIG_PRLC{level_pa}"] = FieldMapping( source_name=f"VSIG_PRLC{level_pa}", @@ -458,6 +454,7 @@ def _convert_timestamp(self, value: Any) -> pd.Timestamp: else: return pd.Timestamp(value) + class CRIS_schema(DataSourceSchema): """CrIS (satellite hyperspectral) BUFR schema mapping to NNJA-AI.""" @@ -649,13 +646,13 @@ def _build_mappings(self): # common channels for use case common_channels = { "radiance_ch19": "CRCHNM.SRAD01_00019", - "radiance_ch24": "CRCHNM.SRAD01_00024", + "radiance_ch24": "CRCHNM.SRAD01_00024", "radiance_ch26": "CRCHNM.SRAD01_00026", "radiance_ch27": "CRCHNM.SRAD01_00027", "radiance_ch31": "CRCHNM.SRAD01_00031", "radiance_ch32": "CRCHNM.SRAD01_00032", } - + for key, source_name in common_channels.items(): self.field_mappings[key] = FieldMapping( source_name=source_name, @@ -673,7 +670,8 @@ def _convert_timestamp(self, value: Any) -> pd.Timestamp: return pd.Timestamp(value) else: return pd.Timestamp(value) - + + class BUFR_processor: """ Low-level BUFR file decoder. From 71d9b69e0d4ee65ab8123ee03584b5da6cfd22b3 Mon Sep 17 00:00:00 2001 From: Yuvraaj Narula Date: Mon, 10 Nov 2025 18:38:23 +0530 Subject: [PATCH 12/13] chore : minor changes as requested --- graph_weather/data/bufr_process.py | 9 ++++++--- graph_weather/data/weather_station_reader.py | 10 ---------- tests/bufr_process/conftest.py | 9 +-------- 3 files changed, 7 insertions(+), 21 deletions(-) diff --git a/graph_weather/data/bufr_process.py b/graph_weather/data/bufr_process.py index ead4c117..60c0662c 100644 --- a/graph_weather/data/bufr_process.py +++ b/graph_weather/data/bufr_process.py @@ -857,9 +857,12 @@ def _infer_schema_from_path(self) -> str: elif "atms" in filename: return "ATMS" else: - # Default to ADPUPA for now - logger.warning(f"Could not infer schema from {filename}, defaulting to ADPUPA") - return "ADPUPA" + available_schemas = list(self.SCHEMA_REGISTRY.keys()) + raise ValueError( + f"Could not infer schema from filename '{filename}'. " + f"Available schemas: {available_schemas}. " + f"Please specify schema_name explicitly." + ) def to_dataframe(self) -> pd.DataFrame: """Process BUFR file to DataFrame.""" diff --git a/graph_weather/data/weather_station_reader.py b/graph_weather/data/weather_station_reader.py index a11b20b8..1ad4cd9a 100644 --- a/graph_weather/data/weather_station_reader.py +++ b/graph_weather/data/weather_station_reader.py @@ -26,16 +26,6 @@ ) logger = logging.getLogger("WeatherStationReader") -# Try importing synopticpy, but don't require it -# try: -# from synoptic import Synoptic - -# SYNOPTIC_AVAILABLE = True -# except ImportError: -# SYNOPTIC_AVAILABLE = False -# logger.warning("SynopticPy package not installed, synoptic functionality won't be available") - - class WeatherStationReader: """ The reader for local weather station observations. diff --git a/tests/bufr_process/conftest.py b/tests/bufr_process/conftest.py index db891654..a8941e63 100644 --- a/tests/bufr_process/conftest.py +++ b/tests/bufr_process/conftest.py @@ -1,13 +1,6 @@ import pytest import pandas as pd -import numpy as np -from unittest.mock import Mock, patch -import tempfile -from pathlib import Path -import sys -import os - -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) +from unittest.mock import Mock from graph_weather.data.bufr_process import DataSourceSchema From 7517a364b3155b60521d2b8fce9b89184fa8303d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Nov 2025 13:09:29 +0000 Subject: [PATCH 13/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- graph_weather/data/weather_station_reader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/graph_weather/data/weather_station_reader.py b/graph_weather/data/weather_station_reader.py index 1ad4cd9a..fc0cd7f1 100644 --- a/graph_weather/data/weather_station_reader.py +++ b/graph_weather/data/weather_station_reader.py @@ -26,6 +26,7 @@ ) logger = logging.getLogger("WeatherStationReader") + class WeatherStationReader: """ The reader for local weather station observations.