[BUG] 'AttributeError: 'MaskedNode' object has no attribute 'add'' error when specifying 'mask' parameter for functional adamw API #232

@gkennickell

Description

Required prerequisites

What version of TorchOpt are you using?

0.7.3

System information

pip install torchopt
Python 3.10.12
3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0] linux
0.7.3 2.5.0a0+872d972e41.nv24.08 2.5.0a0+872d972e41.nv24.08

Problem description

When the functional API optimizer adamw is created with the mask parameter, the expectation is that update succeeds and skips weight decay for the parameters the mask excludes. Instead, update fails with "AttributeError: 'MaskedNode' object has no attribute 'add'".

The docstring for 'MaskedNode' states "This node is ignored when mapping functions across the tree e.g. using :func:pytree.tree_map since it is a container without children. It can therefore be used to mask out parts of a tree." However, the traceback below shows a MaskedNode being passed to the mapped function, so it is not being ignored in this code path.
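
Reading only the frames in the traceback further down, the flattened update list appears to be mapped element by element with only None skipped, so a MaskedNode placeholder reaches g.add. The following is a minimal, dependency-free sketch of that reading. MaskedNodeStandIn, TensorStandIn, and tree_map_flat_sketch are hypothetical stand-ins written for illustration, not TorchOpt classes or functions, and the actual cause inside TorchOpt may differ.

```python
class MaskedNodeStandIn(tuple):
    """Hypothetical stand-in for MaskedNode: an empty container with no tensor methods."""


class TensorStandIn:
    """Hypothetical stand-in for a tensor that supports add(other, alpha=...)."""

    def __init__(self, value):
        self.value = value

    def add(self, other, alpha=1.0):
        return TensorStandIn(self.value + alpha * other.value)


def tree_map_flat_sketch(func, flat_arg, *flat_args):
    # Modeled on the two utils.py lines visible in the traceback:
    # only None is skipped, every other element is passed to func.
    def fn(x, *xs):
        return func(x, *xs) if x is not None else None

    return flat_arg.__class__(map(fn, flat_arg, *flat_args))


weight_decay = 0.1


def f(p, g):
    # Modeled on the add_decayed_weights.py line visible in the traceback.
    return g.add(p, alpha=weight_decay) if g is not None else g


# Assumption: after masking, excluded leaves are replaced by placeholder nodes
# that sit in the flat lists next to the real tensors.
params = [TensorStandIn(1.0), MaskedNodeStandIn()]
updates = [TensorStandIn(0.5), MaskedNodeStandIn()]

try:
    tree_map_flat_sketch(f, params, updates)
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

print(error_message)
```

In this sketch the placeholder is "not None", so it is handed to f and the call to add fails in the same way as in the report.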

Reproducible example code

The Python snippets:

    mask = lambda p: torchopt.pytree.tree_map(lambda x: x.ndim != 1, p)
    optimizer = torchopt.adamw(lr=0.2, weight_decay=0.1, mask=mask)

Command lines:

python parallel_train_torchopt.py

Extra dependencies:


Steps to reproduce:

  1. Start from the example at https://github.com/metaopt/torchopt/blob/main/examples/FuncTorch/parallel_train_torchopt.py#L188
  2. Change the optimizer to adamw and pass a mask, as in the Python snippet above.
  3. python parallel_train_torchopt.py
  4. optimizer.update fails with the AttributeError shown in the traceback below.

Traceback

File "torchopt_test.py", line 230, in <module>
    functorch_original.test_train_step_fn(weights, opt_state, points, labels)
  File "torchopt_test.py", line 160, in test_train_step_fn
    loss, (weights, opt_state) = self.train_step_fn((weights, opt_state), points, labels)
  File "torchopt_test.py", line 154, in train_step_fn
    updates, new_opt_state = optimizer.update(grads, opt_state, params=weights, inplace=False)
  File "/usr/local/lib/python3.10/dist-packages/torchopt/combine.py", line 92, in update_fn
    flat_updates, state = inner.update(flat_updates, state, params=flat_params, inplace=inplace)
  File "/usr/local/lib/python3.10/dist-packages/torchopt/base.py", line 196, in update_fn
    updates, new_s = fn(updates, s, params=params, inplace=inplace)
  File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/add_decayed_weights.py", line 132, in update_fn
    new_masked_updates, new_inner_state = inner.update(
  File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/add_decayed_weights.py", line 243, in update_fn
    updates = tree_map(f, params, updates)
  File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/utils.py", line 65, in tree_map_flat
    return flat_arg.__class__(map(fn, flat_arg, *flat_args))  # type: ignore[call-arg]
  File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/utils.py", line 63, in fn
    return func(x, *xs) if x is not None else None
  File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/add_decayed_weights.py", line 241, in f
    return g.add(p, alpha=weight_decay) if g is not None else g
AttributeError: 'MaskedNode' object has no attribute 'add'

Expected behavior

The expectation is that when a mask is supplied to adamw, update is successful and weight decay is skipped for the masked parameters.
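
As a rough illustration of the expected semantics, the sketch below applies the decay term to flat lists of plain floats. It assumes that a True mask entry means "apply weight decay to this leaf" and False means "leave the update unchanged"; decayed_updates_expected is a hypothetical helper written for this report, not part of TorchOpt.

```python
def decayed_updates_expected(params, updates, mask, weight_decay):
    """Add weight_decay * p to the update only where the mask entry is True."""
    return [
        g + weight_decay * p if keep else g
        for p, g, keep in zip(params, updates, mask)
    ]


params = [2.0, 4.0]    # e.g. one weight entry and one bias entry
updates = [1.0, 1.0]   # incoming updates for the same leaves
mask = [True, False]   # decay the first leaf, skip the second

result = decayed_updates_expected(params, updates, mask, weight_decay=0.5)
print(result)
```

Here the first update picks up the decay term and the second passes through untouched, which is the behavior expected from update when a mask is supplied.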

Additional context

No response
