diff --git a/external/fv3fit/fv3fit/reservoir/adapters.py b/external/fv3fit/fv3fit/reservoir/adapters.py index 94425bb76a..2b4b1c3ed7 100644 --- a/external/fv3fit/fv3fit/reservoir/adapters.py +++ b/external/fv3fit/fv3fit/reservoir/adapters.py @@ -1,13 +1,17 @@ from __future__ import annotations +from dataclasses import asdict +import fsspec import numpy as np import os import typing -from typing import Iterable, Hashable, Sequence, Union, Mapping +from typing import Iterable, Hashable, Sequence, Union, Mapping, Optional import xarray as xr +import yaml import fv3fit from fv3fit import Predictor from fv3fit._shared import io +from fv3fit.reservoir.config import ClipZConfig from .model import ( HybridReservoirComputingModel, ReservoirComputingModel, @@ -66,7 +70,10 @@ def output_array_to_ds( ).transpose(*output_dims) def input_dataset_to_arrays( - self, inputs: xr.Dataset, variables: Iterable[Hashable] + self, + inputs: xr.Dataset, + variables: Iterable[Hashable], + clip_config: Optional[Mapping[Hashable, ClipZConfig]] = None, ) -> Sequence[np.ndarray]: # Converts from xr dataset to sequence of variable ndarrays expected by encoder # Make sure the xy dimensions match the rank divider @@ -80,6 +87,10 @@ def input_dataset_to_arrays( da = transposed_inputs[variable] if "z" not in da.dims: da = da.expand_dims("z", axis=-1) + if clip_config is not None and variable in clip_config: + da = da.isel( + z=slice(clip_config[variable].start, clip_config[variable].stop) + ) input_arrs.append(da.values) return input_arrs @@ -118,6 +129,9 @@ def input_overlap(self): def is_hybrid(self): return False + def dump_state(self, path: str): + self.model.reservoir.dump_state(path) + def predict(self, inputs: xr.Dataset) -> xr.Dataset: # inputs arg is not used, but is required by Predictor signature and prog run prediction_arr = self.model.predict() @@ -127,7 +141,7 @@ def predict(self, inputs: xr.Dataset) -> xr.Dataset: def increment_state(self, inputs: xr.Dataset): xy_input_arrs = self.model_adapter.input_dataset_to_arrays( - inputs, self.input_variables + inputs, self.input_variables, ) # x, y, feature dims self.model.increment_state(xy_input_arrs) @@ -159,12 +173,14 @@ def load(cls, path: str) -> "ReservoirDatasetAdapter": @io.register("hybrid-reservoir-adapter") class HybridReservoirDatasetAdapter(Predictor): MODEL_DIR = "hybrid_reservoir_model" + CLIP_CONFIG_FILE = "clip_config.yaml" def __init__( self, model: HybridReservoirComputingModel, input_variables: Iterable[Hashable], output_variables: Iterable[Hashable], + clip_config: Optional[Mapping[Hashable, ClipZConfig]] = None, ) -> None: """Wraps a hybrid reservoir model to take in and return xarray datasets. 
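+        When a clip_config is provided, input and hybrid variables are
+        clipped to the configured vertical window before being passed to
+        the wrapped model.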
The initialization args for input and output variables are not used and @@ -183,6 +199,10 @@ def __init__( input_variables=self.input_variables, output_variables=model.output_variables, ) + self.clip_config = clip_config + + def dump_state(self, path: str): + self.model.reservoir.dump_state(path) @property def input_overlap(self): @@ -195,7 +215,7 @@ def is_hybrid(self): def predict(self, inputs: xr.Dataset) -> xr.Dataset: xy_input_arrs = self.model_adapter.input_dataset_to_arrays( - inputs, self.model.hybrid_variables + inputs, self.model.hybrid_variables, clip_config=self.clip_config ) # x, y, feature dims prediction_arr = self.model.predict(xy_input_arrs) @@ -205,7 +225,7 @@ def predict(self, inputs: xr.Dataset) -> xr.Dataset: def increment_state(self, inputs: xr.Dataset): xy_input_arrs = self.model_adapter.input_dataset_to_arrays( - inputs, self.model.input_variables + inputs, self.model.input_variables, clip_config=self.clip_config ) # x, y, feature dims self.model.increment_state(xy_input_arrs) @@ -224,14 +244,30 @@ def get_model_from_subdomain( def dump(self, path): self.model.dump(os.path.join(path, self.MODEL_DIR)) + if self.clip_config is not None: + clip_config_dict = { + var: asdict(var_config) for var, var_config in self.clip_config.items() + } + with fsspec.open(os.path.join(path, self.CLIP_CONFIG_FILE), "w") as f: + yaml.dump(clip_config_dict, f) @classmethod def load(cls, path: str) -> "HybridReservoirDatasetAdapter": model = HybridReservoirComputingModel.load(os.path.join(path, cls.MODEL_DIR)) + try: + with fsspec.open(os.path.join(path, cls.CLIP_CONFIG_FILE), "r") as f: + clip_config_dict = yaml.safe_load(f) + clip_config: Optional[Mapping[Hashable, ClipZConfig]] = { + var: ClipZConfig(**var_config) + for var, var_config in clip_config_dict.items() + } + except FileNotFoundError: + clip_config = None adapter = cls( input_variables=model.input_variables, output_variables=model.output_variables, model=model, + clip_config=clip_config, ) return adapter diff --git a/external/fv3fit/fv3fit/reservoir/config.py b/external/fv3fit/fv3fit/reservoir/config.py index 7e5bef9a8e..0adbd1299b 100644 --- a/external/fv3fit/fv3fit/reservoir/config.py +++ b/external/fv3fit/fv3fit/reservoir/config.py @@ -1,6 +1,6 @@ import dacite from dataclasses import dataclass, asdict -from typing import Sequence, Optional, Set, Tuple +from typing import Sequence, Optional, Set, Tuple, Mapping, Hashable import fsspec import yaml from .._shared.training_config import Hyperparameters @@ -65,6 +65,16 @@ class TransformerConfig: hybrid: Optional[str] = None +@dataclass +class ClipZConfig: + """ Vertical levels **between** start and stop are kept, + levels outside start/stop are clipped off. 
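+
+    For example, ClipZConfig(start=2, stop=5) keeps levels 2, 3, and 4
+    (Python slice semantics: start inclusive, stop exclusive); a start or
+    stop of None leaves that end of the column unclipped.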
+ """ + + start: Optional[int] = None + stop: Optional[int] = None + + @dataclass class ReservoirTrainingConfig(Hyperparameters): """ @@ -101,11 +111,13 @@ class ReservoirTrainingConfig(Hyperparameters): n_timesteps_synchronize: int input_noise: float seed: int = 0 + zero_fill_clipped_output_levels: bool = False transformers: Optional[TransformerConfig] = None n_jobs: Optional[int] = 1 square_half_hidden_state: bool = False hybrid_variables: Optional[Sequence[str]] = None mask_variable: Optional[str] = None + clip_config: Optional[Mapping[Hashable, ClipZConfig]] = None _METADATA_NAME = "reservoir_training_config.yaml" def __post_init__(self): diff --git a/external/fv3fit/fv3fit/reservoir/train.py b/external/fv3fit/fv3fit/reservoir/train.py index 53a180c733..e1a99ea518 100644 --- a/external/fv3fit/fv3fit/reservoir/train.py +++ b/external/fv3fit/fv3fit/reservoir/train.py @@ -19,6 +19,8 @@ assure_txyz_dims, SynchronziationTracker, get_standard_normalizing_transformer, + clip_batch_data, + zero_fill_clipped_output_levels, ) from .transformers import TransformerGroup, Transformer from .._shared import register_training_function @@ -48,6 +50,8 @@ def _add_input_noise(arr: np.ndarray, stddev: float) -> np.ndarray: def _get_transformers( sample_batch: Mapping[str, tf.Tensor], hyperparameters: ReservoirTrainingConfig ) -> TransformerGroup: + clipped_sample_batch = clip_batch_data(sample_batch, hyperparameters.clip_config) + # Load transformers with specified paths transformers = {} for variable_group in ["input", "output", "hybrid"]: @@ -58,25 +62,25 @@ def _get_transformers( # If input transformer not specified, always create a standard norm transform if "input" not in transformers: transformers["input"] = get_standard_normalizing_transformer( - hyperparameters.input_variables, sample_batch + hyperparameters.input_variables, clipped_sample_batch ) - # If output transformer not specified and output_variables != input_variables, - # create a separate standard norm transform + # Output is not clipped, so use the original sample batch + if hyperparameters.zero_fill_clipped_output_levels: + sample_batch = zero_fill_clipped_output_levels( + sample_batch, hyperparameters.clip_config + ) if "output" not in transformers: - if hyperparameters.output_variables != hyperparameters.input_variables: - transformers["output"] = get_standard_normalizing_transformer( - hyperparameters.output_variables, sample_batch - ) - else: - transformers["output"] = transformers["input"] + transformers["output"] = get_standard_normalizing_transformer( + hyperparameters.output_variables, sample_batch + ) # If hybrid variables transformer not specified, and hybrid variables are defined, # create a separate standard norm transform if "hybrid" not in transformers: if hyperparameters.hybrid_variables is not None: transformers["hybrid"] = get_standard_normalizing_transformer( - hyperparameters.hybrid_variables, sample_batch + hyperparameters.hybrid_variables, clipped_sample_batch ) else: transformers["hybrid"] = transformers["input"] @@ -115,9 +119,17 @@ def train_reservoir_model( train_batches if isinstance(train_batches, Sequence) else [train_batches] ) sample_batch = next(iter(train_batches_sequence[0])) - sample_X = get_ordered_X(sample_batch, hyperparameters.input_variables) + if hyperparameters.zero_fill_clipped_output_levels: + sample_batch = zero_fill_clipped_output_levels( + sample_batch, hyperparameters.clip_config + ) + + # Clipping is done inside this function to preserve full length outputs + transformers = 
_get_transformers(sample_batch, hyperparameters,) + + clipped_sample_batch = clip_batch_data(sample_batch, hyperparameters.clip_config) + sample_X = get_ordered_X(clipped_sample_batch, hyperparameters.input_variables) - transformers = _get_transformers(sample_batch, hyperparameters) subdomain_config = hyperparameters.subdomain # sample_X[0] is the first data variable, shape elements 1:-1 are the x,y shape @@ -131,7 +143,7 @@ def train_reservoir_model( if hyperparameters.mask_variable is not None: input_mask_array: Optional[np.ndarray] = _get_input_mask_array( - hyperparameters.mask_variable, sample_batch, rank_divider + hyperparameters.mask_variable, clipped_sample_batch, rank_divider ) else: input_mask_array = None @@ -157,9 +169,12 @@ def train_reservoir_model( ) for b, batch_data in enumerate(train_batches): + batch_data_clipped = clip_batch_data( + batch_data, hyperparameters.clip_config + ) input_time_series = process_batch_data( variables=hyperparameters.input_variables, - batch_data=batch_data, + batch_data=batch_data_clipped, rank_divider=rank_divider, autoencoder=transformers.input, trim_halo=False, @@ -171,6 +186,11 @@ def train_reservoir_model( _output_rank_divider_with_overlap = rank_divider.get_new_zdim_rank_divider( z_feature_size=transformers.output.n_latent_dims ) + # don't pass in clipped data here, as clipping is not enabled for outputs + if hyperparameters.zero_fill_clipped_output_levels: + batch_data = zero_fill_clipped_output_levels( + batch_data, hyperparameters.clip_config + ) output_time_series = process_batch_data( variables=hyperparameters.output_variables, batch_data=batch_data, @@ -196,7 +216,7 @@ def train_reservoir_model( hybrid_time_series = process_batch_data( variables=hyperparameters.hybrid_variables, - batch_data=batch_data, + batch_data=batch_data_clipped, rank_divider=_hybrid_rank_divider_w_overlap, autoencoder=transformers.hybrid, trim_halo=True, @@ -279,6 +299,7 @@ def train_reservoir_model( model=model, input_variables=model.input_variables, output_variables=model.output_variables, + clip_config=hyperparameters.clip_config, ) if validation_batches is not None and wandb.run is not None: @@ -287,6 +308,7 @@ def train_reservoir_model( model, val_batches=validation_batches, n_synchronize=hyperparameters.n_timesteps_synchronize, + clip_config=hyperparameters.clip_config, ) log_rmse_z_plots(ds_val, model.output_variables) log_rmse_scalar_metrics(ds_val, model.output_variables) diff --git a/external/fv3fit/fv3fit/reservoir/utils.py b/external/fv3fit/fv3fit/reservoir/utils.py index 12aee8fbc9..a275089c95 100644 --- a/external/fv3fit/fv3fit/reservoir/utils.py +++ b/external/fv3fit/fv3fit/reservoir/utils.py @@ -1,7 +1,7 @@ import logging import numpy as np import tensorflow as tf -from typing import Iterable, Mapping, Optional +from typing import Iterable, Mapping, Optional, Hashable from fv3fit.reservoir.transformers import ( # ReloadableTransformer, @@ -9,6 +9,7 @@ encode_columns, build_concat_and_scale_only_autoencoder, ) +from fv3fit.reservoir.config import ClipZConfig from fv3fit.reservoir.domain2 import RankXYDivider from ._reshaping import stack_array_preserving_last_dim @@ -88,6 +89,7 @@ def square_even_terms(v: np.ndarray, axis=1) -> np.ndarray: def get_ordered_X(X: Mapping[str, tf.Tensor], variables: Iterable[str]): ordered_tensors = [X[v] for v in variables] reshaped_tensors = [assure_txyz_dims(var_tensor) for var_tensor in ordered_tensors] + return reshaped_tensors @@ -121,7 +123,9 @@ def process_batch_data( return 
rank_divider.get_all_subdomains_with_flat_feature(data_trimmed)


-def get_standard_normalizing_transformer(variables, sample_batch):
+def get_standard_normalizing_transformer(
+    variables, sample_batch,
+):
     variable_data = get_ordered_X(sample_batch, variables)
     variable_data_stacked = [
         stack_array_preserving_last_dim(arr).numpy() for arr in variable_data
@@ -129,3 +133,55 @@ def get_standard_normalizing_transformer(variables, sample_batch):
     return build_concat_and_scale_only_autoencoder(
         variables=variables, X=variable_data_stacked
     )
+
+
+def clip_batch_data(
+    batch: Mapping[str, tf.Tensor],
+    clip_config: Optional[Mapping[Hashable, ClipZConfig]],
+):
+    dim_ordered_batch = {k: assure_txyz_dims(v) for k, v in batch.items()}
+
+    if clip_config is None:
+        return dim_ordered_batch
+    else:
+        clipped_batch = {}
+        # Clip the dim-ordered tensors, not the raw inputs, so the
+        # assure_txyz_dims reshaping is preserved in the clipped batch.
+        for var, tensor in dim_ordered_batch.items():
+            if var in clip_config:
+                clipped_batch[var] = tensor[
+                    ..., clip_config[var].start : clip_config[var].stop
+                ]
+            else:
+                clipped_batch[var] = tensor
+
+        return clipped_batch
+
+
+def _zero_fill_last_dim_tensor(tensor, start, stop):
+    # Explicit None checks so start=0 or stop=0 are not mistaken for unset
+    _start = start if start is not None else 0
+    _stop = stop if stop is not None else tensor.shape[-1]
+    return tf.concat(
+        [
+            tf.zeros_like(tensor[..., :_start]),
+            tensor[..., _start:_stop],
+            tf.zeros_like(tensor[..., _stop:]),
+        ],
+        axis=-1,
+    )
+
+
+def zero_fill_clipped_output_levels(batch, clip_config):
+    """Zero-fills output levels that have been clipped out of the training data.
+    """
+    if clip_config is None:
+        return batch
+    else:
+        zero_filled_batch = {}
+        for var, tensor in batch.items():
+            if var in clip_config:
+                zero_filled_batch[var] = _zero_fill_last_dim_tensor(
+                    tensor, clip_config[var].start, clip_config[var].stop
+                )
+            else:
+                zero_filled_batch[var] = tensor
+
+        return zero_filled_batch
diff --git a/external/fv3fit/fv3fit/reservoir/validation.py b/external/fv3fit/fv3fit/reservoir/validation.py
index 70140972e4..ef8653571b 100644
--- a/external/fv3fit/fv3fit/reservoir/validation.py
+++ b/external/fv3fit/fv3fit/reservoir/validation.py
@@ -1,11 +1,12 @@
 import numpy as np
 from scipy.ndimage import generic_filter
-from typing import Union, Optional, Sequence
+from typing import Union, Optional, Sequence, Mapping, Hashable
 import xarray as xr
 import tensorflow as tf
 import wandb
 
-from fv3fit.reservoir.utils import get_ordered_X
+from fv3fit.reservoir.utils import get_ordered_X, clip_batch_data
+from fv3fit.reservoir.config import ClipZConfig
 from fv3fit.reservoir import (
     ReservoirComputingModel,
     HybridReservoirComputingModel,
@@ -68,6 +69,13 @@ def _get_predictions_over_batch(
     return prediction_time_series, imperfect_prediction_time_series
 
 
+def _get_imperfect_prediction(
+    hybrid_inputs_time_series: Optional[Sequence[np.ndarray]],
+):
+    # Non-hybrid models have no hybrid inputs; return an empty series so
+    # the imperfect-prediction bookkeeping below is a no-op for them.
+    if hybrid_inputs_time_series is None:
+        return []
+    return list(hybrid_inputs_time_series)
+
+
 def _time_mean_dataset(variables, arr, label):
     ds = xr.Dataset()
     time_mean_error = np.mean(arr, axis=0)
@@ -90,7 +98,10 @@ def _get_states_without_overlap(
 
 
 def validation_prediction(
-    model: ReservoirModel, val_batches: tf.data.Dataset, n_synchronize: int,
+    model: ReservoirModel,
+    val_batches: tf.data.Dataset,
+    n_synchronize: int,
+    clip_config: Optional[Mapping[Hashable, ClipZConfig]],
 ):
     # Initialize hidden state
     model.reset_state()
@@ -99,29 +110,47 @@
     one_step_imperfect_prediction_time_series = []
     target_time_series = []
     for batch_data in val_batches:
-        states_with_overlap_time_series = get_ordered_X(
-            batch_data, model.input_variables  # type: ignore
+        # outputs are not clipped
+        output_states_with_overlap_time_series = get_ordered_X(
+            batch_data, model.output_variables  # type: ignore
+        )
+        if clip_config is not None:
+            batch_input_data = clip_batch_data(batch_data, clip_config)
+        else:
+            batch_input_data = batch_data
+        input_states_with_overlap_time_series = get_ordered_X(
+            batch_input_data, model.input_variables  # type: ignore
         )
-
         if isinstance(model, HybridReservoirComputingModel):
             hybrid_inputs_time_series = get_ordered_X(
-                batch_data, model.hybrid_variables  # type: ignore
+                batch_input_data, model.hybrid_variables  # type: ignore
             )
             hybrid_inputs_time_series = _get_states_without_overlap(
                 hybrid_inputs_time_series, overlap=model.rank_divider.overlap
             )
+            imperfect_prediction_time_series = get_ordered_X(
+                batch_data, model.hybrid_variables  # type: ignore
+            )
+            imperfect_prediction_time_series = _get_states_without_overlap(
+                imperfect_prediction_time_series, overlap=model.rank_divider.overlap
+            )
         else:
             hybrid_inputs_time_series = None
+            imperfect_prediction_time_series = None
 
-        batch_predictions, batch_imperfect_predictions = _get_predictions_over_batch(
-            model, states_with_overlap_time_series, hybrid_inputs_time_series
+        batch_predictions, _ = _get_predictions_over_batch(
+            model, input_states_with_overlap_time_series, hybrid_inputs_time_series
+        )
+        batch_imperfect_predictions = _get_imperfect_prediction(
+            imperfect_prediction_time_series
         )
         one_step_prediction_time_series += batch_predictions
         one_step_imperfect_prediction_time_series += batch_imperfect_predictions
         target_time_series.append(
             _get_states_without_overlap(
-                states_with_overlap_time_series, overlap=model.rank_divider.overlap
+                output_states_with_overlap_time_series,
+                overlap=model.rank_divider.overlap,
             )
         )
     target_time_series = np.concatenate(target_time_series, axis=0)[n_synchronize:]
@@ -205,16 +234,10 @@ def log_variance_scalar_metrics(ds_val, variables):
                 "prediction",
             ]:
                 key = f"time_mean_{comparison}_{var}"
                 if key in ds_val:
                     variance_key = f"time_mean_{comparison}_2d_variance_zsum_{var}"
                     log_data[variance_key] = _compute_2d_variance_mean_zsum(ds_val[key])
-                try:
-                    log_data[f"variance_ratio_{var}"] = (
-                        log_data[f"time_mean_prediction_2d_variance_zsum_{var}"]
-                        / log_data[f"time_mean_target_2d_variance_zsum_{var}"]
-                    )
-                except (KeyError):
-                    pass
     wandb.log(log_data)
diff --git a/external/fv3gfs-fortran b/external/fv3gfs-fortran
index a4b4cc16e5..e2617bfa1c 160000
--- a/external/fv3gfs-fortran
+++ b/external/fv3gfs-fortran
@@ -1 +1 @@
-Subproject commit a4b4cc16e5f5ef3d8163d5fc8b63265d0e804062
+Subproject commit e2617bfa1c1252971ec76e37a20ff478de6c82e2
diff --git a/projects/reservoir/fv3/save_ranks.py b/projects/reservoir/fv3/save_ranks.py
index 36c9445142..d481867b30 100644
--- a/projects/reservoir/fv3/save_ranks.py
+++ b/projects/reservoir/fv3/save_ranks.py
@@ -125,7 +125,11 @@ def get_ordered_dims_extent(dims: dict):
         _hybrid_data = intake.open_zarr(path).to_dask()
         _hybrid_data = _hybrid_data.shift(time=args.hybrid_data_time_shift)
         rename_hybrid_time_shifted_vars.update(
-            {var: f"{var}_at_next_time_step" for var in _hybrid_data.data_vars}
+            {
+                var: f"{var}_at_next_time_step"
+                for var in _hybrid_data.data_vars
+                if var in args.variables
+            }
         )
         data = data.merge(_hybrid_data)
 
diff --git a/workflows/diagnostics/fv3net/diagnostics/reservoir/compute.py b/workflows/diagnostics/fv3net/diagnostics/reservoir/compute.py
index 787a5919fb..4b5d874d02 100644
--- 
a/workflows/diagnostics/fv3net/diagnostics/reservoir/compute.py
+++ b/workflows/diagnostics/fv3net/diagnostics/reservoir/compute.py
@@ -128,7 +128,9 @@ def main(args):
         nfiles=val_data_config.get("nfiles", None),
     )
 
-    ds = validation_prediction(model, val_batches, args.n_synchronize,)
+    ds = validation_prediction(
+        model, val_batches, args.n_synchronize, clip_config=adapter.clip_config
+    )
 
     output_file = os.path.join(args.output_path, "offline_diags.nc")
 
diff --git a/workflows/prognostic_c48_run/runtime/loop.py b/workflows/prognostic_c48_run/runtime/loop.py
index 59465fe479..1e4d9dae41 100644
--- a/workflows/prognostic_c48_run/runtime/loop.py
+++ b/workflows/prognostic_c48_run/runtime/loop.py
@@ -55,6 +55,7 @@
 from runtime.steppers.combine import CombinedStepper
 from runtime.steppers.reservoir import get_reservoir_steppers
 from runtime.types import Diagnostics, State, Tendencies, Step
+from toolz import dissoc
 
 from runtime.nudging import NudgingConfig
 
@@ -139,6 +140,7 @@ def __init__(self, config: UserConfig, wrapper: Any, comm: Any = None,) -> None:
         self.comm = comm
         self._timer = pace.util.Timer()
         self.rank: int = comm.rank
+        self._steps_taken = 1
 
         namelist = get_namelist()
 
@@ -275,6 +277,7 @@ def _get_stepper(
             offset_seconds=stepper_config.offset_seconds,
             record_fields_before_update=stepper_config.record_fields_before_update,
             n_calls=stepper_config.n_calls,
+            do_total_precip_update=stepper_config.do_total_precip_update,
         )
     else:
         return stepper
@@ -489,7 +492,6 @@ def _step_prephysics(self) -> Diagnostics:
                 f"Applying prephysics state updates for: {list(state_updates.keys())}"
             )
             self._state.update_mass_conserving(state_updates)
-
         return diagnostics
 
     def _compute_postphysics(self) -> Diagnostics:
@@ -530,6 +532,10 @@ def _apply_postphysics_to_dycore_state(self) -> Diagnostics:
                 self._state, self._tendencies
             )
             diagnostics.update(stepper_diags)
+            if net_moistening.size <= 1:
+                net_moistening = xr.zeros_like(self._state[TOTAL_PRECIP])
+                net_moistening.attrs["units"] = "kg/m^2/s"
+
             if self._postphysics_only_diagnostic_ml:
                 rename_diagnostics(diagnostics)
             else:
@@ -602,10 +610,11 @@ def _apply_reservoir_update_to_state(self) -> Diagnostics:
             if self._reservoir_predict_stepper.is_diagnostic:  # type: ignore
                 rename_diagnostics(diags, label="reservoir_predictor")
-            state_updates[TOTAL_PRECIP] = precipitation_sum(
-                self._state[TOTAL_PRECIP], net_moistening, self._timestep,
+            precip = self._reservoir_predict_stepper.update_precip(  # type: ignore
+                self._state[TOTAL_PRECIP], net_moistening
             )
-
+            diags.update(precip)
+            state_updates[TOTAL_PRECIP] = precip[TOTAL_PRECIP]
             self._state.update_mass_conserving(state_updates)
 
             diags.update({name: self._state[name] for name in self._states_to_output})
@@ -615,12 +624,21 @@ def _apply_reservoir_update_to_state(self) -> Diagnostics:
                     "cnvprcp_after_python": self._wrapper.get_diagnostic_by_name(
                         "cnvprcp"
                     ).data_array,
-                    TOTAL_PRECIP_RATE: precipitation_rate(
-                        self._state[TOTAL_PRECIP], self._timestep
-                    ),
+                    TOTAL_PRECIP_RATE: diags["total_precip_rate_res_interval_avg"],
                 }
             )
+            self._log_info(
+                f"steps taken: {self._steps_taken}/{self._wrapper.get_step_count()}"
+            )
+            # save state if configured
+            if self._reservoir_predict_stepper.dump_state_at_end:  # type: ignore
+                if self._steps_taken == self._wrapper.get_step_count():
+                    self._log_info("dumping reservoir 
state") + self._reservoir_predict_stepper.dump_state( # type: ignore + checkpoint_time=self._state.time + ) + return diags else: return {} @@ -661,4 +679,5 @@ def __iter__( ]: with self._timer.clock(substep.__name__): diagnostics.update(substep()) + self._steps_taken += 1 yield self._state.time, {str(k): v for k, v in diagnostics.items()} diff --git a/workflows/prognostic_c48_run/runtime/steppers/interval.py b/workflows/prognostic_c48_run/runtime/steppers/interval.py index db4b28c9d8..1811151dd6 100644 --- a/workflows/prognostic_c48_run/runtime/steppers/interval.py +++ b/workflows/prognostic_c48_run/runtime/steppers/interval.py @@ -4,13 +4,15 @@ from typing import Tuple, Union, Optional, List import xarray as xr import logging +import vcm from runtime.types import Diagnostics from runtime.steppers.stepper import Stepper from runtime.steppers.machine_learning import MachineLearningConfig from runtime.steppers.prescriber import PrescriberConfig from runtime.nudging import NudgingConfig - +from runtime.diagnostics.compute import KG_PER_M2_PER_M +from runtime.names import SPHUM, DELP, TOTAL_PRECIP logger = logging.getLogger(__name__) @@ -34,6 +36,7 @@ class IntervalConfig: offset_seconds: int = 0 record_fields_before_update: Optional[List[str]] = None n_calls: Optional[int] = None + do_total_precip_update: bool = True class IntervalStepper: @@ -44,6 +47,7 @@ def __init__( offset_seconds: float = 0, n_calls: Optional[int] = None, record_fields_before_update: Optional[List[str]] = None, + do_total_precip_update: bool = True, ): self.start_time = None self.interval = timedelta(seconds=apply_interval_seconds) @@ -52,6 +56,7 @@ def __init__( self._record_fields_before_update = record_fields_before_update or [] self.n_calls = n_calls self._call_count = 0 + self._do_total_precip_update = do_total_precip_update @property def label(self): @@ -90,14 +95,39 @@ def __call__(self, time, state): if self._need_to_update(time) is False: # Diagnostic must be available at all timesteps, not just when # the base stepper is called - return {}, self.get_diagnostics_prior_to_update(state), {} + diags = self.get_diagnostics_prior_to_update(state) + return {}, diags, {} else: logger.info(f"applying interval stepper at time {time}") tendencies, diagnostics, state_updates = self.stepper(time, state) diagnostics.update(self.get_diagnostics_prior_to_update(state)) self._call_count += 1 + + if self._do_total_precip_update and SPHUM in state_updates: + logger.info(f"Updating total precip at time {time}") + state_updates[TOTAL_PRECIP] = self._get_precipitation_update( + state, state_updates + ) + else: + logger.info(f"Not updating total precip at time {time} ") return tendencies, diagnostics, state_updates + def _get_precipitation_update( + self, state, state_updates, + ): + corrective_moistening_integral = ( + vcm.mass_integrate(state_updates[SPHUM] - state[SPHUM], state[DELP], "z") + / KG_PER_M2_PER_M + ) + total_precip_before_limiter = ( + state[TOTAL_PRECIP] - corrective_moistening_integral + ) + total_precip = total_precip_before_limiter.where( + total_precip_before_limiter >= 0, 0 + ) + total_precip.attrs["units"] = "m" + return total_precip + def get_diagnostics(self, state, tendency) -> Tuple[Diagnostics, xr.DataArray]: diags, moistening = self.stepper.get_diagnostics(state, tendency) diags.update(self.get_diagnostics_prior_to_update(state)) diff --git a/workflows/prognostic_c48_run/runtime/steppers/reservoir.py b/workflows/prognostic_c48_run/runtime/steppers/reservoir.py index 9ae0e987d0..b4c05ae40e 100644 --- 
a/workflows/prognostic_c48_run/runtime/steppers/reservoir.py +++ b/workflows/prognostic_c48_run/runtime/steppers/reservoir.py @@ -4,6 +4,7 @@ import pandas as pd import xarray as xr from datetime import timedelta +import os from typing import ( Optional, MutableMapping, @@ -18,7 +19,8 @@ import fv3fit from fv3fit._shared.halos import append_halos_using_mpi from fv3fit.reservoir.adapters import ReservoirDatasetAdapter -from runtime.names import SST, SPHUM, TEMP +import vcm +from runtime.names import SST, SPHUM, TEMP, PHYSICS_PRECIP_RATE, TOTAL_PRECIP from runtime.tendency import add_tendency, tendencies_from_state_updates from runtime.diagnostics import ( enforce_heating_and_moistening_tendency_constraints, @@ -31,6 +33,25 @@ logger = logging.getLogger(__name__) +@dataclasses.dataclass +class TaperConfig: + cutoff: int + rate: float + taper_dim: str = "z" + + def blend(self, prediction: xr.DataArray, input: xr.DataArray) -> xr.DataArray: + n_levels = len(prediction[self.taper_dim]) + prediction_scaling = xr.DataArray( + vcm.vertical_tapering_scale_factors( + n_levels=n_levels, cutoff=self.cutoff, rate=self.rate + ), + dims=[self.taper_dim], + ) + input_scaling = 1 - prediction_scaling + + return input_scaling * input + prediction_scaling * prediction + + @dataclasses.dataclass class ReservoirConfig: """ @@ -64,6 +85,10 @@ class ReservoirConfig: rename_mapping: NameDict = dataclasses.field(default_factory=dict) hydrostatic: bool = False mse_conserving_limiter: bool = False + interval_average_precipitation: bool = False + taper_blending: Optional[Mapping] = None + dump_state_at_end: bool = False + state_checkpoint_group: Optional[str] = None def __post_init__(self): # This handles cases in automatic config writing where json/yaml @@ -120,6 +145,49 @@ def __call__(self, state: str): ) +class PrecipTracker: + def __init__(self, reservoir_timestep_seconds: float): + self.reservoir_timestep_seconds = reservoir_timestep_seconds + self.physics_precip_averager = TimeAverageInputs([PHYSICS_PRECIP_RATE]) + self._air_temperature_at_previous_interval = None + self._specific_humidity_at_previous_interval = None + + def increment_physics_precip_rate(self, physics_precip_rate): + self.physics_precip_averager.increment_running_average( + {PHYSICS_PRECIP_RATE: physics_precip_rate} + ) + + def interval_avg_precip_rates(self, net_moistening_due_to_reservoir): + physics_precip_rate = self.physics_precip_averager.get_averages()[ + PHYSICS_PRECIP_RATE + ] + total_precip_rate = physics_precip_rate - net_moistening_due_to_reservoir + total_precip_rate = total_precip_rate.where(total_precip_rate >= 0, 0) + reservoir_precip_rate = total_precip_rate - physics_precip_rate + return { + "total_precip_rate_res_interval_avg": total_precip_rate, + "physics_precip_rate_res_interval_avg": physics_precip_rate, + "reservoir_precip_rate_res_interval_avg": reservoir_precip_rate, + } + + def accumulated_precip_update( + self, + physics_precip_total_over_model_timestep, + reservoir_precip_rate_over_res_interval, + reservoir_timestep, + ): + # Since the reservoir correction is only applied every reservoir_timestep, + # all of the precip due to the reservoir is put into the accumulated precip + # in the model timestep at update time. 
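+        # The rate below is assumed to be in kg/m^2/s (numerically mm/s of
+        # liquid water), so rate * interval seconds gives mm and m_per_mm
+        # converts the accumulated total to meters.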
+        m_per_mm = 1 / 1000
+        reservoir_total_precip = (
+            reservoir_precip_rate_over_res_interval * reservoir_timestep * m_per_mm
+        )
+        total_precip = physics_precip_total_over_model_timestep + reservoir_total_precip
+        total_precip.attrs["units"] = "m"
+        return total_precip
+
+
 class TimeAverageInputs:
     """
     Copy of time averaging components from runtime.diagnostics.manager to
@@ -179,6 +247,7 @@ def __init__(
         self,
         reservoir_timestep: timedelta,
         model_timestep: float,
         synchronize_steps: int,
+        model_path: str,
         state_machine: Optional[_FiniteStateMachine] = None,
         diagnostic_only: bool = False,
         input_averager: Optional[TimeAverageInputs] = None,
@@ -186,7 +255,13 @@ def __init__(
         warm_start: bool = False,
         hydrostatic: bool = False,
         mse_conserving_limiter: bool = False,
+        precip_tracker: Optional[PrecipTracker] = None,
+        taper_blending: Optional[TaperConfig] = None,
+        dump_state_at_end: bool = False,
+        state_checkpoint_group: Optional[str] = None,
     ):
+        self.dump_state_at_end = dump_state_at_end
+        self.state_checkpoint_group = state_checkpoint_group
         self.model = model
         self.synchronize_steps = synchronize_steps
         self.initial_time = init_time
@@ -197,6 +272,9 @@ def __init__(
         self.warm_start = warm_start
         self.hydrostatic = hydrostatic
         self.mse_conserving_limiter = mse_conserving_limiter
+        self.precip_tracker = precip_tracker
+        self.taper_blending = taper_blending
+        self.model_path = model_path
 
         if state_machine is None:
             state_machine = _FiniteStateMachine()
@@ -217,6 +295,22 @@ def __init__(
             rename_mapping = cast(NameDict, {})
         self.rename_mapping = rename_mapping
 
+    def dump_state(self, checkpoint_time: cftime.DatetimeJulian):
+        # Save the current model state to its path
+        self.model.dump_state(
+            os.path.join(self.model_path, "hybrid_reservoir_model", "reservoir")
+        )
+
+        checkpoint_path = os.path.join(
+            self.model_path,
+            "reservoir_state_checkpoints",
+            self.state_checkpoint_group or "",
+            checkpoint_time.strftime("%Y%m%d.%H%M%S"),
+        )
+        # Save the current model state to the checkpoint directory
+        self.model.dump_state(checkpoint_path)
+        logger.info(f"Dumped reservoir state checkpoint to {checkpoint_path}")
+
     @property
     def completed_sync_steps(self):
         return self._state_machine.completed_increments
@@ -266,8 +360,8 @@ def _get_inputs_from_state(self, state):
                 )
             except RuntimeError:
                 raise ValueError(
-                    "MPI not available or tile dimension does not exist in state fields"
-                    " during reservoir increment update"
+                    "MPI not available or tile dimension does not exist in state "
+                    "fields during reservoir increment update"
                 )
             reservoir_inputs = rc_in_with_halos
 
@@ -329,12 +423,19 @@ def predict(self, inputs, state):
         self._state_machine(self._state_machine.PREDICT)
 
         result = self.model.predict(inputs)
+
         output_state = rename_dataset_members(result, self.rename_mapping)
         diags = rename_dataset_members(
             output_state, {k: f"{k}_{self.DIAGS_OUTPUT_SUFFIX}" for k in output_state}
         )
-
+        if self.taper_blending is not None:
+            input_renaming = {
+                k: v for k, v in self.rename_mapping.items() if k in inputs
+            }
+            output_state = self.taper_blending.blend(
+                output_state, inputs.rename(input_renaming)
+            )
         for k, v in output_state.items():
             v.attrs["units"] = state[k].attrs.get("units", "unknown")
 
@@ -376,6 +477,11 @@ def __call__(self, time, state):
         if self.input_averager is not None:
             self.input_averager.increment_running_average(inputs)
 
+        if self.precip_tracker is not None:
+            self.precip_tracker.increment_physics_precip_rate(
+                state[PHYSICS_PRECIP_RATE]
+            )
+
         if self._is_rc_update_step(time):
             logger.info(f"Reservoir model predict at time {time}")
             if self.input_averager is not None:
@@ -423,6 +529,13 @@ def __call__(self, time, state):
                 tendencies=tendency_updates_from_constraints,
                 dt=self.model_timestep,
             )
+            # Adjust corrective tendencies to be averages over
+            # the full reservoir timestep
+            for key in tendency_updates_from_constraints:
+                if key != "specific_humidity_limiter_active":
+                    tendency_updates_from_constraints[key] *= (
+                        self.model_timestep / self.timestep.total_seconds()
+                    )
             tendencies.update(tendency_updates_from_constraints)
 
         else:
@@ -434,6 +547,24 @@ def get_diagnostics(self, state, tendency):
         diags = compute_diagnostics(state, tendency, self.label, self.hydrostatic)
         return diags, diags[f"net_moistening_due_to_{self.label}"]
 
+    def update_precip(
+        self, physics_precip, net_moistening_due_to_reservoir,
+    ):
+        # Assumes interval-average precipitation is enabled, so that
+        # self.precip_tracker is not None when this is called.
+        diags = {}
+
+        # running average gets reset in this call
+        precip_rates = self.precip_tracker.interval_avg_precip_rates(
+            net_moistening_due_to_reservoir
+        )
+        diags.update(precip_rates)
+
+        diags[TOTAL_PRECIP] = self.precip_tracker.accumulated_precip_update(
+            physics_precip,
+            diags["reservoir_precip_rate_res_interval_avg"],
+            self.timestep.total_seconds(),
+        )
+        return diags
+
 
 def open_rc_model(path: str) -> ReservoirDatasetAdapter:
     return cast(ReservoirDatasetAdapter, fv3fit.load(path))
 
@@ -467,7 +598,8 @@ def get_reservoir_steppers(
     using the stepped underlying model + incremented RC state.
     """
     try:
-        model = open_rc_model(config.models[rank])
+        model_path = config.models[rank]
+        model = open_rc_model(model_path)
     except KeyError:
         raise KeyError(
             f"No reservoir model path found for rank {rank}. "
@@ -478,7 +610,16 @@ def get_reservoir_steppers(
     increment_averager, predict_averager = _get_time_averagers(
         model, config.time_average_inputs
     )
-
+    _precip_tracker_kwargs = {}
+    if config.interval_average_precipitation:
+        _precip_tracker_kwargs["precip_tracker"] = PrecipTracker(
+            reservoir_timestep_seconds=rc_tdelta.total_seconds(),
+        )
+    # Default to None so taper_blending is always bound, even when no
+    # (or an incomplete) taper config is provided.
+    taper_blending: Optional[TaperConfig] = None
+    if config.taper_blending is not None:
+        if len({"cutoff", "rate"}.intersection(config.taper_blending.keys())) == 2:
+            taper_blending = TaperConfig(**config.taper_blending)
     incrementer = ReservoirIncrementOnlyStepper(
         model,
         init_time,
@@ -489,6 +630,8 @@ def get_reservoir_steppers(
         rename_mapping=config.rename_mapping,
         warm_start=config.warm_start,
         model_timestep=model_timestep,
+        model_path=model_path,
+        dump_state_at_end=config.dump_state_at_end,
     )
     predictor = ReservoirPredictStepper(
         model,
@@ -503,5 +646,10 @@ def get_reservoir_steppers(
         model_timestep=model_timestep,
         hydrostatic=config.hydrostatic,
         mse_conserving_limiter=config.mse_conserving_limiter,
+        taper_blending=taper_blending,
+        model_path=model_path,
+        dump_state_at_end=config.dump_state_at_end,
+        state_checkpoint_group=config.state_checkpoint_group,
+        **_precip_tracker_kwargs,
     )
     return incrementer, predictor
diff --git a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression[reservoir].out b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression[reservoir].out
index 213a2d3f26..efa100c0bb 100644
--- a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression[reservoir].out
+++ b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression[reservoir].out
@@ -435,7 +435,9 @@ prephysics: null
 radiation_scheme: null
 reservoir_corrector:
   diagnostic_only: false
+  dump_state_at_end: false
   hydrostatic: false
+  
interval_average_precipitation: false models: 0: gs://vcm-ml-scratch/rc-model-tile-0 1: gs://vcm-ml-scratch/rc-model-tile-1 @@ -446,7 +448,9 @@ reservoir_corrector: mse_conserving_limiter: false rename_mapping: {} reservoir_timestep: 900s + state_checkpoint_group: null synchronize_steps: 12 + taper_blending: null time_average_inputs: false warm_start: false scikit_learn: null diff --git a/workflows/prognostic_c48_run/tests/test_reservoir_stepper.py b/workflows/prognostic_c48_run/tests/test_reservoir_stepper.py index 7d078f5ea5..cc2d3335fe 100644 --- a/workflows/prognostic_c48_run/tests/test_reservoir_stepper.py +++ b/workflows/prognostic_c48_run/tests/test_reservoir_stepper.py @@ -134,6 +134,7 @@ def get_mock_ReservoirSteppers(): MODEL_TIMESTEP, 2, state_machine=state_machine, + model_path="path/to/model", ) predictor = ReservoirPredictStepper( @@ -143,6 +144,7 @@ def get_mock_ReservoirSteppers(): MODEL_TIMESTEP, 2, state_machine=state_machine, + model_path="path/to/model", ) return incrementer, predictor
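
A minimal sketch of the clipping semantics these changes introduce (hypothetical values; plain NumPy stands in for the tf/xarray helpers in fv3fit.reservoir.utils): clip_batch_data slices inputs to the half-open [start, stop) window, while zero_fill_clipped_output_levels keeps the full output column and zeroes the clipped levels.

    import numpy as np

    # Hypothetical ClipZConfig(start=2, stop=5) applied to one 8-level column
    start, stop = 2, 5
    column = np.arange(8.0)

    # Input path (clip_batch_data): keep only levels [start, stop)
    clipped_input = column[start:stop]  # -> [2. 3. 4.]

    # Output path (_zero_fill_last_dim_tensor): full column, clipped levels zeroed
    zero_filled_output = np.concatenate(
        [np.zeros(start), column[start:stop], np.zeros(column.size - stop)]
    )  # -> [0. 0. 2. 3. 4. 0. 0. 0.]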