Skip to content

feat(nnx): add arrays=True to nnx.clone for independent buffer copies#5482

Open
Sumu004 wants to merge 2 commits into
google:mainfrom
Sumu004:feat/nnx-clone-arrays-deep-copy
Open

feat(nnx): add arrays=True to nnx.clone for independent buffer copies#5482
Sumu004 wants to merge 2 commits into
google:mainfrom
Sumu004:feat/nnx-clone-arrays-deep-copy

Conversation

@Sumu004
Copy link
Copy Markdown

@Sumu004 Sumu004 commented Jun 4, 2026

What does this PR do?

nnx.clone() currently uses copy-on-write semantics: new Variable wrapper objects are created but the underlying jax.Array buffers are shared with the original. This works for the common mutation-after-clone pattern, but breaks donate_argnums — JAX inspects buffer addresses before running user code and sees the same physical buffer donated twice, raising:

"the same buffer cannot be donated more than once"

This PR adds an arrays=True keyword argument (default False to preserve existing behaviour):

cloned = nnx.clone(model, arrays=True)

When arrays=True, after the standard clone a second pass replaces every jax.Array leaf in the cloned state with jnp.array(x), forcing a new physical allocation. The clone and the original are then fully independent at the buffer level.

Fixes #5461

Root cause

Variable.copy() uses jax.tree.map(lambda x: x, value) to "copy" the value. The comment says "make a copy" but the lambda is the identity — JAX reconstructs the pytree wrapper but reuses the underlying buffer. This is intentional (cheap clone, copy-on-write) but the docstring only shows the mutation-then-diverge path, not the donate_argnums failure mode.

Design

The maintainer (@samanklesaria) suggested arrays=True as the parameter name (#5461). Implementation is a post-merge second pass:

if arrays and variables:
    _, cloned_state = split(merged, graph=graph)
    deep_state = jax.tree.map(
        lambda x: jnp.array(x) if isinstance(x, jax.Array) else x,
        cloned_state)
    return merge(graphdef, deep_state, copy=False)

State is a registered JAX pytree so jax.tree.map reaches the raw arrays correctly.

Tests

  • test_clone_arrays_true_creates_independent_buffers: verifies values are equal and unsafe_buffer_pointer() differs
  • test_clone_arrays_false_default_preserves_copy_on_write: existing behaviour unchanged

All existing clone tests still pass.

Sumu004 added 2 commits June 4, 2026 17:50
nnx.clone() currently uses copy-on-write semantics: new Variable wrapper
objects are created but the underlying jax.Array buffers are shared with
the original (Variable.copy() uses jax.tree.map(lambda x: x, value) which
is structurally a copy but physically a no-op at the buffer level).

This works for the common mutation-after-clone pattern (model.bias[...] += 1
rebinds the attribute, so the two diverge), but breaks donate_argnums:
JAX inspects buffer addresses before running user code, and sees the same
physical buffer donated twice, raising:
  "the same buffer cannot be donated more than once"

Add arrays=True (default False to preserve existing behaviour):
  cloned = nnx.clone(model, arrays=True)

When arrays=True, after the standard clone a second pass replaces every
jax.Array leaf in the cloned State with jnp.array(x), forcing a new
physical allocation.  The clone and the original are then fully independent
at the buffer level, making donate_argnums safe.

Fixes: google#5461
… .value

.value access is deprecated in newest Flax; use variable[...] for
Variable[Array] instances to avoid UnexpectedException in doctest.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

nnx.clone creates buffer copies with the same IDs, causing errors with donate_argnums

1 participant