Skip to content

JAX raises UnexpectedTracerError when tracing an AutoGuide after passing it to get_model_relations #2062

@mochar

Description

@mochar

Bug Description

I create an autoguide with `AutoNormal` and pass it to `get_model_relations`. When I then try to record a trace of that same guide, JAX raises an `UnexpectedTracerError` complaining about side effects:

Error + traceback
---------------------------------------------------------------------------
UnexpectedTracerError                     Traceback (most recent call last)
Cell In[10], line 1
----> 1 handlers.trace(handlers.seed(guide, 0)).get_trace()

File ~/.venv/lib/python3.13/site-packages/numpyro/handlers.py:191, in trace.get_trace(self, *args, **kwargs)
    183 def get_trace(self, *args, **kwargs) -> OrderedDict[str, Message]:
    184     """
    185     Run the wrapped callable and return the recorded trace.
    186 
   (...)    189     :return: `OrderedDict` containing the execution trace.
    190     """
--> 191     self(*args, **kwargs)
    192     return self.trace

File ~/.venv/lib/python3.13/site-packages/numpyro/primitives.py:121, in Messenger.__call__(self, *args, **kwargs)
    119     return self
    120 with self:
--> 121     return self.fn(*args, **kwargs)

File ~/.venv/lib/python3.13/site-packages/numpyro/handlers.py:846, in seed.__call__(self, *args, **kwargs)
    842     cloned_seeded_fn = seed(
    843         self.fn, rng_seed=self.rng_key, hide_types=self.hide_types
    844     )
    845     cloned_seeded_fn.stateful = True
--> 846     return cloned_seeded_fn.__call__(*args, **kwargs)
    847 return super().__call__(*args, **kwargs)

File ~/.venv/lib/python3.13/site-packages/numpyro/handlers.py:847, in seed.__call__(self, *args, **kwargs)
    845     cloned_seeded_fn.stateful = True
    846     return cloned_seeded_fn.__call__(*args, **kwargs)
--> 847 return super().__call__(*args, **kwargs)

File ~/.venv/lib/python3.13/site-packages/numpyro/primitives.py:121, in Messenger.__call__(self, *args, **kwargs)
    119     return self
    120 with self:
--> 121     return self.fn(*args, **kwargs)

File ~/.venv/lib/python3.13/site-packages/numpyro/infer/autoguide.py:440, in AutoNormal.__call__(self, *args, **kwargs)
    435 site_fn = dist.Normal(site_loc, site_scale).to_event(event_dim)
    436 if site["fn"].support is constraints.real or (
    437     isinstance(site["fn"].support, constraints.independent)
    438     and site["fn"].support.base_constraint is constraints.real
    439 ):
--> 440     result[name] = numpyro.sample(name, site_fn)
    441 else:
    442     with helpful_support_errors(site):

File ~/.venv/lib/python3.13/site-packages/numpyro/primitives.py:250, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
    235 initial_msg = {
    236     "type": "sample",
    237     "name": name,
   (...)    246     "infer": {} if infer is None else infer,
    247 }
    249 # ...and use apply_stack to send it to the Messengers
--> 250 msg = apply_stack(initial_msg)
    251 return msg["value"]

File ~/.venv/lib/python3.13/site-packages/numpyro/primitives.py:61, in apply_stack(msg)
     58     if msg.get("stop"):
     59         break
---> 61 default_process_message(msg)
     63 # A Messenger that sets msg["stop"] == True also prevents application
     64 # of postprocess_message by Messengers above it on the stack
     65 # via the pointer variable from the process_message loop
     66 for handler in _PYRO_STACK[-pointer - 1 :]:

File ~/.venv/lib/python3.13/site-packages/numpyro/primitives.py:32, in default_process_message(msg)
     30 if msg["value"] is None:
     31     if msg["type"] == "sample":
---> 32         msg["value"], msg["intermediates"] = msg["fn"](
     33             *msg["args"], sample_intermediates=True, **msg["kwargs"]
     34         )
     35     else:
     36         msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])

File ~/.venv/lib/python3.13/site-packages/numpyro/distributions/distribution.py:393, in Distribution.__call__(self, *args, **kwargs)
    391 sample_intermediates = kwargs.pop("sample_intermediates", False)
    392 if sample_intermediates:
