|
476 | 476 | @test_broken gradient(x -> abs2(x[1].x) + 7 * x[1].x.re, [Ref(1+im)]) == ([(x = 9.0 + 2.0im,)],) |
477 | 477 | @test_broken gradient(x -> abs2(x[1].x) + 7 * real(x[1].x), [Ref(1+im)]) == ([(x = 9.0 + 2.0im,)],) # worked on 0.6.0, 0.6.20 |
478 | 478 |
|
479 | | - @test_broken gradient(x -> abs2(x[].x) + 7 * real(x[].x), Ref(Ref(1+im))) == ((x = 9.0 + 2.0im,),) # gives nothing, same in 0.6.0 |
| 479 | + @test gradient(x -> abs2(x[].x) + 7 * real(x[].x), Ref(Ref(1+im))) == ((x = (x = 9.0 + 2.0im,),),) # gave `nothing` from 0.6.0 to 0.6.41 |
480 | 480 |
|
481 | 481 | # Array of mutables: |
482 | 482 | @test gradient(x -> sum(getindex.(x).^2), Ref.(1:3))[1] == [(;x=2i) for i in 1:3] |
|
490 | 490 | @test gradient(x -> sum(sum, Ref(x) .* [1,2,3]), [4,5]) == ([6.0, 6.0],) |
491 | 491 | end |
492 | 492 |
|
| 493 | +@testset "mutable accum_param bugs" begin |
| 494 | + mutable struct Mut{T}; x::T; end |
| 495 | + struct Imm{T}; x::T; end |
| 496 | + |
| 497 | + # Indexing a tuple containing a mutable struct gave `nothing` |
| 498 | + x1 = (Mut(3.0),) |
| 499 | + x2 = (Imm(3.0),) |
| 500 | + x3 = (Ref(3.0),) |
| 501 | + @test gradient(x -> x[1].x^2, x1)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41 |
| 502 | + @test gradient(x -> x[1].x^2, x2)[1] == ((x = 6.0,),) |
| 503 | + @test gradient(x -> x[1].x^2, x3)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41 |
| 504 | + i1 = 1 |
| 505 | + @test gradient(x -> x[i1].x^2, x1)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41 |
| 506 | + @test gradient(x -> x[i1].x^2, x2)[1] == ((x = 6.0,),) |
| 507 | + @test gradient(x -> x[i1].x^2, x3)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41 |
| 508 | + |
| 509 | + @test gradient(x -> x[1][1].x^2, [x1])[1] == [((x = 6.0,),)] # fails on v0.6.0 v0.6.41 |
| 510 | + @test gradient(x -> x[1][1].x^2, [x2])[1] == [((x = 6.0,),)] |
| 511 | + @test gradient(x -> x[1][1].x^2, [x3])[1] == [((x = 6.0,),)] # fails on v0.6.0 v0.6.41 |
| 512 | + |
| 513 | + # When `getfield` returns a mutable struct, it gave `nothing`: |
| 514 | + x4 = Imm(Mut(4.0)) |
| 515 | + x5 = Mut(Mut(4.0)) |
| 516 | + x6 = Imm(Imm(4.0)) |
| 517 | + @test gradient(x -> x.x.x^3, x4)[1] == (x = (x = 48.0,),) # fails on v0.6.0 v0.6.41 |
| 518 | + @test gradient(x -> x.x.x^3, x5)[1] == (x = (x = 48.0,),) # fails on v0.6.0 |
| 519 | + @test gradient(x -> x.x.x^3, x6)[1] == (x = (x = 48.0,),) # fails on v0.6.41 |
| 520 | + |
| 521 | + @test gradient(x -> x[2].x.x^3, [x4, x4])[1] == [nothing, (x = (x = 48.0,),)] # fails on v0.6.0 v0.6.41 |
| 522 | + @test gradient(x -> x[2].x.x^3, [x4, x5])[1] == [nothing, (x = (x = 48.0,),)] # fails on v0.6.0 |
| 523 | + @test gradient(x -> x[2].x.x^3, [x4, x6])[1] == [nothing, (x = (x = 48.0,),)] # fails on v0.6.41 |
| 524 | + |
| 525 | + # Check when using implicit parameters, Params cases used to pass: |
| 526 | + y1 = [3.0] |
| 527 | + y2 = (Mut(y1),) |
| 528 | + y3 = (Imm(y1),) |
| 529 | + @test gradient(x -> sum(x[1].x)^2, y2)[1] == ((x = [6.0],),) # fails on v0.6.0 v0.6.41 |
| 530 | + @test gradient(() -> sum(y2[1].x)^2, Params([y1]))[y1] == [6.0] |
| 531 | + @test gradient(x -> sum(x[1].x)^2, y3)[1] == ((x = [6.0],),) |
| 532 | + @test gradient(() -> sum(y3[1].x)^2, Params([y1]))[y1] == [6.0] |
| 533 | + |
| 534 | + @test gradient(x -> sum(x[1].x .+ x[1].x)^3, y2)[1] == ((x = [216.0],),) # fails on v0.6.0 v0.6.41 |
| 535 | + @test gradient(() -> sum(y2[1].x .+ y2[1].x)^3, Params([y1]))[y1] == [216.0] |
| 536 | + @test gradient(x -> sum(x[1].x .+ x[1].x)^3, y3)[1] == ((x = [216.0],),) |
| 537 | + @test gradient(() -> sum(y3[1].x .+ y3[1].x)^3, Params([y1]))[y1] == [216.0] |
| 538 | + |
| 539 | + i1 = 1 |
| 540 | + @test gradient(x -> sum(x[i1].x .+ x[1].x)^3, y2)[1] == ((x = [216.0],),) # fails on v0.6.0 v0.6.41 |
| 541 | + @test gradient(() -> sum(y2[i1].x .+ y2[1].x)^3, Params([y1]))[y1] == [216.0] |
| 542 | + @test gradient(x -> sum(x[i1].x .+ x[1].x)^3, y3)[1] == ((x = [216.0],),) |
| 543 | + @test gradient(() -> sum(y3[i1].x .+ y3[1].x)^3, Params([y1]))[y1] == [216.0] |
| 544 | +end |
| 545 | + |
493 | 546 | @testset "NamedTuples" begin |
494 | 547 | @test gradient(x -> x.a, (a=1, b=2)) == ((a = 1, b = nothing),) |
495 | 548 | @test gradient(x -> x[1].a, [(a=1, b=2)]) == ([(a = 1, b = nothing)],) |
|
517 | 570 | @test (x->10*(x => 2)[2])'(100) === nothing |
518 | 571 |
|
519 | 572 | @test gradient(x-> (:x => x)[2], 17) == (1,) |
520 | | - |
| 573 | + |
521 | 574 | d = Dict(:x=>1.0, :y=>3.0); |
522 | 575 | @test gradient(d -> Dict(:x => d[:x])[:x], d) == (Dict(:x => 1),) |
523 | 576 | end |
|
546 | 599 | # zip |
547 | 600 | if VERSION >= v"1.5" |
548 | 601 | # On Julia 1.4 and earlier, [x/y for (x,y) in zip(10:14, 1:10)] is a DimensionMismatch, |
549 | | - # while on 1.5 - 1.7 it stops early. |
| 602 | + # while on 1.5 - 1.7 it stops early. |
550 | 603 |
|
551 | 604 | @test gradient(10:14, 1:10) do xs, ys |
552 | 605 | sum([x/y for (x,y) in zip(xs, ys)]) |
|
608 | 661 |
|
609 | 662 | # Iterators.Product with enumerate |
610 | 663 | @test gradient([2 3; 4 5]) do xs |
611 | | - sum([x^i+y for (i,x) in enumerate(xs), y in xs]) |
| 664 | + sum([x^i+y for (i,x) in enumerate(xs), y in xs]) |
612 | 665 | end == ([8 112; 36 2004],) |
613 | 666 | end |
614 | 667 |
|
|
0 commit comments