0.12.0
#4984
Flax 0.12.0 includes many updates and some important breaking changes to the NNX API.
Breaking Changes
Pytree Strict Attributes
nnx.Pytree and therefore nnx.Module are now stricter with regard to attributes that contain Arrays and to changing the status of attributes. For example, a Module that stores layers as a plain Python list, or that assigns bias without marking it as data on the first assignment, now fails. This happens for two reasons:
1. Pytree structures that contain Arrays must now be marked with nnx.data. Alternatively, if the container pytree is a list or a dict, you can use nnx.List or nnx.Dict, which additionally allow mixed "data" and "static" elements.
2. Attributes no longer change their status automatically; the status must be set explicitly with nnx.data or nnx.static. Additionally, assigning Arrays or structures with Arrays to static attributes is now an error, as they will not automatically change to data.
To fix the above, create layers as a List Module, which is automatically recognized as data, and be explicit about bias being a data attribute on the first assignment by using nnx.data.
For more information, check the Module & Pytree guide.
Eager Sharding
Variables will now eagerly shard their values when sharding_names metadata is provided. A mesh is required; it can be provided either by passing a mesh metadata attribute or by setting the global mesh context via jax.set_mesh. This moves the sharding of a Variable to construction time.
Eager sharding will also occur when using the nnx.with_partitioning initializer decorator and will automatically extend to the Optimizer. This means that both model and optimizer will be sharded at construction without the need for the somewhat cumbersome nnx.get_partition_spec + jax.lax.with_sharding_constraint + nnx.update pattern.
For projects that currently rely on other means for sharding, eager sharding can be turned off by passing eager_sharding=False to the Variable constructor, either directly or through initializer decorators like nnx.with_partitioning.
Eager sharding can also be turned off globally via the flax_always_shard_variable config flag or the FLAX_ALWAYS_SHARD_VARIABLE environment variable.
For more information, check out the Variable eager sharding FLIP.
In-Place Operators No Longer Allowed
In-place operators will now raise an error. This is done as part of the push for Variables to be compatible with Tracer semantics. The fix is to simply operate on the .value property instead.
All Changes
- Avoid passing non-boolean mask to where argument of jax.numpy reductions. Non-boolean mask inputs have been deprecated for several releases, and will result in an error starting in JAX v0.8.0. by @copybara-service[bot] in #4923
- Correctly expose flax.config.temp_flip_flag by @IvyZX in #4969
New Contributors
Full Changelog: v0.11.2...v0.12.0
This discussion was created from the release 0.12.0.