|
| 1 | +using Cassette |
| 2 | +import Cassette: tag, untag, Tagged, metadata, hasmetadata, istagged, canrecurse |
| 3 | +import Core: SSAValue |
| 4 | +using SparseArrays |
| 5 | + |
| 6 | +# Tags: |
| 7 | +Cassette.@context HessianSparsityContext |
| 8 | + |
| 9 | +const TaggedOf{T} = Tagged{A, T} where A |
| 10 | + |
| 11 | +const HTagType = Union{Input, TermCombination} |
| 12 | +Cassette.metadatatype(::Type{<:HessianSparsityContext}, ::DataType) = HTagType |
| 13 | + |
| 14 | +istainted(ctx::HessianSparsityContext, x) = ismetatype(x, ctx, TermCombination) |
| 15 | + |
| 16 | +Cassette.overdub(ctx::HessianSparsityContext, f::typeof(istainted), x) = istainted(ctx, x) |
| 17 | +Cassette.overdub(ctx::HessianSparsityContext, f::typeof(this_here_predicate!)) = this_here_predicate!(ctx.metadata) |
| 18 | + |
| 19 | +# getindex on the input |
| 20 | +function Cassette.overdub(ctx::HessianSparsityContext, |
| 21 | + f::typeof(getindex), |
| 22 | + X::Tagged, |
| 23 | + idx::Tagged...) |
| 24 | + if any(i->ismetatype(i, ctx, TermCombination) && !isone(metadata(i, ctx)), idx) |
| 25 | + error("getindex call depends on input. Cannot determine Hessian sparsity") |
| 26 | + end |
| 27 | + Cassette.overdub(ctx, f, X, map(i->untag(i, ctx), idx)...) |
| 28 | +end |
| 29 | + |
| 30 | +# plugs an ambiguity |
| 31 | +function Cassette.overdub(ctx::HessianSparsityContext, |
| 32 | + f::typeof(getindex), |
| 33 | + X::Tagged) |
| 34 | + Cassette.recurse(ctx, f, X) |
| 35 | +end |
| 36 | + |
| 37 | +function Cassette.overdub(ctx::HessianSparsityContext, |
| 38 | + f::typeof(getindex), |
| 39 | + X::Tagged, |
| 40 | + idx::Integer...) |
| 41 | + if ismetatype(X, ctx, Input) |
| 42 | + val = Cassette.fallback(ctx, f, X, idx...) |
| 43 | + i = LinearIndices(untag(X, ctx))[idx...] |
| 44 | + tag(val, ctx, TermCombination(Set([Dict(i=>1)]))) |
| 45 | + else |
| 46 | + Cassette.recurse(ctx, f, X, idx...) |
| 47 | + end |
| 48 | +end |
| 49 | + |
| 50 | +function Cassette.overdub(ctx::HessianSparsityContext, |
| 51 | + f::typeof(Base.unsafe_copyto!), |
| 52 | + X::Tagged, |
| 53 | + xstart, |
| 54 | + Y::Tagged, |
| 55 | + ystart, |
| 56 | + len) |
| 57 | + if ismetatype(Y, ctx, Input) |
| 58 | + val = Cassette.fallback(ctx, f, X, xstart, Y, ystart, len) |
| 59 | + nometa = Cassette.NoMetaMeta() |
| 60 | + X.meta.meta[xstart:xstart+len-1] .= (i->Cassette.Meta(TermCombination(Set([Dict(i=>1)])), nometa)).(ystart:ystart+len-1) |
| 61 | + val |
| 62 | + else |
| 63 | + Cassette.recurse(ctx, f, X, xstart, Y, ystart, len) |
| 64 | + end |
| 65 | +end |
| 66 | +function Cassette.overdub(ctx::HessianSparsityContext, |
| 67 | + f::typeof(copy), |
| 68 | + X::Tagged) |
| 69 | + if ismetatype(X, ctx, Input) |
| 70 | + val = Cassette.fallback(ctx, f, X) |
| 71 | + tag(val, ctx, Input()) |
| 72 | + else |
| 73 | + Cassette.recurse(ctx, f, X) |
| 74 | + end |
| 75 | +end |
| 76 | + |
| 77 | +combine_terms(::Nothing, terms...) = one(TermCombination) |
| 78 | + |
| 79 | +# 1-arg functions |
| 80 | +combine_terms(::Val{true}, term) = term |
| 81 | +combine_terms(::Val{false}, term) = term * term |
| 82 | + |
| 83 | +# 2-arg functions |
| 84 | +function combine_terms(::Val{linearity}, term1, term2) where linearity |
| 85 | + |
| 86 | + linear11, linear22, linear12 = linearity |
| 87 | + term = zero(TermCombination) |
| 88 | + if linear11 |
| 89 | + if !linear12 |
| 90 | + term += term1 |
| 91 | + end |
| 92 | + else |
| 93 | + term += term1 * term1 |
| 94 | + end |
| 95 | + |
| 96 | + if linear22 |
| 97 | + if !linear12 |
| 98 | + term += term2 |
| 99 | + end |
| 100 | + else |
| 101 | + term += term2 * term2 |
| 102 | + end |
| 103 | + |
| 104 | + if linear12 |
| 105 | + term += term1 + term2 |
| 106 | + else |
| 107 | + term += term1 * term2 |
| 108 | + end |
| 109 | + term |
| 110 | +end |
| 111 | + |
| 112 | + |
| 113 | +# Hessian overdub |
| 114 | +# |
| 115 | +function getterms(ctx, x) |
| 116 | + ismetatype(x, ctx, TermCombination) ? metadata(x, ctx) : one(TermCombination) |
| 117 | +end |
| 118 | + |
| 119 | +function hessian_overdub(ctx::HessianSparsityContext, f, linearity, args...) |
| 120 | + t = combine_terms(linearity, map(x->getterms(ctx, x), args)...) |
| 121 | + val = Cassette.fallback(ctx, f, args...) |
| 122 | + tag(val, ctx, t) |
| 123 | +end |
| 124 | +function Cassette.overdub(ctx::HessianSparsityContext, |
| 125 | + f::typeof(getproperty), |
| 126 | + x::Tagged, prop) |
| 127 | + if ismetatype(x, ctx, TermCombination) && !isone(metadata(x, ctx)) |
| 128 | + error("property of a non-constant term accessed") |
| 129 | + else |
| 130 | + Cassette.recurse(ctx, f, x, prop) |
| 131 | + end |
| 132 | +end |
| 133 | + |
| 134 | +haslinearity(ctx::HessianSparsityContext, f, nargs) = haslinearity(untag(f, ctx), nargs) |
| 135 | +linearity(ctx::HessianSparsityContext, f, nargs) = linearity(untag(f, ctx), nargs) |
| 136 | + |
| 137 | +function Cassette.overdub(ctx::HessianSparsityContext, |
| 138 | + f, |
| 139 | + args...) |
| 140 | + tainted = any(x->ismetatype(x, ctx, TermCombination), args) |
| 141 | + val = if tainted && haslinearity(ctx, f, Val{nfields(args)}()) |
| 142 | + l = linearity(ctx, f, Val{nfields(args)}()) |
| 143 | + hessian_overdub(ctx, f, l, args...) |
| 144 | + else |
| 145 | + val = Cassette.recurse(ctx, f, args...) |
| 146 | + #if tainted && !ismetatype(val, ctx, TermCombination) |
| 147 | + # @warn("Don't know the linearity of function $f") |
| 148 | + #end |
| 149 | + val |
| 150 | + end |
| 151 | + val |
| 152 | +end |
0 commit comments