Skip to content

Commit 1749017

Browse files
committed
add exception handling for reshaping with DataArray._with_updated_copy
1 parent 5d07458 commit 1749017

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

tests/test_data/test_data_arrays.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ def test_interp(method, scalar_index):
472472
xrt.assert_allclose(xr_interp, ag_interp)
473473

474474

475-
def test_with_updated_data():
475+
def test_with_updated_data_grad():
476476
"""Check the ``DataArray.with_updated_data()`` method."""
477477

478478
arr = td.SpatialDataArray(
@@ -499,3 +499,20 @@ def f(x):
499499
assert np.all(g == np.ones_like(data))
500500

501501
check_grads(f, order=1, modes=["rev"])(data)
502+
503+
504+
def test_with_updated_data_shape():
505+
"""Check the ``DataArray.with_updated_data()`` method."""
506+
507+
arr = td.SpatialDataArray(
508+
np.ones((2, 3, 4, 5), dtype=np.complex128),
509+
coords={"x": [0, 1], "y": [1, 2, 3], "z": [2, 3, 4, 5], "w": [0, 1, 2, 3, 4]},
510+
)
511+
512+
# wrong shape
513+
data = np.zeros((1, 1, 1, 3))
514+
515+
coords = {"x": 0, "y": 2, "z": 3}
516+
517+
with pytest.raises(ValueError):
518+
arr2 = arr._with_updated_data(data=data, coords=coords)

tidy3d/components/data/data_array.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,14 @@ def _with_updated_data(self, data: np.ndarray, coords: dict[str, Any]) -> DataAr
507507
for i, dim in enumerate(self.dims):
508508
if dim in coords:
509509
new_shape[i] = 1
510-
new_data = data.reshape(new_shape)
510+
try:
511+
new_data = data.reshape(new_shape)
512+
except ValueError as e:
513+
raise ValueError(
514+
"Couldn't reshape the supplied 'data' to update 'DataArray'. The provided data was "
515+
f"of shape {data.shape} and tried to reshape to {new_shape}. If you encounter this "
516+
"error please raise an issue on the tidy3d github repository with the context."
517+
) from e
511518

512519
# broadcast data to repeat data along the selected dimensions to match mask
513520
new_data = new_data + np.zeros_like(old_data)

0 commit comments

Comments
 (0)