Description
Required prerequisites
- I have read the documentation https://torchopt.readthedocs.io.
- I have searched the Issue Tracker and Discussions that this hasn't already been reported. (+1 or comment there if it has.)
- Consider asking first in a Discussion.
What version of TorchOpt are you using?
0.7.3
System information
- Installed via: pip install torchopt
- OS: Linux
- Python: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0]
- TorchOpt: 0.7.3
- PyTorch / functorch: 2.5.0a0+872d972e41.nv24.08
Problem description
When using the functional API with the adamw optimizer and the mask parameter specified, the expectation is that the update is applied with weight decay skipped for the masked-out parameters. Instead, update fails with 'AttributeError: 'MaskedNode' object has no attribute 'add''.
The docstring for 'MaskedNode' states: "This node is ignored when mapping functions across the tree, e.g. using :func:`pytree.tree_map`, since it is a container without children. It can therefore be used to mask out parts of a tree." However, this does not appear to be the case here: the MaskedNode leaves reach the weight-decay step, which calls tensor methods on them.
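To illustrate what the documented behavior should look like, here is a pure-Python sketch (not TorchOpt's actual implementation): masked leaves are replaced by a placeholder node, and the weight-decay step must pass those placeholders through untouched instead of calling tensor methods such as .add on them.

```python
# Pure-Python sketch (NOT TorchOpt's actual implementation) of the intended
# mask semantics: masked leaves become placeholder nodes, and the decay step
# must skip them rather than treat them as tensors.

class MaskedNode:
    """Placeholder standing in for a leaf excluded by the mask."""

def add_decayed(param, grad, weight_decay):
    # Expected behavior: skip masked leaves entirely; only real leaves
    # receive the decay term grad + weight_decay * param.
    if isinstance(grad, MaskedNode):
        return grad
    return grad + weight_decay * param

params = {'w': 2.0, 'b': 3.0}
grads = {'w': 0.5, 'b': MaskedNode()}  # 'b' was masked out (ndim == 1)

updates = {k: add_decayed(params[k], grads[k], 0.1) for k in params}
# 'w' gets the decay term added; 'b' stays a MaskedNode placeholder.
```

The reported AttributeError corresponds to the case where the isinstance check above is missing and the decay arithmetic is applied to the placeholder unconditionally.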
Reproducible example code
The Python snippet:
import torchopt
mask = lambda p: torchopt.pytree.tree_map(lambda x: x.ndim != 1, p)
optimizer = torchopt.adamw(lr=0.2, weight_decay=0.1, mask=mask)
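For context, the mask marks every parameter with ndim != 1 for weight decay, so 1-D tensors (biases, norm scales) are excluded. A pure-Python sketch of the classification, using stand-in objects instead of real tensors:

```python
from types import SimpleNamespace

# Stand-ins for tensors: only the ndim attribute matters to the mask above.
params = {'weight': SimpleNamespace(ndim=2), 'bias': SimpleNamespace(ndim=1)}

# The same predicate as the mask lambda, applied leaf-wise over the tree.
mask = {name: (p.ndim != 1) for name, p in params.items()}
# 1-D parameters ('bias') map to False, i.e. excluded from weight decay.
```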
Command lines:
python parallel_train_torchopt.py
Extra dependencies:
Steps to reproduce:
- Use example https://github.com/metaopt/torchopt/blob/main/examples/FuncTorch/parallel_train_torchopt.py#L188
- Change the optimizer to adamw and supply a mask as in the Python snippet above.
- python parallel_train_torchopt.py
- optimizer.update fails
Traceback
File "torchopt_test.py", line 230, in <module>
functorch_original.test_train_step_fn(weights, opt_state, points, labels)
File "torchopt_test.py", line 160, in test_train_step_fn
loss, (weights, opt_state) = self.train_step_fn((weights, opt_state), points, labels)
File "torchopt_test.py", line 154, in train_step_fn
updates, new_opt_state = optimizer.update(grads, opt_state, params=weights, inplace=False)
File "/usr/local/lib/python3.10/dist-packages/torchopt/combine.py", line 92, in update_fn
flat_updates, state = inner.update(flat_updates, state, params=flat_params, inplace=inplace)
File "/usr/local/lib/python3.10/dist-packages/torchopt/base.py", line 196, in update_fn
updates, new_s = fn(updates, s, params=params, inplace=inplace)
File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/add_decayed_weights.py", line 132, in update_fn
new_masked_updates, new_inner_state = inner.update(
File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/add_decayed_weights.py", line 243, in update_fn
updates = tree_map(f, params, updates)
File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/utils.py", line 65, in tree_map_flat
return flat_arg.__class__(map(fn, flat_arg, *flat_args)) # type: ignore[call-arg]
File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/utils.py", line 63, in fn
return func(x, *xs) if x is not None else None
File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/add_decayed_weights.py", line 241, in f
return g.add(p, alpha=weight_decay) if g is not None else g
AttributeError: 'MaskedNode' object has no attribute 'add'
Expected behavior
The expectation is that when a mask is supplied to adamw, update is successful and weight decay is skipped for the masked parameters.
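Until this is fixed, one possible workaround is to run a decay-free optimizer and fold the decoupled weight-decay term into the already-lr-scaled updates only where the mask is True. A pure-Python sketch of the arithmetic (the dict values stand in for tensors; lr and weight_decay match the snippet above; the Adam update values are hypothetical):

```python
def apply_masked_decay(updates, params, mask, lr, weight_decay):
    # Decoupled weight decay, applied only to unmasked (mask=True) leaves:
    # u <- u - lr * weight_decay * p. Masked leaves pass through unchanged.
    return {k: (u - lr * weight_decay * params[k]) if mask[k] else u
            for k, u in updates.items()}

params  = {'w': 2.0, 'b': 3.0}
updates = {'w': -0.1, 'b': -0.1}   # hypothetical Adam updates, already scaled by -lr
mask    = {'w': True, 'b': False}  # decay weights, skip 1-D biases

updates = apply_masked_decay(updates, params, mask, lr=0.2, weight_decay=0.1)
# 'w': -0.1 - 0.2 * 0.1 * 2.0 = -0.14; 'b' unchanged at -0.1
```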
Additional context
No response