Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 0e99f97

Browse files
Merge pull request #58 from JuliaDiffEq/cu2
GPU color autodiff
2 parents 923f8a8 + b9bbc00 commit 0e99f97

File tree

6 files changed

+107
-14
lines changed

6 files changed

+107
-14
lines changed

.gitlab-ci.yml

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
image: "julia:1"
2+
3+
variables:
4+
JULIA_DEPOT_PATH: "$CI_PROJECT_DIR/.julia/"
5+
JULIA_NUM_THREADS: '8'
6+
7+
cache:
8+
paths:
9+
- .julia/
10+
11+
build:
12+
stage: build
13+
tags:
14+
- 'p6000'
15+
script:
16+
- curl https://julialang-s3.julialang.org/bin/linux/x64/1.1/julia-1.1.1-linux-x86_64.tar.gz -o julia.tar.gz
17+
- unp julia.tar.gz
18+
- export PATH="$(pwd)/julia-1.1.1/bin:$PATH"
19+
- julia -e "using InteractiveUtils;
20+
versioninfo()"
21+
- julia --project -e "using Pkg;
22+
Pkg.update();
23+
Pkg.instantiate();
24+
Pkg.add(\"SparseDiffTools\");
25+
pkg\"precompile\";
26+
using SparseDiffTools;"
27+
only:
28+
- master
29+
- tags
30+
- external
31+
- pushes
32+
artifacts:
33+
untracked: true
34+
paths:
35+
- .julia/**/*
36+
- julia-1.1.1/**/*
37+
38+
test-GPU:
39+
stage: test
40+
tags:
41+
- 'p6000'
42+
dependencies:
43+
- build
44+
variables:
45+
GROUP: "GPU"
46+
script:
47+
- export PATH="$(pwd)/julia-1.1.1/bin:$PATH"
48+
- julia -e "using InteractiveUtils;
49+
versioninfo()"
50+
- julia --project -e "using Pkg; Pkg.add(\"CuArrays\");
51+
Pkg.test(\"SparseDiffTools\"; coverage=true);"
52+
only:
53+
- master
54+
- tags
55+
- external
56+
- pushes
57+
artifacts:
58+
untracked: true
59+
paths:
60+
- .julia/**/*
61+
- julia-1.1.1/**/*

Project.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ authors = ["Pankaj Mishra <pankajmishra1511@gmail.com>", "Chris Rackauckas <cont
44
version = "0.6.0"
55

66
[deps]
7+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
8+
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
79
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
810
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
911
DiffEqDiffTools = "01453d9d-ee7c-5054-8395-0335cb756afa"
@@ -13,7 +15,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1315
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1416
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1517
VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"
16-
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
1718

1819
[compat]
1920
ArrayInterface = "1.1"

src/SparseDiffTools.jl

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using ForwardDiff
77
using LightGraphs
88
using Requires
99
using VertexSafeGraphs
10+
using Adapt
1011

1112
using LinearAlgebra
1213
using SparseArrays, ArrayInterface

src/differentiation/compute_jacobian_ad.jl

+7-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ struct ForwardColorJacCache{T,T2,T3,T4,T5,T6}
55
p::T4
66
color::T5
77
sparsity::T6
8+
chunksize::Int
89
end
910

1011
function default_chunk_size(maxcolor)
@@ -29,18 +30,19 @@ function ForwardColorJacCache(f,x,_chunksize = nothing;
2930
chunksize = _chunksize
3031
end
3132

32-
t = zeros(Dual{typeof(f), eltype(x), getsize(chunksize)},length(x))
33+
p = adapt.(typeof(x),generate_chunked_partials(x,color,chunksize))
34+
t = Dual{typeof(f)}.(x,first(p))
3335

3436
if dx === nothing
3537
fx = similar(t)
3638
_dx = similar(x)
3739
else
38-
fx = zeros(Dual{typeof(f), eltype(dx), getsize(chunksize)},length(dx))
40+
fx = Dual{typeof(f)}.(dx,first(p))
3941
_dx = dx
4042
end
4143

42-
p = generate_chunked_partials(x,color,chunksize)
43-
ForwardColorJacCache(t,fx,_dx,p,color,sparsity)
44+
45+
ForwardColorJacCache(t,fx,_dx,p,color,sparsity,getsize(chunksize))
4446
end
4547

4648
generate_chunked_partials(x,color,N::Integer) = generate_chunked_partials(x,color,Val(N))
@@ -96,8 +98,8 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
9698
p = jac_cache.p
9799
color = jac_cache.color
98100
sparsity = jac_cache.sparsity
101+
chunksize = jac_cache.chunksize
99102
color_i = 1
100-
chunksize = length(first(first(jac_cache.p)))
101103
fill!(J, zero(eltype(J)))
102104

103105
for i in eachindex(p)

test/runtests.jl

+18-8
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
11
using SafeTestsets
22

3-
@time @safetestset "Exact coloring via contraction" begin include("test_contraction.jl") end
4-
@time @safetestset "Greedy distance-1 coloring" begin include("test_greedy_d1.jl") end
5-
@time @safetestset "Greedy star coloring" begin include("test_greedy_star.jl") end
6-
@time @safetestset "Matrix to graph conversion" begin include("test_matrix2graph.jl") end
7-
@time @safetestset "AD using color vector" begin include("test_ad.jl") end
8-
@time @safetestset "Integration test" begin include("test_integration.jl") end
9-
@time @safetestset "Special matrices" begin include("test_specialmatrices.jl") end
10-
@time @safetestset "Jac Vecs and Hes Vecs" begin include("test_jaches_products.jl") end
3+
const GROUP = get(ENV, "GROUP", "All")
4+
const is_APPVEYOR = ( Sys.iswindows() && haskey(ENV,"APPVEYOR") )
5+
const is_TRAVIS = haskey(ENV,"TRAVIS")
6+
7+
if GROUP == "All"
8+
@time @safetestset "Exact coloring via contraction" begin include("test_contraction.jl") end
9+
@time @safetestset "Greedy distance-1 coloring" begin include("test_greedy_d1.jl") end
10+
@time @safetestset "Greedy star coloring" begin include("test_greedy_star.jl") end
11+
@time @safetestset "Matrix to graph conversion" begin include("test_matrix2graph.jl") end
12+
@time @safetestset "AD using color vector" begin include("test_ad.jl") end
13+
@time @safetestset "Integration test" begin include("test_integration.jl") end
14+
@time @safetestset "Special matrices" begin include("test_specialmatrices.jl") end
15+
@time @safetestset "Jac Vecs and Hes Vecs" begin include("test_jaches_products.jl") end
16+
end
17+
18+
if GROUP == "GPU"
19+
@time @safetestset "GPU AD" begin include("test_gpu_ad.jl") end
20+
end

test/test_gpu_ad.jl

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using SparseDiffTools, CuArrays, Test, LinearAlgebra
2+
using ArrayInterface: allowed_getindex, allowed_setindex!
3+
function f(dx,x)
4+
dx[2:end-1] = x[1:end-2] - 2x[2:end-1] + x[3:end]
5+
allowed_setindex!(dx,-2allowed_getindex(x,1) + allowed_getindex(x,2),1)
6+
allowed_setindex!(dx,-2allowed_getindex(x,30) + allowed_getindex(x,29),30)
7+
nothing
8+
end
9+
10+
_J1 = similar(rand(30,30))
11+
_denseJ1 = cu(collect(_J1))
12+
x = cu(rand(30))
13+
CuArrays.allowscalar(false)
14+
forwarddiff_color_jacobian!(_denseJ1, f, x)
15+
forwarddiff_color_jacobian!(_denseJ1, f, x, sparsity = _J1)
16+
forwarddiff_color_jacobian!(_denseJ1, f, x, color = repeat(1:3,10), sparsity = _J1)
17+
_Jt = similar(Tridiagonal(_J1))
18+
@test_broken forwarddiff_color_jacobian!(_denseJ1, f, x, color = repeat(1:3,10), sparsity = _Jt)

0 commit comments

Comments
 (0)