feat(nnx): add arrays=True to nnx.clone for independent buffer copies #5482
Open
Sumu004 wants to merge 2 commits into
nnx.clone() currently uses copy-on-write semantics: new Variable wrapper objects are created, but the underlying jax.Array buffers are shared with the original. (Variable.copy() uses jax.tree.map(lambda x: x, value), which is structurally a copy but physically a no-op at the buffer level.)

This works for the common mutation-after-clone pattern (model.bias[...] += 1 rebinds the attribute, so the two diverge), but it breaks donate_argnums: JAX inspects buffer addresses before running user code and sees the same physical buffer donated twice, raising:

    "the same buffer cannot be donated more than once"

Add arrays=True (default False to preserve existing behaviour):

    cloned = nnx.clone(model, arrays=True)

When arrays=True, after the standard clone a second pass replaces every jax.Array leaf in the cloned State with jnp.array(x), forcing a new physical allocation. The clone and the original are then fully independent at the buffer level, making donate_argnums safe.

Fixes: google#5461
… .value
.value access is deprecated in the newest Flax; use variable[...] for Variable[Array] instances to avoid UnexpectedException in doctest.
What does this PR do?
`nnx.clone()` currently uses copy-on-write semantics: new `Variable` wrapper objects are created, but the underlying `jax.Array` buffers are shared with the original. This works for the common mutation-after-clone pattern, but it breaks `donate_argnums`: JAX inspects buffer addresses before running user code and sees the same physical buffer donated twice, raising "the same buffer cannot be donated more than once".

This PR adds an `arrays=True` keyword argument (default `False` to preserve existing behaviour), used as `cloned = nnx.clone(model, arrays=True)`.

When `arrays=True`, after the standard clone a second pass replaces every `jax.Array` leaf in the cloned state with `jnp.array(x)`, forcing a new physical allocation. The clone and the original are then fully independent at the buffer level.

Fixes #5461
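The buffer sharing described above can be observed in plain JAX, without Flax. The following is a minimal sketch, assuming a dict of arrays as a stand-in for a Variable's value (the names `params`, `kernel`, and `bias` are illustrative only, not from the PR): an identity `jax.tree.map` rebuilds the container but hands back the very same array objects, so their buffer pointers match.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the value held by a Variable.
params = {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros((3,))}

# Identity tree map: a new dict is built, but every leaf is passed through as-is.
shallow = jax.tree.map(lambda x: x, params)

# The container is new, while each leaf is the same jax.Array object.
new_container = shallow is not params
same_objects = all(shallow[k] is params[k] for k in params)

# Consequently the device buffers are shared as well.
same_pointers = all(
    shallow[k].unsafe_buffer_pointer() == params[k].unsafe_buffer_pointer()
    for k in params
)

print(new_container, same_objects, same_pointers)
```

The sketch only inspects pointers; it does not attempt to reproduce the donation error itself, since donation behaviour can differ between backends.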
Root cause
`Variable.copy()` uses `jax.tree.map(lambda x: x, value)` to "copy" the value. The comment says "make a copy", but the lambda is the identity: JAX reconstructs the pytree wrapper but reuses the underlying buffer. This is intentional (cheap clone, copy-on-write), but the docstring only shows the mutation-then-diverge path, not the `donate_argnums` failure mode.
Design
The maintainer (@samanklesaria) suggested `arrays=True` as the parameter name (#5461). The implementation is a post-merge second pass over the cloned state: `State` is a registered JAX pytree, so `jax.tree.map` reaches the raw arrays correctly.
Tests
- `test_clone_arrays_true_creates_independent_buffers`: verifies values are equal and `unsafe_buffer_pointer()` differs
- `test_clone_arrays_false_default_preserves_copy_on_write`: existing behaviour unchanged

All existing clone tests still pass.
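The second pass described under Design, together with the kind of check the first test performs, might look roughly like the sketch below. It is written in plain JAX under stated assumptions: `copy_array_leaves` is a hypothetical helper name that does not come from the PR, and a plain dict stands in for the cloned `State`. It relies on `jnp.array` copying its input by default.

```python
import jax
import jax.numpy as jnp

def copy_array_leaves(tree):
    """Return a tree in which every jax.Array leaf is a freshly allocated copy."""
    def copy_leaf(x):
        # jnp.array defaults to copy=True; non-array leaves pass through unchanged.
        return jnp.array(x) if isinstance(x, jax.Array) else x
    return jax.tree.map(copy_leaf, tree)

# Hypothetical stand-in for a cloned State: one array leaf, one non-array leaf.
original = {"w": jnp.arange(6.0).reshape(2, 3), "step": 0}
copied = copy_array_leaves(original)

# Same values, different physical buffers.
values_equal = bool(jnp.array_equal(original["w"], copied["w"]))
buffers_differ = (
    original["w"].unsafe_buffer_pointer() != copied["w"].unsafe_buffer_pointer()
)
non_array_kept = copied["step"] == original["step"]

print(values_equal, buffers_differ, non_array_kept)
```

The `isinstance` guard is a design choice of this sketch, so that any non-array leaves are left untouched; whether the PR's actual second pass filters leaves the same way is not stated in the description.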