
Confusing error / silent failure with broadcasted functions with type instability #1439

@DomCRose


When a function that is type unstable on Dual inputs is broadcast, there is a good chance that the element type of the output will be abstract (e.g. Real), so the check at

T <: Union{Dual, Complex{<:Dual}} || return (out, _ -> nothing)
takes the early-return branch: the pullback returns nothing, while the output, which still contains Duals, is passed on as the forward result (thanks to @ToucheSir, who helped debug this). Because the Duals leak out of the broadcast, the resulting error can surface much later than the point where they originated, e.g. when the Dual-valued gradient of an output is pulled back onto the gradient of an input that assumes a non-Dual eltype. This makes the problem confusing to debug. Worse, in some cases the gradient fails silently, returning either nothing or Duals.
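A minimal sketch of the mechanism outside Zygote. It assumes ForwardDiff is installed and builds the Dual inputs by hand, seeding each element with a single partial of 1.0; this is an illustrative choice, not Zygote's actual internals. Since f returns a Dual for some elements and a plain Float64 for others, broadcasting widens the result to the abstract element type Real, which fails the T <: Union{Dual, Complex{<:Dual}} check.

```julia
using ForwardDiff: Dual

# Type unstable on Dual inputs: the first branch returns a Float64 literal,
# the second a Dual (x^2 of a Dual is a Dual).
f(x) = x > 1.0 ? 1.0 : x^2

# Seed every element with one partial equal to 1.0. This only imitates the
# Dual inputs Zygote builds internally; it is not Zygote's own code.
xs = [Dual(x, 1.0) for x in 0.5:0.25:1.5]
ys = f.(xs)

T = eltype(ys)
println(T)                                   # Real
println(T <: Union{Dual, Complex{<:Dual}})   # false, so the early return is taken
println(count(y -> y isa Dual, ys))          # 3: only the elements with x <= 1.0 are Duals
```

The remaining Duals in ys are exactly what leaks into the rest of the computation.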

A minimal working example (MWE) of the silent failure, on Julia 1.9.0 in a temporary environment with only Zygote installed:

using Zygote
f(x) = x > 1.0 ? 1.0 : x^2
g(x) = sum(f.(x))
gradient(g, collect(0.5:0.25:1.5)) # (nothing,)

Compare this with the expected behaviour, obtained by making f type stable:

using Zygote
f(x) = x > 1.0 ? one(x) : x^2
g(x) = sum(f.(x))
gradient(g, collect(0.5:0.25:1.5)) # ([1.0, 1.5, 2.0, 0.0, 0.0],)
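One way to catch this before differentiating is to check the function for type stability on a Dual argument. The sketch below assumes ForwardDiff is installed and uses Test.@inferred from the standard library; f_unstable and f_stable are just the two versions of f above, renamed so they can coexist. The one(x) version infers a single concrete Dual return type, so its broadcast output has a concrete Dual eltype.

```julia
using Test: @inferred
using ForwardDiff: Dual

f_unstable(x) = x > 1.0 ? 1.0 : x^2     # Float64 literal in one branch
f_stable(x)   = x > 1.0 ? one(x) : x^2  # one(x) keeps the input's type

d = Dual(1.25, 1.0)

# @inferred throws an ErrorException when the actual return type (Float64 here)
# differs from the inferred one (a Union of Float64 and Dual).
unstable_caught = try
    @inferred f_unstable(d)
    false
catch err
    err isa ErrorException
end

@inferred f_stable(d)   # passes: the inferred and actual types are the same Dual type

xs = [Dual(x, 1.0) for x in 0.5:0.25:1.5]
println(eltype(f_stable.(xs)))   # a concrete Dual type, so the Dual check succeeds
```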

An MWE of the error case, using repeat with the inner keyword as an example of an operation whose pullback does not let the leaked Duals pass through silently:

using Zygote
f(x) = x > 1.0 ? 1.0 : x^2
g(x) = sum(repeat(x, inner=2) .* f.(repeat(x, inner=2)))
gradient(g, collect(0.5:0.25:1.5))

results in:

ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{Nothing, Float64, 1})

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:207
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:792
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number}
   @ Base char.jl:50
  ...

Stacktrace:
 [1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{Nothing, Float64, 1})
   @ Base .\number.jl:7
 [2] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{Nothing, Float64, 1}, i1::Int64)
   @ Base .\array.jl:969
 [3] (::Zygote.var"#626#634"{Int64, Vector{Float64}})(Δ::Vector{ForwardDiff.Dual{Nothing, Float64, 1}})
   @ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\lib\array.jl:137
 [4] (::Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}})(Δ::Vector{ForwardDiff.Dual{Nothing, Float64, 1}})
   @ Zygote C:\Users\domin\.julia\packages\ZygoteRules\OgCVT\src\adjoint.jl:80
 [5] Pullback
   @ .\REPL[29]:1 [inlined]
 [6] (::Zygote.Pullback{Tuple{typeof(g), Vector{Float64}}, Tuple{Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#3027#back#778"{Zygote.var"#772#776"{Vector{Real}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Real}}, Tuple{}}, Zygote.var"#3802#back#1201"{Zygote.var"#1197#1200"{Vector{Float64}, Vector{Real}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(f), Vector{Float64}}, Tuple{Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4123#back#1356"{Zygote.var"#1386#1388"}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2881#back#684"{Zygote.var"#map_back#678"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Float64}}, Tuple{}}, Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1162"{Tuple{Nothing, Nothing}}}}, Zygote.var"#2017#back#200"{typeof(identity)}}}}})(Δ::ForwardDiff.Dual{Nothing, Float64, 1})
   @ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\interface2.jl:0
 [7] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(g), Vector{Float64}}, Tuple{Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#3027#back#778"{Zygote.var"#772#776"{Vector{Real}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Real}}, Tuple{}}, Zygote.var"#3802#back#1201"{Zygote.var"#1197#1200"{Vector{Float64}, Vector{Real}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(f), Vector{Float64}}, Tuple{Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4123#back#1356"{Zygote.var"#1386#1388"}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2881#back#684"{Zygote.var"#map_back#678"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Float64}}, Tuple{}}, Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1162"{Tuple{Nothing, Nothing}}}}, Zygote.var"#2017#back#200"{typeof(identity)}}}}}})(Δ::ForwardDiff.Dual{Nothing, Float64, 1})
   @ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\interface.jl:45
 [8] gradient(f::Function, args::Vector{Float64})
   @ Zygote C:\Users\domin\.julia\packages\Zygote\JeHtr\src\compiler\interface.jl:97
 [9] top-level scope
   @ REPL[30]:1

which leaves it unclear where the Duals come from, since the forward pass succeeds but returns an incorrect, Dual-valued output:

julia> pullback(g, collect(0.5:0.25:1.5))
(Dual{Nothing}(8.59375,7.25), Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(g), Vector{Float64}}, Tuple{Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.var"#2822#back#642"{Zygote.var"#626#634"{Int64, Vector{Float64}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.var"#3027#back#778"{Zygote.var"#772#776"{Vector{Real}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Real}}, Tuple{}}, Zygote.var"#3802#back#1201"{Zygote.var"#1197#1200"{Vector{Float64}, Vector{Real}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,)}}, Tuple{Int64}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:inner,), Tuple{Int64}}}, Tuple{Int64}}, Tuple{Zygote.var"#2224#back#311"{Zygote.Jnew{NamedTuple{(:inner,), Tuple{Int64}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(f), Vector{Float64}}, Tuple{Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4123#back#1356"{Zygote.var"#1386#1388"}}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.var"#2881#back#684"{Zygote.var"#map_back#678"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Float64}}, Tuple{}}, Zygote.var"#2173#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1162"{Tuple{Nothing, Nothing}}}}, Zygote.var"#2017#back#200"{typeof(identity)}}}}}}(∂(g)))

In the long run it would be better to fix this properly. In the short term, however, simply throwing an error before

T <: Union{Dual, Complex{<:Dual}} || return (out, _ -> nothing)
whenever the element type is abstract, with a message saying that the function needs to be made type stable on Dual inputs, would at least make this much easier to debug. I'm happy to open a PR adding that.
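A rough sketch of the proposed check, written as a standalone function. The name check_dual_eltype and the message text are hypothetical and not existing Zygote API. The idea, following the suggestion above, is to call it on the broadcast output's eltype just before the quoted early return. A concrete non-Dual eltype (e.g. Float64, when f returns a constant for every element) is still let through, since the nothing pullback is then plausibly correct.

```julia
using ForwardDiff: Dual

# Hypothetical helper, not part of Zygote: raise an informative error instead of
# silently returning a `nothing` pullback when the eltype is abstract.
function check_dual_eltype(T::Type)
    T <: Union{Dual, Complex{<:Dual}} && return nothing   # normal Dual path
    if !isconcretetype(T)
        error("Broadcasting produced an output with abstract element type $T. ",
              "The broadcasted function is probably not type stable on Dual inputs; ",
              "make every branch return the input's type (e.g. one(x) instead of 1.0).")
    end
    return nothing   # concrete non-Dual eltype: no Duals survived the broadcast
end

check_dual_eltype(Dual{Nothing, Float64, 1})  # no error
check_dual_eltype(Float64)                    # no error
threw = try
    check_dual_eltype(Real)                   # the eltype produced in the MWEs above
    false
catch err
    err isa ErrorException
end
```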

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
