-
Notifications
You must be signed in to change notification settings - Fork 10
Added group norm L0 and shifted group norm L0 #117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
dpo
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @AHsu98! Sorry I must have missed this PR. It looks great, thank you! Here are a few minor comments.
Would you mind adding some using tests though? Thanks!
|
Thank you @AHsu98, could you please rebase and update this PR ? |
+ to .+ and some minor syntax changes Co-authored-by: Dominique <dominique.orban@gmail.com>
…same changes that @dpo suggested on the checking for this groupNormL0 to make groupNormL2 match as well
…in groupNormL0 and groupNormL2 as it was causing tests to error with the way I was checking, and added the groupNormL0 and shiftedGroupNormL0 to the main file. Added groupNormL0 to the tests, and its not erroring, but haven't added the correctness check yet.
|
Alright, I rebased, made some updates (fixed some original mistakes), and incorporated the previous suggestions. It'd be nice to add a check that the indices are non-overlapping (the naive thing I tried worked for the basic case of vectors of vectors of integers as indices, but caused errors in the tests, maybe from shifts, or multiple indices). I also made the suggested changes to the checks on groupNormL2 as well. |
dpo
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @AHsu98 ! Sorry for the delay. Just a few cosmetic changes.
Could you please also add some unit tests to valide this code?
| ysum = R(0) | ||
| for (idx, λ) ∈ zip(f.idx, f.lambda) | ||
| yt = norm(x[idx])^2 | ||
| if yt !=0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| if yt !=0 | |
| if yt > 0 |
| struct GroupNormL0{R <: Real, RR <: AbstractVector{R}, I} | ||
| lambda::RR | ||
| idx::I | ||
|
|
||
| function GroupNormL0{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I} | ||
| any(lambda .< 0) && error("weights λ must be nonnegative") | ||
| length(lambda) != length(idx) && error("number of weights and groups must be the same") | ||
| new{R, RR, I}(lambda, idx) | ||
| end | ||
| end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| struct GroupNormL0{R <: Real, RR <: AbstractVector{R}, I} | |
| lambda::RR | |
| idx::I | |
| function GroupNormL0{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I} | |
| any(lambda .< 0) && error("weights λ must be nonnegative") | |
| length(lambda) != length(idx) && error("number of weights and groups must be the same") | |
| new{R, RR, I}(lambda, idx) | |
| end | |
| end | |
| struct GroupNormL0{R <: Real, V <: AbstractVector{R}, I} | |
| lambda::V | |
| idx::I | |
| function GroupNormL0{R, V, I}(lambda::V, idx::I) where {R <: Real, V <: AbstractVector{R}, I} | |
| any(lambda .< 0) && error("weights λ must be nonnegative") | |
| length(lambda) != length(idx) && error("number of weights and groups must be the same") | |
| new{R, V, I}(lambda, idx) | |
| end | |
| end |
| function prox!( | ||
| y::AbstractArray{R}, | ||
| f::GroupNormL0{R, RR, I}, | ||
| x::AbstractArray{R}, | ||
| γ::R = R(1), | ||
| ) where {R <: Real, RR <: AbstractVector{R}, I} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| function prox!( | |
| y::AbstractArray{R}, | |
| f::GroupNormL0{R, RR, I}, | |
| x::AbstractArray{R}, | |
| γ::R = R(1), | |
| ) where {R <: Real, RR <: AbstractVector{R}, I} | |
| function prox!( | |
| y::AbstractArray{R}, | |
| f::GroupNormL0{R, V, I}, | |
| x::AbstractArray{R}, | |
| γ::R = R(1), | |
| ) where {R <: Real, V <: AbstractVector{R}, I} |
| lambda::RR | ||
| idx::I | ||
|
|
||
| function GroupNormL2{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| function GroupNormL2{R, RR, I}(lambda::RR, idx::I) where {R <: Real, RR <: AbstractVector{R}, I} | |
| function GroupNormL2{R, V, I}(lambda::V, idx::I) where {R <: Real, V <: AbstractVector{R}, I} |
| end | ||
| any(lambda .< 0) && error("weights λ must be nonnegative") | ||
| length(lambda) != length(idx) && error("number of weights and groups must be the same") | ||
| new{R, RR, I}(lambda, idx) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| new{R, RR, I}(lambda, idx) | |
| new{R, V, I}(lambda, idx) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There may be others in this file. It’s just easier to remember that V stands for “vector”.
|
|
||
| mutable struct ShiftedGroupNormL0{ | ||
| R <: Real, | ||
| RR <: AbstractVector{R}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use V0, V1, V2, V3 here.
First commit towards adding a group L0 norm. Happy to change the name, perhaps L02 is better. This is defined as the L0 norm of the vector of L2 norms of the indices, weighted by the values of lambda. Still needs a bit of testing I think, but pretty simple, only a couple of lines to change from GroupL2.
Also, in both this and GroupL2, we should call out that we require it to be a separable sum (no indices overlap).