
[BUG] Bugs in the f function of cw.py #184

@EthanChu7

Description

✨ Short description of the bug [tl;dr]

The f function in cw.py returns wrong values when the logits are negative.

💬 Detailed code and results

In the f function of cw.py, real and other are computed as follows:

```python
other = torch.max((1 - one_hot_labels) * outputs, dim=1)[0]
real = torch.max(one_hot_labels * outputs, dim=1)[0]
```

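However, multiplying by the one-hot mask sets the excluded entries to 0 rather than removing them, so the max breaks on negative logits. If every non-label logit is negative, other evaluates to 0 instead of the largest non-label logit; if the label logit is negative, real evaluates to 0 instead of the label logit.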
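A minimal reproduction, assuming a batch of two samples over three classes where every logit is negative (the logit values and labels are made up for illustration). Both masked maxima collapse to 0, while the correct values are [-2.0, -0.5] for other and [-1.0, -6.0] for real.

```python
import torch

# Hypothetical all-negative logits: 2 samples, 3 classes.
outputs = torch.tensor([[-1.0, -2.0, -3.0],
                        [-4.0, -0.5, -6.0]])
labels = torch.tensor([0, 2])
one_hot_labels = torch.eye(outputs.shape[1])[labels]

# Computation as it appears in cw.py
other = torch.max((1 - one_hot_labels) * outputs, dim=1)[0]
real = torch.max(one_hot_labels * outputs, dim=1)[0]

# Every entry evaluates to 0 (it may print as -0.) because the masked
# positions hold 0, which beats any negative logit in the max.
print(other, real)
```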

I suggest the following modifications:

```python
other = torch.max((1 - one_hot_labels) * outputs - one_hot_labels * 1e4, dim=1)[0]
real = torch.sum(one_hot_labels * outputs, dim=1)
```
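A sketch that checks the suggested modifications on the same kind of hypothetical all-negative logits, next to an alternative that is an assumption of this sketch and not code from cw.py: it excludes the label entry with `masked_fill(..., -inf)` and reads the label logit with `gather`. The helper name `label_and_other_logits` is hypothetical.

```python
import torch
import torch.nn.functional as F

def label_and_other_logits(outputs, labels):
    """Hypothetical helper: return (real, other) = (label logit, largest non-label logit)."""
    mask = F.one_hot(labels, num_classes=outputs.shape[1]).bool()
    # Replace the label entry with -inf so it can never win the max.
    other = outputs.masked_fill(mask, float("-inf")).max(dim=1).values
    # Read the label logit directly instead of taking a masked max.
    real = outputs.gather(1, labels.unsqueeze(1)).squeeze(1)
    return real, other

# Hypothetical all-negative logits: 2 samples, 3 classes.
outputs = torch.tensor([[-1.0, -2.0, -3.0],
                        [-4.0, -0.5, -6.0]])
labels = torch.tensor([0, 2])
one_hot_labels = torch.eye(outputs.shape[1])[labels]

# Modifications suggested in this issue
other_fix = torch.max((1 - one_hot_labels) * outputs - one_hot_labels * 1e4, dim=1)[0]
real_fix = torch.sum(one_hot_labels * outputs, dim=1)

real, other = label_and_other_logits(outputs, labels)
print(real_fix, other_fix)  # label logits -1 and -6; largest non-label logits -2 and -0.5
print(real, other)          # same values from the mask-free variant
```

Both versions agree here. The design difference: the `1e4` offset only works while every non-label logit stays above `-1e4`, whereas `-inf` makes no assumption about the logit range.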

Metadata


Assignees

No one assigned

Labels

bug: Something isn't working

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
