Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions driver/pace/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,6 @@ def exit_function(*args, **kwargs):
damping_coefficients=self.state.damping_coefficients,
config=self.config.dycore_config,
phis=self.state.dycore_state.phis,
state=self.state.dycore_state,
)

self.dycore.update_state(
Expand Down Expand Up @@ -313,11 +312,8 @@ def exit_function(*args, **kwargs):
namelist=self.config.physics_config,
comm=communicator,
grid_info=self.state.driver_grid_data,
state=self.state.dycore_state,
quantity_factory=self.quantity_factory,
dycore_only=self.config.dycore_only,
apply_tendencies=self.config.apply_tendencies,
tendency_state=self.state.tendency_state,
)
else:
# Make sure those are set to None to raise any issues
Expand Down Expand Up @@ -439,6 +435,7 @@ def _step_dynamics(
):
self.dycore.step_dynamics(
state=state,
tracers_dict=state.tracers_as_array(),
timer=timer,
)

Expand All @@ -454,9 +451,9 @@ def _step_physics(self, timestep: float):
self.end_of_step_update(
dycore_state=self.state.dycore_state,
phy_state=self.state.physics_state,
u_dt=self.state.tendency_state.u_dt.storage,
v_dt=self.state.tendency_state.v_dt.storage,
pt_dt=self.state.tendency_state.pt_dt.storage,
u_dt=self.state.tendency_state.u_dt,
v_dt=self.state.tendency_state.v_dt,
pt_dt=self.state.tendency_state.pt_dt,
dt=float(timestep),
)

Expand Down
63 changes: 18 additions & 45 deletions dsl/pace/dsl/dace/wrapped_halo_exchange.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import dataclasses
from typing import List, Optional
from typing import List, Optional, Union

import numpy as np

from pace.dsl.dace.orchestration import dace_inhibitor
from pace.util.communicator import CubedSphereCommunicator
from pace.util.halo_updater import HaloUpdater
from pace.util.halo_updater import HaloUpdater, VectorInterfaceHaloUpdater


class WrappedHaloUpdater:
Expand All @@ -17,57 +17,30 @@ class WrappedHaloUpdater:

def __init__(
self,
updater: HaloUpdater,
state,
qty_x_names: List[str],
qty_y_names: List[str] = None,
comm: Optional[CubedSphereCommunicator] = None,
updater: Union[HaloUpdater, VectorInterfaceHaloUpdater],
) -> None:
self._updater = updater
self._state = state
self._qtx_x_names = qty_x_names
self._qtx_y_names = qty_y_names
self._comm = comm

@dace_inhibitor
def start(self):
if self._qtx_y_names is None:
if dataclasses.is_dataclass(self._state):
self._updater.start(
[self._state.__getattribute__(x) for x in self._qtx_x_names]
)
elif isinstance(self._state, dict):
self._updater.start([self._state[x] for x in self._qtx_x_names])
else:
raise NotImplementedError
else:
if dataclasses.is_dataclass(self._state):
self._updater.start(
[self._state.__getattribute__(x) for x in self._qtx_x_names],
[self._state.__getattribute__(y) for y in self._qtx_y_names],
)
elif isinstance(self._state, dict):
self._updater.start(
[self._state[x] for x in self._qtx_x_names],
[self._state[y] for y in self._qtx_y_names],
)
else:
raise NotImplementedError
def start(
self, arrays_x: List[np.ndarray], arrays_y: Optional[List[np.ndarray]] = None
):
assert isinstance(self._updater, HaloUpdater)
self._updater.start(arrays_x, arrays_y)

@dace_inhibitor
def wait(self):
self._updater.wait()

@dace_inhibitor
def update(self):
self.start()
def update(
self, arrays_x: List[np.ndarray], arrays_y: Optional[List[np.ndarray]] = None
):
self.start(arrays_x, arrays_y)
self.wait()

