Skip to content

Commit 5129015

Browse files
committed
Copy changes
1 parent 479c18e commit 5129015

File tree

1 file changed

+33
-265
lines changed

1 file changed

+33
-265
lines changed

src/riemannian/hamiltonian.jl

Lines changed: 33 additions & 265 deletions
Original file line numberDiff line numberDiff line change
@@ -1,257 +1,16 @@
1-
using Random
2-
3-
### integrator.jl
4-
5-
import AdvancedHMC: ∂H∂θ, ∂H∂r, DualValue, PhasePoint, phasepoint, step
6-
using AdvancedHMC: TYPEDEF, TYPEDFIELDS, AbstractScalarOrVec, AbstractLeapfrog, step_size
7-
8-
"""
9-
$(TYPEDEF)
10-
11-
Generalized leapfrog integrator with fixed step size `ϵ`.
12-
13-
# Fields
14-
15-
$(TYPEDFIELDS)
16-
"""
17-
struct GeneralizedLeapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T}
18-
"Step size."
19-
ϵ::T
20-
n::Int
21-
end
22-
function Base.show(io::IO, l::GeneralizedLeapfrog)
23-
return print(io, "GeneralizedLeapfrog(ϵ=$(round.(l.ϵ; sigdigits=3)), n=$(l.n))")
24-
end
25-
26-
# Fallback to ignore return_cache & cache kwargs for other ∂H∂θ
27-
function ∂H∂θ_cache(h, θ, r; return_cache=false, cache=nothing) where {T}
28-
dv = ∂H∂θ(h, θ, r)
29-
return return_cache ? (dv, nothing) : dv
30-
end
31-
32-
# TODO Make sure vectorization works
33-
# TODO Check if tempering is valid
34-
function step(
35-
lf::GeneralizedLeapfrog{T},
36-
h::Hamiltonian,
37-
z::P,
38-
n_steps::Int=1;
39-
fwd::Bool=n_steps > 0, # simulate hamiltonian backward when n_steps < 0
40-
full_trajectory::Val{FullTraj}=Val(false),
41-
) where {T<:AbstractScalarOrVec{<:AbstractFloat},P<:PhasePoint,FullTraj}
42-
n_steps = abs(n_steps) # to support `n_steps < 0` cases
43-
44-
ϵ = fwd ? step_size(lf) : -step_size(lf)
45-
ϵ = ϵ'
46-
47-
res = if FullTraj
48-
Vector{P}(undef, n_steps)
49-
else
50-
z
51-
end
52-
53-
for i in 1:n_steps
54-
θ_init, r_init = z.θ, z.r
55-
# Tempering
56-
#r = temper(lf, r, (i=i, is_half=true), n_steps)
57-
#! Eq (16) of Girolami & Calderhead (2011)
58-
r_half = copy(r_init)
59-
local cache
60-
for j in 1:(lf.n)
61-
# Reuse cache for the first iteration
62-
if j == 1
63-
(; value, gradient) = z.ℓπ
64-
elseif j == 2 # cache intermediate values that depends on θ only (which are unchanged)
65-
retval, cache = ∂H∂θ_cache(h, θ_init, r_half; return_cache=true)
66-
(; value, gradient) = retval
67-
else # reuse cache
68-
(; value, gradient) = ∂H∂θ_cache(h, θ_init, r_half; cache=cache)
69-
end
70-
r_half = r_init - ϵ / 2 * gradient
71-
# println("r_half: ", r_half)
72-
end
73-
#! Eq (17) of Girolami & Calderhead (2011)
74-
θ_full = copy(θ_init)
75-
term_1 = ∂H∂r(h, θ_init, r_half) # unchanged across the loop
76-
for j in 1:(lf.n)
77-
θ_full = θ_init + ϵ / 2 * (term_1 + ∂H∂r(h, θ_full, r_half))
78-
# println("θ_full :", θ_full)
79-
end
80-
#! Eq (18) of Girolami & Calderhead (2011)
81-
(; value, gradient) = ∂H∂θ(h, θ_full, r_half)
82-
r_full = r_half - ϵ / 2 * gradient
83-
# println("r_full: ", r_full)
84-
# Tempering
85-
#r = temper(lf, r, (i=i, is_half=false), n_steps)
86-
# Create a new phase point by caching the logdensity and gradient
87-
z = phasepoint(h, θ_full, r_full; ℓπ=DualValue(value, gradient))
88-
# Update result
89-
if FullTraj
90-
res[i] = z
91-
else
92-
res = z
93-
end
94-
if !isfinite(z)
95-
# Remove undef
96-
if FullTraj
97-
res = res[isassigned.(Ref(res), 1:n_steps)]
98-
end
99-
break
100-
end
101-
# @assert false
102-
end
103-
return res
104-
end
105-
106-
# TODO Make the order of θ and r consistent with neg_energy
107-
∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂θ(h, θ)
108-
∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂r(h, r)
109-
110-
### hamiltonian.jl
111-
112-
import AdvancedHMC: refresh, phasepoint
113-
using AdvancedHMC: FullMomentumRefreshment, PartialMomentumRefreshment, AbstractMetric
114-
115-
# To change L180 of hamiltonian.jl
116-
function phasepoint(
117-
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
118-
θ::AbstractVecOrMat{T},
119-
h::Hamiltonian,
120-
) where {T<:Real}
121-
return phasepoint(h, θ, rand_momentum(rng, h.metric, h.kinetic, θ))
122-
end
123-
124-
# To change L191 of hamiltonian.jl
125-
function refresh(
126-
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
127-
::FullMomentumRefreshment,
128-
h::Hamiltonian,
129-
z::PhasePoint,
130-
)
131-
return phasepoint(h, z.θ, rand_momentum(rng, h.metric, h.kinetic, z.θ))
132-
end
133-
134-
# To change L215 of hamiltonian.jl
135-
function refresh(
136-
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
137-
ref::PartialMomentumRefreshment,
138-
h::Hamiltonian,
139-
z::PhasePoint,
140-
)
141-
return phasepoint(
142-
h,
143-
z.θ,
144-
ref.α * z.r + sqrt(1 - ref.α^2) * rand_momentum(rng, h.metric, h.kinetic, z.θ),
145-
)
146-
end
147-
148-
### metric.jl
149-
150-
import AdvancedHMC: _rand
151-
using AdvancedHMC: AbstractMetric
152-
using LinearAlgebra: eigen, cholesky, Symmetric
153-
154-
abstract type AbstractRiemannianMetric <: AbstractMetric end
155-
156-
abstract type AbstractHessianMap end
157-
158-
struct IdentityMap <: AbstractHessianMap end
159-
160-
(::IdentityMap)(x) = x
161-
162-
struct SoftAbsMap{T} <: AbstractHessianMap
163-
α::T
164-
end
165-
166-
# TODO Register softabs with ReverseDiff
167-
#! The definition of SoftAbs from Page 3 of Betancourt (2012)
168-
function softabs(X, α=20.0)
169-
F = eigen(X) # ReverseDiff cannot diff through `eigen`
170-
Q = hcat(F.vectors)
171-
λ = F.values
172-
softabsλ = λ .* coth.(α * λ)
173-
return Q * diagm(softabsλ) * Q', Q, λ, softabsλ
174-
end
175-
176-
(map::SoftAbsMap)(x) = softabs(x, map.α)[1]
177-
178-
struct DenseRiemannianMetric{
179-
T,
180-
TM<:AbstractHessianMap,
181-
A<:Union{Tuple{Int},Tuple{Int,Int}},
182-
AV<:AbstractVecOrMat{T},
183-
TG,
184-
T∂G∂θ,
185-
} <: AbstractRiemannianMetric
186-
size::A
187-
G::TG # TODO store G⁻¹ here instead
188-
∂G∂θ::T∂G∂θ
189-
map::TM
190-
_temp::AV
191-
end
192-
193-
# TODO Make dense mass matrix support matrix-mode parallel
194-
function DenseRiemannianMetric(size, G, ∂G∂θ, map=IdentityMap()) where {T<:AbstractFloat}
195-
_temp = Vector{Float64}(undef, size[1])
196-
return DenseRiemannianMetric(size, G, ∂G∂θ, map, _temp)
197-
end
198-
# DenseEuclideanMetric(::Type{T}, D::Int) where {T} = DenseEuclideanMetric(Matrix{T}(I, D, D))
199-
# DenseEuclideanMetric(D::Int) = DenseEuclideanMetric(Float64, D)
200-
# DenseEuclideanMetric(::Type{T}, sz::Tuple{Int}) where {T} = DenseEuclideanMetric(Matrix{T}(I, first(sz), first(sz)))
201-
# DenseEuclideanMetric(sz::Tuple{Int}) = DenseEuclideanMetric(Float64, sz)
202-
203-
# renew(ue::DenseEuclideanMetric, M⁻¹) = DenseEuclideanMetric(M⁻¹)
204-
205-
Base.size(e::DenseRiemannianMetric) = e.size
206-
Base.size(e::DenseRiemannianMetric, dim::Int) = e.size[dim]
207-
Base.show(io::IO, dem::DenseRiemannianMetric) = print(io, "DenseRiemannianMetric(...)")
208-
209-
function rand_momentum(
210-
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
211-
metric::DenseRiemannianMetric{T},
212-
kinetic,
1+
#! Eq (14) of Girolami & Calderhead (2011)
2+
function ∂H∂r(
3+
h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic},
2134
θ::AbstractVecOrMat,
214-
) where {T}
215-
r = _randn(rng, T, size(metric)...)
216-
G⁻¹ = inv(metric.map(metric.G(θ)))
217-
chol = cholesky(Symmetric(G⁻¹))
218-
ldiv!(chol.U, r)
219-
return r
220-
end
221-
222-
### hamiltonian.jl
223-
224-
import AdvancedHMC: phasepoint, neg_energy, ∂H∂θ, ∂H∂r
225-
using LinearAlgebra: logabsdet, tr
226-
227-
# QUES Do we want to change everything to position dependent by default?
228-
# Add θ to ∂H∂r for DenseRiemannianMetric
229-
function phasepoint(
230-
h::Hamiltonian{<:DenseRiemannianMetric},
231-
θ::T,
232-
r::T;
233-
ℓπ=∂H∂θ(h, θ),
234-
ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)),
235-
) where {T<:AbstractVecOrMat}
236-
return PhasePoint(θ, r, ℓπ, ℓκ)
237-
end
238-
239-
# Negative kinetic energy
240-
#! Eq (13) of Girolami & Calderhead (2011)
241-
function neg_energy(
242-
h::Hamiltonian{<:DenseRiemannianMetric}, r::T, θ::T
243-
) where {T<:AbstractVecOrMat}
244-
G = h.metric.map(h.metric.G(θ))
245-
D = size(G, 1)
246-
# Need to consider the normalizing term as it is no longer same for different θs
247-
logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined
248-
mul!(h.metric._temp, inv(G), r)
249-
return -logZ - dot(r, h.metric._temp) / 2
5+
r::AbstractVecOrMat,
6+
)
7+
H = h.metric.G(θ)
8+
G = h.metric.map(H)
9+
return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't
25010
end
25111

