Skip to content

Commit

Permalink
Update Held-Suarez forcing to support model parallelism.
Browse files Browse the repository at this point in the history
This small update should enable model parallelism with Held-Suarez forcing -- the velocity tendencies should be a tuple, not an array with a leading dimension of size 2. I've updated the documentation and type annotations of DiagnosticState.cos_lat_u to match.

xref: #45

TODO: write a unit test to verify that running Held-Suarez with model parallelism works.
PiperOrigin-RevId: 660470844
  • Loading branch information
shoyer authored and Dinosaur authors committed Aug 7, 2024
1 parent ede9b37 commit 3b4b923
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 19 deletions.
41 changes: 24 additions & 17 deletions dinosaur/held_suarez.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@
Held, I. M., and M. J. Suarez, 1994: "A proposal for the intercomparison of
the dynamical cores of atmospheric general circulation models."
Bulletin of the American Meteorological Society, 75, 1825–1830.
"""

from dinosaur import coordinate_systems
from dinosaur import primitive_equations
from dinosaur import scales
from dinosaur import time_integration
from dinosaur import typing

import jax
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -92,49 +91,57 @@ def __init__(

def kv(self):
kv_coeff = self.kf * (
np.maximum(0, (self.sigma - self.sigma_b) / (1 - self.sigma_b)))
np.maximum(0, (self.sigma - self.sigma_b) / (1 - self.sigma_b))
)
return kv_coeff[:, np.newaxis, np.newaxis]

def kt(self):
cutoff = np.maximum(0, (self.sigma - self.sigma_b) / (1 - self.sigma_b))
return self.ka + (self.ks - self.ka) * (
cutoff[:, np.newaxis, np.newaxis] * np.cos(self.lat)**4)
cutoff[:, np.newaxis, np.newaxis] * np.cos(self.lat) ** 4
)

def equilibrium_temperature(self, nodal_surface_pressure):
p_over_p0 = (
self.sigma[:, np.newaxis, np.newaxis] * nodal_surface_pressure /
self.p0)
self.sigma[:, np.newaxis, np.newaxis] * nodal_surface_pressure / self.p0
)
temperature = p_over_p0**self.physics_specs.kappa * (
self.maxT - self.dTy * np.sin(self.lat)**2 -
self.dThz * jnp.log(p_over_p0) * np.cos(self.lat)**2)
self.maxT
- self.dTy * np.sin(self.lat) ** 2
- self.dThz * jnp.log(p_over_p0) * np.cos(self.lat) ** 2
)
return jnp.maximum(self.minT, temperature)

def explicit_terms(
self, state: primitive_equations.State
) -> primitive_equations.State:
"""Computes explicit tendencies due to Held-Suarez forcing."""
aux_state = primitive_equations.compute_diagnostic_state(
state=state, coords=self.coords)
state=state, coords=self.coords
)

# Nodal velocity tendencies
# "velocity" here is `velocity / cos(lat)`
nodal_velocity = (
jnp.stack(aux_state.cos_lat_u) / self.coords.horizontal.cos_lat**2)
nodal_velocity_tendency = -self.kv() * nodal_velocity
nodal_velocity_tendency = jax.tree.map(
lambda x: -self.kv() * x / self.coords.horizontal.cos_lat**2,
aux_state.cos_lat_u,
)

# Nodal temperature tendency
nodal_temperature = (
self.reference_temperature[:, np.newaxis, np.newaxis] +
aux_state.temperature_variation)
self.reference_temperature[:, np.newaxis, np.newaxis]
+ aux_state.temperature_variation
)
nodal_log_surface_pressure = self.coords.horizontal.to_nodal(
state.log_surface_pressure)
state.log_surface_pressure
)
nodal_surface_pressure = jnp.exp(nodal_log_surface_pressure)
Teq = self.equilibrium_temperature(nodal_surface_pressure)
nodal_temperature_tendency = -self.kt() * (nodal_temperature - Teq)

# Convert to modal
temperature_tendency = self.coords.horizontal.to_modal(
nodal_temperature_tendency)
nodal_temperature_tendency
)
velocity_tendency = self.coords.horizontal.to_modal(nodal_velocity_tendency)
vorticity_tendency = self.coords.horizontal.curl_cos_lat(velocity_tendency)
divergence_tendency = self.coords.horizontal.div_cos_lat(velocity_tendency)
Expand Down
5 changes: 3 additions & 2 deletions dinosaur/primitive_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ class DiagnosticState:
vorticity: nodal values of the vorticity field of shape [h, q, t].
divergence: nodal values of the divergence field of shape [h, q, t].
temperature_variation: nodal values of the T' field of shape [h, q, t].
cos_lat_u: (2,) nodal values of cosθ * velocity_vector of shape [h, q, t].
cos_lat_u: tuple of nodal values of cosθ * velocity_vector, each of shape
[h, q, t].
sigma_dot_explicit: nodal values of d𝜎/dt due to pressure gradient terms
`u · ∇(log(ps))` of shape [h, q, t].
sigma_dot_full: nodal values of d𝜎/dt due to all terms of shape [h, q, t].
Expand All @@ -152,7 +153,7 @@ class DiagnosticState:
vorticity: Array
divergence: Array
temperature_variation: Array
cos_lat_u: Array
cos_lat_u: tuple[Array, Array]
sigma_dot_explicit: Array
sigma_dot_full: Array
cos_lat_grad_log_sp: Array
Expand Down

0 comments on commit 3b4b923

Please sign in to comment.