4
4
5
5
from typing import Optional , Union
6
6
7
- import numpy as np
7
+ import autograd . numpy as np
8
8
import pydantic .v1 as pd
9
9
10
10
from tidy3d .components .base import cached_property
@@ -202,7 +202,11 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> TerminalPortData
202
202
203
203
port_names = [port .name for port in self .ports ]
204
204
205
- values = np .zeros (
205
+ a_values = np .zeros (
206
+ (len (self .freqs ), len (port_names ), len (port_names )),
207
+ dtype = complex ,
208
+ )
209
+ b_values = np .zeros (
206
210
(len (self .freqs ), len (port_names ), len (port_names )),
207
211
dtype = complex ,
208
212
)
@@ -211,21 +215,24 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> TerminalPortData
211
215
"port_out" : port_names ,
212
216
"port_in" : port_names ,
213
217
}
214
- a_matrix = TerminalPortDataArray (values , coords = coords )
215
- b_matrix = a_matrix . copy ( deep = True )
218
+ a_matrix = TerminalPortDataArray (a_values , coords = coords )
219
+ b_matrix = TerminalPortDataArray ( b_values , coords = coords )
216
220
217
221
# Tabulate the reference impedances at each port and frequency
218
222
port_impedances = self ._port_reference_impedances (batch_data = batch_data )
219
223
220
- # loop through source ports
221
- for port_in in self .ports :
224
+ for _port_in_idx , port_in in enumerate (self .ports ):
222
225
sim_data = batch_data [self ._task_name (port = port_in )]
223
226
a , b = self .compute_power_wave_amplitudes_at_each_port (port_impedances , sim_data )
224
227
indexer = {"f" : a .f , "port_in" : port_in .name , "port_out" : a .port }
225
- a_matrix .loc [indexer ] = a
226
- b_matrix .loc [indexer ] = b
228
+
229
+ a_data = np .expand_dims (a .data , axis = 1 )
230
+ b_data = np .expand_dims (b .data , axis = 1 )
231
+ a_matrix = a_matrix .with_updated_data (data = a_data , coords = indexer )
232
+ b_matrix = b_matrix .with_updated_data (data = b_data , coords = indexer )
227
233
228
234
s_matrix = self .ab_to_s (a_matrix , b_matrix )
235
+
229
236
return s_matrix
230
237
231
238
@pd .validator ("simulation" )
@@ -315,9 +322,12 @@ def compute_power_wave_amplitudes_at_each_port(
315
322
316
323
for port_out in self .ports :
317
324
V_out , I_out = self .compute_port_VI (port_out , sim_data )
318
- indexer = {"port" : port_out .name }
319
- V_matrix .loc [indexer ] = V_out
320
- I_matrix .loc [indexer ] = I_out
325
+ indexer = {"f" : V_out .f , "port" : port_out .name }
326
+
327
+ V_out_data = np .expand_dims (V_out .data , axis = 1 )
328
+ I_out_data = np .expand_dims (I_out .data , axis = 1 )
329
+ V_matrix = V_matrix .with_updated_data (data = V_out_data , coords = indexer )
330
+ I_matrix = V_matrix .with_updated_data (data = I_out_data , coords = indexer )
321
331
322
332
V_numpy = V_matrix .values
323
333
I_numpy = I_matrix .values
0 commit comments