Commit a088ea1

Merge branch 'master' into coloring

2 parents: e869fa3 + 701b096

17 files changed: +528, -18 lines

Project.toml (+4, -2)

@@ -1,7 +1,7 @@
 name = "SparseDiffTools"
 uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
 authors = ["Pankaj Mishra <pankajmishra1511@gmail.com>", "Chris Rackauckas <contact@chrisrackauckas.com>"]
-version = "0.2.0"
+version = "0.3.0"
 
 [deps]
 BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
@@ -13,6 +13,7 @@ LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
@@ -22,7 +23,8 @@ julia = "1"
 [extras]
 DiffEqDiffTools = "01453d9d-ee7c-5054-8395-0335cb756afa"
 IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test", "DiffEqDiffTools", "IterativeSolvers"]
+test = ["Test", "DiffEqDiffTools", "IterativeSolvers", "Random"]

README.md (+4, -2)

@@ -159,7 +159,8 @@ forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
                             f,
                             x::AbstractArray{<:Number};
                             dx = nothing,
-                            color = eachindex(x))
+                            color = eachindex(x),
+                            sparsity = nothing)
 ```
 
 This call will allocate the cache variables each time. To avoid allocating the
@@ -168,7 +169,8 @@ cache, construct the cache in advance:
 ```julia
 ForwardColorJacCache(f,x,_chunksize = nothing;
                          dx = nothing,
-                         color=1:length(x))
+                         color=1:length(x),
+                         sparsity = nothing)
 ```
 
 and utilize the following signature:
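Taken together, these README changes document a new `sparsity` keyword on both entry points. The sketch below is not part of the commit: the tridiagonal test function `f!`, the hand-written 3-coloring, and the problem size are invented for illustration, and the calls assume the signatures exactly as documented above.

```julia
using SparseDiffTools, SparseArrays, LinearAlgebra

# Hypothetical in-place map whose Jacobian is tridiagonal.
function f!(dx, x)
    dx[1] = -2x[1] + x[2]
    for i in 2:length(x)-1
        dx[i] = x[i-1] - 2x[i] + x[i+1]
    end
    dx[end] = x[end-1] - 2x[end]
    nothing
end

x = rand(30)

# Known sparsity pattern and a hand-written 3-coloring of its columns
# (columns i and i+3 never share a row in a tridiagonal pattern).
sparsity_pattern = sparse(Tridiagonal(ones(29), ones(30), ones(29)))
colors = [mod1(i, 3) for i in 1:30]

J = similar(sparsity_pattern)
forwarddiff_color_jacobian!(J, f!, x; color = colors, sparsity = sparsity_pattern)
# When J is itself a SparseMatrixCSC, `sparsity` defaults to J, so the last
# keyword above is redundant and only illustrates the new argument.
```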

src/SparseDiffTools.jl (+5, -1)

@@ -29,7 +29,7 @@ export contract_color,
        numback_hesvec,numback_hesvec!,
        autoback_hesvec,autoback_hesvec!,
        JacVec,HesVec,HesVecGrad,
-       Sparsity, sparsity!
+       Sparsity, sparsity!, hsparsity
 
 
 include("coloring/high_level.jl")
@@ -44,5 +44,9 @@ include("program_sparsity/program_sparsity.jl")
 include("program_sparsity/sparsity_tracker.jl")
 include("program_sparsity/path.jl")
 include("program_sparsity/take_all_branches.jl")
+include("program_sparsity/terms.jl")
+include("program_sparsity/linearity.jl")
+include("program_sparsity/hessian.jl")
+include("program_sparsity/blas.jl")
 
 end # module

src/differentiation/compute_jacobian_ad.jl (+11, -7)

@@ -1,9 +1,10 @@
-struct ForwardColorJacCache{T,T2,T3,T4,T5}
+struct ForwardColorJacCache{T,T2,T3,T4,T5,T6}
     t::T
     fx::T2
     dx::T3
     p::T4
     color::T5
+    sparsity::T6
 end
 
 function default_chunk_size(maxcolor)
@@ -19,7 +20,8 @@ getsize(N::Integer) = N
 
 function ForwardColorJacCache(f,x,_chunksize = nothing;
                               dx = nothing,
-                              color=1:length(x))
+                              color=1:length(x),
+                              sparsity::Union{SparseMatrixCSC,Nothing}=nothing)
 
     if _chunksize === nothing
         chunksize = default_chunk_size(maximum(color))
@@ -38,7 +40,7 @@ function ForwardColorJacCache(f,x,_chunksize = nothing;
     end
 
     p = generate_chunked_partials(x,color,chunksize)
-    ForwardColorJacCache(t,fx,_dx,p,color)
+    ForwardColorJacCache(t,fx,_dx,p,color,sparsity)
 end
 
 generate_chunked_partials(x,color,N::Integer) = generate_chunked_partials(x,color,Val(N))
@@ -78,8 +80,9 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
                                      f,
                                      x::AbstractArray{<:Number};
                                      dx = nothing,
-                                     color = eachindex(x))
-    forwarddiff_color_jacobian!(J,f,x,ForwardColorJacCache(f,x,dx=dx,color=color))
+                                     color = eachindex(x),
+                                     sparsity = J isa SparseMatrixCSC ? J : nothing)
+    forwarddiff_color_jacobian!(J,f,x,ForwardColorJacCache(f,x,dx=dx,color=color,sparsity=sparsity))
 end
 
 function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
@@ -92,15 +95,16 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
     dx = jac_cache.dx
     p = jac_cache.p
     color = jac_cache.color
+    sparsity = jac_cache.sparsity
     color_i = 1
     chunksize = length(first(first(jac_cache.p)))
 
     for i in 1:length(p)
         partial_i = p[i]
         t .= Dual{typeof(f)}.(x, partial_i)
         f(fx,t)
-        if J isa SparseMatrixCSC
-            rows_index, cols_index, val = findnz(J)
+        if sparsity isa SparseMatrixCSC
+            rows_index, cols_index, val = findnz(sparsity)
             for j in 1:chunksize
                 dx .= partials.(fx, j)
                 for k in 1:length(cols_index)
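Since the pattern now lives in the cache, the decompression loop is driven by `findnz(sparsity)` instead of `findnz(J)`. A hedged sketch of the cache-reuse workflow this enables, using the same kind of invented tridiagonal setup as the README example above (none of these names come from the commit):

```julia
using SparseDiffTools, SparseArrays, LinearAlgebra

# Invented setup: a map with a tridiagonal Jacobian, its sparsity pattern,
# and a 3-coloring of its columns.
f!(dx, x) = (dx .= -2 .* x .+ [x[2:end]; 0] .+ [0; x[1:end-1]]; nothing)
x = rand(30)
sparsity_pattern = sparse(Tridiagonal(ones(29), ones(30), ones(29)))
colors = [mod1(i, 3) for i in 1:30]

# Build the cache once; it now also carries the sparsity pattern.
cache = ForwardColorJacCache(f!, x; color = colors, sparsity = sparsity_pattern)

J = similar(sparsity_pattern)
for _ in 1:100
    forwarddiff_color_jacobian!(J, f!, x, cache)  # reuses t, fx, dx, partials and sparsity
    x .+= 0.01 .* randn(length(x))                # pretend the state evolves between calls
end
```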

src/program_sparsity/blas.jl (new file, +25 lines)

using LinearAlgebra
import LinearAlgebra.BLAS

# generic implementations

macro reroute(f, g)
    quote
        function Cassette.overdub(ctx::HessianSparsityContext,
                                  f::typeof($(esc(f))),
                                  args...)
            println("rerouted")
            Cassette.overdub(
                ctx,
                invoke,
                $(esc(g.args[1])),
                $(esc(:(Tuple{$(g.args[2:end]...)}))),
                args...)
        end
    end
end

@reroute BLAS.dot dot(Any, Any)
@reroute BLAS.axpy! axpy!(Any,
                          AbstractArray,
                          AbstractArray)
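For orientation, this is roughly what `@reroute BLAS.dot dot(Any, Any)` expands to, written out by hand for illustration (not code from the commit): the BLAS-specialized method is intercepted and re-dispatched through `invoke` to the generic `dot(::Any, ::Any)` method, whose scalar operations the sparsity tracer knows how to analyze.

```julia
# Approximate hand expansion of `@reroute BLAS.dot dot(Any, Any)`.
function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(LinearAlgebra.BLAS.dot),
                          args...)
    println("rerouted")
    # Re-enter the tracer on the generic method rather than the BLAS kernel,
    # forcing dispatch to `dot(::Any, ::Any)` via `invoke`.
    Cassette.overdub(ctx, invoke, dot, Tuple{Any, Any}, args...)
end
```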

src/program_sparsity/hessian.jl (new file, +152 lines)

using Cassette
import Cassette: tag, untag, Tagged, metadata, hasmetadata, istagged, canrecurse
import Core: SSAValue
using SparseArrays

# Tags:
Cassette.@context HessianSparsityContext

const TaggedOf{T} = Tagged{A, T} where A

const HTagType = Union{Input, TermCombination}
Cassette.metadatatype(::Type{<:HessianSparsityContext}, ::DataType) = HTagType

istainted(ctx::HessianSparsityContext, x) = ismetatype(x, ctx, TermCombination)

Cassette.overdub(ctx::HessianSparsityContext, f::typeof(istainted), x) = istainted(ctx, x)
Cassette.overdub(ctx::HessianSparsityContext, f::typeof(this_here_predicate!)) = this_here_predicate!(ctx.metadata)

# getindex on the input
function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(getindex),
                          X::Tagged,
                          idx::Tagged...)
    if any(i->ismetatype(i, ctx, TermCombination) && !isone(metadata(i, ctx)), idx)
        error("getindex call depends on input. Cannot determine Hessian sparsity")
    end
    Cassette.overdub(ctx, f, X, map(i->untag(i, ctx), idx)...)
end

# plugs an ambiguity
function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(getindex),
                          X::Tagged)
    Cassette.recurse(ctx, f, X)
end

function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(getindex),
                          X::Tagged,
                          idx::Integer...)
    if ismetatype(X, ctx, Input)
        val = Cassette.fallback(ctx, f, X, idx...)
        i = LinearIndices(untag(X, ctx))[idx...]
        tag(val, ctx, TermCombination(Set([Dict(i=>1)])))
    else
        Cassette.recurse(ctx, f, X, idx...)
    end
end

function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(Base.unsafe_copyto!),
                          X::Tagged,
                          xstart,
                          Y::Tagged,
                          ystart,
                          len)
    if ismetatype(Y, ctx, Input)
        val = Cassette.fallback(ctx, f, X, xstart, Y, ystart, len)
        nometa = Cassette.NoMetaMeta()
        X.meta.meta[xstart:xstart+len-1] .= (i->Cassette.Meta(TermCombination(Set([Dict(i=>1)])), nometa)).(ystart:ystart+len-1)
        val
    else
        Cassette.recurse(ctx, f, X, xstart, Y, ystart, len)
    end
end
function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(copy),
                          X::Tagged)
    if ismetatype(X, ctx, Input)
        val = Cassette.fallback(ctx, f, X)
        tag(val, ctx, Input())
    else
        Cassette.recurse(ctx, f, X)
    end
end

combine_terms(::Nothing, terms...) = one(TermCombination)

# 1-arg functions
combine_terms(::Val{true}, term) = term
combine_terms(::Val{false}, term) = term * term

# 2-arg functions
function combine_terms(::Val{linearity}, term1, term2) where linearity

    linear11, linear22, linear12 = linearity
    term = zero(TermCombination)
    if linear11
        if !linear12
            term += term1
        end
    else
        term += term1 * term1
    end

    if linear22
        if !linear12
            term += term2
        end
    else
        term += term2 * term2
    end

    if linear12
        term += term1 + term2
    else
        term += term1 * term2
    end
    term
end


# Hessian overdub
#
function getterms(ctx, x)
    ismetatype(x, ctx, TermCombination) ? metadata(x, ctx) : one(TermCombination)
end

function hessian_overdub(ctx::HessianSparsityContext, f, linearity, args...)
    t = combine_terms(linearity, map(x->getterms(ctx, x), args)...)
    val = Cassette.fallback(ctx, f, args...)
    tag(val, ctx, t)
end
function Cassette.overdub(ctx::HessianSparsityContext,
                          f::typeof(getproperty),
                          x::Tagged, prop)
    if ismetatype(x, ctx, TermCombination) && !isone(metadata(x, ctx))
        error("property of a non-constant term accessed")
    else
        Cassette.recurse(ctx, f, x, prop)
    end
end

haslinearity(ctx::HessianSparsityContext, f, nargs) = haslinearity(untag(f, ctx), nargs)
linearity(ctx::HessianSparsityContext, f, nargs) = linearity(untag(f, ctx), nargs)

function Cassette.overdub(ctx::HessianSparsityContext,
                          f,
                          args...)
    tainted = any(x->ismetatype(x, ctx, TermCombination), args)
    val = if tainted && haslinearity(ctx, f, Val{nfields(args)}())
        l = linearity(ctx, f, Val{nfields(args)}())
        hessian_overdub(ctx, f, l, args...)
    else
        val = Cassette.recurse(ctx, f, args...)
        #if tainted && !ismetatype(val, ctx, TermCombination)
        #    @warn("Don't know the linearity of function $f")
        #end
        val
    end
    val
end
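The heart of the tracer is `combine_terms`, which merges the `TermCombination` metadata carried by a call's arguments according to the callee's linearity. A hedged worked example, assuming the `TermCombination` constructor and arithmetic used elsewhere in this file (they are defined in `terms.jl`, added by this commit but not shown in this view):

```julia
t1 = TermCombination(Set([Dict(1 => 1)]))   # "depends linearly on x[1]"
t2 = TermCombination(Set([Dict(2 => 1)]))   # "depends linearly on x[2]"

# `*` has linearity (true, true, false): linear in each argument separately,
# but with a nonzero mixed second derivative, so the product term t1*t2 is
# included, recording a potential H[1,2]/H[2,1] entry.
combine_terms(Val((true, true, false)), t1, t2)   # t1 + t2 + t1*t2

# `+` has linearity (true, true, true): fully linear, so the terms are only
# added and no second-order interaction is recorded.
combine_terms(Val((true, true, true)), t1, t2)    # t1 + t2

# A nonlinear 1-arg function such as `sin` squares its argument's term,
# marking the corresponding diagonal entry.
combine_terms(Val(false), t1)                     # t1 * t1
```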

src/program_sparsity/linearity.jl (new file, +60 lines)

using SpecialFunctions
import Base.Broadcast

const constant_funcs = []

const monadic_linear = [deg2rad, +, rad2deg, transpose, -, conj]

const monadic_nonlinear = [asind, log1p, acsch, erfc, digamma, acos, asec, acosh, airybiprime, acsc, cscd, log, tand, log10, csch, asinh, airyai, abs2, gamma, lgamma, erfcx, bessely0, cosh, sin, cos, atan, cospi, cbrt, acosd, bessely1, acoth, erfcinv, erf, dawson, inv, acotd, airyaiprime, erfinv, trigamma, asecd, besselj1, exp, acot, sqrt, sind, sinpi, asech, log2, tan, invdigamma, airybi, exp10, sech, erfi, coth, asin, cotd, cosd, sinh, abs, besselj0, csc, tanh, secd, atand, sec, acscd, cot, exp2, expm1, atanh]

diadic_of_linearity(::Val{(true, true, true)}) = [+, rem2pi, -, >, isless, <, isequal, max, min, convert]
diadic_of_linearity(::Val{(true, true, false)}) = [*]
#diadic_of_linearit(::(Val{(true, false, true)}) = [besselk, hankelh2, bessely, besselj, besseli, polygamma, hankelh1]
diadic_of_linearity(::Val{(true, false, false)}) = [/]
diadic_of_linearity(::Val{(false, true, false)}) = [\]
diadic_of_linearity(::Val{(false, false, false)}) = [hypot, atan, mod, rem, lbeta, ^, beta]
diadic_of_linearity(::Val) = []

haslinearity(f, nargs) = false

# some functions strip the linearity metadata

for f in constant_funcs
    @eval begin
        haslinearity(::typeof($f), ::Val) = true
        linearity(::typeof($f), ::Val) = nothing
    end
end

# linearity of a single input function is either
# Val{true}() or Val{false}()
#
for f in monadic_linear
    @eval begin
        haslinearity(::typeof($f), ::Val{1}) = true
        linearity(::typeof($f), ::Val{1}) = Val{true}()
    end
end
# linearity of a 2-arg function is:
# Val{(linear11, linear22, linear12)}()
#
# linearIJ refers to the zeroness of d^2/dxIxJ
for f in monadic_nonlinear
    @eval begin
        haslinearity(::typeof($f), ::Val{1}) = true
        linearity(::typeof($f), ::Val{1}) = Val{false}()
    end
end

for linearity_mask = 0:2^3-1
    lin = Val{map(x->x!=0, (linearity_mask & 4,
                            linearity_mask & 2,
                            linearity_mask & 1))}()

    for f in diadic_of_linearity(lin)
        @eval begin
            haslinearity(::typeof($f), ::Val{2}) = true
            linearity(::typeof($f), ::Val{2}) = $lin
        end
    end
end
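These tables back the `haslinearity`/`linearity` queries made from `hessian.jl`. A short illustrative sketch of what the generated methods should answer once the definitions above are loaded:

```julia
haslinearity(sin, Val(1))    # true: sin is a known 1-argument function
linearity(sin, Val(1))       # Val{false}(): nonlinear, so its term gets squared

linearity(+, Val(2))         # Val{(true, true, true)}(): all second derivatives vanish
linearity(*, Val(2))         # Val{(true, true, false)}(): only the mixed d^2/dx1dx2 is nonzero
linearity(hypot, Val(2))     # Val{(false, false, false)}(): nothing can be ruled out

haslinearity(clamp, Val(3))  # false: not in any table, so the tracer must recurse into it
```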

src/program_sparsity/program_sparsity.jl (+24, -1)

@@ -38,5 +38,28 @@ function sparsity!(f!, Y, X, args...; sparsity=Sparsity(length(Y), length(X)),
         alldone(path) && break
         reset!(path)
     end
-    sparsity
+    sparse(sparsity)
+end
+
+function hsparsity(f, X, args...; verbose=true)
+
+    terms = zero(TermCombination)
+    path = Path()
+    while true
+        ctx = HessianSparsityContext(metadata=path, pass=BranchesPass)
+        ctx = Cassette.enabletagging(ctx, f)
+        ctx = Cassette.disablehooks(ctx)
+        val = Cassette.recurse(ctx,
+                               f,
+                               tag(X, ctx, Input()),
+                               # TODO: make this recursive
+                               map(arg -> arg isa Fixed ?
+                                   arg.value : tag(arg, ctx, one(TermCombination)), args)...)
+        terms += metadata(val, ctx)
+        verbose && println("Explored path: ", path)
+        alldone(path) && break
+        reset!(path)
+    end
+
+    _sparse(terms, length(X))
 end
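Finally, a hedged sketch of how the new `hsparsity` entry point might be used; the objective function is invented for illustration and the commented pattern is the expected, not a verified, result:

```julia
using SparseDiffTools

# Invented scalar objective: one bilinear coupling, one nonlinear diagonal
# term, and one purely linear term that contributes nothing to the Hessian.
f(x) = x[1]*x[2] + sin(x[3]) + x[4]

x = rand(4)
H = hsparsity(f, x)
# Expected pattern: nonzeros at (1,2) and (2,1) from the product term and at
# (3,3) from sin; row and column 4 should stay empty.
```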
