Here are some thoughts on how we could support randomness in autogram. Consider this context manager:
```python
import torch


class local_rng_context:
    """
    A context manager that saves the global CPU and CUDA RNG states upon entry,
    sets a provided state, runs the code, and restores the original states upon exit.
    """

    def __init__(self, cpu_state: torch.ByteTensor):
        self.cpu_state = cpu_state  # Need support for cuda state too
        self.original_cpu_state: torch.ByteTensor | None = None

    def __enter__(self):
        self.original_cpu_state = torch.get_rng_state()
        torch.set_rng_state(self.cpu_state)

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.set_rng_state(self.original_cpu_state)
```
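
The "cuda state too" comment is not addressed yet; as a sketch, the same class could be extended to save and restore the per-device CUDA states via `torch.cuda.get_rng_state_all` / `torch.cuda.set_rng_state_all`. The extra `cuda_states` argument below is just an assumption about the interface:

```python
import torch


class local_rng_context:
    """
    Saves the global CPU and per-device CUDA RNG states upon entry, sets the
    provided states, and restores the original states upon exit.
    """

    def __init__(self, cpu_state: torch.Tensor, cuda_states: list[torch.Tensor] | None = None):
        self.cpu_state = cpu_state
        self.cuda_states = cuda_states  # One saved state per visible CUDA device
        self.original_cpu_state: torch.Tensor | None = None
        self.original_cuda_states: list[torch.Tensor] | None = None

    def __enter__(self):
        self.original_cpu_state = torch.get_rng_state()
        torch.set_rng_state(self.cpu_state)
        if self.cuda_states is not None and torch.cuda.is_available():
            self.original_cuda_states = torch.cuda.get_rng_state_all()
            torch.cuda.set_rng_state_all(self.cuda_states)

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.set_rng_state(self.original_cpu_state)
        if self.original_cuda_states is not None:
            torch.cuda.set_rng_state_all(self.original_cuda_states)
```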
It can be used like this:

```python
cpu_rng_state = torch.get_rng_state()
model(input)
with local_rng_context(cpu_rng_state):
    y_B, vjp_fn = autograd.functional.vjp(functional_call_model, x)
```
Then we would expect the forward phase of the vjp to be identical to the first forward pass. Moreover, the RNG state is restored afterwards to what it was before entering the context.
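
As a quick sanity check of that expectation, here is a small self-contained sketch using dropout as the source of randomness. `functional_call_model` from the snippet above is stood in by a plain closure, and `torch.func.vjp` is used so the call is runnable as written:

```python
import torch
from torch.func import vjp

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.5))
x = torch.randn(3, 4)

def forward_fn(inp):
    return model(inp)  # stand-in for functional_call_model

# Save the RNG state right before the first forward pass.
cpu_rng_state = torch.get_rng_state()
y_first = model(x)

state_after_forward = torch.get_rng_state()

# Replaying the saved state makes the vjp's forward draw the same dropout mask.
with local_rng_context(cpu_rng_state):
    y_vjp, vjp_fn = vjp(forward_fn, x)

assert torch.allclose(y_first, y_vjp)                           # same forward result
assert torch.equal(state_after_forward, torch.get_rng_state())  # outer RNG stream untouched
```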
In the case of autogram, this amounts to saving the RNG state in a pre-forward hook on the module, and then setting that state around the call to vjp. I'm not completely sure this would work, but it seems promising.
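
A rough sketch of that idea, ignoring CUDA state for now; the hook, the helper, and the `_saved_cpu_rng_state` attribute are only illustrative, not an actual autogram API:

```python
import torch
from torch.func import vjp


def save_rng_pre_hook(module, inputs):
    # Pre-forward hook: remember the global CPU RNG state right before this
    # module's forward runs, so its vjp can later replay the same random draws.
    module._saved_cpu_rng_state = torch.get_rng_state()


def vjp_under_saved_rng(module, functional_forward, *primals):
    # Hypothetical helper: run the vjp's forward under the RNG state captured
    # by the pre-forward hook, restoring the ambient state on exit.
    with local_rng_context(module._saved_cpu_rng_state):
        return vjp(functional_forward, *primals)


# Registration would happen once per module tracked by autogram:
# module.register_forward_pre_hook(save_rng_pre_hook)
```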