@@ -658,6 +658,10 @@ def plot_sim(sim: td.Simulation, plot_eps: bool = True) -> None:
 args = [("polyslab", "mode")]
 
 
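+# subset of (structure, monitor) pairs shared by the async autograd tests below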
+ASYNC_TEST_ARGS = args[:2]
+
+
 def get_functions(structure_key: str, monitor_key: str) -> dict[str, typing.Callable]:
     if structure_key == ALL_KEY:
         structure_keys = structure_keys_
@@ -856,50 +859,58 @@ def objective(*args):
     assert anp.all(grad != 0.0), "some gradients are 0"
 
 
-@pytest.mark.slow
-def _run_autograd_async(use_emulated_run, fn_dicts, use_task_names=True):
-    """Shared core logic for the autograd async objective test (all combos in one batch)."""
+def _compare_async_vs_sync(fn_dicts) -> None:
+    """Compare async vs non-async autograd for a subset of structure/monitor pairs."""
 
-    def objective(*params):
-        """Compute combined objective across all structure/monitor combinations."""
+    # synchronous objective: run() one sim after another
+    def objective_sync(*params):
+        total = 0.0
+        for i, fn_dict in enumerate(fn_dicts):
+            sim = fn_dict["sim"](*params)
+            data = run(sim, task_name=f"autograd_sync_{i}", verbose=False)
+            total += fn_dict["postprocess"](data)
+        return total
+
+    def objective_async(*params):
         sims = {}
         for i, fn_dict in enumerate(fn_dicts):
-            if use_task_names:
-                # simulate different tasks per (structure, monitor)
-                sims[f"sim_{i}_fwd"] = fn_dict["sim"](*params)
-            else:
-                sims[f"sim_{i}"] = fn_dict["sim"](*params)
+            sim = fn_dict["sim"](*params)
+            key = f"autograd_{i}"
+            sims[key] = sim
 
-        # run all simulations asynchronously in one go
-        batch_data = run_async(sims, verbose=False)
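+        # local_gradient=False: adjoint gradients are computed server-side instead of from locally downloaded field data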
+        batch_data = run_async(sims, verbose=False, local_gradient=False)
 
-        # accumulate total objective value
-        total_value = 0.0
+        total = 0.0
         for i, fn_dict in enumerate(fn_dicts):
-            key = f"sim_{i}_fwd" if use_task_names else f"sim_{i}"
-            total_value += fn_dict["postprocess"](batch_data[key])
+            key = f"autograd_{i}"
+            total += fn_dict["postprocess"](batch_data[key])
+        return total
 
-        return total_value
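+    # evaluate value and gradient through both code paths at the same parameters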
+    val_sync, grad_sync = ag.value_and_grad(objective_sync)(params0)
+    val_async, grad_async = ag.value_and_grad(objective_async)(params0)
 
-    val, grad = ag.value_and_grad(objective)(params0)
-    print(f"use_task_names={use_task_names}", val, grad)
-    assert anp.all(grad != 0.0), "some gradients are 0"
+    val_sync = float(val_sync)
+    val_async = float(val_async)
+    grad_sync = np.asarray(grad_sync)
+    grad_async = np.asarray(grad_async)
 
+    np.testing.assert_allclose(val_async, val_sync, rtol=1e-8, atol=1e-10)
+    np.testing.assert_allclose(grad_async, grad_sync, rtol=1e-6, atol=1e-8)
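+    # gradients get a looser tolerance than the value: the adjoint pass accumulates more floating-point noise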
 
-@pytest.mark.slow
-def test_autograd_async_all(use_emulated_run):
-    """Run async autograd objective for all structure/monitor pairs in one combined batch."""
-    # collect all combinations once
-    fn_dicts = [get_functions(structure_key, monitor_key) for structure_key, monitor_key in args]
 
-    # call the unified helper
-    _run_autograd_async(use_emulated_run, fn_dicts, use_task_names=True)
+@pytest.mark.slow
+def test_autograd_async(use_emulated_run):
+    """Async autograd for a small subset; must match non-async autograd."""
 
+    # only use two structure/monitor combinations to keep this test cheap
+    fn_dicts = [
+        get_functions(structure_key, monitor_key) for structure_key, monitor_key in ASYNC_TEST_ARGS
+    ]
 
-def test_autograd_async_without_taskname(use_emulated_run):
-    """Test autograd objective varying use_task_names only."""
-    fn_dict = get_functions(*args[0])
-    _run_autograd_async(use_emulated_run, [fn_dict, fn_dict], use_task_names=False)
+    _compare_async_vs_sync(fn_dicts)
 
 
 class TestTupleGrads:
@@ -1021,6 +1029,8 @@ def objective(*args):
     grad = ag.grad(objective)(params0)
 
 
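+# performance benchmark; deselect in routine runs with: pytest -m "not perf"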
+@pytest.mark.perf
 def test_autograd_speed_num_structures(use_emulated_run):
     """Test an objective function through tidy3d autograd."""
 
@@ -1144,31 +1153,14 @@ def objective(*args):
 
 @pytest.mark.slow
 def test_autograd_async_server(use_emulated_run):
-    """Run async autograd objective across all structure/monitor combos in one batch."""
-
-    # gather all structure/monitor function pairs
-    fn_dicts = [get_functions(structure_key, monitor_key) for structure_key, monitor_key in args]
-
-    def objective(*params):
-        """Combined async objective for all cases."""
-        sims = {}
-        for i, fn_dict in enumerate(fn_dicts):
-            sim = fn_dict["sim"](*params)
-            # mimic two tasks per combination (as original test did)
-            sims[f"autograd_{i}"] = sim
-
-        # run everything asynchronously in one go
-        batch_data = run_async(sims, verbose=False, local_gradient=False)
+    """Same async-vs-sync comparison, for the server-side gradient (local_gradient=False) path."""
 
-        # sum the results from all simulations
-        total_value = 0.0
-        for i, fn_dict in enumerate(fn_dicts):
-            total_value += fn_dict["postprocess"](batch_data[f"autograd_{i}"])
-        return total_value
+    fn_dicts = [
+        get_functions(structure_key, monitor_key) for structure_key, monitor_key in ASYNC_TEST_ARGS
+    ]
 
-    # compute value and gradient once for the entire batch
-    val, grad = ag.value_and_grad(objective)(params0)
-    assert np.all(np.abs(grad) > 0), "some gradients are 0"
+    # same helper as test_autograd_async; its run_async call uses local_gradient=False (server-side gradients)
+    _compare_async_vs_sync(fn_dicts)
 
 
 @pytest.mark.parametrize("structure_key", ("custom_med",))