custom autograd hooks #2987
base: develop
… and numerical_structures arguments to provide hooks into gradient computation for user-defined vjp calculation.
22 files reviewed, 3 comments
tidy3d/web/api/autograd/autograd.py
Outdated
| sim_dict[task_name] = sim | ||
|
|
||
| if user_vjp is not None: | ||
| if type(user_vjp) is not type(simulations): |
syntax: Use isinstance() for type checking instead of type() to correctly handle inheritance
| if type(user_vjp) is not type(simulations): | |
| if not isinstance(user_vjp, type(simulations)): |
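A minimal sketch of the difference this suggestion points at. `SimulationMap` is a hypothetical dict subclass invented for illustration and is not a tidy3d class; the variable names only mirror the snippet above.

```python
class SimulationMap(dict):
    """Hypothetical dict subclass standing in for a container of simulations."""


# 'simulations' is a plain dict; 'user_vjp' is an instance of the subclass.
simulations = {"task_a": "sim_a"}
user_vjp = SimulationMap(task_a="vjp_a")

# Exact type comparison rejects the subclass instance.
strict_match = type(user_vjp) is type(simulations)

# isinstance() accepts any instance of dict, including subclasses.
lenient_match = isinstance(user_vjp, type(simulations))

print(strict_match, lenient_match)
```

Note that `isinstance(a, type(b))` is not symmetric: if `simulations` were the subclass instance and `user_vjp` a plain dict, the check would fail. Whether that matters depends on which container types the function is meant to accept.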
tidy3d/web/api/autograd/autograd.py
Outdated
| user_vjp = user_vjp_dict | ||
|
|
||
| if numerical_structures is not None: | ||
| if type(numerical_structures) is not type(simulations): |
syntax: Use isinstance() for type checking instead of type() to correctly handle inheritance
| if type(numerical_structures) is not type(simulations): | |
| if not isinstance(numerical_structures, type(simulations)): |
tidy3d/components/structure.py
| derivative_values_map.update( | ||
| med_or_geo_field._compute_derivatives(derivative_info=info) |
logic: Checking vjp_fns is not None before using it, but not checking if path_key in vjp_fns. If vjp_fns is an empty dict, the code will skip custom VJP even when one exists for the path_key
Consider: if vjp_fns and path_key in vjp_fns:
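The suggested guard can be sketched as follows. This is an illustration only: `compute_derivative`, `default_vjp`, and the tuple-shaped path keys are hypothetical names chosen here, not the actual code in `structure.py`.

```python
def default_vjp(path_key):
    """Hypothetical built-in gradient routine used when no hook applies."""
    return {"path": path_key, "source": "default"}


def compute_derivative(path_key, vjp_fns=None):
    """Dispatch to a user-supplied VJP when one is registered for path_key."""
    # 'vjp_fns and ...' is falsy for both None and an empty dict, and the
    # membership test avoids a KeyError for paths without a custom hook.
    if vjp_fns and path_key in vjp_fns:
        return vjp_fns[path_key](path_key)
    return default_vjp(path_key)


custom = {("geometry", "radius"): lambda key: {"path": key, "source": "custom"}}

print(compute_derivative(("geometry", "radius"), custom)["source"])
print(compute_derivative(("geometry", "center"), custom)["source"])
print(compute_derivative(("geometry", "radius"), {})["source"])
```

With this guard, `None`, an empty dict, and a dict that lacks the requested key all fall through to the default path.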
Diff Coverage
Diff: origin/develop...HEAD, staged and unstaged changes
Summary
- tidy3d/components/autograd/utils.py
- tidy3d/components/simulation.py
- tidy3d/plugins/smatrix/run.py
- tidy3d/web/api/autograd/autograd.py
- tidy3d/web/api/autograd/backward.py
marcorudolphflex
left a comment
Really a cool feature!
Would work on readability/formatting.
And I wonder if some helpers to construct the user_vjp/numericals are feasible. When I look at the tests, it seems like one has to wire quite a lot around with stuff which is already pretty wired? Not sure how feasible, maybe also for a follow-up PR.
I know it should serve as an internal feature for now, but still I question if the usability could be easier.
tidy3d/components/structure.py
Outdated
|
|
||
| def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: | ||
| def _compute_derivatives( | ||
| self, derivative_info: DerivativeInfo, vjp_fns=None |
missing type, consider add docstring update
tidy3d/plugins/smatrix/run.py
Outdated
| def _run_local( | ||
| modeler: ComponentModelerType, | ||
| path_dir: str = DEFAULT_DATA_DIR, | ||
| numerical_structures=None, |
missing types, extend docstring
| } | ||
| derivative_info_custom_medium = derivative_info.updated_copy(**update_kwargs) | ||
|
|
||
| def finite_difference_gradient(perturb_up, perturb_down, derivative_info_): |
arguments unused - does not matter in this case but could be confusing after changes
| frequencies: ArrayLike | ||
| """Frequencies at which the adjoint gradient should be computed.""" | ||
|
|
||
| updated_epsilon: Callable |
Optional?
I'm leaning towards this being required since it's made for every structure that we call _compute_derivatives on
| >>> b = Sphere(center=(1,2,3), radius=2) | ||
| """ | ||
|
|
||
| radius: TracedSize1D = pydantic.Field( |
loses non-negativity constraint, consider adding validator
tidy3d/web/api/autograd/autograd.py
Outdated
| pay_type: typing.Union[PayType, str] = PayType.AUTO, | ||
| priority: typing.Optional[int] = None, | ||
| lazy: typing.Optional[bool] = None, | ||
| numerical_structures: typing.Optional[dict[int, dict[str, typing.Any]]] = None, |
update docstring
| pay_type: typing.Union[PayType, str] = PayType.AUTO, | ||
| priority: typing.Optional[int] = None, | ||
| lazy: typing.Optional[bool] = None, | ||
| numerical_structures: typing.Optional[ |
typing and handling is really hard to read/interpret... don't we want to go with class instances here (and throughout the code until some point)?
very good point - see above, but made these contain class instances in the new version
| size=(dim_um, dim_um, thickness_um), | ||
| ) | ||
|
|
||
| eval_fns, eval_fn_names = make_eval_fns() |
unused?
| """Test a variety of autograd permittivity gradients for DiffractionData by""" | ||
| """comparing them to numerical finite difference.""" | ||
|
|
||
| test_number = test_parameters["test_number"] |
unused
| params_up[param_idx] += step_size | ||
| params_down[param_idx] -= step_size | ||
|
|
||
| rin_up = create_ring(params_up) |
typo
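The `params_up` / `params_down` / `step_size` pattern in the test snippet above is a central finite difference. A self-contained sketch of that check is shown here, assuming a stand-in objective (`ring_area`, the area of an annulus) in place of the simulation-based objective the real test uses; all function names are illustrative.

```python
import math


def ring_area(params):
    """Stand-in objective: area of a ring with given inner and outer radius."""
    r_inner, r_outer = params
    return math.pi * (r_outer**2 - r_inner**2)


def central_difference_gradient(fn, params, step_size=1e-4):
    """Estimate d(fn)/d(params[i]) by perturbing one parameter at a time."""
    grad = []
    for param_idx in range(len(params)):
        params_up = list(params)
        params_down = list(params)
        params_up[param_idx] += step_size
        params_down[param_idx] -= step_size
        grad.append((fn(params_up) - fn(params_down)) / (2 * step_size))
    return grad


params = [1.0, 2.0]
numerical = central_difference_gradient(ring_area, params)
# Analytic gradient of pi * (r_outer^2 - r_inner^2) for comparison.
analytic = [-2 * math.pi * params[0], 2 * math.pi * params[1]]
print(numerical, analytic)
```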
…r vjp interface and updated unit tests (not numerical tests yet)
This adds custom user hooks for plugging into autograd. They are not exposed externally, but are meant to be used/tested privately for now to help implement new/custom features and integrate with things like photonforge.
There are two hooks. The first is user_vjp, which lets one override the vjp calculation for a given structure that we already have in our library. For example, the sphere gradient could be computed with permittivity finite differences using this method. The second hook is numerical_structures, which allows a different kind of custom integration: you specify a function that creates a tidy3d structure from input parameters and provide a vjp for those parameters. This is useful when the vjp needs the adjoint fields. For example, you can implement a ring structure with a trimesh that depends on the inner and outer radius. The structure is inserted into the simulation, and the adjoint fields are then available inside the vjp for writing a numerical or other type of gradient calculation.
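To make the numerical_structures idea concrete, here is a rough sketch of a parameter VJP built from finite differences of a permittivity profile. Everything in it is an assumption made for illustration: `make_permittivity`, `numerical_vjp`, the 1D smoothed ring profile, and the `upstream` list (standing in for whatever sensitivity the adjoint fields provide) are invented names and do not reflect the signatures this PR uses.

```python
import math


def make_permittivity(params, xs, eps_ring=4.0, eps_bg=1.0, sharpness=20.0):
    """Hypothetical 'structure factory': smoothed 1D ring permittivity profile."""
    r_inner, r_outer = params

    def sigmoid(t):
        return 1.0 / (1.0 + math.exp(-t))

    return [
        eps_bg
        + (eps_ring - eps_bg)
        * sigmoid(sharpness * (abs(x) - r_inner))
        * sigmoid(sharpness * (r_outer - abs(x)))
        for x in xs
    ]


def numerical_vjp(params, upstream, xs, step_size=1e-5):
    """Contract upstream dL/d(eps) with a central-difference d(eps)/d(param)."""
    grads = []
    for i in range(len(params)):
        up, down = list(params), list(params)
        up[i] += step_size
        down[i] -= step_size
        eps_up = make_permittivity(up, xs)
        eps_down = make_permittivity(down, xs)
        grads.append(
            sum(g * (u - d) / (2 * step_size)
                for g, u, d in zip(upstream, eps_up, eps_down))
        )
    return grads


xs = [i * 0.05 for i in range(61)]   # sample points from 0 to 3
upstream = [1.0] * len(xs)           # placeholder for adjoint-derived sensitivity
grads = numerical_vjp([1.0, 2.0], upstream, xs)
print(grads)
```

With a uniform positive `upstream`, growing the inner radius removes ring material and growing the outer radius adds it, so the two entries of `grads` have opposite signs.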
Greptile Overview
Greptile Summary
This PR introduces custom autograd hooks (`user_vjp` and `numerical_structures`) for gradient computation in electromagnetic simulations. The implementation adds two new internal functions, `run_custom` and `run_async_custom`, that wrap the existing public API.
Key Changes:
- `user_vjp` parameter allows overriding VJP calculations for existing structure geometries (e.g., computing sphere gradients via permittivity finite differences)
- `numerical_structures` parameter enables dynamic structure insertion with custom gradient functions, useful when adjoint fields are needed for gradient computation (e.g., trimesh-based ring structures)
- … `_compute_derivatives` and `postprocess_adj`
- … `run` and `run_async` functions remain unchanged; new functionality accessed via `run_custom` and `run_async_custom`
Issues Found:
- … `type()` instead of `isinstance()` in two locations (lines 617, 630 of autograd.py)
- … `vjp_fns` may not correctly trigger custom VJP
Testing:
Comprehensive numerical tests compare custom VJP/numerical structure gradients against finite differences for both single and batch runs.
Confidence Score: 3/5
- … `type()` instead of `isinstance()` that violate coding standards and could cause inheritance issues. There's also a potential logic bug in the custom VJP dispatch that needs verification. The implementation is complex but well-structured, and the hooks are appropriately kept internal for now.
- `tidy3d/web/api/autograd/autograd.py` (type checking violations) and `tidy3d/components/structure.py` (potential VJP dispatch logic issue)
Important Files Changed
File Analysis
- … `run_custom` and `run_async_custom` functions. Contains type checking issues using `type()` instead of `isinstance()`. Complex flow for handling numerical structures and user VJP hooks.
- … `updated_epsilon` function provides permittivity computation for finite differences.
- … `NumericalStructureInfo` dataclass and `CustomVJPPathType` to existing autograd types. Clean addition to type system.
- … `_compute_derivatives` to accept optional `vjp_fns` parameter. Has potential logic issue where empty dict for `vjp_fns` may not trigger custom VJP correctly.
Sequence Diagram
sequenceDiagram
    participant User
    participant run_custom
    participant setup_run
    participant _run_primitive
    participant run_webapi
    participant setup_adj
    participant postprocess_adj
    participant Structure
    User->>run_custom: simulation, numerical_structures, user_vjp
    alt has numerical_structures or traced fields
        run_custom->>setup_run: prepare simulation with hooks
        setup_run->>setup_run: insert numerical structures
        setup_run->>setup_run: extract traced fields
        setup_run-->>run_custom: sim_fields, prepared_sim, numerical_info
        run_custom->>_run_primitive: execute forward simulation
        _run_primitive->>run_webapi: run forward with aux_data
        run_webapi-->>_run_primitive: sim_data_fwd
        _run_primitive-->>run_custom: traced_fields_data
        Note over User,Structure: Backward Pass (when gradient needed)
        User->>setup_adj: request gradients with vjp seeds
        setup_adj->>setup_adj: create adjoint simulations
        setup_adj->>run_webapi: run adjoint simulations
        run_webapi-->>setup_adj: sim_data_adj
        setup_adj-->>User: adjoint data
        User->>postprocess_adj: compute VJPs from adjoint data
        alt user_vjp provided
            postprocess_adj->>Structure: _compute_derivatives(vjp_fns)
            Structure->>Structure: call user-provided VJP function
            Structure-->>postprocess_adj: custom gradients
        else numerical_structures
            postprocess_adj->>postprocess_adj: call numerical VJP function
            postprocess_adj->>postprocess_adj: compute finite difference
        else default
            postprocess_adj->>Structure: _compute_derivatives()
            Structure->>Structure: compute analytical gradients
            Structure-->>postprocess_adj: gradients
        end
        postprocess_adj-->>User: final VJP values
    else no traced fields or numerical structures
        run_custom->>run_webapi: regular simulation run
        run_webapi-->>User: simulation data
    end
Context used:
- dashboard: Use isinstance() for type checking instead of type() to correctly handle inheritance.