Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 2b7aae7

Browse files
Merge pull request #272 from avik-pal/ap/fix_nonsq_vecjac
Rework VecJac Operator
2 parents 7d23bec + 635828a commit 2b7aae7

File tree

4 files changed

+183
-107
lines changed

4 files changed

+183
-107
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseDiffTools"
22
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
33
authors = ["Pankaj Mishra <pankajmishra1511@gmail.com>", "Chris Rackauckas <contact@chrisrackauckas.com>"]
4-
version = "2.9.2"
4+
version = "2.10.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/SparseDiffToolsZygoteExt.jl

+20-30
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,21 @@ import SparseDiffTools: __f̂, __jacobian!, __gradient, __gradient!
1313
import ADTypes: AutoZygote, AutoSparseZygote
1414

1515
## Satisfying High-Level Interface for Sparse Jacobians
16-
function __gradient(::Union{AutoSparseZygote, AutoZygote}, f, x, cols)
16+
function __gradient(::Union{AutoSparseZygote, AutoZygote}, f::F, x, cols) where {F}
1717
_, ∂x, _ = Zygote.gradient(__f̂, f, x, cols)
1818
return vec(∂x)
1919
end
2020

21-
function __gradient!(::Union{AutoSparseZygote, AutoZygote}, f!, fx, x, cols)
21+
function __gradient!(::Union{AutoSparseZygote, AutoZygote}, f!::F, fx, x, cols) where {F}
2222
return error("Zygote.jl cannot differentiate in-place (mutating) functions.")
2323
end
2424

2525
# Zygote doesn't provide a way to accumulate directly into `J`. So we modify the code from
2626
# https://github.com/FluxML/Zygote.jl/blob/82c7a000bae7fb0999275e62cc53ddb61aed94c7/src/lib/grad.jl#L140-L157C4
2727
import Zygote: _jvec, _eyelike, _gradcopy!
2828

29-
@views function __jacobian!(J::AbstractMatrix, ::Union{AutoSparseZygote, AutoZygote}, f, x)
29+
@views function __jacobian!(J::AbstractMatrix, ::Union{AutoSparseZygote, AutoZygote}, f::F,
30+
x) where {F}
3031
y, back = Zygote.pullback(_jvec f, x)
3132
δ = _eyelike(y)
3233
for k in LinearIndices(y)
@@ -36,13 +37,13 @@ import Zygote: _jvec, _eyelike, _gradcopy!
3637
return J
3738
end
3839

39-
function __jacobian!(J, ::Union{AutoSparseZygote, AutoZygote}, f!, fx, x)
40+
function __jacobian!(_, ::Union{AutoSparseZygote, AutoZygote}, f!::F, fx, x) where {F}
4041
return error("Zygote.jl cannot differentiate in-place (mutating) functions.")
4142
end
4243

4344
### Jac, Hes products
4445

45-
function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
46+
function numback_hesvec!(dy, f::F, x, v, cache1 = similar(v), cache2 = similar(v)) where {F}
4647
g = let f = f
4748
(dx, x) -> dx .= first(Zygote.gradient(f, x))
4849
end
@@ -57,15 +58,14 @@ function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
5758
@. dy = (cache1 - cache2) / (2ϵ)
5859
end
5960

60-
function numback_hesvec(f, x, v)
61-
g = x -> first(Zygote.gradient(f, x))
61+
function numback_hesvec(f::F, x, v) where {F}
6262
T = eltype(x)
6363
# Should it be min? max? mean?
6464
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
6565
x += ϵ * v
66-
gxp = g(x)
66+
gxp = first(Zygote.gradient(f, x))
6767
x -= 2ϵ * v
68-
gxm = g(x)
68+
gxm = first(Zygote.gradient(f, x))
6969
(gxp - gxm) / (2ϵ)
7070
end
7171

@@ -94,61 +94,51 @@ end
9494
## VecJac products
9595

9696
# VJP methods
97-
function auto_vecjac!(du, f, x, v)
97+
function auto_vecjac!(du, f::F, x, v) where {F}
9898
!static_hasmethod(f, typeof((x,))) &&
9999
error("For inplace function use autodiff = AutoFiniteDiff()")
100100
du .= reshape(SparseDiffTools.auto_vecjac(f, x, v), size(du))
101101
end
102102

103-
function auto_vecjac(f, x, v)
103+
function auto_vecjac(f::F, x, v) where {F}
104104
y, back = Zygote.pullback(f, x)
105-
return vec(back(reshape(v, size(y)))[1])
105+
return vec(only(back(reshape(v, size(y)))))
106106
end
107107

108108
# overload operator interface
109-
function SparseDiffTools._vecjac(f, u, autodiff::AutoZygote)
110-
cache = ()
109+
function SparseDiffTools._vecjac(f::F, _, u, autodiff::AutoZygote) where {F}
110+
!static_hasmethod(f, typeof((u,))) &&
111+
error("For inplace function use autodiff = AutoFiniteDiff()")
111112
pullback = Zygote.pullback(f, u)
112-
113-
return AutoDiffVJP(f, u, cache, autodiff, pullback)
113+
return AutoDiffVJP(f, u, (), autodiff, pullback)
114114
end
115115

116116
function update_coefficients(L::AutoDiffVJP{<:AutoZygote}, u, p, t; VJP_input = nothing)
117117
VJP_input !== nothing && (@set! L.u = VJP_input)
118-
119118
@set! L.f = update_coefficients(L.f, L.u, p, t)
120119
@set! L.pullback = Zygote.pullback(L.f, L.u)
120+
return L
121121
end
122122

123123
function update_coefficients!(L::AutoDiffVJP{<:AutoZygote}, u, p, t; VJP_input = nothing)
124124
VJP_input !== nothing && copy!(L.u, VJP_input)
125-
126125
update_coefficients!(L.f, L.u, p, t)
127126
L.pullback = Zygote.pullback(L.f, L.u)
128-
129127
return L
130128
end
131129

132130
# Interpret the call as df/du' * v
133131
function (L::AutoDiffVJP{<:AutoZygote})(v, p, t; VJP_input = nothing)
134132
# ignore VJP_input as pullback was computed in update_coefficients(...)
135133
y, back = L.pullback
136-
V = reshape(v, size(y))
137-
138-
return vec(first(back(V)))
134+
return vec(only(back(reshape(v, size(y)))))
139135
end
140136

141137
# prefer non in-place method
142-
function (L::AutoDiffVJP{<:AutoZygote, IIP, true})(dv, v, p, t;
143-
VJP_input = nothing) where {IIP}
138+
function (L::AutoDiffVJP{<:AutoZygote})(dv, v, p, t; VJP_input = nothing)
144139
# ignore VJP_input as pullback was computed in update_coefficients!(...)
145-
146-
_dv = L(v, p, t; VJP_input = VJP_input)
140+
_dv = L(v, p, t; VJP_input)
147141
copy!(dv, _dv)
148142
end
149143

150-
function (L::AutoDiffVJP{<:AutoZygote, true, false})(args...; kwargs...)
151-
error("Zygote requires an out of place method with signature f(u).")
152-
end
153-
154144
end # module

0 commit comments

Comments
 (0)