@dace_inhibitor
def interface(self):
assert len(self._qtx_x_names) == 1
assert len(self._qtx_y_names) == 1
self._comm.synchronize_vector_interfaces(
self._state.__getattribute__(self._qtx_x_names[0]),
self._state.__getattribute__(self._qtx_y_names[0]),
)
def interface(self, arrays_x: np.ndarray, arrays_y: np.ndarray):
assert isinstance(self._updater, VectorInterfaceHaloUpdater)
request = self._updater.start_synchronize_vector_interfaces(arrays_x, arrays_y)
request.wait()
6 changes: 6 additions & 0 deletions dsl/pace/dsl/gt4py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
halo = 3
origin = (halo, halo, 0)

# nq is actually given by ncnst - pnats, where those are given in atmosphere.F90 by:
# ncnst = Atm(mytile)%ncnst
# pnats = Atm(mytile)%flagstruct%pnats
# here we hard-coded it because 8 is the only supported value, refactor this later!
NQ = 8 # state.nq_tot - spec.namelist.dnats

# TODO get from field_table
tracer_variables = [
"qvapor",
Expand Down
7 changes: 2 additions & 5 deletions examples/notebooks/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,9 +917,7 @@ def run_finite_volume_fluxprep(
return flux_prep


def build_tracer_advection(
stencil_configuration: Dict[str, Any], tracers: Dict[str, Quantity]
) -> TracerAdvection:
def build_tracer_advection(stencil_configuration: Dict[str, Any]) -> TracerAdvection:
"""
Use: tracer_advection =
build_tracer_advection(stencil_configuration, tracers)
Expand Down Expand Up @@ -949,7 +947,6 @@ def build_tracer_advection(
fvtp_2d,
stencil_configuration["grid_data"],
stencil_configuration["communicator"],
tracers,
)

return tracer_advection
Expand Down Expand Up @@ -993,7 +990,7 @@ def prepare_everything_for_advection(
timestep,
)

tracer_advection = build_tracer_advection(stencil_configuration, tracers)
tracer_advection = build_tracer_advection(stencil_configuration)

tracer_advection_data = {
"tracers": tracers,
Expand Down
1 change: 0 additions & 1 deletion fv3core/examples/standalone/runfile/acoustics.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ def driver(
dycore_config.acoustic_dynamics,
input_data["pfull"],
input_data["phis"],
state,
)

# warm-up timestep.
Expand Down
5 changes: 2 additions & 3 deletions fv3core/examples/standalone/runfile/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,6 @@ def setup_dycore(
damping_coefficients=DampingCoefficients.new_from_metric_terms(metric_terms),
config=dycore_config,
phis=state.phis,
state=state,
)
dycore.update_state(
conserve_total_energy=dycore_config.consv_te,
Expand Down Expand Up @@ -310,7 +309,7 @@ def setup_dycore(
# warmup/compilation from the internal timers
if rank == 0:
print("timestep 1")
dycore.step_dynamics(state, timer)
dycore.step_dynamics(state, state.tracers_as_array(), timer)

if profiler is not None:
profiler.enable()
Expand All @@ -324,7 +323,7 @@ def setup_dycore(
with timestep_timer.clock("mainloop"):
if rank == 0:
print(f"timestep {i+2}")
dycore.step_dynamics(state, timer=timestep_timer)
dycore.step_dynamics(state, state.tracers_as_array(), timer=timestep_timer)
times_per_step.append(timestep_timer.times)
hits_per_step.append(timestep_timer.hits)
timestep_timer.reset()
Expand Down
19 changes: 18 additions & 1 deletion fv3core/fv3core/initialization/dycore_state.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass, field, fields
from typing import Any, Mapping
from typing import Any, Dict, List, Mapping

import numpy as np
import xarray as xr

import pace.dsl.gt4py_utils as gt_utils
Expand Down Expand Up @@ -372,5 +373,21 @@ def xr_dataset(self):
)
return xr.Dataset(data_vars=data_vars)

@property
def tracers(self) -> List[pace.util.Quantity]:
return [self.__getattribute__(x) for x in DycoreState.tracer_names()]

def tracers_as_array(self) -> Dict[str, np.ndarray]:
all_tracers = {
name: self.__getattribute__(name).data
for name in DycoreState.tracer_names()
}
all_tracers.pop("qcld")
return all_tracers

@classmethod
def tracer_names(cls) -> List[str]:
return gt_utils.tracer_variables

def __getitem__(self, item):
return getattr(self, item)
Loading