Description
From the paper:
"we experimented with initializing the hidden states with zeros on half of the examples in the batch, and with standard Gaussian noise on the rest of the examples"
"Mixed initialization: During each training forward pass, each sample was assigned with either zero initialization (i.e. the fixed point was initialized with the 0 vector) or standard normal distribution (i.e. ...) using a Bernoulli random variable of probability 0.5 (i.e. the examples that were run with zero vs. normal initializations were roughly half-half."
Current implementation:
torchdeq/torchdeq/utils/init.py
Lines 4 to 21 in 4f6bd5f
    def mixed_init(z_shape, device=None):
        """
        Initializes a tensor with a shape of `z_shape` with half Gaussian random values and hald zeros.
        Proposed in the paper, `Path Independent Equilibrium Models Can Better Exploit Test-Time Computation <https://arxiv.org/abs/2211.09961>`_,
        for better path independence.
        Args:
            z_shape (tuple): Shape of the tensor to be initialized.
            device (torch.device, optional): The desired device of returned tensor. Default None.
        Returns:
            torch.Tensor: A tensor of shape `z_shape` with values randomly initialized and zero masked.
        """
        z_init = torch.randn(*z_shape, device=device)
        mask = torch.zeros_like(z_init, device=device).bernoulli_(0.5)
        return z_init * mask
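The current implementation draws the Bernoulli mask per element, so every sample ends up with roughly half of its entries zeroed, whereas the paper assigns each sample as a whole to either zero or Gaussian initialization. To match the paper, it seems more appropriate to draw the mask per sample instead: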
    *mask_shape, _ = z_shape
    mask = torch.empty(*mask_shape, device=device).bernoulli_(0.5).unsqueeze(-1)
The drawback of this form is that it assumes every dimension except the last is a batch dimension. That seems a reasonable assumption, though, and downstream users can adapt by reshaping or permuting the dimensions before and after the call.
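For concreteness, here is a minimal sketch of how the two proposed lines could slot into the existing function. The name `mixed_init_per_sample` is hypothetical (it is not part of torchdeq), and the sketch assumes `z_shape` has at least two dimensions with the feature dimension last.

```python
import torch


def mixed_init_per_sample(z_shape, device=None):
    """Hypothetical per-sample variant of `mixed_init` (illustrative name only)."""
    # Assumption: every dimension except the last is a batch dimension.
    *mask_shape, _ = z_shape
    z_init = torch.randn(*z_shape, device=device)
    # One Bernoulli(0.5) draw per sample; unsqueeze so it broadcasts over features.
    mask = torch.empty(*mask_shape, device=device).bernoulli_(0.5).unsqueeze(-1)
    return z_init * mask


torch.manual_seed(0)
z = mixed_init_per_sample((8, 16))
# Each row should be either entirely zero or entirely Gaussian noise.
row_is_zero = (z == 0).all(dim=-1)
row_is_dense = (z != 0).all(dim=-1)
```

With this variant, each sample is zero-initialized or noise-initialized as a whole, which is how I read the paper's description. The element-wise mask in the current code instead mixes zeros and noise within every sample.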