-
Notifications
You must be signed in to change notification settings - Fork 268
Open
Labels
bug (Something isn't working)
Description
Bug Description
I create an autoguide with `AutoNormal` and pass it to `get_model_relations`. Next, I attempt to create a trace from that same guide. JAX then raises an error complaining about side effects:
Error + traceback
---------------------------------------------------------------------------
UnexpectedTracerError Traceback (most recent call last)
Cell In[10], line 1
----> 1 handlers.trace(handlers.seed(guide, 0)).get_trace()
File ~/.venv/lib/python3.13/site-packages/numpyro/handlers.py:191, in trace.get_trace(self, *args, **kwargs)
183 def get_trace(self, *args, **kwargs) -> OrderedDict[str, Message]:
184 """
185 Run the wrapped callable and return the recorded trace.
186
(...) 189 :return: `OrderedDict` containing the execution trace.
190 """
--> 191 self(*args, **kwargs)
192 return self.trace
File ~/.venv/lib/python3.13/site-packages/numpyro/primitives.py:121, in Messenger.__call__(self, *args, **kwargs)
119 return self
120 with self:
--> 121 return self.fn(*args, **kwargs)
File ~/.venv/lib/python3.13/site-packages/numpyro/handlers.py:846, in seed.__call__(self, *args, **kwargs)
842 cloned_seeded_fn = seed(
843 self.fn, rng_seed=self.rng_key, hide_types=self.hide_types
844 )
845 cloned_seeded_fn.stateful = True
--> 846 return cloned_seeded_fn.__call__(*args, **kwargs)
847 return super().__call__(*args, **kwargs)
File ~/.venv/lib/python3.13/site-packages/numpyro/handlers.py:847, in seed.__call__(self, *args, **kwargs)
845 cloned_seeded_fn.stateful = True
846 return cloned_seeded_fn.__call__(*args, **kwargs)
--> 847 return super().__call__(*args, **kwargs)
File ~/.venv/lib/python3.13/site-packages/numpyro/primitives.py:121, in Messenger.__call__(self, *args, **kwargs)
119 return self
120 with self:
--> 121 return self.fn(*args, **kwargs)
File ~/.venv/lib/python3.13/site-packages/numpyro/infer/autoguide.py:440, in AutoNormal.__call__(self, *args, **kwargs)
435 site_fn = dist.Normal(site_loc, site_scale).to_event(event_dim)
436 if site["fn"].support is constraints.real or (
437 isinstance(site["fn"].support, constraints.independent)
438 and site["fn"].support.base_constraint is constraints.real
439 ):
--> 440 result[name] = numpyro.sample(name, site_fn)
441 else:
442 with helpful_support_errors(site):
File ~/.venv/lib/python3.13/site-packages/numpyro/primitives.py:250, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
235 initial_msg = {
236 "type": "sample",
237 "name": name,
(...) 246 "infer": {} if infer is None else infer,
247 }
249 # ...and use apply_stack to send it to the Messengers
--> 250 msg = apply_stack(initial_msg)
251 return msg["value"]
File ~/.venv/lib/python3.13/site-packages/numpyro/primitives.py:61, in apply_stack(msg)
58 if msg.get("stop"):
59 break
---> 61 default_process_message(msg)
63 # A Messenger that sets msg["stop"] == True also prevents application
64 # of postprocess_message by Messengers above it on the stack
65 # via the pointer variable from the process_message loop
66 for handler in _PYRO_STACK[-pointer - 1 :]:
File ~/.venv/lib/python3.13/site-packages/numpyro/primitives.py:32, in default_process_message(msg)
30 if msg["value"] is None:
31 if msg["type"] == "sample":
---> 32 msg["value"], msg["intermediates"] = msg["fn"](
33 *msg["args"], sample_intermediates=True, **msg["kwargs"]
34 )
35 else:
36 msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
File ~/.venv/lib/python3.13/site-packages/numpyro/distributions/distribution.py:393, in Distribution.__call__(self, *args, **kwargs)
391 sample_intermediates = kwargs.pop("sample_intermediates", False)
392 if sample_intermediates:
--> 393 return self.sample_with_intermediates(key, *args, **kwargs)
394 return self.sample(key, *args, **kwargs)
File ~/.venv/lib/python3.13/site-packages/numpyro/distributions/distribution.py:351, in Distribution.sample_with_intermediates(self, key, sample_shape)
341 def sample_with_intermediates(self, key, sample_shape=()):
342 """
343 Same as ``sample`` except that any intermediate computations are
344 returned (useful for `TransformedDistribution`).
(...) 349 :rtype: numpy.ndarray
350 """
--> 351 return self.sample(key, sample_shape=sample_shape), []
File ~/.venv/lib/python3.13/site-packages/numpyro/distributions/continuous.py:2198, in Normal.sample(self, key, sample_shape)
2194 assert is_prng_key(key)
2195 eps = random.normal(
2196 key, shape=sample_shape + self.batch_shape + self.event_shape
2197 )
-> 2198 return self.loc + eps * self.scale
File ~/.venv/lib/python3.13/site-packages/jax/_src/numpy/array_methods.py:1083, in _forward_operator_to_aval.<locals>.op(self, *args)
1082 def op(self, *args):
-> 1083 return getattr(self.aval, f"_{name}")(self, *args)
File ~/.venv/lib/python3.13/site-packages/jax/_src/numpy/array_methods.py:583, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
581 args = (other, self) if swap else (self, other)
582 if isinstance(other, _accepted_binop_types):
--> 583 return binary_op(*args)
584 # Note: don't use isinstance here, because we don't want to raise for
585 # subclasses, e.g. NamedTuple objects that may override operators.
586 if type(other) in _rejected_binop_types:
File ~/.venv/lib/python3.13/site-packages/jax/_src/numpy/ufunc_api.py:182, in ufunc.__call__(self, out, where, *args)
180 raise NotImplementedError(f"where argument of {self}")
181 call = self.__static_props['call'] or self._call_vectorized
--> 182 return call(*args)
[... skipping hidden 3 frame]
File ~/.venv/lib/python3.13/site-packages/jax/_src/core.py:1053, in check_eval_args(args)
1051 for arg in args:
1052 if isinstance(arg, Tracer):
-> 1053 raise escaped_tracer_error(arg)
UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was get_trace at /home/mochar/.venv/lib/python3.13/site-packages/numpyro/infer/inspect.py:307 traced for jit.
------------------------------
The leaked intermediate value was created on line /home/mochar/.venv/lib/python3.13/site-packages/numpyro/util.py:141:15 (while_loop).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/home/mochar/.venv/lib/python3.13/site-packages/numpyro/infer/autoguide.py:160:16 (AutoGuide._setup_prototype)
/home/mochar/.venv/lib/python3.13/site-packages/numpyro/infer/util.py:750:40 (initialize_model)
/home/mochar/.venv/lib/python3.13/site-packages/numpyro/infer/util.py:472:46 (find_valid_initial_params)
/home/mochar/.venv/lib/python3.13/site-packages/numpyro/infer/util.py:465:52 (find_valid_initial_params.<locals>._find_valid_params)
/home/mochar/.venv/lib/python3.13/site-packages/numpyro/util.py:141:15 (while_loop)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.UnexpectedTracerError
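I do not encounter this problem when I write my own guide instead of using an `AutoGuide`. The traceback shows that the leaked value was created inside `AutoGuide._setup_prototype` while `get_model_relations` was tracing `get_trace` under `jit`. This suggests that the autoguide's lazily initialized prototype state captured a JAX tracer on that first call, and the tracer is then reused outside the transformation.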
Steps to Reproduce
# Minimal reproduction: calling `get_model_relations` on an AutoNormal guide
# and then tracing that same guide raises jax's UnexpectedTracerError.
import numpyro
from numpyro import distributions as dist
from numpyro import handlers  # needed for handlers.trace / handlers.seed below
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.inspect import get_model_relations


def model():
    numpyro.sample('a', dist.Normal())


guide = AutoNormal(model)

# First use of the guide: per the reported traceback, get_model_relations
# traces `get_trace` under jax.jit, and the guide's setup runs inside that trace.
relations = get_model_relations(guide)

# Second use of the guide, outside any jit: this is the call that fails with
# UnexpectedTracerError in the report.
handlers.trace(handlers.seed(guide, 0)).get_trace()
Metadata
Metadata
Assignees
Labels
bug (Something isn't working)