--> 393     return self.sample_with_intermediates(key, *args, **kwargs)
    394 return self.sample(key, *args, **kwargs)

File ~/.venv/lib/python3.13/site-packages/numpyro/distributions/distribution.py:351, in Distribution.sample_with_intermediates(self, key, sample_shape)
    341 def sample_with_intermediates(self, key, sample_shape=()):
    342     """
    343     Same as ``sample`` except that any intermediate computations are
    344     returned (useful for `TransformedDistribution`).
   (...)    349     :rtype: numpy.ndarray
    350     """
--> 351     return self.sample(key, sample_shape=sample_shape), []

File ~/.venv/lib/python3.13/site-packages/numpyro/distributions/continuous.py:2198, in Normal.sample(self, key, sample_shape)
   2194 assert is_prng_key(key)
   2195 eps = random.normal(
   2196     key, shape=sample_shape + self.batch_shape + self.event_shape
   2197 )
-> 2198 return self.loc + eps * self.scale

File ~/.venv/lib/python3.13/site-packages/jax/_src/numpy/array_methods.py:1083, in _forward_operator_to_aval.<locals>.op(self, *args)
   1082 def op(self, *args):
-> 1083   return getattr(self.aval, f"_{name}")(self, *args)

File ~/.venv/lib/python3.13/site-packages/jax/_src/numpy/array_methods.py:583, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
    581 args = (other, self) if swap else (self, other)
    582 if isinstance(other, _accepted_binop_types):
--> 583   return binary_op(*args)
    584 # Note: don't use isinstance here, because we don't want to raise for
    585 # subclasses, e.g. NamedTuple objects that may override operators.
    586 if type(other) in _rejected_binop_types:

File ~/.venv/lib/python3.13/site-packages/jax/_src/numpy/ufunc_api.py:182, in ufunc.__call__(self, out, where, *args)
    180   raise NotImplementedError(f"where argument of {self}")
    181 call = self.__static_props['call'] or self._call_vectorized
--> 182 return call(*args)

    [... skipping hidden 3 frame]

File ~/.venv/lib/python3.13/site-packages/jax/_src/core.py:1053, in check_eval_args(args)
   1051 for arg in args:
   1052   if isinstance(arg, Tracer):
-> 1053     raise escaped_tracer_error(arg)

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was get_trace at /home/mochar/.venv/lib/python3.13/site-packages/numpyro/infer/inspect.py:307 traced for jit.
------------------------------
The leaked intermediate value was created on line /home/mochar/.venv/lib/python3.13/site-packages/numpyro/util.py:141:15 (while_loop). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/home/mochar/.venv/lib/python3.13/site-packages/numpyro/infer/autoguide.py:160:16 (AutoGuide._setup_prototype)
/home/mochar/.venv/lib/python3.13/site-packages/numpyro/infer/util.py:750:40 (initialize_model)
/home/mochar/.venv/lib/python3.13/site-packages/numpyro/infer/util.py:472:46 (find_valid_initial_params)
/home/mochar/.venv/lib/python3.13/site-packages/numpyro/infer/util.py:465:52 (find_valid_initial_params.<locals>._find_valid_params)
/home/mochar/.venv/lib/python3.13/site-packages/numpyro/util.py:141:15 (while_loop)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.UnexpectedTracerError

I do not encounter this problem when I write my own guide instead of using an AutoGuide.

Steps to Reproduce

import numpyro
from numpyro import distributions as dist
from numpyro import handlers  # needed for handlers.trace / handlers.seed below
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.inspect import get_model_relations


def model():
    """Minimal model with a single standard-normal latent site 'a'."""
    numpyro.sample('a', dist.Normal())


guide = AutoNormal(model)

# According to the traceback, get_model_relations traces its argument under
# jit, and the AutoNormal prototype (AutoGuide._setup_prototype) is first
# built during that jitted trace. A tracer created there appears to remain
# in the guide's state afterwards.
relations = get_model_relations(guide)

# Expected to raise UnexpectedTracerError (the bug being reported).
handlers.trace(handlers.seed(guide, 0)).get_trace()

Metadata

Metadata

Assignees

Labels

bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions