|
1 |
| -using Random |
2 |
| - |
3 |
| -### integrator.jl |
4 |
| - |
5 |
| -import AdvancedHMC: ∂H∂θ, ∂H∂r, DualValue, PhasePoint, phasepoint, step |
6 |
| -using AdvancedHMC: TYPEDEF, TYPEDFIELDS, AbstractScalarOrVec, AbstractLeapfrog, step_size |
7 |
| - |
8 |
| -""" |
9 |
| -$(TYPEDEF) |
10 |
| -
|
11 |
| -Generalized leapfrog integrator with fixed step size `ϵ`. |
12 |
| -
|
13 |
| -# Fields |
14 |
| -
|
15 |
| -$(TYPEDFIELDS) |
16 |
| -""" |
17 |
| -struct GeneralizedLeapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T} |
18 |
| - "Step size." |
19 |
| - ϵ::T |
20 |
| - n::Int |
21 |
| -end |
22 |
| -function Base.show(io::IO, l::GeneralizedLeapfrog) |
23 |
| - return print(io, "GeneralizedLeapfrog(ϵ=$(round.(l.ϵ; sigdigits=3)), n=$(l.n))") |
24 |
| -end |
25 |
| - |
26 |
| -# Fallback to ignore return_cache & cache kwargs for other ∂H∂θ |
27 |
| -function ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing) where {T} |
28 |
| - dv = ∂H∂θ(h, θ, r) |
29 |
| - return return_cache ? (dv, nothing) : dv |
30 |
| -end |
31 |
| - |
32 |
| -# TODO Make sure vectorization works |
33 |
| -# TODO Check if tempering is valid |
34 |
| -function step( |
35 |
| - lf::GeneralizedLeapfrog{T}, |
36 |
| - h::Hamiltonian, |
37 |
| - z::P, |
38 |
| - n_steps::Int=1; |
39 |
| - fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0 |
40 |
| - full_trajectory::Val{FullTraj}=Val(false), |
41 |
| -) where {T<:AbstractScalarOrVec{<:AbstractFloat},P<:PhasePoint,FullTraj} |
42 |
| - n_steps = abs(n_steps) # to support `n_steps < 0` cases |
43 |
| - |
44 |
| - ϵ = fwd ? step_size(lf) : -step_size(lf) |
45 |
| - ϵ = ϵ' |
46 |
| - |
47 |
| - res = if FullTraj |
48 |
| - Vector{P}(undef, n_steps) |
49 |
| - else |
50 |
| - z |
51 |
| - end |
52 |
| - |
53 |
| - for i in 1:n_steps |
54 |
| - θ_init, r_init = z.θ, z.r |
55 |
| - # Tempering |
56 |
| - #r = temper(lf, r, (i=i, is_half=true), n_steps) |
57 |
| - #! Eq (16) of Girolami & Calderhead (2011) |
58 |
| - r_half = copy(r_init) |
59 |
| - local cache |
60 |
| - for j in 1:(lf.n) |
61 |
| - # Reuse cache for the first iteration |
62 |
| - if j == 1 |
63 |
| - (; value, gradient) = z.ℓπ |
64 |
| - elseif j == 2 # cache intermediate values that depends on θ only (which are unchanged) |
65 |
| - retval, cache = ∂H∂θ_cache(h, θ_init, r_half; return_cache=true) |
66 |
| - (; value, gradient) = retval |
67 |
| - else # reuse cache |
68 |
| - (; value, gradient) = ∂H∂θ_cache(h, θ_init, r_half; cache=cache) |
69 |
| - end |
70 |
| - r_half = r_init - ϵ / 2 * gradient |
71 |
| - # println("r_half: ", r_half) |
72 |
| - end |
73 |
| - #! Eq (17) of Girolami & Calderhead (2011) |
74 |
| - θ_full = copy(θ_init) |
75 |
| - term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop |
76 |
| - for j in 1:(lf.n) |
77 |
| - θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half)) |
78 |
| - # println("θ_full :", θ_full) |
79 |
| - end |
80 |
| - #! Eq (18) of Girolami & Calderhead (2011) |
81 |
| - (; value, gradient) = ∂H∂θ(h, θ_full, r_half) |
82 |
| - r_full = r_half - ϵ / 2 * gradient |
83 |
| - # println("r_full: ", r_full) |
84 |
| - # Tempering |
85 |
| - #r = temper(lf, r, (i=i, is_half=false), n_steps) |
86 |
| - # Create a new phase point by caching the logdensity and gradient |
87 |
| - z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient)) |
88 |
| - # Update result |
89 |
| - if FullTraj |
90 |
| - res[i] = z |
91 |
| - else |
92 |
| - res = z |
93 |
| - end |
94 |
| - if !isfinite(z) |
95 |
| - # Remove undef |
96 |
| - if FullTraj |
97 |
| - res = res[isassigned.(Ref(res), 1:n_steps)] |
98 |
| - end |
99 |
| - break |
100 |
| - end |
101 |
| - # @assert false |
102 |
| - end |
103 |
| - return res |
104 |
| -end |
105 |
| - |
106 |
| -# TODO Make the order of θ and r consistent with neg_energy |
107 |
| -∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂θ(h, θ) |
108 |
| -∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂r(h, r) |
109 |
| - |
110 |
| -### hamiltonian.jl |
111 |
| - |
112 |
| -import AdvancedHMC: refresh, phasepoint |
113 |
| -using AdvancedHMC: FullMomentumRefreshment, PartialMomentumRefreshment, AbstractMetric |
114 |
| - |
115 |
| -# To change L180 of hamiltonian.jl |
116 |
| -function phasepoint( |
117 |
| - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, |
118 |
| - θ::AbstractVecOrMat{T}, |
119 |
| - h::Hamiltonian, |
120 |
| -) where {T<:Real} |
121 |
| - return phasepoint(h, θ, rand_momentum(rng, h.metric, h.kinetic, θ)) |
122 |
| -end |
123 |
| - |
124 |
| -# To change L191 of hamiltonian.jl |
125 |
| -function refresh( |
126 |
| - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, |
127 |
| - ::FullMomentumRefreshment, |
128 |
| - h::Hamiltonian, |
129 |
| - z::PhasePoint, |
130 |
| -) |
131 |
| - return phasepoint(h, z.θ, rand_momentum(rng, h.metric, h.kinetic, z.θ)) |
132 |
| -end |
133 |
| - |
134 |
| -# To change L215 of hamiltonian.jl |
135 |
| -function refresh( |
136 |
| - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, |
137 |
| - ref::PartialMomentumRefreshment, |
138 |
| - h::Hamiltonian, |
139 |
| - z::PhasePoint, |
140 |
| -) |
141 |
| - return phasepoint( |
142 |
| - h, |
143 |
| - z.θ, |
144 |
| - ref.α * z.r + sqrt(1 - ref.α^2) * rand_momentum(rng, h.metric, h.kinetic, z.θ), |
145 |
| - ) |
146 |
| -end |
147 |
| - |
148 |
| -### metric.jl |
149 |
| - |
150 |
| -import AdvancedHMC: _rand |
151 |
| -using AdvancedHMC: AbstractMetric |
152 |
| -using LinearAlgebra: eigen, cholesky, Symmetric |
153 |
| - |
154 |
| -abstract type AbstractRiemannianMetric <: AbstractMetric end |
155 |
| - |
156 |
| -abstract type AbstractHessianMap end |
157 |
| - |
158 |
| -struct IdentityMap <: AbstractHessianMap end |
159 |
| - |
160 |
| -(::IdentityMap)(x) = x |
161 |
| - |
162 |
| -struct SoftAbsMap{T} <: AbstractHessianMap |
163 |
| - α::T |
164 |
| -end |
165 |
| - |
166 |
| -# TODO Register softabs with ReverseDiff |
167 |
| -#! The definition of SoftAbs from Page 3 of Betancourt (2012) |
168 |
| -function softabs(X, α=20.0) |
169 |
| - F = eigen(X) # ReverseDiff cannot diff through `eigen` |
170 |
| - Q = hcat(F.vectors) |
171 |
| - λ = F.values |
172 |
| - softabsλ = λ .* coth.(α * λ) |
173 |
| - return Q * diagm(softabsλ) * Q', Q, λ, softabsλ |
174 |
| -end |
175 |
| - |
176 |
| -(map::SoftAbsMap)(x) = softabs(x, map.α)[1] |
177 |
| - |
178 |
| -struct DenseRiemannianMetric{ |
179 |
| - T, |
180 |
| - TM<:AbstractHessianMap, |
181 |
| - A<:Union{Tuple{Int},Tuple{Int,Int}}, |
182 |
| - AV<:AbstractVecOrMat{T}, |
183 |
| - TG, |
184 |
| - T∂G∂θ, |
185 |
| -} <: AbstractRiemannianMetric |
186 |
| - size::A |
187 |
| - G::TG # TODO store G⁻¹ here instead |
188 |
| - ∂G∂θ::T∂G∂θ |
189 |
| - map::TM |
190 |
| - _temp::AV |
191 |
| -end |
192 |
| - |
193 |
| -# TODO Make dense mass matrix support matrix-mode parallel |
194 |
| -function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap()) where {T<:AbstractFloat} |
195 |
| - _temp = Vector{Float64}(undef, size[1]) |
196 |
| - return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp) |
197 |
| -end |
198 |
| -# DenseEuclideanMetric(::Type{T}, D::Int) where {T} = DenseEuclideanMetric(Matrix{T}(I, D, D)) |
199 |
| -# DenseEuclideanMetric(D::Int) = DenseEuclideanMetric(Float64, D) |
200 |
| -# DenseEuclideanMetric(::Type{T}, sz::Tuple{Int}) where {T} = DenseEuclideanMetric(Matrix{T}(I, first(sz), first(sz))) |
201 |
| -# DenseEuclideanMetric(sz::Tuple{Int}) = DenseEuclideanMetric(Float64, sz) |
202 |
| - |
203 |
| -# renew(ue::DenseEuclideanMetric, M⁻¹) = DenseEuclideanMetric(M⁻¹) |
204 |
| - |
205 |
| -Base.size(e::DenseRiemannianMetric) = e.size |
206 |
| -Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim] |
207 |
| -Base.show(io::IO, dem::DenseRiemannianMetric) = print(io, "DenseRiemannianMetric(...)") |
208 |
| - |
209 |
| -function rand_momentum( |
210 |
| - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, |
211 |
| - metric::DenseRiemannianMetric{T}, |
212 |
| - kinetic, |
| 1 | +#! Eq (14) of Girolami & Calderhead (2011) |
| 2 | +function ∂H∂r( |
| 3 | + h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, |
213 | 4 | θ::AbstractVecOrMat,
|
214 |
| -) where {T} |
215 |
| - r = _randn(rng, T, size(metric)...) |
216 |
| - G⁻¹ = inv(metric.map(metric.G(θ))) |
217 |
| - chol = cholesky(Symmetric(G⁻¹)) |
218 |
| - ldiv!(chol.U, r) |
219 |
| - return r |
220 |
| -end |
221 |
| - |
222 |
| -### hamiltonian.jl |
223 |
| - |
224 |
| -import AdvancedHMC: phasepoint, neg_energy, ∂H∂θ, ∂H∂r |
225 |
| -using LinearAlgebra: logabsdet, tr |
226 |
| - |
227 |
| -# QUES Do we want to change everything to position dependent by default? |
228 |
| -# Add θ to ∂H∂r for DenseRiemannianMetric |
229 |
| -function phasepoint( |
230 |
| - h::Hamiltonian{<:DenseRiemannianMetric}, |
231 |
| - θ::T, |
232 |
| - r::T; |
233 |
| - ℓπ=∂H∂θ(h, θ), |
234 |
| - ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), |
235 |
| -) where {T<:AbstractVecOrMat} |
236 |
| - return PhasePoint(θ, r, ℓπ, ℓκ) |
237 |
| -end |
238 |
| - |
239 |
| -# Negative kinetic energy |
240 |
| -#! Eq (13) of Girolami & Calderhead (2011) |
241 |
| -function neg_energy( |
242 |
| - h::Hamiltonian{<:DenseRiemannianMetric}, r::T, θ::T |
243 |
| -) where {T<:AbstractVecOrMat} |
244 |
| - G = h.metric.map(h.metric.G(θ)) |
245 |
| - D = size(G, 1) |
246 |
| - # Need to consider the normalizing term as it is no longer same for different θs |
247 |
| - logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined |
248 |
| - mul!(h.metric._temp, inv(G), r) |
249 |
| - return -logZ - dot(r, h.metric._temp) / 2 |
| 5 | + r::AbstractVecOrMat, |
| 6 | +) |
| 7 | + H = h.metric.G(θ) |
| 8 | + G = h.metric.map(H) |
| 9 | + return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't |
250 | 10 | end
|
251 | 11 |
|
252 |
| -# QUES L31 of hamiltonian.jl now reads a bit weird (semantically) |
253 | 12 | function ∂H∂θ(
|
254 |
| - h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap}}, |
| 13 | + h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap},<:GaussianKinetic}, |
255 | 14 | θ::AbstractVecOrMat{T},
|
256 | 15 | r::AbstractVecOrMat{T},
|
257 | 16 | ) where {T}
|
@@ -293,14 +52,14 @@ function make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat}
|
293 | 52 | end
|
294 | 53 |
|
295 | 54 | function ∂H∂θ(
|
296 |
| - h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}}, |
| 55 | + h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic}, |
297 | 56 | θ::AbstractVecOrMat{T},
|
298 | 57 | r::AbstractVecOrMat{T},
|
299 | 58 | ) where {T}
|
300 | 59 | return ∂H∂θ_cache(h, θ, r)
|
301 | 60 | end
|
302 | 61 | function ∂H∂θ_cache(
|
303 |
| - h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}}, |
| 62 | + h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic}, |
304 | 63 | θ::AbstractVecOrMat{T},
|
305 | 64 | r::AbstractVecOrMat{T};
|
306 | 65 | return_cache=false,
|
@@ -342,17 +101,26 @@ function ∂H∂θ_cache(
|
342 | 101 | return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv
|
343 | 102 | end
|
344 | 103 |
|
345 |
| -#! Eq (14) of Girolami & Calderhead (2011) |
346 |
| -function ∂H∂r( |
347 |
| - h::Hamiltonian{<:DenseRiemannianMetric}, θ::AbstractVecOrMat, r::AbstractVecOrMat |
348 |
| -) |
349 |
| - H = h.metric.G(θ) |
350 |
| - # if !all(isfinite, H) |
351 |
| - # println("θ: ", θ) |
352 |
| - # println("H: ", H) |
353 |
| - # end |
354 |
| - G = h.metric.map(H) |
355 |
| - # return inv(G) * r |
356 |
| - # println("G \ r: ", G \ r) |
357 |
| - return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't |
| 104 | +# QUES Do we want to change everything to position dependent by default? |
| 105 | +# Add θ to ∂H∂r for DenseRiemannianMetric |
| 106 | +function phasepoint( |
| 107 | + h::Hamiltonian{<:DenseRiemannianMetric}, |
| 108 | + θ::T, |
| 109 | + r::T; |
| 110 | + ℓπ=∂H∂θ(h, θ), |
| 111 | + ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)), |
| 112 | +) where {T<:AbstractVecOrMat} |
| 113 | + return PhasePoint(θ, r, ℓπ, ℓκ) |
| 114 | +end |
| 115 | + |
| 116 | +#! Eq (13) of Girolami & Calderhead (2011) |
| 117 | +function neg_energy( |
| 118 | + h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, r::T, θ::T |
| 119 | +) where {T<:AbstractVecOrMat} |
| 120 | + G = h.metric.map(h.metric.G(θ)) |
| 121 | + D = size(G, 1) |
| 122 | + # Need to consider the normalizing term as it is no longer same for different θs |
| 123 | + logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined |
| 124 | + mul!(h.metric._temp, inv(G), r) |
| 125 | + return -logZ - dot(r, h.metric._temp) / 2 |
358 | 126 | end
|
0 commit comments