Skip to content

Commit 4d0c091

Browse files
committed
Implement Polyester Colored AD for oop functions
1 parent 49b5ab9 commit 4d0c091

File tree

3 files changed

+83
-14
lines changed

3 files changed

+83
-14
lines changed

ext/SparseDiffToolsPolyesterExt.jl

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,69 @@
11
module SparseDiffToolsPolyesterExt
2-
2+
3+
using Adapt, ArrayInterface, ForwardDiff, FiniteDiff, Polyester, SparseDiffTools,
4+
SparseArrays
5+
import SparseDiffTools: polyesterforwarddiff_color_jacobian, ForwardColorJacCache,
6+
__parameterless_type
7+
8+
function cld_fast(a::A, b::B) where {A, B}
9+
T = promote_type(A, B)
10+
return cld_fast(a % T, b % T)
11+
end
12+
function cld_fast(n::T, d::T) where {T}
13+
x = Base.udiv_int(n, d)
14+
x += n != d * x
15+
return x
16+
end
17+
18+
function polyesterforwarddiff_color_jacobian(J::AbstractMatrix{<:Number}, f::F,
19+
x::AbstractArray{<:Number}, jac_cache::ForwardColorJacCache) where {F}
20+
t = jac_cache.t
21+
dx = jac_cache.dx
22+
p = jac_cache.p
23+
colorvec = jac_cache.colorvec
24+
sparsity = jac_cache.sparsity
25+
chunksize = jac_cache.chunksize
26+
maxcolor = maximum(colorvec)
27+
28+
vecx = vec(x)
29+
30+
nrows, ncols = size(J)
31+
32+
if !(sparsity isa Nothing)
33+
rows_index, cols_index = ArrayInterface.findstructralnz(sparsity)
34+
rows_index = [rows_index[i] for i in 1:length(rows_index)]
35+
cols_index = [cols_index[i] for i in 1:length(cols_index)]
36+
else
37+
rows_index = 1:nrows
38+
cols_index = 1:ncols
39+
end
40+
41+
if J isa AbstractSparseMatrix
42+
fill!(nonzeros(J), zero(eltype(J)))
43+
else
44+
fill!(J, zero(eltype(J)))
45+
end
46+
47+
batch((length(p), min(length(p), Threads.nthreads()))) do _, start, stop
48+
for i in start:stop
49+
partial_i = p[i]
50+
color_i = i
51+
t_ = reshape(eltype(t).(vecx, ForwardDiff.Partials.(partial_i)), size(t))
52+
fx = f(t_)
53+
for j in 1:chunksize
54+
dx = vec(ForwardDiff.partials.(fx, j))
55+
pick_inds = [idx
56+
for idx in 1:length(rows_index)
57+
if colorvec[cols_index[idx]] == color_i]
58+
rows_index_c = rows_index[pick_inds]
59+
cols_index_c = cols_index[pick_inds]
60+
@inbounds @simd for i in 1:length(rows_index_c)
61+
J[rows_index_c[i], cols_index_c[i]] = dx[rows_index_c[i]]
62+
end
63+
end
64+
end
65+
end
66+
return J
67+
end
68+
369
end

ext/SparseDiffToolsPolyesterForwardDiffExt.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ import ForwardDiff
55
import SparseDiffTools: AbstractMaybeSparseJacobianCache, AbstractMaybeSparsityDetection,
66
ForwardColorJacCache, NoMatrixColoring, sparse_jacobian_cache,
77
sparse_jacobian!,
8-
sparse_jacobian_static_array, __standard_tag, __chunksize
8+
sparse_jacobian_static_array, __standard_tag, __chunksize,
9+
polyesterforwarddiff_color_jacobian,
10+
polyesterforwarddiff_color_jacobian!
911

