The mask at https://github.com/hkproj/pytorch-llama/blob/067f8a37fe36ac8b52dca9cc6f2a2e8d6aa372d6/inference.py#L121 should be `~mask`, since we want to select all the indices where the value is less than `p`.
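For comparison, here is a minimal sketch in plain Python (no PyTorch, so it runs anywhere). It assumes the linked line builds the mask as `probs_sum - probs_sort > p` over probabilities sorted in descending order. That form is an assumption, so check it against the linked line. The helper `top_p_masks` is hypothetical. It only shows which positions `mask` and its negation (`~mask` on a tensor) each pick out.

```python
from itertools import accumulate


def top_p_masks(probs_sorted, p):
    """Hypothetical helper: return (mask, keep) for probabilities sorted descending.

    mask[i] is True where the cumulative probability of the tokens *before*
    position i already exceeds p (mirrors the assumed `probs_sum - probs_sort > p`).
    keep[i] is its negation, i.e. what `~mask` would select on a tensor.
    """
    cumsum = list(accumulate(probs_sorted))
    # Subtracting the token's own probability shifts the cumulative sum right by one.
    mask = [c - q > p for c, q in zip(cumsum, probs_sorted)]
    keep = [not m for m in mask]
    return mask, keep


# Exact binary fractions avoid floating-point surprises in the comparison.
mask, keep = top_p_masks([0.5, 0.25, 0.125, 0.125], p=0.6)
print(mask)  # [False, False, True, True]
print(keep)  # [True, True, False, False]
```

Whichever form is correct depends on how the following line uses it. Indexing with `mask` to zero out entries needs the positions past the nucleus. Selecting the entries to keep needs `~mask`.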