The `alpha` entry in the docstring at https://github.com/vandit15/Class-balanced-loss-pytorch/blob/921ccb8725b1eb0903b2c22a1a752a594fcae138/class_balanced_loss.py#L28 has the wrong shape. It should read `alpha: A float tensor of size [num_classes]`.
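The sketch below illustrates what the `[num_classes]` shape means. It is not the repository's code. It uses plain Python and assumes `alpha` holds one weight per class, applied to every row of a `[batch, num_classes]` loss. PyTorch broadcasting does the same thing for `alpha * loss` when `alpha` has shape `[num_classes]`. The helper name `apply_class_weights` is hypothetical.

```python
from typing import List


def apply_class_weights(loss: List[List[float]], alpha: List[float]) -> List[List[float]]:
    """Hypothetical helper: scale a [batch, num_classes] loss by a per-class alpha.

    alpha has size [num_classes], so the same weight vector is reused for
    every example in the batch, as broadcasting `alpha * loss` would do in PyTorch.
    """
    num_classes = len(alpha)
    for row in loss:
        if len(row) != num_classes:
            raise ValueError("each loss row must have num_classes entries")
    # Element-wise product of each row with the shared per-class weights.
    return [[a * l for a, l in zip(alpha, row)] for row in loss]


# Two examples, three classes.
loss = [[1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0]]
alpha = [0.5, 1.0, 2.0]  # one weight per class, i.e. size [num_classes]
print(apply_class_weights(loss, alpha))
# [[0.5, 2.0, 6.0], [2.0, 5.0, 12.0]]
```

A `[batch_size]` vector would instead assign one weight per example. That is a different contract from per-class weighting.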