Skip to content

Commit 69bbee5

Browse files
committed
better solution to with_updated_copy
1 parent 1a7d6d0 commit 69bbee5

File tree

3 files changed

+31
-26
lines changed

3 files changed

+31
-26
lines changed

tidy3d/components/data/data_array.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -489,24 +489,37 @@ def _ag_interp_func(var, indexes_coords, method, **kwargs):
489489
result = result.transpose(*out_dims)
490490
return result
491491

492-
def with_updated_data(self, data, coords: dict[str, Any]) -> DataArray:
492+
def with_updated_data(self, data: xr, coords: dict[str, Any]) -> DataArray:
493493
"""Make copy of ``DataArray`` with ``data`` at specified ``coords``, autograd-approved."""
494494

495+
old_values = self.values
496+
shape_in = data.shape
497+
shape_in_new = []
498+
i = 0
499+
for dim in self.dims:
500+
if dim in coords:
501+
shape_in_new.append(1)
502+
else:
503+
shape_in_new.append(shape_in[i])
504+
i += 1
505+
506+
data = data.reshape(shape_in_new)
507+
data = data + np.zeros_like(old_values)
508+
495509
mask = xr.zeros_like(self, dtype=bool)
496510
mask.loc[coords] = True
497511

498-
old_values = self.values
499-
replacement_data = np.zeros_like(old_values)
500-
axes_to_fill = [self.get_axis_num(key) for key in coords]
512+
modified_values = np.where(mask, data, old_values)
501513

502-
reshape_guide = [1] * len(old_values.shape)
503-
for idx, dim in enumerate(axes_to_fill):
504-
reshape_guide[dim] = data.shape[idx]
514+
return self.copy(deep=True, data=modified_values)
505515

506-
replacement_data = replacement_data + data.reshape(reshape_guide)
507-
modified_values = np.where(mask, replacement_data, old_values)
516+
def _with_updated_data(self, data: DataArray, coords: dict[str, Any]) -> DataArray:
517+
"""Make copy of ``DataArray`` with ``data`` at specified ``coords``, autograd-approved."""
508518

509-
return self.copy(deep=True, data=modified_values)
519+
mask = xr.zeros_like(self, dtype=bool)
520+
mask.loc[coords] = True
521+
522+
return xr.where(mask, data, self)
510523

511524

512525
class FreqDataArray(DataArray):

tidy3d/plugins/smatrix/component_modelers/modal.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -325,8 +325,7 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> ModalPortDataArr
325325
"mode_index_out": mode_index_out,
326326
}
327327

328-
reshaped_smatrix_data = np.reshape(s_matrix_elements, (1, 1, 1, -1))
329-
s_matrix = s_matrix.with_updated_data(data=reshaped_smatrix_data, coords=coords_set)
328+
s_matrix = s_matrix.with_updated_data(data=s_matrix_elements, coords=coords_set)
330329

331330
# element can be determined by user-defined mapping
332331
for (row_in, col_in), (row_out, col_out), mult_by in self.element_mappings:
@@ -350,7 +349,6 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> ModalPortDataArr
350349
"mode_index_out": mode_index_out_to,
351350
}
352351

353-
reshaped_elements_from = np.reshape(elements_from, (1, 1, 1, -1))
354-
s_matrix = s_matrix.with_updated_data(data=reshaped_elements_from, coords=coords_to)
352+
s_matrix = s_matrix.with_updated_data(data=elements_from, coords=coords_to)
355353

356354
return s_matrix

tidy3d/plugins/smatrix/component_modelers/terminal.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -224,12 +224,9 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> TerminalPortData
224224
for _port_in_idx, port_in in enumerate(self.ports):
225225
sim_data = batch_data[self._task_name(port=port_in)]
226226
a, b = self.compute_power_wave_amplitudes_at_each_port(port_impedances, sim_data)
227-
indexer = {"f": a.f, "port_in": port_in.name, "port_out": a.port}
228-
229-
a_data = np.expand_dims(a.data, axis=1)
230-
b_data = np.expand_dims(b.data, axis=1)
231-
a_matrix = a_matrix.with_updated_data(data=a_data, coords=indexer)
232-
b_matrix = b_matrix.with_updated_data(data=b_data, coords=indexer)
227+
indexer = {"port_in": port_in.name}
228+
a_matrix = a_matrix.with_updated_data(data=a.data, coords=indexer)
229+
b_matrix = b_matrix.with_updated_data(data=b.data, coords=indexer)
233230

234231
s_matrix = self.ab_to_s(a_matrix, b_matrix)
235232

@@ -322,12 +319,9 @@ def compute_power_wave_amplitudes_at_each_port(
322319

323320
for port_out in self.ports:
324321
V_out, I_out = self.compute_port_VI(port_out, sim_data)
325-
indexer = {"f": V_out.f, "port": port_out.name}
326-
327-
V_out_data = np.expand_dims(V_out.data, axis=1)
328-
I_out_data = np.expand_dims(I_out.data, axis=1)
329-
V_matrix = V_matrix.with_updated_data(data=V_out_data, coords=indexer)
330-
I_matrix = V_matrix.with_updated_data(data=I_out_data, coords=indexer)
322+
indexer = {"port": port_out.name}
323+
V_matrix = V_matrix.with_updated_data(data=V_out.data, coords=indexer)
324+
I_matrix = V_matrix.with_updated_data(data=I_out.data, coords=indexer)
331325

332326
V_numpy = V_matrix.values
333327
I_numpy = I_matrix.values

0 commit comments

Comments
 (0)