Support of randomness in autogram #444

@PierreQuinton

Description

Here are some thoughts on how we could support randomness in autogram.

Consider this context manager:

import torch


class local_rng_context:
    """
    A context manager that saves the global CPU RNG state upon entry,
    sets a provided state, runs the code, and restores the original state upon exit.
    """

    def __init__(self, cpu_state: torch.ByteTensor):
        self.cpu_state = cpu_state  # Need support for the CUDA state too (see the sketch after this class)
        self.original_cpu_state: torch.ByteTensor | None = None

    def __enter__(self):
        self.original_cpu_state = torch.get_rng_state()
        torch.set_rng_state(self.cpu_state)

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.set_rng_state(self.original_cpu_state)
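
The comment about the CUDA state could be handled along the same lines. Here is a hedged sketch, assuming torch.cuda.get_rng_state_all and torch.cuda.set_rng_state_all are acceptable for capturing the states of all visible devices (the name local_rng_context_with_cuda is just illustrative):

class local_rng_context_with_cuda:
    """
    Like local_rng_context, but also saves, sets, and restores the CUDA RNG states
    of all visible devices when CUDA states are provided.
    """

    def __init__(self, cpu_state: torch.ByteTensor, cuda_states: list[torch.Tensor] | None = None):
        self.cpu_state = cpu_state
        self.cuda_states = cuda_states

    def __enter__(self):
        self.original_cpu_state = torch.get_rng_state()
        torch.set_rng_state(self.cpu_state)
        if self.cuda_states is not None and torch.cuda.is_available():
            # One state per visible CUDA device.
            self.original_cuda_states = torch.cuda.get_rng_state_all()
            torch.cuda.set_rng_state_all(self.cuda_states)

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.set_rng_state(self.original_cpu_state)
        if self.cuda_states is not None and torch.cuda.is_available():
            torch.cuda.set_rng_state_all(self.original_cuda_states)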

It can be used as follows:

cpu_rng_state = torch.get_rng_state()
model(input)

with local_rng_context(cpu_rng_state):
    y_B, vjp_fn = autograd.functional.vjp(functional_call_model, x)

We would then expect the forward pass performed inside the vjp call to be identical to the first one, since it starts from the same RNG state. Moreover, the RNG state is restored upon exit, so the surrounding code is unaffected.
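
As a quick sanity check of this expectation, here is a minimal sketch (hypothetical, using a plain dropout function as the only source of randomness):

import torch
from torch import autograd

def f(x):
    # The only randomness here is the dropout mask, drawn from the global CPU RNG.
    return torch.nn.functional.dropout(x, p=0.5, training=True)

x = torch.randn(3)
cpu_rng_state = torch.get_rng_state()
y = f(x)

with local_rng_context(cpu_rng_state):
    # The forward pass inside vjp starts from the same RNG state as the first one.
    y_B, vjp_fn = autograd.functional.vjp(f, x, v=torch.ones_like(y))

assert torch.allclose(y, y_B)  # same RNG state, hence the same dropout mask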

In the case of autogram, this amounts to saving the RNG state in a pre-forward hook on the module, and then setting that saved state before the call to vjp. I'm not completely sure this would work, but it seems promising.
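
I'm not sure what the exact integration in autogram would look like, but a rough sketch of the hook-based mechanism (hypothetical names, not autogram's actual API) could be:

import torch
from torch import nn, autograd

saved_rng_states: dict[nn.Module, torch.Tensor] = {}

def save_rng_pre_hook(module, args):
    # Capture the global CPU RNG state just before the module's forward runs.
    saved_rng_states[module] = torch.get_rng_state()

module = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
handle = module.register_forward_pre_hook(save_rng_pre_hook)

x = torch.randn(2, 4)
y = module(x)  # the pre-forward hook records the RNG state used by this forward pass

# Later, replay the same randomness inside the vjp computation.
with local_rng_context(saved_rng_states[module]):
    # The hook fires again here and overwrites the saved state; harmless for this one-shot check.
    y_B, vjp_fn = autograd.functional.vjp(module, x, v=torch.ones_like(y))

assert torch.allclose(y, y_B)  # the dropout mask is the same in both forward passes
handle.remove()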