252-
# QUES L31 of hamiltonian.jl now reads a bit weird (semantically)
25312
function ∂H∂θ(
254-
h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap}},
13+
h::Hamiltonian{<:DenseRiemannianMetric{T,<:IdentityMap},<:GaussianKinetic},
25514
θ::AbstractVecOrMat{T},
25615
r::AbstractVecOrMat{T},
25716
) where {T}
@@ -293,14 +52,14 @@ function make_J(λ::AbstractVector{T}, α::T) where {T<:AbstractFloat}
29352
end
29453

29554
function ∂H∂θ(
296-
h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}},
55+
h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic},
29756
θ::AbstractVecOrMat{T},
29857
r::AbstractVecOrMat{T},
29958
) where {T}
30059
return ∂H∂θ_cache(h, θ, r)
30160
end
30261
function ∂H∂θ_cache(
303-
h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap}},
62+
h::Hamiltonian{<:DenseRiemannianMetric{T,<:SoftAbsMap},<:GaussianKinetic},
30463
θ::AbstractVecOrMat{T},
30564
r::AbstractVecOrMat{T};
30665
return_cache=false,
@@ -342,17 +101,26 @@ function ∂H∂θ_cache(
342101
return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv
343102
end
344103

345-
#! Eq (14) of Girolami & Calderhead (2011)
346-
function ∂H∂r(
347-
h::Hamiltonian{<:DenseRiemannianMetric}, θ::AbstractVecOrMat, r::AbstractVecOrMat
348-
)
349-
H = h.metric.G(θ)
350-
# if !all(isfinite, H)
351-
# println("θ: ", θ)
352-
# println("H: ", H)
353-
# end
354-
G = h.metric.map(H)
355-
# return inv(G) * r
356-
# println("G \ r: ", G \ r)
357-
return G \ r # NOTE it's actually pretty weird that ∂H∂θ returns DualValue but ∂H∂r doesn't
104+
# QUES Do we want to change everything to position dependent by default?
105+
# Add θ to ∂H∂r for DenseRiemannianMetric
106+
function phasepoint(
107+
h::Hamiltonian{<:DenseRiemannianMetric},
108+
θ::T,
109+
r::T;
110+
ℓπ=∂H∂θ(h, θ),
111+
ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, θ, r)),
112+
) where {T<:AbstractVecOrMat}
113+
return PhasePoint(θ, r, ℓπ, ℓκ)
114+
end
115+
116+
#! Eq (13) of Girolami & Calderhead (2011)
117+
function neg_energy(
118+
h::Hamiltonian{<:DenseRiemannianMetric,<:GaussianKinetic}, r::T, θ::T
119+
) where {T<:AbstractVecOrMat}
120+
G = h.metric.map(h.metric.G(θ))
121+
D = size(G, 1)
122+
# Need to consider the normalizing term as it is no longer same for different θs
123+
logZ = 1 / 2 * (D * log(2π) + logdet(G)) # it will be user's responsibility to make sure G is SPD and logdet(G) is defined
124+
mul!(h.metric._temp, inv(G), r)
125+
return -logZ - dot(r, h.metric._temp) / 2
358126
end

0 commit comments

Comments
 (0)