Skip to content

Commit 1a7d6d0

Browse files
Gregory Robertstylerflex
authored andcommitted
add terminal modifications for autograd smatrix
1 parent 8440325 commit 1a7d6d0

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

tidy3d/plugins/smatrix/component_modelers/terminal.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from typing import Optional, Union
66

7-
import numpy as np
7+
import autograd.numpy as np
88
import pydantic.v1 as pd
99

1010
from tidy3d.components.base import cached_property
@@ -202,7 +202,11 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> TerminalPortData
202202

203203
port_names = [port.name for port in self.ports]
204204

205-
values = np.zeros(
205+
a_values = np.zeros(
206+
(len(self.freqs), len(port_names), len(port_names)),
207+
dtype=complex,
208+
)
209+
b_values = np.zeros(
206210
(len(self.freqs), len(port_names), len(port_names)),
207211
dtype=complex,
208212
)
@@ -211,21 +215,24 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> TerminalPortData
211215
"port_out": port_names,
212216
"port_in": port_names,
213217
}
214-
a_matrix = TerminalPortDataArray(values, coords=coords)
215-
b_matrix = a_matrix.copy(deep=True)
218+
a_matrix = TerminalPortDataArray(a_values, coords=coords)
219+
b_matrix = TerminalPortDataArray(b_values, coords=coords)
216220

217221
# Tabulate the reference impedances at each port and frequency
218222
port_impedances = self._port_reference_impedances(batch_data=batch_data)
219223

220-
# loop through source ports
221-
for port_in in self.ports:
224+
for _port_in_idx, port_in in enumerate(self.ports):
222225
sim_data = batch_data[self._task_name(port=port_in)]
223226
a, b = self.compute_power_wave_amplitudes_at_each_port(port_impedances, sim_data)
224227
indexer = {"f": a.f, "port_in": port_in.name, "port_out": a.port}
225-
a_matrix.loc[indexer] = a
226-
b_matrix.loc[indexer] = b
228+
229+
a_data = np.expand_dims(a.data, axis=1)
230+
b_data = np.expand_dims(b.data, axis=1)
231+
a_matrix = a_matrix.with_updated_data(data=a_data, coords=indexer)
232+
b_matrix = b_matrix.with_updated_data(data=b_data, coords=indexer)
227233

228234
s_matrix = self.ab_to_s(a_matrix, b_matrix)
235+
229236
return s_matrix
230237

231238
@pd.validator("simulation")
@@ -315,9 +322,12 @@ def compute_power_wave_amplitudes_at_each_port(
315322

316323
for port_out in self.ports:
317324
V_out, I_out = self.compute_port_VI(port_out, sim_data)
318-
indexer = {"port": port_out.name}
319-
V_matrix.loc[indexer] = V_out
320-
I_matrix.loc[indexer] = I_out
325+
indexer = {"f": V_out.f, "port": port_out.name}
326+
327+
V_out_data = np.expand_dims(V_out.data, axis=1)
328+
I_out_data = np.expand_dims(I_out.data, axis=1)
329+
V_matrix = V_matrix.with_updated_data(data=V_out_data, coords=indexer)
330+
I_matrix = V_matrix.with_updated_data(data=I_out_data, coords=indexer)
321331

322332
V_numpy = V_matrix.values
323333
I_numpy = I_matrix.values

0 commit comments

Comments
 (0)