
0.12.0

@cgarciae released this on 25 Sep 23:58

Flax 0.12.0 includes many updates and some important breaking changes to the NNX API.

Breaking Changes

Pytree Strict Attributes

nnx.Pytree, and therefore nnx.Module, are now stricter about attributes that contain Arrays and about changing the status of attributes. For example, the code below now fails:

from flax import nnx
import jax
import jax.numpy as jnp

class Foo(nnx.Module):
  def __init__(self, use_bias, rngs):
    self.layers = [  # ERROR
      nnx.Linear(3, 3, rngs=rngs) for _ in range(5)
    ]
    self.bias = None # status = static
    if use_bias:
      self.bias = nnx.Param(rngs.params.uniform((3,))) # ERROR

This happens for two reasons:

  1. JAX pytree structures that contain Arrays now have to be marked with nnx.data. Alternatively, if the container pytree is a list or a dict, you can use nnx.List or nnx.Dict, which additionally allow mixed "data" and "static" elements.
  2. Attributes will no longer automatically change their status—this now has to be done explicitly using nnx.data or nnx.static. Additionally, assigning Arrays or structures with Arrays to static attributes is now an error, as they will not automatically change to data.

To fix the above, create layers as an nnx.List (a Module that is automatically recognized as data), and be explicit about bias being a data attribute on its first assignment by using nnx.data:

class Foo(nnx.Module):
  def __init__(self, use_bias, rngs):
    self.layers = nnx.List([  # nnx.data also works but List is recommended
      nnx.Linear(3, 3, rngs=rngs) for _ in range(5)
    ])
    self.bias = nnx.data(None)
    if use_bias:
      self.bias = nnx.Param(rngs.params.uniform((3,)))
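
Item 2 above also mentions nnx.static, and nnx.List can mix "data" and "static" elements. The following is a minimal sketch of both, assuming the semantics described above; the attribute names are illustrative only:

class Bar(nnx.Module):
  def __init__(self, rngs):
    # nnx.List can hold "data" elements (Modules, Arrays) next to "static" ones (plain Python values).
    self.blocks = nnx.List([nnx.Linear(2, 2, rngs=rngs), 'relu'])
    # Explicitly mark an attribute as data even though its initial value is not an Array.
    self.scale = nnx.data(1.0)
    # Explicitly mark an attribute as static; assigning an Array to it later would be an error.
    self.activation_name = nnx.static('gelu')

bar = Bar(nnx.Rngs(0))
bar.scale = jnp.ones(())  # OK: scale was declared as data on its first assignment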

For more information, check the Module & Pytree guide.

Eager Sharding

Variables will now eagerly shard their values when sharding_names metadata is provided. A mesh is required; it can be supplied either as a mesh metadata attribute or by setting the global mesh context via jax.set_mesh. This moves the sharding of a Variable to construction time:

jax.config.update('jax_num_cpu_devices', 8)
mesh = jax.make_mesh((2, 4), ('data', 'model'))

with jax.set_mesh(mesh):
  variable = nnx.Param(jnp.ones((16, 32)), sharding_names=(None, 'model'))
  
print(variable.value.sharding)
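
Alternatively, the mesh can be attached to the Variable as metadata instead of relying on the global context. A minimal sketch, assuming the metadata attribute is simply named mesh as in the description above:

variable = nnx.Param(
  jnp.ones((16, 32)),
  sharding_names=(None, 'model'),
  mesh=mesh,  # mesh passed as Variable metadata instead of via jax.set_mesh
)
print(variable.value.sharding)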

Eager sharding also applies when using the nnx.with_partitioning initializer decorator, and it automatically extends to the Optimizer. This means that both the model and the optimizer are sharded at construction, without the need for the somewhat cumbersome nnx.get_partition_spec + jax.lax.with_sharding_constraint + nnx.update pattern:

import optax
with jax.set_mesh(mesh):
  linear = nnx.Linear(
    in_features=16, out_features=16, use_bias=False,
    kernel_init=nnx.with_partitioning(
      nnx.initializers.lecun_normal(), (None, 'model')
    ),
    rngs=nnx.Rngs(0),
  )
  optimizer = nnx.Optimizer(linear, optax.adam(1e-3), wrt=nnx.Param)
  
print(linear.kernel.value.sharding)
print(optimizer.opt_state[0].mu.kernel.value.sharding)
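
For reference, the pattern this replaces looked roughly like the following sketch, based on the existing NNX sharding utilities (nnx.state, nnx.get_partition_spec, nnx.update); the jitted creation function is illustrative and details may differ from your setup:

@nnx.jit
def create_sharded_model():
  model = nnx.Linear(
    in_features=16, out_features=16, use_bias=False,
    kernel_init=nnx.with_partitioning(
      nnx.initializers.lecun_normal(), (None, 'model')
    ),
    rngs=nnx.Rngs(0),
  )
  state = nnx.state(model)                  # unsharded state, sharding metadata attached
  pspecs = nnx.get_partition_spec(state)    # extract PartitionSpecs from the metadata
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  nnx.update(model, sharded_state)          # write the sharded arrays back into the model
  return model

with jax.set_mesh(mesh):
  model = create_sharded_model()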

For projects that currently rely on other means for sharding, eager sharding can be turned off by passing eager_sharding=False to the Variable constructor, either directly or through initializer decorators like nnx.with_partitioning:

linear = nnx.Linear(
  in_features=16, out_features=16, use_bias=False,
  kernel_init=nnx.with_partitioning(
    nnx.initializers.lecun_normal(), (None, 'model'), eager_sharding=False
  ),
  rngs=nnx.Rngs(0),
)
optimizer = nnx.Optimizer(linear, optax.adam(1e-3), wrt=nnx.Param)
  
print(linear.kernel.value.sharding)
print(optimizer.opt_state[0].mu.kernel.value.sharding)

Eager sharding can also be turned off globally via the flax_always_shard_variable config flag or the FLAX_ALWAYS_SHARD_VARIABLE environment variable:

import flax
flax.config.update('flax_always_shard_variable', False)

For more information, check out the Variable eager sharding FLIP.

In-Place Operators No Longer Allowed

In-place operators on Variables will now raise an error. This is part of the push to make Variables compatible with Tracer semantics:

w = nnx.Variable(jnp.array(0))
w += 1  # ERROR

The fix is simply to operate on the .value property instead:

w.value += 1
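
Put together, a minimal runnable sketch of the new behavior (the commented-out line is what now fails):

from flax import nnx
import jax.numpy as jnp

w = nnx.Variable(jnp.array(0))
# w += 1           # raises an error in 0.12.0
w.value += 1       # operate on the underlying Array instead
w.value *= 2       # the same applies to other in-place operators
print(w.value)     # 2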

Full Changelog: v0.11.2...v0.12.0