Labels: bug
Description
✨ Short description of the bug [tl;dr]
Bug in the `f` function of `cw.py` when the logits are negative
💬 Detailed code and results
In the `f` function of `cw.py`, `real` and `other` are computed as:
```python
other = torch.max((1 - one_hot_labels) * outputs, dim=1)[0]
real = torch.max(one_hot_labels * outputs, dim=1)[0]
```
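However, this breaks when the logits are negative. Multiplying by the one-hot mask sets the masked-out entries to 0 instead of removing them, so if all non-target logits are negative, `other` becomes 0 instead of the largest non-target logit, and if the target logit is negative, `real` becomes 0 instead of the target logit.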
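A minimal reproduction of the problem, assuming a single sample with three made-up, all-negative logits (the tensor values are illustrative, not taken from a real model). The two lines computing `other` and `real` are the ones quoted from `cw.py`; the correct values here would be `other = -2.0` and `real = -1.0`.

```python
import torch

# Illustrative batch of one sample whose logits are all negative.
outputs = torch.tensor([[-2.0, -1.0, -3.0]])
labels = torch.tensor([1])  # target class is index 1 (logit -1.0)
one_hot_labels = torch.eye(outputs.shape[1])[labels]

# Original computation: masked-out entries become 0 (actually -0.0),
# which is larger than every negative logit, so max() returns 0.
other = torch.max((1 - one_hot_labels) * outputs, dim=1)[0]
real = torch.max(one_hot_labels * outputs, dim=1)[0]

print(other, real)  # both are zero instead of -2.0 and -1.0
```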
I suggest the following modifications:
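```python
# Push the target logit down by 1e4 so it is never selected (assumes all logits > -1e4).
other = torch.max((1 - one_hot_labels) * outputs - one_hot_labels * 1e4, dim=1)[0]
# Each one-hot row has a single 1, so the sum is exactly the target logit, even if negative.
real = torch.sum(one_hot_labels * outputs, dim=1)
```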
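A sketch checking the suggested fix on the same all-negative logits. The helper names `real_and_other` and `real_and_other_masked` are hypothetical, not part of `cw.py`. The second helper swaps in a different technique, `masked_fill` with `-inf` plus `gather`, which avoids relying on the `1e4` offset being larger than the logit range; it is shown only as an alternative for comparison.

```python
import torch
import torch.nn.functional as F


def real_and_other(outputs: torch.Tensor, labels: torch.Tensor):
    """Hypothetical helper: the fix suggested in this issue."""
    one_hot_labels = torch.eye(outputs.shape[1], device=outputs.device)[labels]
    # Target entry is shifted to about -1e4 so max() picks the best non-target logit.
    other = torch.max((1 - one_hot_labels) * outputs - one_hot_labels * 1e4, dim=1)[0]
    # Sum over the one-hot row recovers the target logit exactly.
    real = torch.sum(one_hot_labels * outputs, dim=1)
    return real, other


def real_and_other_masked(outputs: torch.Tensor, labels: torch.Tensor):
    """Hypothetical alternative: mask the target with -inf instead of a constant."""
    mask = F.one_hot(labels, num_classes=outputs.shape[1]).bool()
    other = outputs.masked_fill(mask, float("-inf")).max(dim=1).values
    real = outputs.gather(1, labels.unsqueeze(1)).squeeze(1)
    return real, other


outputs = torch.tensor([[-2.0, -1.0, -3.0], [0.5, 2.0, -1.0]])
labels = torch.tensor([1, 2])

real, other = real_and_other(outputs, labels)
real_m, other_m = real_and_other_masked(outputs, labels)
print(real, other)  # target logits and largest non-target logits
```

Both helpers give the same result here; the `-inf` variant is independent of the logit scale, while the suggested fix stays closer to the existing code.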