@@ -13,20 +13,21 @@ import SparseDiffTools: __f̂, __jacobian!, __gradient, __gradient!
13
13
import ADTypes: AutoZygote, AutoSparseZygote
14
14
15
15
# # Satisfying High-Level Interface for Sparse Jacobians
16
- function __gradient (:: Union{AutoSparseZygote, AutoZygote} , f, x, cols)
16
+ function __gradient (:: Union{AutoSparseZygote, AutoZygote} , f:: F , x, cols) where {F}
17
17
_, ∂x, _ = Zygote. gradient (__f̂, f, x, cols)
18
18
return vec (∂x)
19
19
end
20
20
21
- function __gradient! (:: Union{AutoSparseZygote, AutoZygote} , f!, fx, x, cols)
21
+ function __gradient! (:: Union{AutoSparseZygote, AutoZygote} , f!:: F , fx, x, cols) where {F}
22
22
return error (" Zygote.jl cannot differentiate in-place (mutating) functions." )
23
23
end
24
24
25
25
# Zygote doesn't provide a way to accumulate directly into `J`. So we modify the code from
26
26
# https://github.com/FluxML/Zygote.jl/blob/82c7a000bae7fb0999275e62cc53ddb61aed94c7/src/lib/grad.jl#L140-L157C4
27
27
import Zygote: _jvec, _eyelike, _gradcopy!
28
28
29
- @views function __jacobian! (J:: AbstractMatrix , :: Union{AutoSparseZygote, AutoZygote} , f, x)
29
+ @views function __jacobian! (J:: AbstractMatrix , :: Union{AutoSparseZygote, AutoZygote} , f:: F ,
30
+ x) where {F}
30
31
y, back = Zygote. pullback (_jvec ∘ f, x)
31
32
δ = _eyelike (y)
32
33
for k in LinearIndices (y)
@@ -36,13 +37,13 @@ import Zygote: _jvec, _eyelike, _gradcopy!
36
37
return J
37
38
end
38
39
39
- function __jacobian! (J , :: Union{AutoSparseZygote, AutoZygote} , f!, fx, x)
40
+ function __jacobian! (_ , :: Union{AutoSparseZygote, AutoZygote} , f!:: F , fx, x) where {F}
40
41
return error (" Zygote.jl cannot differentiate in-place (mutating) functions." )
41
42
end
42
43
43
44
# ## Jac, Hes products
44
45
45
- function numback_hesvec! (dy, f, x, v, cache1 = similar (v), cache2 = similar (v))
46
+ function numback_hesvec! (dy, f:: F , x, v, cache1 = similar (v), cache2 = similar (v)) where {F}
46
47
g = let f = f
47
48
(dx, x) -> dx .= first (Zygote. gradient (f, x))
48
49
end
@@ -57,15 +58,14 @@ function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
57
58
@. dy = (cache1 - cache2) / (2 ϵ)
58
59
end
59
60
60
- function numback_hesvec (f, x, v)
61
- g = x -> first (Zygote. gradient (f, x))
61
+ function numback_hesvec (f:: F , x, v) where {F}
62
62
T = eltype (x)
63
63
# Should it be min? max? mean?
64
64
ϵ = sqrt (eps (real (T))) * max (one (real (T)), abs (norm (x)))
65
65
x += ϵ * v
66
- gxp = g (x )
66
+ gxp = first (Zygote . gradient (f, x) )
67
67
x -= 2 ϵ * v
68
- gxm = g (x )
68
+ gxm = first (Zygote . gradient (f, x) )
69
69
(gxp - gxm) / (2 ϵ)
70
70
end
71
71
94
94
# # VecJac products
95
95
96
96
# VJP methods
97
- function auto_vecjac! (du, f, x, v)
97
+ function auto_vecjac! (du, f:: F , x, v) where {F}
98
98
! static_hasmethod (f, typeof ((x,))) &&
99
99
error (" For inplace function use autodiff = AutoFiniteDiff()" )
100
100
du .= reshape (SparseDiffTools. auto_vecjac (f, x, v), size (du))
101
101
end
102
102
103
- function auto_vecjac (f, x, v)
103
+ function auto_vecjac (f:: F , x, v) where {F}
104
104
y, back = Zygote. pullback (f, x)
105
- return vec (back (reshape (v, size (y)))[ 1 ] )
105
+ return vec (only ( back (reshape (v, size (y)))) )
106
106
end
107
107
108
108
# overload operator interface
109
- function SparseDiffTools. _vecjac (f, u, autodiff:: AutoZygote )
110
- cache = ()
109
+ function SparseDiffTools. _vecjac (f:: F , _, u, autodiff:: AutoZygote ) where {F}
110
+ ! static_hasmethod (f, typeof ((u,))) &&
111
+ error (" For inplace function use autodiff = AutoFiniteDiff()" )
111
112
pullback = Zygote. pullback (f, u)
112
-
113
- return AutoDiffVJP (f, u, cache, autodiff, pullback)
113
+ return AutoDiffVJP (f, u, (), autodiff, pullback)
114
114
end
115
115
116
116
function update_coefficients (L:: AutoDiffVJP{<:AutoZygote} , u, p, t; VJP_input = nothing )
117
117
VJP_input != = nothing && (@set! L. u = VJP_input)
118
-
119
118
@set! L. f = update_coefficients (L. f, L. u, p, t)
120
119
@set! L. pullback = Zygote. pullback (L. f, L. u)
120
+ return L
121
121
end
122
122
123
123
function update_coefficients! (L:: AutoDiffVJP{<:AutoZygote} , u, p, t; VJP_input = nothing )
124
124
VJP_input != = nothing && copy! (L. u, VJP_input)
125
-
126
125
update_coefficients! (L. f, L. u, p, t)
127
126
L. pullback = Zygote. pullback (L. f, L. u)
128
-
129
127
return L
130
128
end
131
129
132
130
# Interpret the call as df/du' * v
133
131
function (L:: AutoDiffVJP{<:AutoZygote} )(v, p, t; VJP_input = nothing )
134
132
# ignore VJP_input as pullback was computed in update_coefficients(...)
135
133
y, back = L. pullback
136
- V = reshape (v, size (y))
137
-
138
- return vec (first (back (V)))
134
+ return vec (only (back (reshape (v, size (y)))))
139
135
end
140
136
141
137
# prefer non in-place method
142
- function (L:: AutoDiffVJP{<:AutoZygote, IIP, true} )(dv, v, p, t;
143
- VJP_input = nothing ) where {IIP}
138
+ function (L:: AutoDiffVJP{<:AutoZygote} )(dv, v, p, t; VJP_input = nothing )
144
139
# ignore VJP_input as pullback was computed in update_coefficients!(...)
145
-
146
- _dv = L (v, p, t; VJP_input = VJP_input)
140
+ _dv = L (v, p, t; VJP_input)
147
141
copy! (dv, _dv)
148
142
end
149
143
150
- function (L:: AutoDiffVJP{<:AutoZygote, true, false} )(args... ; kwargs... )
151
- error (" Zygote requires an out of place method with signature f(u)." )
152
- end
153
-
154
144
end # module
0 commit comments