1012
struct PolyesterForwardDiffJacobianCache{CO, CA, J, FX, X} <:
1113
AbstractMaybeSparseJacobianCache
@@ -25,8 +27,6 @@ function sparse_jacobian_cache(
2527
cache = __chunksize(ad, x)
2628
jac_prototype = nothing
2729
else
28-
@warn """Currently PolyesterForwardDiff does not support sparsity detection
29-
natively. Falling back to using ForwardDiff.jl""" maxlog=1
3030
tag = __standard_tag(nothing, x)
3131
# Colored ForwardDiff passes `tag` directly into Dual so we need the `typeof`
3232
cache = ForwardColorJacCache(f, x, __chunksize(ad); coloring_result.colorvec,
@@ -45,7 +45,8 @@ function sparse_jacobian_cache(
4545
jac_prototype = nothing
4646
else
4747
@warn """Currently PolyesterForwardDiff does not support sparsity detection
48-
natively. Falling back to using ForwardDiff.jl""" maxlog=1
48+
natively for inplace functions. Falling back to using
49+
ForwardDiff.jl""" maxlog=1
4950
tag = __standard_tag(nothing, x)
5051
# Colored ForwardDiff passes `tag` directly into Dual so we need the `typeof`
5152
cache = ForwardColorJacCache(f!, x, __chunksize(ad); coloring_result.colorvec,
@@ -58,7 +59,7 @@ end
5859
function sparse_jacobian!(J::AbstractMatrix, _, cache::PolyesterForwardDiffJacobianCache,
5960
f::F, x) where {F}
6061
if cache.cache isa ForwardColorJacCache
61-
forwarddiff_color_jacobian(J, f, x, cache.cache) # Use Sparse ForwardDiff
62+
polyesterforwarddiff_color_jacobian(J, f, x, cache.cache)
6263
else
6364
PolyesterForwardDiff.threaded_jacobian!(f, J, x, cache.cache) # Don't try to exploit sparsity
6465
end
@@ -68,7 +69,7 @@ end
6869
function sparse_jacobian!(J::AbstractMatrix, _, cache::PolyesterForwardDiffJacobianCache,
6970
f!::F, fx, x) where {F}
7071
if cache.cache isa ForwardColorJacCache
71-
forwarddiff_color_jacobian!(J, f!, x, cache.cache) # Use Sparse ForwardDiff
72+
forwarddiff_color_jacobian!(J, f!, x, cache.cache)
7273
else
7374
PolyesterForwardDiff.threaded_jacobian!(f!, fx, J, x, cache.cache) # Don't try to exploit sparsity
7475
end

src/differentiation/compute_jacobian_ad.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,9 @@ function forwarddiff_color_jacobian(f::F, x::AbstractArray{<:Number},
173173
end
174174
end
175175

176+
# Defined in extension. Polyester version of `forwarddiff_color_jacobian`
177+
function polyesterforwarddiff_color_jacobian end
178+
176179
# When J is mutable, this version of forwarddiff_color_jacobian will mutate J to avoid allocations
177180
function forwarddiff_color_jacobian(J::AbstractMatrix{<:Number}, f::F,
178181
x::AbstractArray{<:Number},
@@ -249,9 +252,8 @@ function forwarddiff_color_jacobian(J::AbstractMatrix{<:Number}, f::F,
249252
end
250253

251254
# When J is immutable, this version of forwarddiff_color_jacobian will avoid mutating J
252-
function forwarddiff_color_jacobian_immutable(f, x::AbstractArray{<:Number},
253-
jac_cache::ForwardColorJacCache,
254-
jac_prototype = nothing)
255+
function forwarddiff_color_jacobian_immutable(f::F, x::AbstractArray{<:Number},
256+
jac_cache::ForwardColorJacCache, jac_prototype = nothing) where {F}
255257
t = jac_cache.t
256258
dx = jac_cache.dx
257259
p = jac_cache.p
@@ -315,16 +317,16 @@ function forwarddiff_color_jacobian_immutable(f, x::AbstractArray{<:Number},
315317
return J
316318
end
317319

318-
function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number}, f,
320+
function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number}, f::F,
319321
x::AbstractArray{<:Number}; dx = similar(x, size(J, 1)), colorvec = 1:length(x),
320-
sparsity = ArrayInterface.has_sparsestruct(J) ? J : nothing)
322+
sparsity = ArrayInterface.has_sparsestruct(J) ? J : nothing) where {F}
321323
forwarddiff_color_jacobian!(J, f, x, ForwardColorJacCache(f, x; dx, colorvec, sparsity))
322324
end
323325

324326
function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
325-
f,
327+
f::F,
326328
x::AbstractArray{<:Number},
327-
jac_cache::ForwardColorJacCache)
329+
jac_cache::ForwardColorJacCache) where {F}
328330
t = jac_cache.t
329331
fx = jac_cache.fx
330332
dx = jac_cache.dx

0 commit comments

Comments
 (0)