From a9ec598e0353511a13429d8a2a35cb2788254117 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 17 Aug 2022 05:27:29 -0700 Subject: [PATCH 1/8] VectorInterfaceHaloUpdater reasons on QuantityHaloSpec & arrays Communicator keeps a Quantity-based API QuantityHaloSpec gets it's own file (re-use) --- dsl/pace/dsl/dace/wrapped_halo_exchange.py | 14 +- pace-util/pace/util/communicator.py | 28 +++- pace-util/pace/util/halo_data_transformer.py | 18 +-- .../pace/util/halo_quantity_specification.py | 39 +++++ pace-util/pace/util/halo_updater.py | 136 +++++++++++++----- 5 files changed, 170 insertions(+), 65 deletions(-) create mode 100644 pace-util/pace/util/halo_quantity_specification.py diff --git a/dsl/pace/dsl/dace/wrapped_halo_exchange.py b/dsl/pace/dsl/dace/wrapped_halo_exchange.py index ad88fb118..debe60d2d 100644 --- a/dsl/pace/dsl/dace/wrapped_halo_exchange.py +++ b/dsl/pace/dsl/dace/wrapped_halo_exchange.py @@ -1,9 +1,9 @@ import dataclasses -from typing import List, Optional +from typing import List, Optional, Union from pace.dsl.dace.orchestration import dace_inhibitor from pace.util.communicator import CubedSphereCommunicator -from pace.util.halo_updater import HaloUpdater +from pace.util.halo_updater import HaloUpdater, VectorInterfaceHaloUpdater class WrappedHaloUpdater: @@ -17,7 +17,7 @@ class WrappedHaloUpdater: def __init__( self, - updater: HaloUpdater, + updater: Union[HaloUpdater, VectorInterfaceHaloUpdater], state, qty_x_names: List[str], qty_y_names: List[str] = None, @@ -65,9 +65,11 @@ def update(self): @dace_inhibitor def interface(self): + assert isinstance(self._updater, VectorInterfaceHaloUpdater) assert len(self._qtx_x_names) == 1 assert len(self._qtx_y_names) == 1 - self._comm.synchronize_vector_interfaces( - self._state.__getattribute__(self._qtx_x_names[0]), - self._state.__getattribute__(self._qtx_y_names[0]), + request = self._updater.start_synchronize_vector_interfaces( + self._state.__getattribute__(self._qtx_x_names[0]).data, + self._state.__getattribute__(self._qtx_y_names[0]).data, ) + request.wait() diff --git a/pace-util/pace/util/communicator.py b/pace-util/pace/util/communicator.py index 709468a4e..47379dc14 100644 --- a/pace-util/pace/util/communicator.py +++ b/pace-util/pace/util/communicator.py @@ -8,7 +8,7 @@ from ._timing import NullTimer, Timer from .boundary import Boundary from .buffer import array_buffer, recv_buffer, send_buffer -from .halo_data_transformer import QuantityHaloSpec +from .halo_quantity_specification import QuantityHaloSpec from .halo_updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater from .partitioner import CubedSpherePartitioner, Partitioner, TilePartitioner from .quantity import Quantity, QuantityMetadata @@ -479,14 +479,20 @@ def start_synchronize_vector_interfaces( """ halo_updater = VectorInterfaceHaloUpdater( comm=self.comm, + qty_x_spec=QuantityHaloSpec.from_quantity(x_quantity, -1), + qty_y_spec=QuantityHaloSpec.from_quantity(y_quantity, -1), boundaries=self.boundaries, force_cpu=self._force_cpu, timer=self.timer, ) - req = halo_updater.start_synchronize_vector_interfaces(x_quantity, y_quantity) + req = halo_updater.start_synchronize_vector_interfaces( + x_quantity.data, y_quantity.data + ) return req - def get_scalar_halo_updater(self, specifications: List[QuantityHaloSpec]): + def get_scalar_halo_updater( + self, specifications: List[QuantityHaloSpec] + ) -> HaloUpdater: if len(specifications) == 0: raise RuntimeError("Cannot create updater with specifications list") if specifications[0].n_points == 0: @@ -504,7 +510,7 @@ def get_vector_halo_updater( self, specifications_x: List[QuantityHaloSpec], specifications_y: List[QuantityHaloSpec], - ): + ) -> HaloUpdater: if len(specifications_x) == 0 and len(specifications_y) == 0: raise RuntimeError("Cannot create updater with empty specifications list") if specifications_x[0].n_points == 0 and specifications_y[0].n_points == 0: @@ -519,6 +525,20 @@ def get_vector_halo_updater( self.timer, ) + def get_vector_interface_halo_updater( + self, + specification_x: QuantityHaloSpec, + specification_y: QuantityHaloSpec, + ) -> VectorInterfaceHaloUpdater: + return VectorInterfaceHaloUpdater( + comm=self.comm, + qty_x_spec=specification_x, + qty_y_spec=specification_y, + boundaries=self.boundaries, + force_cpu=self._force_cpu, + timer=self.timer, + ) + def _get_halo_tag(self) -> int: self._last_halo_tag += 1 return self._last_halo_tag diff --git a/pace-util/pace/util/halo_data_transformer.py b/pace-util/pace/util/halo_data_transformer.py index 794714c86..cfaa0b763 100644 --- a/pace-util/pace/util/halo_data_transformer.py +++ b/pace-util/pace/util/halo_data_transformer.py @@ -1,7 +1,7 @@ import abc from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Dict, List, Optional, Sequence, Tuple from uuid import UUID, uuid1 import numpy as np @@ -14,27 +14,13 @@ unpack_scalar_f64_kernel, unpack_vector_f64_kernel, ) +from .halo_quantity_specification import QuantityHaloSpec from .quantity import Quantity from .rotate import rotate_scalar_data, rotate_vector_data from .types import NumpyModule from .utils import device_synchronize -@dataclass -class QuantityHaloSpec: - """Describe the memory to be exchanged, including size of the halo.""" - - n_points: int - strides: Tuple[int] - itemsize: int - shape: Tuple[int] - origin: Tuple[int, ...] - extent: Tuple[int, ...] - dims: Tuple[str, ...] - numpy_module: NumpyModule - dtype: Any - - # ------------------------------------------------------------------------ # Simple pool of streams to lower the driver pressure # Use _pop/_push_stream to manipulate the pool diff --git a/pace-util/pace/util/halo_quantity_specification.py b/pace-util/pace/util/halo_quantity_specification.py new file mode 100644 index 000000000..bcb6b2de0 --- /dev/null +++ b/pace-util/pace/util/halo_quantity_specification.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass +from .types import NumpyModule +from typing import Tuple, Any +from .quantity import Quantity + + +@dataclass +class QuantityHaloSpec: + """Describe the memory to be exchanged. + + Specification needs to cover all aspect of the memory layout for + borth scalar, vector and interface fields, for their halo to be exchanged. + `numpy_module` carries a numpy-like (numpy, cupy) module that will be + used to direct the exchange on the right device. + """ + + n_points: int + strides: Tuple[int] + itemsize: int + shape: Tuple[int] + origin: Tuple[int, ...] + extent: Tuple[int, ...] + dims: Tuple[str, ...] + numpy_module: NumpyModule + dtype: Any + + @classmethod + def from_quantity(cls, quantity: Quantity, n_points: int) -> "QuantityHaloSpec": + return QuantityHaloSpec( + n_points=n_points, + strides=quantity.data.strides, + itemsize=quantity.data.itemsize, + shape=quantity.data.shape, + origin=quantity.origin, + extent=quantity.extent, + dims=quantity.dims, + numpy_module=quantity.np, + dtype=quantity.data.dtype, + ) diff --git a/pace-util/pace/util/halo_updater.py b/pace-util/pace/util/halo_updater.py index 3a98c5c74..60cab950f 100644 --- a/pace-util/pace/util/halo_updater.py +++ b/pace-util/pace/util/halo_updater.py @@ -7,12 +7,9 @@ from ._timing import NullTimer, Timer from .boundary import Boundary from .buffer import Buffer -from .halo_data_transformer import ( - HaloDataTransformer, - HaloExchangeSpec, - QuantityHaloSpec, -) -from .quantity import Quantity +from .halo_data_transformer import HaloDataTransformer, HaloExchangeSpec +from .halo_quantity_specification import QuantityHaloSpec +from .quantity import BoundaryArrayView, Quantity from .rotate import rotate_scalar_data from .types import AsyncRequest, NumpyModule from .utils import device_synchronize @@ -333,15 +330,15 @@ def wait(self): Buffer.push_to_cache(transfer_buffer) -def on_c_grid(x_quantity, y_quantity): +def on_c_grid(x_spec: QuantityHaloSpec, y_spec: QuantityHaloSpec): if ( - constants.X_DIM not in x_quantity.dims - or constants.Y_INTERFACE_DIM not in x_quantity.dims + constants.X_DIM not in x_spec.dims + or constants.Y_INTERFACE_DIM not in x_spec.dims ): return False if ( - constants.Y_DIM not in y_quantity.dims - or constants.X_INTERFACE_DIM not in y_quantity.dims + constants.Y_DIM not in y_spec.dims + or constants.X_INTERFACE_DIM not in y_spec.dims ): return False else: @@ -349,9 +346,19 @@ def on_c_grid(x_quantity, y_quantity): class VectorInterfaceHaloUpdater: + """Exchange halo on information between ranks for data living on the interface. + + This class reasons on QuantityHaloSpec for initialization and assumes the arrays given + to the start_synchronize_vector_interfaces adhere to those specs. + + See start_synchronize_vector_interfaces for details on interface exchange. + """ + def __init__( self, comm, + qty_x_spec: QuantityHaloSpec, + qty_y_spec: QuantityHaloSpec, boundaries: Mapping[int, Boundary], force_cpu: bool = False, timer: Optional[Timer] = None, @@ -360,6 +367,8 @@ def __init__( Args: comm: mpi4py.Comm object + qty_x_spec: halo specification for data to exchange on the X-axis + qty_y_spec: halo specification for data to exchange on the Y-axis partitioner: cubed sphere partitioner force_cpu: Force all communication to go through central memory. Optional. timer: Time communication operations. Optional. @@ -369,13 +378,15 @@ def __init__( self._force_cpu = force_cpu self.comm = comm self.boundaries = boundaries + self._qty_x_spec = qty_x_spec + self._qty_y_spec = qty_y_spec def _get_halo_tag(self) -> int: self._last_halo_tag += 1 return self._last_halo_tag def start_synchronize_vector_interfaces( - self, x_quantity: Quantity, y_quantity: Quantity + self, x_array: np.ndarray, y_array: np.ndarray ) -> HaloUpdateRequest: """ Synchronize shared points at the edges of a vector interface variable. @@ -389,70 +400,91 @@ def start_synchronize_vector_interfaces( rotation of vector quantities needed to move data across the edge. Args: - x_quantity: the x-component quantity to be synchronized - y_quantity: the y-component quantity to be synchronized + x_array: the x-component data to be synchronized + y_array: the y-component data to be synchronized Returns: request: an asynchronous request object with a .wait() method """ - if not on_c_grid(x_quantity, y_quantity): + if not on_c_grid(self._qty_x_spec, self._qty_y_spec): raise ValueError("vector must be defined on Arakawa C-grid") device_synchronize() tag = self._get_halo_tag() - send_requests = self._Isend_vector_shared_boundary( - x_quantity, y_quantity, tag=tag - ) - recv_requests = self._Irecv_vector_shared_boundary( - x_quantity, y_quantity, tag=tag - ) + send_requests = self._Isend_vector_shared_boundary(x_array, y_array, tag=tag) + recv_requests = self._Irecv_vector_shared_boundary(x_array, y_array, tag=tag) return HaloUpdateRequest(send_requests, recv_requests, self.timer) def _Isend_vector_shared_boundary( - self, x_quantity, y_quantity, tag=0 + self, x_array: np.ndarray, y_array: np.ndarray, tag=0 ) -> _HaloRequestSendList: + # South boundary south_boundary = self.boundaries[constants.SOUTH] - west_boundary = self.boundaries[constants.WEST] - south_data = x_quantity.view.southwest.sel( + southwest_x_view = BoundaryArrayView( + x_array, + constants.SOUTHWEST, + self._qty_x_spec.dims, + self._qty_x_spec.origin, + self._qty_x_spec.extent, + ) + south_data = southwest_x_view.sel( **{ constants.Y_INTERFACE_DIM: 0, constants.X_DIM: slice( - 0, x_quantity.extent[x_quantity.dims.index(constants.X_DIM)] + 0, + self._qty_x_spec.extent[ + self._qty_x_spec.dims.index(constants.X_DIM) + ], ), } ) south_data = rotate_scalar_data( south_data, [constants.X_DIM], - x_quantity.np, + self._qty_x_spec.numpy_module, -south_boundary.n_clockwise_rotations, ) if south_boundary.n_clockwise_rotations in (3, 2): south_data = -south_data - west_data = y_quantity.view.southwest.sel( + + # West boundary + west_boundary = self.boundaries[constants.WEST] + southwest_y_view = BoundaryArrayView( + y_array, + constants.SOUTHWEST, + self._qty_y_spec.dims, + self._qty_y_spec.origin, + self._qty_y_spec.extent, + ) + west_data = southwest_y_view.sel( **{ constants.X_INTERFACE_DIM: 0, constants.Y_DIM: slice( - 0, y_quantity.extent[y_quantity.dims.index(constants.Y_DIM)] + 0, + self._qty_y_spec.extent[ + self._qty_y_spec.dims.index(constants.Y_DIM) + ], ), } ) west_data = rotate_scalar_data( west_data, [constants.Y_DIM], - y_quantity.np, + self._qty_y_spec.numpy_module, -west_boundary.n_clockwise_rotations, ) if west_boundary.n_clockwise_rotations in (1, 2): west_data = -west_data + + # Send requests send_requests = [ self._Isend( - self._maybe_force_cpu(x_quantity.np), + self._maybe_force_cpu(self._qty_x_spec.numpy_module), south_data, dest=south_boundary.to_rank, tag=tag, ), self._Isend( - self._maybe_force_cpu(y_quantity.np), + self._maybe_force_cpu(self._qty_y_spec.numpy_module), west_data, dest=west_boundary.to_rank, tag=tag, @@ -470,35 +502,61 @@ def _maybe_force_cpu(self, module: NumpyModule) -> NumpyModule: return module def _Irecv_vector_shared_boundary( - self, x_quantity, y_quantity, tag=0 + self, x_array: np.ndarray, y_array: np.ndarray, tag=0 ) -> _HaloRequestRecvList: + # North boundary north_rank = self.boundaries[constants.NORTH].to_rank - east_rank = self.boundaries[constants.EAST].to_rank - north_data = x_quantity.view.northwest.sel( + northwest_x_view = BoundaryArrayView( + x_array, + constants.NORTHWEST, + self._qty_x_spec.dims, + self._qty_x_spec.origin, + self._qty_x_spec.extent, + ) + + north_data = northwest_x_view.sel( **{ constants.Y_INTERFACE_DIM: -1, constants.X_DIM: slice( - 0, x_quantity.extent[x_quantity.dims.index(constants.X_DIM)] + 0, + self._qty_x_spec.extent[ + self._qty_x_spec.dims.index(constants.X_DIM) + ], ), } ) - east_data = y_quantity.view.southeast.sel( + + # East boundary + east_rank = self.boundaries[constants.EAST].to_rank + southeast_y_view = BoundaryArrayView( + y_array, + constants.SOUTHEAST, + self._qty_y_spec.dims, + self._qty_y_spec.origin, + self._qty_y_spec.extent, + ) + east_data = southeast_y_view.sel( **{ constants.X_INTERFACE_DIM: -1, constants.Y_DIM: slice( - 0, y_quantity.extent[y_quantity.dims.index(constants.Y_DIM)] + 0, + self._qty_y_spec.extent[ + self._qty_y_spec.dims.index(constants.Y_DIM) + ], ), } ) + + # Receive requests recv_requests = [ self._Irecv( - self._maybe_force_cpu(x_quantity.np), + self._maybe_force_cpu(self._qty_x_spec.numpy_module), north_data, source=north_rank, tag=tag, ), self._Irecv( - self._maybe_force_cpu(y_quantity.np), + self._maybe_force_cpu(self._qty_y_spec.numpy_module), east_data, source=east_rank, tag=tag, From d68acbd9a0bb2dd2353fdb1e1fcd98d3d572945e Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 17 Aug 2022 06:15:44 -0700 Subject: [PATCH 2/8] HaloUpdater reasons on arrays & quantity halo spec only Communicator retains a Quantity API --- dsl/pace/dsl/dace/wrapped_halo_exchange.py | 12 +- pace-util/pace/util/communicator.py | 10 +- pace-util/pace/util/halo_data_transformer.py | 237 +++++++++--------- pace-util/pace/util/halo_updater.py | 50 ++-- pace-util/pace/util/rotate.py | 8 +- pace-util/tests/test_halo_data_transformer.py | 12 +- pace-util/tests/test_halo_update.py | 6 +- 7 files changed, 168 insertions(+), 167 deletions(-) diff --git a/dsl/pace/dsl/dace/wrapped_halo_exchange.py b/dsl/pace/dsl/dace/wrapped_halo_exchange.py index debe60d2d..02dbde828 100644 --- a/dsl/pace/dsl/dace/wrapped_halo_exchange.py +++ b/dsl/pace/dsl/dace/wrapped_halo_exchange.py @@ -34,22 +34,22 @@ def start(self): if self._qtx_y_names is None: if dataclasses.is_dataclass(self._state): self._updater.start( - [self._state.__getattribute__(x) for x in self._qtx_x_names] + [self._state.__getattribute__(x).data for x in self._qtx_x_names] ) elif isinstance(self._state, dict): - self._updater.start([self._state[x] for x in self._qtx_x_names]) + self._updater.start([self._state[x].data for x in self._qtx_x_names]) else: raise NotImplementedError else: if dataclasses.is_dataclass(self._state): self._updater.start( - [self._state.__getattribute__(x) for x in self._qtx_x_names], - [self._state.__getattribute__(y) for y in self._qtx_y_names], + [self._state.__getattribute__(x).data for x in self._qtx_x_names], + [self._state.__getattribute__(y).data for y in self._qtx_y_names], ) elif isinstance(self._state, dict): self._updater.start( - [self._state[x] for x in self._qtx_x_names], - [self._state[y] for y in self._qtx_y_names], + [self._state[x].data for x in self._qtx_x_names], + [self._state[y].data for y in self._qtx_y_names], ) else: raise NotImplementedError diff --git a/pace-util/pace/util/communicator.py b/pace-util/pace/util/communicator.py index 47379dc14..3a8a96c8c 100644 --- a/pace-util/pace/util/communicator.py +++ b/pace-util/pace/util/communicator.py @@ -325,8 +325,10 @@ def start_halo_update( """ if isinstance(quantity, Quantity): quantities = [quantity] + arrays = [quantity.data] else: quantities = quantity + arrays = [qty.data for qty in quantity] specifications = [] for quantity in quantities: @@ -345,7 +347,7 @@ def start_halo_update( halo_updater = self.get_scalar_halo_updater(specifications) halo_updater.force_finalize_on_wait() - halo_updater.start(quantities) + halo_updater.start(arrays) return halo_updater def vector_halo_update( @@ -397,12 +399,16 @@ def start_vector_halo_update( """ if isinstance(x_quantity, Quantity): x_quantities = [x_quantity] + x_arrays = [x_quantity.data] else: x_quantities = x_quantity + x_arrays = [qty.data for qty in x_quantity] if isinstance(y_quantity, Quantity): y_quantities = [y_quantity] + y_arrays = [y_quantity.data] else: y_quantities = y_quantity + y_arrays = [qty.data for qty in y_quantity] x_specifications = [] y_specifications = [] @@ -434,7 +440,7 @@ def start_vector_halo_update( halo_updater = self.get_vector_halo_updater(x_specifications, y_specifications) halo_updater.force_finalize_on_wait() - halo_updater.start(x_quantities, y_quantities) + halo_updater.start(x_arrays, y_arrays) return halo_updater def synchronize_vector_interfaces(self, x_quantity: Quantity, y_quantity: Quantity): diff --git a/pace-util/pace/util/halo_data_transformer.py b/pace-util/pace/util/halo_data_transformer.py index cfaa0b763..c55d6e766 100644 --- a/pace-util/pace/util/halo_data_transformer.py +++ b/pace-util/pace/util/halo_data_transformer.py @@ -15,7 +15,6 @@ unpack_vector_f64_kernel, ) from .halo_quantity_specification import QuantityHaloSpec -from .quantity import Quantity from .rotate import rotate_scalar_data, rotate_vector_data from .types import NumpyModule from .utils import device_synchronize @@ -260,7 +259,7 @@ def get( ) raise NotImplementedError( - f"Quantity module {np_module} has no HaloDataTransformer implemented" + f"Numpy-like module {np_module} has no HaloDataTransformer implemented" ) def get_unpack_buffer(self) -> Buffer: @@ -311,41 +310,41 @@ def ready(self) -> bool: @abc.abstractmethod def async_pack( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, + arrays_x: List[np.ndarray], + arrays_y: Optional[List[np.ndarray]] = None, ): - """Pack all given quantities into a single send Buffer. + """Pack all given arrays into a single send Buffer. Does not guarantee the buffer returned by `get_unpack_buffer` has received data, doing so requires calling `synchronize`. Reaching for the buffer via get_pack_buffer() will call synchronize(). Args: - quantities_x: scalar or vector x-component quantities to pack, + arrays_x: scalar or vector x-component data to pack, if one is vector they must all be vector - quantities_y: if quantities are vector, the y-component - quantities. + arrays_y: if data to exchange are vectors, the y-component + data. """ pass @abc.abstractmethod def async_unpack( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, + arrays_x: List[np.ndarray], + arrays_y: Optional[List[np.ndarray]] = None, ): - """Unpack the buffer into destination quantities. + """Unpack the buffer into destination arrays. Does not guarantee the buffer returned by `get_unpack_buffer` has received data, doing so requires calling `synchronize`. Reaching for the buffer via get_unpack_buffer() will call synchronize(). Args: - quantities_x: scalar or vector x-component quantities to be unpacked into, + arrays_x: scalar or vector x-component data to pack, if one is vector they must all be vector - quantities_y: if quantities are vector, the y-component - quantities. + arrays_y: if data to exchange are vectors, the y-component + data. """ pass @@ -372,41 +371,41 @@ def synchronize(self): def async_pack( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, + arrays_x: List[np.ndarray], + arrays_y: Optional[List[np.ndarray]] = None, ): # Unpack per type if self._type == _HaloDataTransformerType.SCALAR: - self._pack_scalar(quantities_x) + self._pack_scalar(arrays_x) elif self._type == _HaloDataTransformerType.VECTOR: - assert quantities_y is not None - self._pack_vector(quantities_x, quantities_y) + assert arrays_y is not None + self._pack_vector(arrays_x, arrays_y) else: raise RuntimeError(f"Unimplemented {self._type} pack") assert isinstance(self._pack_buffer, Buffer) # e.g. allocate happened - def _pack_scalar(self, quantities: List[Quantity]): + def _pack_scalar(self, arrays: List[np.ndarray]): if __debug__: - if len(quantities) != len(self._infos_x): + if len(arrays) != len(self._infos_x): raise RuntimeError( - f"Quantities count ({len(quantities)}" + f"Arrays count ({len(arrays)}" f" is different that edges count {len(self._infos_x)}" ) - # TODO Per quantity check + # TODO Per array check assert isinstance(self._pack_buffer, Buffer) # e.g. allocate happened offset = 0 - for quantity, info_x in zip(quantities, self._infos_x): + for array, info_x in zip(arrays, self._infos_x): data_size = _slices_size(info_x.pack_slices) # sending data across the boundary will rotate the data # n_clockwise_rotations times, due to the difference in axis orientation.\ # Thus we rotate that number of times counterclockwise before sending, # to get the right final orientation source_view = rotate_scalar_data( - quantity.data[info_x.pack_slices], - quantity.dims, - quantity.np, + array[info_x.pack_slices], + info_x.specification.dims, + info_x.specification.numpy_module, -info_x.pack_clockwise_rotation, ) self._pack_buffer.assign_from( @@ -415,38 +414,38 @@ def _pack_scalar(self, quantities: List[Quantity]): ) offset += data_size - def _pack_vector(self, quantities_x: List[Quantity], quantities_y: List[Quantity]): + def _pack_vector(self, arrays_x: List[np.ndarray], arrays_y: List[np.ndarray]): if __debug__: - if len(quantities_x) != len(self._infos_x) and len(quantities_y) != len( + if len(arrays_x) != len(self._infos_x) and len(arrays_y) != len( self._infos_y ): raise RuntimeError( - f"Quantities count (x: {len(quantities_x)}, y: {len(quantities_y)})" + f"Arrays count (x: {len(arrays_x)}, y: {len(arrays_y)})" " is different that specifications count " f"(x: {len(self._infos_x)}, y: {len(self._infos_y)}" ) - # TODO Per quantity check + # TODO Per array check assert isinstance(self._pack_buffer, Buffer) # e.g. allocate happened - assert len(quantities_y) == len(quantities_x) + assert len(arrays_y) == len(arrays_x) assert len(self._infos_x) == len(self._infos_y) offset = 0 for ( - quantity_x, - quantity_y, + array_x, + array_y, info_x, info_y, - ) in zip(quantities_x, quantities_y, self._infos_x, self._infos_y): + ) in zip(arrays_x, arrays_y, self._infos_x, self._infos_y): # sending data across the boundary will rotate the data # n_clockwise_rotations times, due to the difference in axis orientation # Thus we rotate that number of times counterclockwise before sending, # to get the right final orientation x_view, y_view = rotate_vector_data( - quantity_x.data[info_x.pack_slices], - quantity_y.data[info_y.pack_slices], + array_x[info_x.pack_slices], + array_y[info_y.pack_slices], -info_x.pack_clockwise_rotation, - quantity_x.dims, - quantity_x.np, + info_x.specification.dims, + info_x.specification.numpy_module, ) # Pack X/Y data slices in the buffer @@ -463,74 +462,72 @@ def _pack_vector(self, quantities_x: List[Quantity], quantities_y: List[Quantity def async_unpack( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, + arrays_x: List[np.ndarray], + arrays_y: Optional[List[np.ndarray]] = None, ): # Unpack per type if self._type == _HaloDataTransformerType.SCALAR: - self._unpack_scalar(quantities_x) + self._unpack_scalar(arrays_x) elif self._type == _HaloDataTransformerType.VECTOR: - assert quantities_y is not None - self._unpack_vector(quantities_x, quantities_y) + assert arrays_y is not None + self._unpack_vector(arrays_x, arrays_y) else: raise RuntimeError(f"Unimplemented {self._type} unpack") assert isinstance(self._unpack_buffer, Buffer) # e.g. allocate happened - def _unpack_scalar(self, quantities: List[Quantity]): + def _unpack_scalar(self, arrays: List[np.ndarray]): if __debug__: - if len(quantities) != len(self._infos_x): + if len(arrays) != len(self._infos_x): raise RuntimeError( - f"Quantities count ({len(quantities)}" + f"Arrays count ({len(arrays)}" f" is different that specifications count {len(self._infos_x)}" ) - # TODO Per quantity check + # TODO Per array check assert isinstance(self._unpack_buffer, Buffer) # e.g. allocate happened offset = 0 - for quantity, info_x in zip(quantities, self._infos_x): - quantity_view = quantity.data[info_x.unpack_slices] + for array, info_x in zip(arrays, self._infos_x): + array_view = array[info_x.unpack_slices] data_size = _slices_size(info_x.unpack_slices) self._unpack_buffer.assign_to( - quantity_view, + array_view, buffer_slice=np.index_exp[offset : offset + data_size], - buffer_reshape=quantity_view.shape, + buffer_reshape=array_view.shape, ) offset += data_size - def _unpack_vector( - self, quantities_x: List[Quantity], quantities_y: List[Quantity] - ): + def _unpack_vector(self, arrays_x: List[np.ndarray], arrays_y: List[np.ndarray]): if __debug__: - if len(quantities_x) != len(self._infos_x) and len(quantities_y) != len( + if len(arrays_x) != len(self._infos_x) and len(arrays_y) != len( self._infos_y ): raise RuntimeError( - f"Quantities count (x: {len(quantities_x)}, y: {len(quantities_y)})" + f"Arrays count (x: {len(arrays_x)}, y: {len(arrays_y)})" " is different that specifications count " f"(x: {len(self._infos_x)}, y: {len(self._infos_y)})" ) - # TODO Per quantity check + # TODO Per array check assert isinstance(self._unpack_buffer, Buffer) # e.g. allocate happened offset = 0 - for quantity_x, quantity_y, info_x, info_y in zip( - quantities_x, quantities_y, self._infos_x, self._infos_y + for array_x, array_y, info_x, info_y in zip( + arrays_x, arrays_y, self._infos_x, self._infos_y ): - quantity_view = quantity_x.data[info_x.unpack_slices] + array_view = array_x[info_x.unpack_slices] data_size = _slices_size(info_x.unpack_slices) self._unpack_buffer.assign_to( - quantity_view, + array_view, buffer_slice=np.index_exp[offset : offset + data_size], - buffer_reshape=quantity_view.shape, + buffer_reshape=array_view.shape, ) offset += data_size - quantity_view = quantity_y.data[info_y.unpack_slices] + array_view = array_y[info_y.unpack_slices] data_size = _slices_size(info_y.unpack_slices) self._unpack_buffer.assign_to( - quantity_view, + array_view, buffer_slice=np.index_exp[offset : offset + data_size], - buffer_reshape=quantity_view.shape, + buffer_reshape=array_view.shape, ) offset += data_size @@ -539,7 +536,7 @@ class HaloDataTransformerGPU(HaloDataTransformer): """Pack/unpack data in a single buffer using CUDA Kernels. In order to efficiently pack/unpack on the GPU to a single GPU buffer - we use streamed (e.g. async) kernels per quantity per edge to send. The + we use streamed (e.g. async) kernels per array per edge to send. The kernels are store in `cuda_kernels.py`, they both follow the same simple pattern by reading the indices to the device memory of the data to pack/unpack. `_flatten_indices` is the routine that take the layout of the memory and @@ -669,47 +666,47 @@ def _get_stream(self, stream) -> "cp.cuda.stream": def async_pack( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, + arrays_x: List["cp.ndarray"], + arrays_y: Optional[List["cp.ndarray"]] = None, ): - """Pack the quantities into a single buffer via streamed cuda kernels + """Pack the arrays into a single buffer via streamed cuda kernels Writes into self._pack_buffer using self._x_infos and self._y_infos - to read the offsets and sizes per quantity. + to read the offsets and sizes per array. Args: - quantities_x: list of quantities to pack. Must fit the specifications given + arrays_x: list of arrays to pack. Must fit the specifications given at init time. - quantities_y: Same as above but optional, used only for vector transfer. + arrays_y: Same as above but optional, used only for vector transfer. """ # Unpack per type if self._type == _HaloDataTransformerType.SCALAR: - self._opt_pack_scalar(quantities_x) + self._opt_pack_scalar(arrays_x) elif self._type == _HaloDataTransformerType.VECTOR: - assert quantities_y is not None - self._opt_pack_vector(quantities_x, quantities_y) + assert arrays_y is not None + self._opt_pack_vector(arrays_x, arrays_y) else: raise RuntimeError(f"Unimplemented {self._type} pack") - def _opt_pack_scalar(self, quantities: List[Quantity]): + def _opt_pack_scalar(self, arrays: List["cp.ndarray"]): """Specialized packing for scalar. See async_pack docs for usage.""" if __debug__: - if len(quantities) != len(self._infos_x): + if len(arrays) != len(self._infos_x): raise RuntimeError( - f"Quantities count ({len(quantities)}" + f"Quantities count ({len(arrays)}" f" is different that specifications count {len(self._infos_x)}" ) - # TODO Per quantity check + # TODO Per array check assert isinstance(self._pack_buffer, Buffer) # e.g. allocate happened offset = 0 - for info_x, quantity in zip(self._infos_x, quantities): + for info_x, array in zip(self._infos_x, arrays): cu_kernel_args = self._cu_kernel_args[info_x._id] # Use private stream with self._get_stream(cu_kernel_args.stream): - if quantity.metadata.dtype != np.float64: + if info_x.specification.dtype != np.float64: raise RuntimeError(f"Kernel requires f64 given {np.float64}") # Launch kernel @@ -722,7 +719,7 @@ def _opt_pack_scalar(self, quantities: List[Quantity]): (grid_x,), (blocks,), ( - quantity.data[:], # source_array + array[:], # source_array cu_kernel_args.x_send_indices, # indices info_x.pack_buffer_size, # nIndex offset, @@ -734,29 +731,29 @@ def _opt_pack_scalar(self, quantities: List[Quantity]): offset += info_x.pack_buffer_size def _opt_pack_vector( - self, quantities_x: List[Quantity], quantities_y: List[Quantity] + self, arrays_x: List["cp.ndarray"], arrays_y: List["cp.ndarray"] ): """Specialized packing for vectors. See async_pack docs for usage.""" if __debug__: - if len(quantities_x) != len(self._infos_x) and len(quantities_y) != len( + if len(arrays_x) != len(self._infos_x) and len(arrays_y) != len( self._infos_y ): raise RuntimeError( - f"Quantities count (x: {len(quantities_x)}, y: {len(quantities_y)}" + f"Arrays count (x: {len(arrays_x)}, y: {len(arrays_y)}" " is different that specifications count " f"(x: {len(self._infos_x)}, y: {len(self._infos_y)}" ) - # TODO Per quantity check + # TODO Per array check assert isinstance(self._pack_buffer, Buffer) # e.g. allocate happened assert len(self._infos_x) == len(self._infos_y) - assert len(quantities_x) == len(quantities_y) + assert len(arrays_x) == len(arrays_y) offset = 0 for ( - quantity_x, - quantity_y, + array_x, + array_y, info_x, info_y, - ) in zip(quantities_x, quantities_y, self._infos_x, self._infos_y): + ) in zip(arrays_x, arrays_y, self._infos_x, self._infos_y): cu_kernel_args = self._cu_kernel_args[info_x._id] # Use private stream @@ -765,7 +762,7 @@ def _opt_pack_vector( # Buffer sizes transformer_size = info_x.pack_buffer_size + info_y.pack_buffer_size - if quantity_x.metadata.dtype != np.float64: + if info_x.specification.dtype != np.float64: raise RuntimeError(f"Kernel requires f64 given {np.float64}") # Launch kernel @@ -778,8 +775,8 @@ def _opt_pack_vector( (grid_x,), (blocks,), ( - quantity_x.data[:], # source_array_x - quantity_y.data[:], # source_array_y + array_x[:], # source_array_x + array_y[:], # source_array_y cu_kernel_args.x_send_indices, # indices_x cu_kernel_args.y_send_indices, # indices_y info_x.pack_buffer_size, # nIndex_x @@ -795,40 +792,40 @@ def _opt_pack_vector( def async_unpack( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, + arrays_x: List["cp.ndarray"], + arrays_y: Optional[List["cp.ndarray"]] = None, ): - """Unpack the quantities from a single buffer via streamed cuda kernels + """Unpack the arrays from a single buffer via streamed cuda kernels Reads from self._unpack_buffer using self._x_infos and self._y_infos - to read the offsets and sizes per quantity. + to read the offsets and sizes per array. Args: - quantities_x: list of quantities to unpack. Must fit + arrays_x: list of arrays to unpack. Must fit the specifications given at init time. - quantities_y: Same as above but optional, used only for vector transfer. + arrays_y: Same as above but optional, used only for vector transfer. """ # Unpack per type if self._type == _HaloDataTransformerType.SCALAR: - self._opt_unpack_scalar(quantities_x) + self._opt_unpack_scalar(arrays_x) elif self._type == _HaloDataTransformerType.VECTOR: - assert quantities_y is not None - self._opt_unpack_vector(quantities_x, quantities_y) + assert arrays_y is not None + self._opt_unpack_vector(arrays_x, arrays_y) else: raise RuntimeError(f"Unimplemented {self._type} unpack") - def _opt_unpack_scalar(self, quantities: List[Quantity]): + def _opt_unpack_scalar(self, arrays: List["cp.ndarray"]): """Specialized unpacking for scalars. See async_unpack docs for usage.""" if __debug__: - if len(quantities) != len(self._infos_x): + if len(arrays) != len(self._infos_x): raise RuntimeError( - f"Quantities count ({len(quantities)})" + f"Arrays count ({len(arrays)})" f" is different that specifications count ({len(self._infos_x)})" ) - # TODO Per quantity check + # TODO Per array check assert isinstance(self._unpack_buffer, Buffer) # e.g. allocate happened offset = 0 - for quantity, info_x in zip(quantities, self._infos_x): + for array, info_x in zip(arrays, self._infos_x): cu_kernel_args = self._cu_kernel_args[info_x._id] # Use private stream @@ -848,7 +845,7 @@ def _opt_unpack_scalar(self, quantities: List[Quantity]): cu_kernel_args.x_recv_indices, # indices info_x._unpack_buffer_size, # nIndex offset, - quantity.data[:], # destination_array + array[:], # destination_array ), ) @@ -856,31 +853,31 @@ def _opt_unpack_scalar(self, quantities: List[Quantity]): offset += info_x._unpack_buffer_size def _opt_unpack_vector( - self, quantities_x: List[Quantity], quantities_y: List[Quantity] + self, arrays_x: List["cp.ndarray"], arrays_y: List["cp.ndarray"] ): """Specialized unpacking for vectors. See async_unpack docs for usage.""" if __debug__: - if len(quantities_x) != len(self._infos_x) and len(quantities_y) != len( + if len(arrays_x) != len(self._infos_x) and len(arrays_y) != len( self._infos_y ): raise RuntimeError( - f"Quantities count (x: {len(quantities_x)}, y: {len(quantities_y)}" + f"Arrays count (x: {len(arrays_x)}, y: {len(arrays_y)}" " is different that specifications count " f"(x: {len(self._infos_x)}, y: {len(self._infos_y)}" ) - # TODO Per quantity check + # TODO Per array check assert isinstance(self._unpack_buffer, Buffer) # e.g. allocate happened assert len(self._infos_x) == len(self._infos_y) - assert len(quantities_x) == len(quantities_y) + assert len(arrays_x) == len(arrays_y) offset = 0 for ( - quantity_x, - quantity_y, + array_x, + array_y, info_x, info_y, - ) in zip(quantities_x, quantities_y, self._infos_x, self._infos_y): + ) in zip(arrays_x, arrays_y, self._infos_x, self._infos_y): # We only have writte a f64 kernel - if quantity_x.metadata.dtype != np.float64: + if info_x.specification.dtype != np.float64: raise RuntimeError(f"Kernel requires f64 given {np.float64}") cu_kernel_args = self._cu_kernel_args[info_x._id] @@ -907,8 +904,8 @@ def _opt_unpack_vector( info_x._unpack_buffer_size, # nIndex_x info_y._unpack_buffer_size, # nIndex_y offset, - quantity_x.data[:], # destination_array_x - quantity_y.data[:], # destination_array_y + array_x[:], # destination_array_x + array_y[:], # destination_array_y ), ) diff --git a/pace-util/pace/util/halo_updater.py b/pace-util/pace/util/halo_updater.py index 60cab950f..eff679f9e 100644 --- a/pace-util/pace/util/halo_updater.py +++ b/pace-util/pace/util/halo_updater.py @@ -9,7 +9,7 @@ from .buffer import Buffer from .halo_data_transformer import HaloDataTransformer, HaloExchangeSpec from .halo_quantity_specification import QuantityHaloSpec -from .quantity import BoundaryArrayView, Quantity +from .quantity import BoundaryArrayView from .rotate import rotate_scalar_data from .types import AsyncRequest, NumpyModule from .utils import device_synchronize @@ -61,8 +61,8 @@ def __init__( self._timer = timer self._recv_requests: List[AsyncRequest] = [] self._send_requests: List[AsyncRequest] = [] - self._inflight_x_quantities: Optional[Tuple[Quantity, ...]] = None - self._inflight_y_quantities: Optional[Tuple[Quantity, ...]] = None + self._inflight_x_arrays: Optional[Tuple[np.ndarray, ...]] = None + self._inflight_y_arrays: Optional[Tuple[np.ndarray, ...]] = None self._finalize_on_wait = False def force_finalize_on_wait(self): @@ -74,10 +74,7 @@ def force_finalize_on_wait(self): def __del__(self): """Clean up all buffers on garbage collection""" - if ( - self._inflight_x_quantities is not None - or self._inflight_y_quantities is not None - ): + if self._inflight_x_arrays is not None or self._inflight_y_arrays is not None: raise RuntimeError( "An halo exchange wasn't completed and a wait() call was expected" ) @@ -205,25 +202,22 @@ def from_vector_specifications( def update( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, + arrays_x: List[np.ndarray], + arrays_y: Optional[List[np.ndarray]] = None, ): """Exhange the data and blocks until finished.""" - self.start(quantities_x, quantities_y) + self.start(arrays_x, arrays_y) self.wait() def start( self, - quantities_x: List[Quantity], - quantities_y: Optional[List[Quantity]] = None, + arrays_x: List[np.ndarray], + arrays_y: Optional[List[np.ndarray]] = None, ): """Start data exchange.""" self._comm._device_synchronize() - if ( - self._inflight_x_quantities is not None - or self._inflight_y_quantities is not None - ): + if self._inflight_x_arrays is not None or self._inflight_y_arrays is not None: raise RuntimeError( "Previous exchange hasn't been properly finished." "E.g. previous start() call didn't have a wait() call." @@ -241,15 +235,13 @@ def start( ) ) - # Pack quantities halo points data into buffers + # Pack arrays halo points data into buffers with self._timer.clock("pack"): for transformer in self._transformers.values(): - transformer.async_pack(quantities_x, quantities_y) + transformer.async_pack(arrays_x, arrays_y) - self._inflight_x_quantities = tuple(quantities_x) - self._inflight_y_quantities = ( - tuple(quantities_y) if quantities_y is not None else None - ) + self._inflight_x_arrays = tuple(arrays_x) + self._inflight_y_arrays = tuple(arrays_y) if arrays_y is not None else None # Post send MPI order with self._timer.clock("Isend"): @@ -265,7 +257,7 @@ def start( def wait(self): """Finalize data exchange.""" - if __debug__ and self._inflight_x_quantities is None: + if __debug__ and self._inflight_x_arrays is None: raise RuntimeError('Halo update "wait" call before "start"') # Wait message to be exchange with self._timer.clock("wait"): @@ -275,12 +267,10 @@ def wait(self): recv_req.wait() # Unpack buffers (updated by MPI with neighbouring halos) - # to proper quantities + # to proper arrays with self._timer.clock("unpack"): for buffer in self._transformers.values(): - buffer.async_unpack( - self._inflight_x_quantities, self._inflight_y_quantities - ) + buffer.async_unpack(self._inflight_x_arrays, self._inflight_y_arrays) if self._finalize_on_wait: for transformer in self._transformers.values(): transformer.finalize() @@ -288,8 +278,8 @@ def wait(self): for transformer in self._transformers.values(): transformer.synchronize() - self._inflight_x_quantities = None - self._inflight_y_quantities = None + self._inflight_x_arrays = None + self._inflight_y_arrays = None class HaloUpdateRequest: @@ -397,7 +387,7 @@ def start_synchronize_vector_interfaces( For interface variables, the edges of the tile are computed on both ranks bordering that edge. This routine copies values across those shared edges so that both ranks have the same value for that edge. It also handles any - rotation of vector quantities needed to move data across the edge. + rotation of vector data needed to move data across the edge. Args: x_array: the x-component data to be synchronized diff --git a/pace-util/pace/util/rotate.py b/pace-util/pace/util/rotate.py index 27ab1f252..0c15a60b4 100644 --- a/pace-util/pace/util/rotate.py +++ b/pace-util/pace/util/rotate.py @@ -1,7 +1,9 @@ from . import constants +from typing import List +import numpy as np -def rotate_scalar_data(data, dims, numpy, n_clockwise_rotations): +def rotate_scalar_data(data, dims, numpy, n_clockwise_rotations) -> List[np.ndarray]: n_clockwise_rotations = n_clockwise_rotations % 4 if n_clockwise_rotations == 0: pass @@ -34,7 +36,9 @@ def rotate_scalar_data(data, dims, numpy, n_clockwise_rotations): return data -def rotate_vector_data(x_data, y_data, n_clockwise_rotations, dims, numpy): +def rotate_vector_data( + x_data, y_data, n_clockwise_rotations, dims, numpy +) -> List[np.ndarray]: x_data = rotate_scalar_data(x_data, dims, numpy, n_clockwise_rotations) y_data = rotate_scalar_data(y_data, dims, numpy, n_clockwise_rotations) data = [x_data, y_data] diff --git a/pace-util/tests/test_halo_data_transformer.py b/pace-util/tests/test_halo_data_transformer.py index 8f512d08b..d6defc35b 100644 --- a/pace-util/tests/test_halo_data_transformer.py +++ b/pace-util/tests/test_halo_data_transformer.py @@ -326,12 +326,12 @@ def test_data_transformer_scalar_pack_unpack(quantity, rotation, n_halos): data_transformer = HaloDataTransformer.get(quantity.np, exchange_descriptors) - data_transformer.async_pack([quantity, quantity]) + data_transformer.async_pack([quantity.data, quantity.data]) # Simulate data transfer data_transformer.get_unpack_buffer().assign_from( data_transformer.get_pack_buffer().array ) - data_transformer.async_unpack([quantity, quantity]) + data_transformer.async_unpack([quantity.data, quantity.data]) data_transformer.synchronize() # From the copy of the original quantity we rotate data @@ -433,12 +433,16 @@ def test_data_transformer_vector_pack_unpack(quantity, rotation, n_halos): x_quantity.np, exchange_descriptors_x, exchange_descriptors_y ) - data_transformer.async_pack([x_quantity, x_quantity], [y_quantity, y_quantity]) + data_transformer.async_pack( + [x_quantity.data, x_quantity.data], [y_quantity.data, y_quantity.data] + ) # Simulate data transfer data_transformer.get_unpack_buffer().assign_from( data_transformer.get_pack_buffer().array ) - data_transformer.async_unpack([x_quantity, x_quantity], [y_quantity, y_quantity]) + data_transformer.async_unpack( + [x_quantity.data, x_quantity.data], [y_quantity.data, y_quantity.data] + ) data_transformer.synchronize() # From the copy of the original quantity we rotate data diff --git a/pace-util/tests/test_halo_update.py b/pace-util/tests/test_halo_update.py index 834c1beb1..8cdfefca7 100644 --- a/pace-util/tests/test_halo_update.py +++ b/pace-util/tests/test_halo_update.py @@ -866,7 +866,7 @@ def test_halo_updater_stability( # First run for halo_updater in halo_updaters: - halo_updater.start([quantity]) + halo_updater.start([quantity.data]) for halo_updater in halo_updaters: halo_updater.wait() @@ -874,11 +874,11 @@ def test_halo_updater_stability( # The buffer should stay stable since we are exchanging the same information exchanged_once_quantity = copy.deepcopy(quantity) for halo_updater in halo_updaters: - halo_updater.start([quantity]) + halo_updater.start([quantity.data]) for halo_updater in halo_updaters: halo_updater.wait() for halo_updater in halo_updaters: - halo_updater.start([quantity]) + halo_updater.start([quantity.data]) for halo_updater in halo_updaters: halo_updater.wait() assert (quantity.data == exchanged_once_quantity.data).all() From 1eed1817230eb8d17837c1583cdf1eb0f97d55eb Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 23 Aug 2022 12:39:55 -0700 Subject: [PATCH 3/8] Remove tracers Add helper function on DycoreState --- driver/pace/driver/driver.py | 3 +- dsl/pace/dsl/dace/wrapped_halo_exchange.py | 53 ++++----------- dsl/pace/dsl/gt4py_utils.py | 6 ++ examples/notebooks/functions.py | 7 +- .../examples/standalone/runfile/acoustics.py | 1 - .../examples/standalone/runfile/dynamics.py | 3 +- .../fv3core/initialization/dycore_state.py | 19 +++++- fv3core/fv3core/stencils/dyn_core.py | 66 +++++++------------ fv3core/fv3core/stencils/fillz.py | 11 +--- fv3core/fv3core/stencils/fv_dynamics.py | 56 +++++++--------- fv3core/fv3core/stencils/mapn_tracer.py | 7 +- fv3core/fv3core/stencils/remapping.py | 6 +- fv3core/fv3core/stencils/tracer_2d_1l.py | 14 ++-- fv3core/fv3core/testing/translate_dyncore.py | 1 - .../fv3core/testing/translate_fvdynamics.py | 7 +- fv3core/tests/mpi/test_doubly_periodic.py | 10 ++- .../translate/translate_cubedtolatlon.py | 1 - .../savepoint/translate/translate_fillz.py | 1 - .../translate/translate_mapn_tracer_2d.py | 1 - .../translate/translate_remapping.py | 1 - .../translate/translate_tracer2d1l.py | 6 +- .../translate/translate_fv_update_phys.py | 3 - stencils/pace/stencils/c2l_ord.py | 10 +-- stencils/pace/stencils/fv_update_phys.py | 14 +--- stencils/pace/stencils/update_atmos_state.py | 9 --- tests/main/fv3core/test_dycore_call.py | 9 ++- 26 files changed, 125 insertions(+), 200 deletions(-) diff --git a/driver/pace/driver/driver.py b/driver/pace/driver/driver.py index b0de12c18..5548ec6b1 100644 --- a/driver/pace/driver/driver.py +++ b/driver/pace/driver/driver.py @@ -274,7 +274,6 @@ def exit_function(*args, **kwargs): damping_coefficients=self.state.damping_coefficients, config=self.config.dycore_config, phis=self.state.dycore_state.phis, - state=self.state.dycore_state, ) self.dycore.update_state( @@ -303,7 +302,6 @@ def exit_function(*args, **kwargs): namelist=self.config.physics_config, comm=communicator, grid_info=self.state.driver_grid_data, - state=self.state.dycore_state, quantity_factory=self.quantity_factory, dycore_only=self.config.dycore_only, apply_tendencies=self.config.apply_tendencies, @@ -425,6 +423,7 @@ def _step_dynamics( ): self.dycore.step_dynamics( state=state, + tracers_dict=state.tracers_as_array(), timer=timer, ) diff --git a/dsl/pace/dsl/dace/wrapped_halo_exchange.py b/dsl/pace/dsl/dace/wrapped_halo_exchange.py index 02dbde828..516ef8b32 100644 --- a/dsl/pace/dsl/dace/wrapped_halo_exchange.py +++ b/dsl/pace/dsl/dace/wrapped_halo_exchange.py @@ -1,6 +1,8 @@ import dataclasses from typing import List, Optional, Union +import numpy as np + from pace.dsl.dace.orchestration import dace_inhibitor from pace.util.communicator import CubedSphereCommunicator from pace.util.halo_updater import HaloUpdater, VectorInterfaceHaloUpdater @@ -18,58 +20,29 @@ class WrappedHaloUpdater: def __init__( self, updater: Union[HaloUpdater, VectorInterfaceHaloUpdater], - state, - qty_x_names: List[str], - qty_y_names: List[str] = None, - comm: Optional[CubedSphereCommunicator] = None, ) -> None: self._updater = updater - self._state = state - self._qtx_x_names = qty_x_names - self._qtx_y_names = qty_y_names - self._comm = comm @dace_inhibitor - def start(self): - if self._qtx_y_names is None: - if dataclasses.is_dataclass(self._state): - self._updater.start( - [self._state.__getattribute__(x).data for x in self._qtx_x_names] - ) - elif isinstance(self._state, dict): - self._updater.start([self._state[x].data for x in self._qtx_x_names]) - else: - raise NotImplementedError - else: - if dataclasses.is_dataclass(self._state): - self._updater.start( - [self._state.__getattribute__(x).data for x in self._qtx_x_names], - [self._state.__getattribute__(y).data for y in self._qtx_y_names], - ) - elif isinstance(self._state, dict): - self._updater.start( - [self._state[x].data for x in self._qtx_x_names], - [self._state[y].data for y in self._qtx_y_names], - ) - else: - raise NotImplementedError + def start( + self, arrays_x: List[np.ndarray], arrays_y: Optional[List[np.ndarray]] = None + ): + assert isinstance(self._updater, HaloUpdater) + self._updater.start(arrays_x, arrays_y) @dace_inhibitor def wait(self): self._updater.wait() @dace_inhibitor - def update(self): - self.start() + def update( + self, arrays_x: List[np.ndarray], arrays_y: Optional[List[np.ndarray]] = None + ): + self.start(arrays_x, arrays_y) self.wait() @dace_inhibitor - def interface(self): + def interface(self, arrays_x: np.ndarray, arrays_y: np.ndarray): assert isinstance(self._updater, VectorInterfaceHaloUpdater) - assert len(self._qtx_x_names) == 1 - assert len(self._qtx_y_names) == 1 - request = self._updater.start_synchronize_vector_interfaces( - self._state.__getattribute__(self._qtx_x_names[0]).data, - self._state.__getattribute__(self._qtx_y_names[0]).data, - ) + request = self._updater.start_synchronize_vector_interfaces(arrays_x, arrays_y) request.wait() diff --git a/dsl/pace/dsl/gt4py_utils.py b/dsl/pace/dsl/gt4py_utils.py index cbd0d51fd..6fbefcac4 100644 --- a/dsl/pace/dsl/gt4py_utils.py +++ b/dsl/pace/dsl/gt4py_utils.py @@ -21,6 +21,12 @@ halo = 3 origin = (halo, halo, 0) +# nq is actually given by ncnst - pnats, where those are given in atmosphere.F90 by: +# ncnst = Atm(mytile)%ncnst +# pnats = Atm(mytile)%flagstruct%pnats +# here we hard-coded it because 8 is the only supported value, refactor this later! +NQ = 8 # state.nq_tot - spec.namelist.dnats + # TODO get from field_table tracer_variables = [ "qvapor", diff --git a/examples/notebooks/functions.py b/examples/notebooks/functions.py index 756cf4d31..7b5e42700 100644 --- a/examples/notebooks/functions.py +++ b/examples/notebooks/functions.py @@ -917,9 +917,7 @@ def run_finite_volume_fluxprep( return flux_prep -def build_tracer_advection( - stencil_configuration: Dict[str, Any], tracers: Dict[str, Quantity] -) -> TracerAdvection: +def build_tracer_advection(stencil_configuration: Dict[str, Any]) -> TracerAdvection: """ Use: tracer_advection = build_tracer_advection(stencil_configuration, tracers) @@ -949,7 +947,6 @@ def build_tracer_advection( fvtp_2d, stencil_configuration["grid_data"], stencil_configuration["communicator"], - tracers, ) return tracer_advection @@ -993,7 +990,7 @@ def prepare_everything_for_advection( timestep, ) - tracer_advection = build_tracer_advection(stencil_configuration, tracers) + tracer_advection = build_tracer_advection(stencil_configuration) tracer_advection_data = { "tracers": tracers, diff --git a/fv3core/examples/standalone/runfile/acoustics.py b/fv3core/examples/standalone/runfile/acoustics.py index 360cfc9b0..a82018890 100755 --- a/fv3core/examples/standalone/runfile/acoustics.py +++ b/fv3core/examples/standalone/runfile/acoustics.py @@ -180,7 +180,6 @@ def driver( dycore_config.acoustic_dynamics, input_data["pfull"], input_data["phis"], - state, ) # warm-up timestep. diff --git a/fv3core/examples/standalone/runfile/dynamics.py b/fv3core/examples/standalone/runfile/dynamics.py index 0ca8bc4bd..127271404 100755 --- a/fv3core/examples/standalone/runfile/dynamics.py +++ b/fv3core/examples/standalone/runfile/dynamics.py @@ -263,7 +263,6 @@ def setup_dycore( damping_coefficients=DampingCoefficients.new_from_metric_terms(metric_terms), config=dycore_config, phis=state.phis, - state=state, ) dycore.update_state( conserve_total_energy=dycore_config.consv_te, @@ -310,7 +309,7 @@ def setup_dycore( # warmup/compilation from the internal timers if rank == 0: print("timestep 1") - dycore.step_dynamics(state, timer) + dycore.step_dynamics(state, state.tracers_as_array(), timer) if profiler is not None: profiler.enable() diff --git a/fv3core/fv3core/initialization/dycore_state.py b/fv3core/fv3core/initialization/dycore_state.py index 4baf2f90a..b846864ac 100644 --- a/fv3core/fv3core/initialization/dycore_state.py +++ b/fv3core/fv3core/initialization/dycore_state.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field, fields -from typing import Any, Mapping +from typing import Any, Dict, List, Mapping +import numpy as np import xarray as xr import pace.dsl.gt4py_utils as gt_utils @@ -372,5 +373,21 @@ def xr_dataset(self): ) return xr.Dataset(data_vars=data_vars) + @property + def tracers(self) -> List[pace.util.Quantity]: + return [self.__getattribute__(x) for x in DycoreState.tracer_names()] + + def tracers_as_array(self) -> Dict[str, np.ndarray]: + all_tracers = { + name: self.__getattribute__(name).data + for name in DycoreState.tracer_names() + } + all_tracers.pop("qcld") + return all_tracers + + @classmethod + def tracer_names(cls) -> List[str]: + return gt_utils.tracer_variables + def __getitem__(self, item): return getattr(self, item) diff --git a/fv3core/fv3core/stencils/dyn_core.py b/fv3core/fv3core/stencils/dyn_core.py index 2b474d436..73e48d68b 100644 --- a/fv3core/fv3core/stencils/dyn_core.py +++ b/fv3core/fv3core/stencils/dyn_core.py @@ -245,7 +245,6 @@ def __init__( comm: pace.util.CubedSphereCommunicator, grid_indexing: GridIndexing, backend: str, - state, ): origin = grid_indexing.origin_compute() shape = grid_indexing.max_shape @@ -296,51 +295,32 @@ def __init__( # quantities at runtime paradigm self.q_con__cappa = WrappedHaloUpdater( comm.get_scalar_halo_updater([full_size_xyz_halo_spec] * 2), - state, - ["q_con", "cappa"], ) self.delp__pt = WrappedHaloUpdater( comm.get_scalar_halo_updater([full_size_xyz_halo_spec] * 2), - state, - ["delp", "pt"], ) self.u__v = WrappedHaloUpdater( comm.get_vector_halo_updater( [full_size_xyiz_halo_spec], [full_size_xiyz_halo_spec] ), - state, - ["u"], - ["v"], ) self.w = WrappedHaloUpdater( comm.get_scalar_halo_updater([full_size_xyz_halo_spec]), - state, - ["w"], ) self.gz = WrappedHaloUpdater( comm.get_scalar_halo_updater([full_size_xyzi_halo_spec]), - state, - ["gz"], ) self.delp__pt__q_con = WrappedHaloUpdater( comm.get_scalar_halo_updater([full_size_xyz_halo_spec] * 3), - state, - ["delp", "pt", "q_con"], ) self.zh = WrappedHaloUpdater( comm.get_scalar_halo_updater([full_size_xyzi_halo_spec]), - state, - ["zh"], ) self.divgd = WrappedHaloUpdater( comm.get_scalar_halo_updater([full_size_xiyiz_halo_spec]), - state, - ["divgd"], ) self.heat_source = WrappedHaloUpdater( comm.get_scalar_halo_updater([full_size_xyz_halo_spec]), - state, - ["heat_source"], ) if grid_indexing.domain[0] == grid_indexing.domain[1]: full_3Dfield_2pts_halo_spec = grid_indexing.get_quantity_halo_spec( @@ -352,21 +332,20 @@ def __init__( ) self.pkc = WrappedHaloUpdater( comm.get_scalar_halo_updater([full_3Dfield_2pts_halo_spec]), - state, - ["pkc"], ) else: self.pkc = comm.get_scalar_halo_updater([full_size_xyzi_halo_spec]) self.uc__vc = WrappedHaloUpdater( comm.get_vector_halo_updater( - [full_size_xiyz_halo_spec], [full_size_xyiz_halo_spec] + [full_size_xiyz_halo_spec], + [full_size_xyiz_halo_spec], ), - state, - ["uc"], - ["vc"], ) self.interface_uc__vc = WrappedHaloUpdater( - None, state, ["u"], ["v"], comm=comm + comm.get_vector_interface_halo_updater( + full_size_xyiz_halo_spec, + full_size_xiyz_halo_spec, + ) ) def __init__( @@ -381,7 +360,6 @@ def __init__( config: AcousticDynamicsConfig, pfull: FloatFieldK, phis: FloatFieldIJ, - state, # [DaCe] hack to get around quantity as parameters for halo updates checkpointer: Optional[pace.util.Checkpointer] = None, ): """ @@ -621,7 +599,7 @@ def __init__( # Halo updaters self._halo_updaters = AcousticDynamics._HaloUpdaters( - comm, grid_indexing, stencil_factory.backend, state + comm, grid_indexing, stencil_factory.backend ) def _checkpoint_csw(self, state, tag: str): @@ -704,9 +682,9 @@ def __call__( # m_split = 1. + abs(dt_atmos)/real(k_split*n_split*abs(p_split)) # n_split = nint( real(n0split)/real(k_split*abs(p_split)) * stretch_fac + 0.5 ) # NOTE: In Fortran model the halo update starts happens in fv_dynamics, not here - self._halo_updaters.q_con__cappa.start() - self._halo_updaters.delp__pt.start() - self._halo_updaters.u__v.start() + self._halo_updaters.q_con__cappa.start([state.q_con.data, state.cappa.data]) + self._halo_updaters.delp__pt.start([state.delp.data, state.pt.data]) + self._halo_updaters.u__v.start([state.u.data], [state.v.data]) self._halo_updaters.q_con__cappa.wait() if update_temporaries: @@ -741,14 +719,14 @@ def __call__( if self.config.breed_vortex_inline or (it == n_split - 1): remap_step = True if not self.config.hydrostatic: - self._halo_updaters.w.start() + self._halo_updaters.w.start([state.w.data]) if it == 0: self._set_gz( self._zs, state.delz, state.gz, ) - self._halo_updaters.gz.start() + self._halo_updaters.gz.start([state.gz.data]) if it == 0: self._halo_updaters.delp__pt.wait() @@ -785,7 +763,7 @@ def __call__( self._checkpoint_csw(state, tag="Out") if self.config.nord > 0: - self._halo_updaters.divgd.start() + self._halo_updaters.divgd.start([state.divgd.data]) if not self.config.hydrostatic: if it == 0: self._halo_updaters.gz.wait() @@ -826,7 +804,7 @@ def __call__( state.gz, dt2, ) - self._halo_updaters.uc__vc.start() + self._halo_updaters.uc__vc.start([state.uc.data], [state.vc.data]) if self.config.nord > 0: self._halo_updaters.divgd.wait() self._halo_updaters.uc__vc.wait() @@ -863,7 +841,9 @@ def __call__( # note that uc and vc are not needed at all past this point. # they will be re-computed from scratch on the next acoustic timestep. - self._halo_updaters.delp__pt__q_con.update() + self._halo_updaters.delp__pt__q_con.update( + [state.delp.data, state.pt.data, state.q_con.data] + ) # Not used unless we implement other betas and alternatives to nh_p_grad # if self.namelist.d_ext > 0: @@ -901,8 +881,8 @@ def __call__( state.w, ) - self._halo_updaters.zh.start() - self._halo_updaters.pkc.start() + self._halo_updaters.zh.start([state.zh.data]) + self._halo_updaters.pkc.start([state.pkc.data]) if remap_step: self._edge_pe_stencil(state.pe, state.delp, self._ptop) if self.config.use_logp: @@ -949,13 +929,15 @@ def __call__( # [DaCe] this should be a reuse of # self._halo_updaters.u__v but it creates # parameter generation issues, and therefore has been duplicated - self._halo_updaters.u__v.start() + self._halo_updaters.u__v.start([state.u.data], [state.v.data]) else: if self.config.grid_type < 4: - self._halo_updaters.interface_uc__vc.interface() + self._halo_updaters.interface_uc__vc.interface( + state.u.data, state.v.data + ) if self._do_del2cubed: - self._halo_updaters.heat_source.update() + self._halo_updaters.heat_source.update([state.heat_source.data]) # TODO: move dependence on da_min into init of hyperdiffusion class cd = constants.CNST_0P20 * self._da_min self._hyperdiffusion(state.heat_source, cd) diff --git a/fv3core/fv3core/stencils/fillz.py b/fv3core/fv3core/stencils/fillz.py index c2d64ad98..851aeab42 100644 --- a/fv3core/fv3core/stencils/fillz.py +++ b/fv3core/fv3core/stencils/fillz.py @@ -1,13 +1,13 @@ import typing from typing import Dict +import numpy as np from gt4py.gtscript import BACKWARD, FORWARD, PARALLEL, computation, interval import pace.dsl.gt4py_utils as utils from pace.dsl.dace import orchestrate from pace.dsl.stencil import StencilFactory from pace.dsl.typing import FloatField, FloatFieldIJ, IntFieldIJ -from pace.util import Quantity @typing.no_type_check @@ -125,7 +125,6 @@ def __init__( jm: int, km: int, nq: int, - tracers: Dict[str, Quantity], ): orchestrate( obj=self, @@ -155,21 +154,17 @@ def make_storage(*args, **kwargs): self._sum0 = make_storage(shape_ij, origin=(0, 0)) self._sum1 = make_storage(shape_ij, origin=(0, 0)) - self._filtered_tracer_dict = { - name: tracers[name] for name in utils.tracer_variables[0 : self._nq] - } - def __call__( self, dp2: FloatField, - tracers: Dict[str, Quantity], + tracers: Dict[str, np.ndarray], ): """ Args: dp2 (in): pressure thickness of atmospheric layer tracers (inout): tracers to fix negative masses in """ - for tracer_name in self._filtered_tracer_dict.keys(): + for tracer_name in utils.tracer_variables[0 : self._nq]: self._fix_tracer_stencil( tracers[tracer_name], dp2, diff --git a/fv3core/fv3core/stencils/fv_dynamics.py b/fv3core/fv3core/stencils/fv_dynamics.py index 10eda2093..baabdd016 100644 --- a/fv3core/fv3core/stencils/fv_dynamics.py +++ b/fv3core/fv3core/stencils/fv_dynamics.py @@ -1,5 +1,6 @@ -from typing import Dict, Optional +from typing import Dict, List, Optional +import numpy as np from dace.frontend.python.interface import nounroll as dace_no_unroll from gt4py.gtscript import PARALLEL, computation, interval, log @@ -23,14 +24,6 @@ from pace.util import Timer from pace.util.grid import DampingCoefficients, GridData from pace.util.mpi import MPI -from pace.util.quantity import Quantity - - -# nq is actually given by ncnst - pnats, where those are given in atmosphere.F90 by: -# ncnst = Atm(mytile)%ncnst -# pnats = Atm(mytile)%flagstruct%pnats -# here we hard-coded it because 8 is the only supported value, refactor this later! -NQ = 8 # state.nq_tot - spec.namelist.dnats def pt_adjust(pkz: FloatField, dp1: FloatField, q_con: FloatField, pt: FloatField): @@ -105,7 +98,6 @@ def __init__( damping_coefficients: DampingCoefficients, config: DynamicalCoreConfig, phis: pace.util.Quantity, - state: DycoreState, checkpointer: Optional[pace.util.Checkpointer] = None, ): """ @@ -210,19 +202,14 @@ def __init__( hord=config.hord_tr, ) - self.tracers = {} - for name in utils.tracer_variables[0:NQ]: - self.tracers[name] = state.__dict__[name] - self.tracer_storages = { - name: quantity.storage for name, quantity in self.tracers.items() - } - self._temporaries = fvdyn_temporaries(quantity_factory) - state.__dict__.update(self._temporaries) # Build advection stencils self.tracer_advection = tracer_2d_1l.TracerAdvection( - stencil_factory, tracer_transport, self.grid_data, comm, self.tracers + stencil_factory, + tracer_transport, + self.grid_data, + comm, ) self._ak = grid_data.ak self._bk = grid_data.bk @@ -274,7 +261,6 @@ def __init__( self.config.acoustic_dynamics, self._pfull, self._phis, - state, checkpointer=checkpointer, ) self._hyperdiffusion = HyperdiffusionDamping( @@ -284,7 +270,7 @@ def __init__( self.config.nf_omega, ) self._cubed_to_latlon = CubedToLatLon( - state, stencil_factory, grid_data, config.c2l_ord, comm + stencil_factory, grid_data, config.c2l_ord, comm ) self._temporaries = fvdyn_temporaries(quantity_factory) @@ -293,7 +279,7 @@ def __init__( # if self._temporaries were a dataclass we can remove this for name, value in self._temporaries.items(): setattr(self, f"_tmp_{name}", value) - if not (not self.config.inline_q and NQ != 0): + if not (not self.config.inline_q and utils.NQ != 0): raise NotImplementedError("tracer_2d not implemented, turn on z_tracer") self._adjust_tracer_mixing_ratio = AdjustNegativeTracerMixingRatio( stencil_factory, @@ -305,9 +291,8 @@ def __init__( stencil_factory, config.remapping, grid_data.area_64, - NQ, + utils.NQ, self._pfull, - tracers=self.tracers, ) full_xyz_spec = grid_indexing.get_quantity_halo_spec( @@ -318,7 +303,7 @@ def __init__( backend=stencil_factory.backend, ) self._omega_halo_updater = WrappedHaloUpdater( - comm.get_scalar_halo_updater([full_xyz_spec]), state, ["omga"], comm=comm + comm.get_scalar_halo_updater([full_xyz_spec]) ) def _checkpoint_fvdynamics(self, state: DycoreState, tag: str): @@ -369,7 +354,9 @@ def update_state( state.__dict__.update(self._temporaries) state.__dict__.update(self.acoustic_dynamics._temporaries) - def step_dynamics(self, state: DycoreState, timer: Timer): + def step_dynamics( + self, state: DycoreState, tracers_dict: Dict[str, np.ndarray], timer: Timer + ): """ Step the model state forward by one timestep. @@ -378,7 +365,7 @@ def step_dynamics(self, state: DycoreState, timer: Timer): state: model prognostic state and inputs """ self._checkpoint_fvdynamics(state=state, tag="In") - self._compute(state, timer) + self._compute(state, tracers_dict, timer) self._checkpoint_fvdynamics(state=state, tag="Out") def compute_preamble(self, state, is_root_rank: bool): @@ -428,7 +415,12 @@ def compute_preamble(self, state, is_root_rank: bool): def __call__(self, *args, **kwargs): return self.step_dynamics(*args, **kwargs) - def _compute(self, state, timer: pace.util.Timer): + def _compute( + self, + state: DycoreState, + tracers_dict: Dict[str, np.ndarray], + timer: pace.util.Timer, + ): last_step = False self.compute_preamble( state, @@ -438,7 +430,7 @@ def _compute(self, state, timer: pace.util.Timer): for k_split in dace_no_unroll(range(state.k_split)): n_map = k_split + 1 last_step = k_split == state.k_split - 1 - self._dyn(state=state, tracers=self.tracers, n_map=n_map, timer=timer) + self._dyn(state=state, tracers=tracers_dict, n_map=n_map, timer=timer) if self.grid_indexing.domain[2] > 4: # nq is actually given by ncnst - pnats, @@ -455,7 +447,7 @@ def _compute(self, state, timer: pace.util.Timer): log_on_rank_0("Remapping") with timer.clock("Remapping"): self._lagrangian_to_eulerian_obj( - self.tracer_storages, + tracers_dict, state.pt, state.delp, state.delz, @@ -503,7 +495,7 @@ def _compute(self, state, timer: pace.util.Timer): def _dyn( self, state, - tracers: Dict[str, Quantity], + tracers: Dict[str, np.ndarray], n_map, timer: pace.util.Timer, ): @@ -547,7 +539,7 @@ def post_remap( ) if self.config.nf_omega > 0: log_on_rank_0("Del2Cubed") - self._omega_halo_updater.update() + self._omega_halo_updater.update([state.omga.data]) self._hyperdiffusion(state.omga, 0.18 * da_min) def wrapup( diff --git a/fv3core/fv3core/stencils/mapn_tracer.py b/fv3core/fv3core/stencils/mapn_tracer.py index 880633c7b..fa86a1d4a 100644 --- a/fv3core/fv3core/stencils/mapn_tracer.py +++ b/fv3core/fv3core/stencils/mapn_tracer.py @@ -1,12 +1,13 @@ from typing import Dict +import numpy as np + import pace.dsl.gt4py_utils as utils from fv3core.stencils.fillz import FillNegativeTracerValues from fv3core.stencils.map_single import MapSingle from pace.dsl.dace.orchestration import orchestrate from pace.dsl.stencil import StencilFactory from pace.dsl.typing import FloatField -from pace.util import Quantity class MapNTracer: @@ -24,7 +25,6 @@ def __init__( j1: int, j2: int, fill: bool, - tracers: Dict[str, Quantity], ): orchestrate( obj=self, @@ -62,7 +62,6 @@ def __init__( self._list_of_remap_objects[0].j_extent, self._nk, self._nq, - tracers, ) else: self._fill_negative_tracers = False @@ -72,7 +71,7 @@ def __call__( pe1: FloatField, pe2: FloatField, dp2: FloatField, - tracers: Dict[str, Quantity], + tracers: Dict[str, np.ndarray], ): """ Remaps the tracer species onto the Eulerian grid diff --git a/fv3core/fv3core/stencils/remapping.py b/fv3core/fv3core/stencils/remapping.py index c081f1a56..1ab349a9b 100644 --- a/fv3core/fv3core/stencils/remapping.py +++ b/fv3core/fv3core/stencils/remapping.py @@ -1,5 +1,6 @@ from typing import Dict +import numpy as np from gt4py.gtscript import ( __INLINED, BACKWARD, @@ -24,7 +25,6 @@ from pace.dsl.dace.orchestration import orchestrate from pace.dsl.stencil import StencilFactory from pace.dsl.typing import FloatField, FloatFieldIJ, FloatFieldK -from pace.util import Quantity # TODO: Should this be set here or in global_constants? @@ -285,7 +285,6 @@ def __init__( area_64, nq, pfull, - tracers: Dict[str, Quantity], ): orchestrate( obj=self, @@ -372,7 +371,6 @@ def __init__( grid_indexing.jsc, grid_indexing.jec, fill=config.fill, - tracers=tracers, ) self._map_single_w = MapSingle( @@ -486,7 +484,7 @@ def __init__( def __call__( self, - tracers: Dict[str, Quantity], + tracers: Dict[str, np.ndarray], pt: FloatField, delp: FloatField, delz: FloatField, diff --git a/fv3core/fv3core/stencils/tracer_2d_1l.py b/fv3core/fv3core/stencils/tracer_2d_1l.py index b86ce1db8..a3fde159a 100644 --- a/fv3core/fv3core/stencils/tracer_2d_1l.py +++ b/fv3core/fv3core/stencils/tracer_2d_1l.py @@ -2,6 +2,7 @@ from typing import Dict import gt4py.gtscript as gtscript +import numpy as np from gt4py.gtscript import PARALLEL, computation, horizontal, interval, region import pace.dsl.gt4py_utils as utils @@ -11,7 +12,6 @@ from pace.dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater from pace.dsl.stencil import StencilFactory from pace.dsl.typing import FloatField, FloatFieldIJ -from pace.util import Quantity @gtscript.function @@ -182,7 +182,6 @@ def __init__( transport: FiniteVolumeTransport, grid_data, comm: pace.util.CubedSphereCommunicator, - tracers: Dict[str, Quantity], ): orchestrate( obj=self, @@ -191,7 +190,6 @@ def __init__( ) grid_indexing = stencil_factory.grid_indexing self.grid_indexing = grid_indexing # needed for selective validation - self._tracer_count = len(tracers) self.grid_data = grid_data shape = grid_indexing.domain_full(add=(1, 1, 1)) origin = grid_indexing.origin_compute() @@ -271,12 +269,10 @@ def make_storage(): backend=stencil_factory.backend, ) self._tracers_halo_updater = WrappedHaloUpdater( - comm.get_scalar_halo_updater([tracer_halo_spec] * self._tracer_count), - tracers, - [t for t in tracers.keys()], + comm.get_scalar_halo_updater([tracer_halo_spec] * utils.NQ), ) - def __call__(self, tracers: Dict[str, Quantity], dp1, mfxd, mfyd, cxd, cyd, mdt): + def __call__(self, tracers: Dict[str, np.ndarray], dp1, mfxd, mfyd, cxd, cyd, mdt): """ Args: tracers (inout): @@ -352,7 +348,7 @@ def __call__(self, tracers: Dict[str, Quantity], dp1, mfxd, mfyd, cxd, cyd, mdt) n_split, ) - self._tracers_halo_updater.update() + self._tracers_halo_updater.update(tracers.values()) dp2 = self._tmp_dp @@ -386,6 +382,6 @@ def __call__(self, tracers: Dict[str, Quantity], dp1, mfxd, mfyd, cxd, cyd, mdt) dp2, ) if not last_call: - self._tracers_halo_updater.update() + self._tracers_halo_updater.update(tracers.values()) # use variable assignment to avoid a data copy self._swap_dp(dp1, dp2) diff --git a/fv3core/fv3core/testing/translate_dyncore.py b/fv3core/fv3core/testing/translate_dyncore.py index 27f18b18c..374dc4389 100644 --- a/fv3core/fv3core/testing/translate_dyncore.py +++ b/fv3core/fv3core/testing/translate_dyncore.py @@ -171,7 +171,6 @@ def compute_parallel(self, inputs, communicator): config=DynamicalCoreConfig.from_namelist(self.namelist).acoustic_dynamics, pfull=inputs["pfull"], phis=inputs["phis"], - state=state, ) state.__dict__.update(acoustic_dynamics._temporaries) acoustic_dynamics(state, n_map=state.n_map, update_temporaries=False) diff --git a/fv3core/fv3core/testing/translate_fvdynamics.py b/fv3core/fv3core/testing/translate_fvdynamics.py index 9c8e07faa..a49f2909d 100644 --- a/fv3core/fv3core/testing/translate_fvdynamics.py +++ b/fv3core/fv3core/testing/translate_fvdynamics.py @@ -13,7 +13,7 @@ from pace.stencils.testing.translate import TranslateFortranData2Py -ADVECTED_TRACER_NAMES = utils.tracer_variables[: fv_dynamics.NQ] +ADVECTED_TRACER_NAMES = utils.tracer_variables[: utils.NQ] class TranslateDycoreFortranData2Py(TranslateFortranData2Py): @@ -330,7 +330,6 @@ def compute_parallel(self, inputs, communicator): damping_coefficients=self.grid.damping_coefficients, config=DynamicalCoreConfig.from_namelist(self.namelist), phis=state.phis, - state=state, ) self.dycore.update_state( self.namelist.consv_te, @@ -339,7 +338,9 @@ def compute_parallel(self, inputs, communicator): self.namelist.n_split, state, ) - self.dycore.step_dynamics(state, pace.util.NullTimer()) + self.dycore.step_dynamics( + state, state.tracers_as_array(), pace.util.NullTimer() + ) outputs = self.outputs_from_state(state) for name, value in outputs.items(): outputs[name] = self.subset_output(name, value) diff --git a/fv3core/tests/mpi/test_doubly_periodic.py b/fv3core/tests/mpi/test_doubly_periodic.py index 3949e3611..98a1a0a0b 100644 --- a/fv3core/tests/mpi/test_doubly_periodic.py +++ b/fv3core/tests/mpi/test_doubly_periodic.py @@ -109,18 +109,22 @@ def setup_dycore() -> Tuple[fv3core.DynamicalCore, List[Any]]: damping_coefficients=DampingCoefficients.new_from_metric_terms(metric_terms), config=config, phis=state.phis, - state=state, ) do_adiabatic_init = False # TODO compute from namelist bdt = config.dt_atmos - args = [ - state, + dycore.update_state( config.consv_te, do_adiabatic_init, bdt, config.n_split, + state, + ) + + args = [ + state, + state.tracers_as_array(), ] return dycore, args diff --git a/fv3core/tests/savepoint/translate/translate_cubedtolatlon.py b/fv3core/tests/savepoint/translate/translate_cubedtolatlon.py index c25167255..baf960e8a 100644 --- a/fv3core/tests/savepoint/translate/translate_cubedtolatlon.py +++ b/fv3core/tests/savepoint/translate/translate_cubedtolatlon.py @@ -47,7 +47,6 @@ def compute_parallel(self, inputs, communicator): state_dict = {"u": u_quantity, "v": v_quantity} self._cubed_to_latlon = CubedToLatLon( - state=state_dict, stencil_factory=self.stencil_factory, grid_data=self.grid.grid_data, order=self.namelist.c2l_ord, diff --git a/fv3core/tests/savepoint/translate/translate_fillz.py b/fv3core/tests/savepoint/translate/translate_fillz.py index 3093ea6a4..590f0d263 100644 --- a/fv3core/tests/savepoint/translate/translate_fillz.py +++ b/fv3core/tests/savepoint/translate/translate_fillz.py @@ -74,7 +74,6 @@ def compute(self, inputs): inputs.pop("jm"), inputs.pop("km"), inputs.pop("nq"), - inputs["tracers"], ) run_fillz(**inputs) ds = self.grid.default_domain_dict() diff --git a/fv3core/tests/savepoint/translate/translate_mapn_tracer_2d.py b/fv3core/tests/savepoint/translate/translate_mapn_tracer_2d.py index 27d885d4e..5f51257a2 100644 --- a/fv3core/tests/savepoint/translate/translate_mapn_tracer_2d.py +++ b/fv3core/tests/savepoint/translate/translate_mapn_tracer_2d.py @@ -66,7 +66,6 @@ def compute(self, inputs): inputs.pop("j1"), inputs.pop("j2"), fill=self.namelist.fill, - tracers=inputs["tracers"], ) self.compute_func(**inputs) return self.slice_output(inputs) diff --git a/fv3core/tests/savepoint/translate/translate_remapping.py b/fv3core/tests/savepoint/translate/translate_remapping.py index 01d3bce26..7a70655f7 100644 --- a/fv3core/tests/savepoint/translate/translate_remapping.py +++ b/fv3core/tests/savepoint/translate/translate_remapping.py @@ -130,7 +130,6 @@ def compute_from_storage(self, inputs): self.grid.area_64, inputs["nq"], inputs["pfull"], - inputs["tracers"], ) inputs.pop("nq") l_to_e_obj(**inputs) diff --git a/fv3core/tests/savepoint/translate/translate_tracer2d1l.py b/fv3core/tests/savepoint/translate/translate_tracer2d1l.py index b7856974f..9fc9d4bb5 100644 --- a/fv3core/tests/savepoint/translate/translate_tracer2d1l.py +++ b/fv3core/tests/savepoint/translate/translate_tracer2d1l.py @@ -51,8 +51,9 @@ def compute_parallel(self, inputs, communicator): self._base.make_storage_data_input_vars(inputs) all_tracers = inputs["tracers"] + tracer_count = int(inputs.pop("nq")) inputs["tracers"] = self.get_advected_tracer_dict( - inputs["tracers"], int(inputs.pop("nq")) + inputs["tracers"], tracer_count ) transport = fv3core.stencils.fvtp2d.FiniteVolumeTransport( stencil_factory=self.stencil_factory, @@ -67,7 +68,6 @@ def compute_parallel(self, inputs, communicator): transport, self.grid.grid_data, communicator, - inputs["tracers"], ) self.tracer_advection(**inputs) inputs[ @@ -89,7 +89,7 @@ def get_advected_tracer_dict(self, all_tracers, nq): units=properties["units"], ) tracer_names = utils.tracer_variables[:nq] - return {name: all_tracers[name + "_quantity"] for name in tracer_names} + return {name: all_tracers[name + "_quantity"].data for name in tracer_names} def compute_sequential(self, a, b): pytest.skip( diff --git a/fv3gfs-physics/tests/savepoint/translate/translate_fv_update_phys.py b/fv3gfs-physics/tests/savepoint/translate/translate_fv_update_phys.py index bc3b95f7d..e9e48a9f2 100644 --- a/fv3gfs-physics/tests/savepoint/translate/translate_fv_update_phys.py +++ b/fv3gfs-physics/tests/savepoint/translate/translate_fv_update_phys.py @@ -174,9 +174,6 @@ def compute_parallel(self, inputs, communicator): self.namelist, communicator, self.grid.driver_grid_data, - state, - tendencies["u_dt"], - tendencies["v_dt"], ) dims_u = [pace.util.X_DIM, pace.util.Y_INTERFACE_DIM, pace.util.Z_DIM] u_quantity = self.grid.make_quantity( diff --git a/stencils/pace/stencils/c2l_ord.py b/stencils/pace/stencils/c2l_ord.py index f5962ae20..0325bb82f 100644 --- a/stencils/pace/stencils/c2l_ord.py +++ b/stencils/pace/stencils/c2l_ord.py @@ -1,6 +1,5 @@ from gt4py.gtscript import PARALLEL, computation, horizontal, interval, region -import fv3core import pace.dsl.gt4py_utils as utils from pace.dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater from pace.dsl.stencil import StencilFactory @@ -104,7 +103,6 @@ class CubedToLatLon: def __init__( self, - state: fv3core.DycoreState, stencil_factory: StencilFactory, grid_data: GridData, order: int, @@ -162,11 +160,7 @@ def __init__( self.u__v = WrappedHaloUpdater( comm.get_vector_halo_updater( [full_size_xyiz_halo_spec], [full_size_xiyz_halo_spec] - ), - state, - ["u"], - ["v"], - comm=comm, + ) ) def __call__( @@ -186,7 +180,7 @@ def __call__( comm: Cubed-sphere communicator """ if self._do_ord4: - self.u__v.update() + self.u__v.update([u.data], [v.data]) self._compute_cubed_to_latlon( u, v, diff --git a/stencils/pace/stencils/fv_update_phys.py b/stencils/pace/stencils/fv_update_phys.py index 9b7196ba3..6171f2000 100644 --- a/stencils/pace/stencils/fv_update_phys.py +++ b/stencils/pace/stencils/fv_update_phys.py @@ -1,7 +1,6 @@ import gt4py.gtscript as gtscript from gt4py.gtscript import FORWARD, PARALLEL, computation, exp, interval, log -import fv3core import pace.dsl.gt4py_utils as utils import pace.util import pace.util.constants as constants @@ -88,9 +87,6 @@ def __init__( namelist, comm: pace.util.CubedSphereCommunicator, grid_info: DriverGridData, - state: fv3core.DycoreState, - u_dt: pace.util.Quantity, - v_dt: pace.util.Quantity, ): orchestrate( obj=self, @@ -113,7 +109,7 @@ def __init__( stencil_factory, comm.partitioner, comm.rank, namelist, grid_info ) self._do_cubed_to_latlon = CubedToLatLon( - state, stencil_factory, grid_data, order=namelist.c2l_ord, comm=comm + stencil_factory, grid_data, order=namelist.c2l_ord, comm=comm ) self.origin = grid_indexing.origin_compute() self.extent = grid_indexing.domain_compute() @@ -127,13 +123,9 @@ def __init__( ) self._udt_halo_updater = WrappedHaloUpdater( self.comm.get_scalar_halo_updater([full_3Dfield_1pts_halo_spec]), - {"u_dt": u_dt}, - ["u_dt"], ) self._vdt_halo_updater = WrappedHaloUpdater( self.comm.get_scalar_halo_updater([full_3Dfield_1pts_halo_spec]), - {"v_dt": v_dt}, - ["v_dt"], ) # TODO: check if we actually need surface winds self._u_srf = utils.make_storage_from_shape( @@ -164,8 +156,8 @@ def __call__( dt, ) - self._udt_halo_updater.start() - self._vdt_halo_updater.start() + self._udt_halo_updater.start([u_dt]) + self._vdt_halo_updater.start([v_dt]) self._update_pressure_and_surface_winds( state.pe, state.delp, diff --git a/stencils/pace/stencils/update_atmos_state.py b/stencils/pace/stencils/update_atmos_state.py index 8fa138cf0..921ebac31 100644 --- a/stencils/pace/stencils/update_atmos_state.py +++ b/stencils/pace/stencils/update_atmos_state.py @@ -242,11 +242,8 @@ def __init__( namelist, comm: pace.util.CubedSphereCommunicator, grid_info: DriverGridData, - state: fv3core.DycoreState, - quantity_factory: pace.util.QuantityFactory, dycore_only: bool, apply_tendencies: bool, - tendency_state, ): orchestrate( obj=self, @@ -259,8 +256,6 @@ def __init__( grid_indexing = stencil_factory.grid_indexing self.namelist = namelist - origin = grid_indexing.origin_compute() - shape = grid_indexing.domain_full(add=(1, 1, 1)) self._rdt = 1.0 / Float(self.namelist.dt_atmos) self._prepare_tendencies_and_update_tracers = ( @@ -271,7 +266,6 @@ def __init__( ) ) - dims = [pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM] self._fill_GFS_delp = stencil_factory.from_origin_domain( fill_gfs_delp, origin=grid_indexing.origin_full(), @@ -284,9 +278,6 @@ def __init__( self.namelist, comm, grid_info, - state, - tendency_state.u_dt, - tendency_state.v_dt, ) self._dycore_only = dycore_only # apply_tendencies when we have run physics or fv_subgridz diff --git a/tests/main/fv3core/test_dycore_call.py b/tests/main/fv3core/test_dycore_call.py index 92bc020e2..ef7221709 100644 --- a/tests/main/fv3core/test_dycore_call.py +++ b/tests/main/fv3core/test_dycore_call.py @@ -123,7 +123,6 @@ def setup_dycore() -> Tuple[ damping_coefficients=DampingCoefficients.new_from_metric_terms(metric_terms), config=config, phis=state.phis, - state=state, ) do_adiabatic_init = False @@ -160,10 +159,10 @@ def test_temporaries_are_deterministic(): dycore1, state1, timer1 = setup_dycore() dycore2, state2, timer2 = setup_dycore() - dycore1.step_dynamics(state1, timer1) + dycore1.step_dynamics(state1, state1.tracers_as_array(), timer1) first_temporaries = copy_temporaries(dycore1, max_depth=10) assert len(first_temporaries) > 0 - dycore2.step_dynamics(state2, timer2) + dycore2.step_dynamics(state2, state2.tracers_as_array(), timer2) second_temporaries = copy_temporaries(dycore2, max_depth=10) assert_same_temporaries(second_temporaries, first_temporaries) @@ -180,14 +179,14 @@ def test_call_on_same_state_same_dycore_produces_same_temporaries(): # state_1 and state_2 are identical, if the dycore is stateless then they # should produce identical dycore final states when used to call - dycore.step_dynamics(state_1, timer_1) + dycore.step_dynamics(state_1, state_1.tracers_as_array(), timer_1) first_temporaries = copy_temporaries(dycore, max_depth=10) assert len(first_temporaries) > 0 # TODO: The orchestrated code pushed us to make the dycore stateful for halo # exchange, so we must copy into state_1 instead of using state_2. # We should call with state_2 directly when this is fixed. copy_state(state_2, state_1) - dycore.step_dynamics(state_1, timer_2) + dycore.step_dynamics(state_1, state_1.tracers_as_array(), timer_2) second_temporaries = copy_temporaries(dycore, max_depth=10) assert_same_temporaries(second_temporaries, first_temporaries) From c94b045c5a82e3ffb19410b4f76b911660b48369 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 23 Aug 2022 21:58:14 +0200 Subject: [PATCH 4/8] lint --- dsl/pace/dsl/dace/wrapped_halo_exchange.py | 2 -- .../examples/standalone/runfile/dynamics.py | 2 +- fv3core/fv3core/stencils/fv_dynamics.py | 18 +++++++++--------- .../pace/util/halo_quantity_specification.py | 5 +++-- pace-util/pace/util/halo_updater.py | 12 ++++++------ pace-util/pace/util/rotate.py | 6 ++++-- 6 files changed, 23 insertions(+), 22 deletions(-) diff --git a/dsl/pace/dsl/dace/wrapped_halo_exchange.py b/dsl/pace/dsl/dace/wrapped_halo_exchange.py index 516ef8b32..f094483c3 100644 --- a/dsl/pace/dsl/dace/wrapped_halo_exchange.py +++ b/dsl/pace/dsl/dace/wrapped_halo_exchange.py @@ -1,10 +1,8 @@ -import dataclasses from typing import List, Optional, Union import numpy as np from pace.dsl.dace.orchestration import dace_inhibitor -from pace.util.communicator import CubedSphereCommunicator from pace.util.halo_updater import HaloUpdater, VectorInterfaceHaloUpdater diff --git a/fv3core/examples/standalone/runfile/dynamics.py b/fv3core/examples/standalone/runfile/dynamics.py index 127271404..1d1353056 100755 --- a/fv3core/examples/standalone/runfile/dynamics.py +++ b/fv3core/examples/standalone/runfile/dynamics.py @@ -323,7 +323,7 @@ def setup_dycore( with timestep_timer.clock("mainloop"): if rank == 0: print(f"timestep {i+2}") - dycore.step_dynamics(state, timer=timestep_timer) + dycore.step_dynamics(state, state.tracers_as_array(), timer=timestep_timer) times_per_step.append(timestep_timer.times) hits_per_step.append(timestep_timer.hits) timestep_timer.reset() diff --git a/fv3core/fv3core/stencils/fv_dynamics.py b/fv3core/fv3core/stencils/fv_dynamics.py index e38e74058..45a5ea9a6 100644 --- a/fv3core/fv3core/stencils/fv_dynamics.py +++ b/fv3core/fv3core/stencils/fv_dynamics.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Dict, Optional import numpy as np from dace.frontend.python.interface import nounroll as dace_no_unroll @@ -432,9 +432,9 @@ def _compute( is_root_rank=self.comm_rank == 0, ) - for k_split in dace_no_unroll(range(state.k_split)): + for k_split in dace_no_unroll(range(state.k_split)): # type: ignore n_map = k_split + 1 - last_step = k_split == state.k_split - 1 + last_step = k_split == state.k_split - 1 # type: ignore self._dyn(state=state, tracers=tracers_dict, n_map=n_map, timer=timer) if self.grid_indexing.domain[2] > 4: @@ -463,27 +463,27 @@ def _compute( state.w, state.ua, state.va, - state.cappa, + state.cappa, # type: ignore state.q_con, state.qcld, state.pkz, state.pk, state.pe, state.phis, - state.te0_2d, + state.te0_2d, # type: ignore state.ps, - state.wsd, + state.wsd, # type: ignore state.omga, self._ak, self._bk, self._pfull, - state.dp1, + state.dp1, # type: ignore self._ptop, constants.KAPPA, constants.ZVIR, last_step, - state.consv_te, - state.bdt / state.k_split, + state.consv_te, # type: ignore + state.bdt / state.k_split, # type: ignore state.bdt, state.do_adiabatic_init, ) diff --git a/pace-util/pace/util/halo_quantity_specification.py b/pace-util/pace/util/halo_quantity_specification.py index bcb6b2de0..a93ac2a35 100644 --- a/pace-util/pace/util/halo_quantity_specification.py +++ b/pace-util/pace/util/halo_quantity_specification.py @@ -1,7 +1,8 @@ from dataclasses import dataclass -from .types import NumpyModule -from typing import Tuple, Any +from typing import Any, Tuple + from .quantity import Quantity +from .types import NumpyModule @dataclass diff --git a/pace-util/pace/util/halo_updater.py b/pace-util/pace/util/halo_updater.py index eff679f9e..4d3cee411 100644 --- a/pace-util/pace/util/halo_updater.py +++ b/pace-util/pace/util/halo_updater.py @@ -338,8 +338,8 @@ def on_c_grid(x_spec: QuantityHaloSpec, y_spec: QuantityHaloSpec): class VectorInterfaceHaloUpdater: """Exchange halo on information between ranks for data living on the interface. - This class reasons on QuantityHaloSpec for initialization and assumes the arrays given - to the start_synchronize_vector_interfaces adhere to those specs. + This class reasons on QuantityHaloSpec for initialization and assumes + the arrays given to the start_synchronize_vector_interfaces adhere to those specs. See start_synchronize_vector_interfaces for details on interface exchange. """ @@ -417,7 +417,7 @@ def _Isend_vector_shared_boundary( self._qty_x_spec.extent, ) south_data = southwest_x_view.sel( - **{ + **{ # type: ignore constants.Y_INTERFACE_DIM: 0, constants.X_DIM: slice( 0, @@ -446,7 +446,7 @@ def _Isend_vector_shared_boundary( self._qty_y_spec.extent, ) west_data = southwest_y_view.sel( - **{ + **{ # type: ignore constants.X_INTERFACE_DIM: 0, constants.Y_DIM: slice( 0, @@ -505,7 +505,7 @@ def _Irecv_vector_shared_boundary( ) north_data = northwest_x_view.sel( - **{ + **{ # type: ignore constants.Y_INTERFACE_DIM: -1, constants.X_DIM: slice( 0, @@ -526,7 +526,7 @@ def _Irecv_vector_shared_boundary( self._qty_y_spec.extent, ) east_data = southeast_y_view.sel( - **{ + **{ # type: ignore constants.X_INTERFACE_DIM: -1, constants.Y_DIM: slice( 0, diff --git a/pace-util/pace/util/rotate.py b/pace-util/pace/util/rotate.py index 0c15a60b4..dcbbd71f5 100644 --- a/pace-util/pace/util/rotate.py +++ b/pace-util/pace/util/rotate.py @@ -1,9 +1,11 @@ -from . import constants from typing import List + import numpy as np +from . import constants + -def rotate_scalar_data(data, dims, numpy, n_clockwise_rotations) -> List[np.ndarray]: +def rotate_scalar_data(data, dims, numpy, n_clockwise_rotations) -> np.ndarray: n_clockwise_rotations = n_clockwise_rotations % 4 if n_clockwise_rotations == 0: pass From 35abf3676406fcc5767ec3bb7c791d3fddd900d6 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 24 Aug 2022 16:20:09 +0200 Subject: [PATCH 5/8] Fix calls --- driver/pace/driver/driver.py | 2 -- tests/main/fv3core/test_dycore_call.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/driver/pace/driver/driver.py b/driver/pace/driver/driver.py index 2e2355630..0f89e19f1 100644 --- a/driver/pace/driver/driver.py +++ b/driver/pace/driver/driver.py @@ -312,10 +312,8 @@ def exit_function(*args, **kwargs): namelist=self.config.physics_config, comm=communicator, grid_info=self.state.driver_grid_data, - quantity_factory=self.quantity_factory, dycore_only=self.config.dycore_only, apply_tendencies=self.config.apply_tendencies, - tendency_state=self.state.tendency_state, ) else: # Make sure those are set to None to raise any issues diff --git a/tests/main/fv3core/test_dycore_call.py b/tests/main/fv3core/test_dycore_call.py index ef7221709..b9d729c66 100644 --- a/tests/main/fv3core/test_dycore_call.py +++ b/tests/main/fv3core/test_dycore_call.py @@ -199,7 +199,7 @@ def error_func(*args, **kwargs): with unittest.mock.patch("gt4py.storage.storage.zeros", new=error_func): with unittest.mock.patch("gt4py.storage.storage.empty", new=error_func): - dycore.step_dynamics(state, timer) + dycore.step_dynamics(state, state.tracers_as_array(), timer) def test_call_does_not_define_stencils(): From cea88c01bed55f62221c78c41702cadecbe10d98 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 24 Aug 2022 16:37:41 +0200 Subject: [PATCH 6/8] Fix more tests call --- tests/main/fv3core/test_dycore_call.py | 2 +- tests/mpi/test_checkpoints.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/main/fv3core/test_dycore_call.py b/tests/main/fv3core/test_dycore_call.py index b9d729c66..e3d9f4178 100644 --- a/tests/main/fv3core/test_dycore_call.py +++ b/tests/main/fv3core/test_dycore_call.py @@ -209,4 +209,4 @@ def error_func(*args, **kwargs): raise AssertionError("call not allowed") with unittest.mock.patch("gt4py.gtscript.stencil", new=error_func): - dycore.step_dynamics(state, timer) + dycore.step_dynamics(state, state.tracers_as_array(), timer) diff --git a/tests/mpi/test_checkpoints.py b/tests/mpi/test_checkpoints.py index af306b1d6..d9675f6a5 100644 --- a/tests/mpi/test_checkpoints.py +++ b/tests/mpi/test_checkpoints.py @@ -173,7 +173,7 @@ def test_fv_dynamics( checkpointer=validation, ) with validation.trial(): - dycore.step_dynamics(state) + dycore.step_dynamics(state, state.tracers_as_array()) def _calibrate_thresholds( @@ -202,7 +202,7 @@ def _calibrate_thresholds( trial_state, _ = initializer.new_state() perturb(dycore_state_to_dict(trial_state)) with calibration.trial(): - dycore.step_dynamics(trial_state) + dycore.step_dynamics(trial_state, trial_state.tracers_as_array()) all_thresholds = communicator.comm.allgather(calibration.thresholds) thresholds = merge_thresholds(all_thresholds) set_manual_thresholds(thresholds) From cfc700edfe0fb611f9b7efb7fe43b2f5bddbdcd9 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 24 Aug 2022 17:14:18 +0200 Subject: [PATCH 7/8] Pass u_dt, v_dt, pt_dt in UpdateAtmosphereState as Qty --- driver/pace/driver/driver.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/driver/pace/driver/driver.py b/driver/pace/driver/driver.py index 0f89e19f1..b4ad261af 100644 --- a/driver/pace/driver/driver.py +++ b/driver/pace/driver/driver.py @@ -451,9 +451,9 @@ def _step_physics(self, timestep: float): self.end_of_step_update( dycore_state=self.state.dycore_state, phy_state=self.state.physics_state, - u_dt=self.state.tendency_state.u_dt.storage, - v_dt=self.state.tendency_state.v_dt.storage, - pt_dt=self.state.tendency_state.pt_dt.storage, + u_dt=self.state.tendency_state.u_dt, + v_dt=self.state.tendency_state.v_dt, + pt_dt=self.state.tendency_state.pt_dt, dt=float(timestep), ) From 11216ea53c2c8db9abc86b83153da644b2b3b0ef Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 24 Aug 2022 17:30:44 +0200 Subject: [PATCH 8/8] Exchange data instead of qty --- stencils/pace/stencils/fv_update_phys.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stencils/pace/stencils/fv_update_phys.py b/stencils/pace/stencils/fv_update_phys.py index 6171f2000..e67a750ca 100644 --- a/stencils/pace/stencils/fv_update_phys.py +++ b/stencils/pace/stencils/fv_update_phys.py @@ -156,8 +156,8 @@ def __call__( dt, ) - self._udt_halo_updater.start([u_dt]) - self._vdt_halo_updater.start([v_dt]) + self._udt_halo_updater.start([u_dt.data]) + self._vdt_halo_updater.start([v_dt.data]) self._update_pressure_and_surface_winds( state.pe, state.delp,