From c5db6f683ce3306573183bf15e950df193b8af2b Mon Sep 17 00:00:00 2001
From: James Schloss <jrs.schloss@gmail.com>
Date: Mon, 13 Jun 2022 18:42:24 +0200
Subject: [PATCH] adding atomics via CUDA and Core.Intrinsics

---
 lib/CUDAKernels/src/CUDAKernels.jl |  26 ++++
 src/KernelAbstractions.jl          |   4 +
 src/atomics.jl                     | 203 ++++++++++++++++++++++++++++
 src/cpu.jl                         |  27 ++++
 test/atomic_test.jl                | 207 +++++++++++++++++++++++++++++
 test/testsuite.jl                  |   8 ++
 6 files changed, 475 insertions(+)
 create mode 100644 src/atomics.jl
 create mode 100644 test/atomic_test.jl

diff --git a/lib/CUDAKernels/src/CUDAKernels.jl b/lib/CUDAKernels/src/CUDAKernels.jl
index a9dae97ac..bb03e4ac9 100644
--- a/lib/CUDAKernels/src/CUDAKernels.jl
+++ b/lib/CUDAKernels/src/CUDAKernels.jl
@@ -359,6 +359,7 @@ else
 end
 
 import KernelAbstractions: ConstAdaptor, SharedMemory, Scratchpad, __synchronize, __size
+import KernelAbstractions: atomic_add!, atomic_and!, atomic_cas!, atomic_dec!, atomic_inc!, atomic_max!, atomic_min!, atomic_op!, atomic_or!, atomic_sub!, atomic_xchg!, atomic_xor!
 
 ###
 # GPU implementation of shared memory
@@ -395,4 +396,29 @@ Adapt.adapt_storage(to::ConstAdaptor, a::CUDA.CuDeviceArray) = Base.Experimental
 # Argument conversion
 KernelAbstractions.argconvert(k::Kernel{CUDADevice}, arg) = CUDA.cudaconvert(arg)
 
+
+###
+# GPU implementation of atomics
+###
+
+afxs = Dict(
+    atomic_add! => CUDA.atomic_add!,
+    atomic_and! => CUDA.atomic_and!,
+    atomic_cas! => CUDA.atomic_cas!,
+    atomic_dec! => CUDA.atomic_dec!,
+    atomic_inc! => CUDA.atomic_inc!,
+    atomic_max! => CUDA.atomic_max!,
+    atomic_min! => CUDA.atomic_min!,
+    atomic_op! => CUDA.atomic_op!,
+    atomic_or! => CUDA.atomic_or!,
+    atomic_sub! => CUDA.atomic_sub!,
+    atomic_xchg! => CUDA.atomic_xchg!,
+    atomic_xor! => CUDA.atomic_xor!
+)
+
+# `@eval` + `nameof` attach each override to the real KernelAbstractions function.
+for (afx, cfx) in afxs
+    name = nameof(afx)
+    @eval @device_override @inline KernelAbstractions.$name(args...) = $cfx(args...)
+end
 end
diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl
index 1066375f2..c8dffb372 100644
--- a/src/KernelAbstractions.jl
+++ b/src/KernelAbstractions.jl
@@ -496,6 +496,10 @@ include("extras/extras.jl")
 
 include("reflection.jl")
 
+# Atomics
+
+include("atomics.jl")
+
 # CPU backend
 
 include("cpu.jl")
diff --git a/src/atomics.jl b/src/atomics.jl
new file mode 100644
index 000000000..4913c0d51
--- /dev/null
+++ b/src/atomics.jl
@@ -0,0 +1,203 @@
+###
+# Atomics
+###
+
+export atomic_add!, atomic_sub!, atomic_and!, atomic_or!, atomic_xor!,
+       atomic_min!, atomic_max!, atomic_inc!, atomic_dec!, atomic_xchg!,
+       atomic_op!, atomic_cas!
+
+# Wrapping decrement (the counterpart of `inc`): reset to `b` when `a` is
+# zero or exceeds `b`, otherwise step down by one.
+dec(a::T, b::T) where {T} =
+    ((a == 0) | (a > b)) ? b : a - T(1)
+
+# Wrapping increment: restart at zero once `a` has reached `b`,
+# otherwise step up by one.
+inc(a::T, b::T) where {T} = a >= b ? T(0) : a + T(1)
+
+# name => operation table for the arithmetic, bitwise, min/max and inc/dec atomics
+const ops = Dict(
+    :atomic_add!   => +,
+    :atomic_sub!   => -,
+    :atomic_and!   => &,
+    :atomic_or!    => |,
+    :atomic_xor!   => ⊻,
+    :atomic_min!   => min,
+    :atomic_max!   => max,
+    :atomic_inc!   => inc,
+    :atomic_dec!   => dec,
+)
+
+# Note: ptr and b must share T (no implicit conversion, e.g. Float32 -> Float64).
+#       atomic_pointermodify returns `old => new`; only `old` is handed back.
+for (name, op) in ops
+    @eval @inline function $name(ptr::Ptr{T}, b::T) where T
+        return first(Core.Intrinsics.atomic_pointermodify(ptr, $op, b, :monotonic))
+    end
+end
+
+"""
+    atomic_cas!(ptr::Ptr{T}, cmp::T, val::T)
+
+This is an atomic Compare And Swap (CAS).
+It reads the value `old` located at address `ptr` and compares it with `cmp`.
+If `old` equals `cmp`, it stores `val` at the same address.
+Otherwise, it leaves the stored value unchanged.
+These operations are performed in one atomic transaction. The function returns `old`.
+
+This operation is supported for values of type Int32, Int64, UInt32 and UInt64.
+Additionally, on GPU hardware with compute capability 7.0+, values of type UInt16 are supported.
+Also: atomic operations for the CPU require a Julia version of 1.7.0 or above.
+"""
+function atomic_cas!(ptr::Ptr{T}, cmp::T, val::T) where T
+    # atomic_pointerreplace yields `(; old, success)`; only `old` is part of the API
+    return Core.Intrinsics.atomic_pointerreplace(ptr, cmp, val, :acquire_release, :monotonic).old
+end
+
+"""
+    atomic_xchg!(ptr::Ptr{T}, val::T)
+
+Atomically exchange the value stored at `ptr` with `val`.
+The previous value `old` is read and `val` is written to the same address
+in one atomic transaction. The function returns `old`.
+
+This operation is supported for values of type Int32, Int64, UInt32 and UInt64.
+Also: atomic operations for the CPU require a Julia version of 1.7.0 or above.
+"""
+function atomic_xchg!(ptr::Ptr{T}, b::T) where {T}
+    return Core.Intrinsics.atomic_pointerswap(ptr, b, :monotonic)
+end
+
+"""
+    atomic_op!(ptr::Ptr{T}, op, val::T)
+
+This is an arbitrary atomic operation.
+It reads the value `old` located at address `ptr`, computes `op(old, val)`, and stores the result back to memory at the same address.
+These operations are performed in one atomic transaction. The function returns `old`.
+
+This function is somewhat experimental; on the CPU it requires a Julia version of 1.7.0 or above.
+"""
+function atomic_op!(ptr::Ptr{T}, op, b::T) where T
+    # atomic_pointermodify yields the pair `old => new`; only `old` is part of the API
+    return first(Core.Intrinsics.atomic_pointermodify(ptr, op, b, :monotonic))
+end
+
+# Other Documentation
+
+"""
+    atomic_add!(ptr::Ptr{T}, val::T)
+
+This is an atomic addition.
+It reads the value `old` located at address `ptr`, computes `old + val`, and stores the result back to memory at the same address.
+These operations are performed in one atomic transaction.
+The function returns `old`.
+
+This operation is supported for values of type Int32, Int64, UInt32, UInt64, and Float32.
+Additionally, on GPU hardware with compute capability 6.0+, values of type Float64 are supported.
+Also: atomic operations for the CPU requires a Julia version of 1.7.0 or above.
+"""
+atomic_add!
+
+"""
+    atomic_sub!(ptr::Ptr{T}, val::T)
+
+This is an atomic subtraction.
+It reads the value `old` located at address `ptr`, computes `old - val`, and stores the result back to memory at the same address.
+These operations are performed in one atomic transaction.
+The function returns `old`.
+
+This operation is supported for values of type Int32, Int64, UInt32 and UInt64.
+Also: atomic operations for the CPU requires a Julia version of 1.7.0 or above.
+"""
+atomic_sub!
+
+"""
+    atomic_and!(ptr::Ptr{T}, val::T)
+
+This is an atomic and.
+It reads the value `old` located at address `ptr`, computes `old & val`, and stores the result back to memory at the same address.
+These operations are performed in one atomic transaction.
+The function returns `old`.
+
+This operation is supported for values of type Int32, Int64, UInt32 and UInt64.
+Also: atomic operations for the CPU requires a Julia version of 1.7.0 or above.
+"""
+atomic_and!
+
+"""
+    atomic_or!(ptr::Ptr{T}, val::T)
+
+This is an atomic or.
+It reads the value `old` located at address `ptr`, computes `old | val`, and stores the result back to memory at the same address.
+These operations are performed in one atomic transaction.
+The function returns `old`.
+
+This operation is supported for values of type Int32, Int64, UInt32 and UInt64.
+Also: atomic operations for the CPU requires a Julia version of 1.7.0 or above.
+"""
+atomic_or!
+
+"""
+    atomic_xor!(ptr::Ptr{T}, val::T)
+
+This is an atomic xor.
+It reads the value `old` located at address `ptr`, computes `old ⊻ val`, and stores the result back to memory at the same address.
+These operations are performed in one atomic transaction.
+The function returns `old`.
+
+This operation is supported for values of type Int32, Int64, UInt32 and UInt64.
+Also: atomic operations for the CPU requires a Julia version of 1.7.0 or above.
+"""
+atomic_xor!
+
+"""
+    atomic_min!(ptr::Ptr{T}, val::T)
+
+This is an atomic min.
+It reads the value `old` located at address `ptr`, computes `min(old, val)`, and stores the result back to memory at the same address.
+These operations are performed in one atomic transaction.
+The function returns `old`.
+
+This operation is supported for values of type Int32, Int64, UInt32 and UInt64.
+Also: atomic operations for the CPU requires a Julia version of 1.7.0 or above.
+"""
+atomic_min!
+
+"""
+    atomic_max!(ptr::Ptr{T}, val::T)
+
+This is an atomic max.
+It reads the value `old` located at address `ptr`, computes `max(old, val)`, and stores the result back to memory at the same address.
+These operations are performed in one atomic transaction.
+The function returns `old`.
+
+This operation is supported for values of type Int32, Int64, UInt32 and UInt64.
+Also: atomic operations for the CPU requires a Julia version of 1.7.0 or above.
+"""
+atomic_max!
+
+"""
+    atomic_inc!(ptr::Ptr{T}, val::T)
+
+This is an atomic increment function that counts up to a certain number before starting again at 0.
+It reads the value `old` located at address `ptr`, computes `((old >= val) ? 0 : (old+1))`, and stores the result back to memory at the same address.
+These three operations are performed in one atomic transaction.
+The function returns `old`.
+
+This operation is only supported for values of type Int32.
+Also: atomic operations for the CPU requires a Julia version of 1.7.0 or above.
+"""
+atomic_inc!
+
+"""
+    atomic_dec!(ptr::Ptr{T}, val::T)
+
+This is an atomic decrement function that counts down to 0 from a defined value `val`.
+It reads the value `old` located at address `ptr`, computes `(((old == 0) | (old > val)) ? val : (old-1))`, and stores the result back to memory at the same address.
+These three operations are performed in one atomic transaction.
+The function returns `old`.
+
+This operation is only supported for values of type Int32.
+Also: atomic operations for the CPU requires a Julia version of 1.7.0 or above.
+"""
+atomic_dec!
diff --git a/src/cpu.jl b/src/cpu.jl
index 802918978..186ec948b 100644
--- a/src/cpu.jl
+++ b/src/cpu.jl
@@ -234,3 +234,30 @@ end
 
 # Argument conversion
 KernelAbstractions.argconvert(k::Kernel{CPU}, arg) = arg
+
+###
+# CPU error handling if under 1.7
+###
+
+if Base.VERSION < v"1.7.0"
+
+    # The CPU atomics are built on Core.Intrinsics.atomic_pointer*, which need
+    # Julia 1.7 (see the message below); on older versions every atomic is
+    # replaced by a method that raises a descriptive error instead.
+    function atomic_error(args...)
+        error("CPU Atomics are not allowed for julia version under 1.7!")
+    end
+
+    # `@eval` splices in the function *name*: a plain `function afx(...)` inside
+    # the loop would only define a loop-local `afx` and add no methods at all.
+    # The signatures mirror the `Ptr` methods so that these take their place.
+    for afx in (:atomic_add!, :atomic_and!, :atomic_dec!, :atomic_inc!,
+                :atomic_max!, :atomic_min!, :atomic_or!, :atomic_sub!,
+                :atomic_xchg!, :atomic_xor!)
+        @eval @inline $afx(ptr::Ptr{T}, b::T) where {T} = atomic_error(ptr, b)
+    end
+
+    @inline atomic_cas!(ptr::Ptr{T}, old::T, new::T) where {T} = atomic_error(ptr, old, new)
+    @inline atomic_op!(ptr::Ptr{T}, op, b::T) where {T} = atomic_error(ptr, op, b)
+end
+
diff --git a/test/atomic_test.jl b/test/atomic_test.jl
new file mode 100644
index 000000000..68eed69fc
--- /dev/null
+++ b/test/atomic_test.jl
@@ -0,0 +1,207 @@
+using KernelAbstractions, Test
+
+# Note: kernels affect second element because some CPU defaults will affect the
+#       first element of a pointer if not specified, so I am covering the bases
+@kernel function atomic_add_kernel!(input, b)
+    atomic_add!(pointer(input,2),b) 
+end
+
+@kernel function atomic_sub_kernel!(input, b)
+    atomic_sub!(pointer(input,2),b) 
+end
+
+@kernel function atomic_inc_kernel!(input, b)
+    atomic_inc!(pointer(input,2),b) 
+end
+
+@kernel function atomic_dec_kernel!(input, b)
+    atomic_dec!(pointer(input,2),b) 
+end
+
+@kernel function atomic_xchg_kernel!(input, b)
+    atomic_xchg!(pointer(input,2),b) 
+end
+
+@kernel function atomic_and_kernel!(input, b)
+    tid = @index(Global)
+    atomic_and!(pointer(input),b[tid]) 
+end
+
+@kernel function atomic_or_kernel!(input, b)
+    tid = @index(Global)
+    atomic_or!(pointer(input),b[tid]) 
+end
+
+@kernel function atomic_xor_kernel!(input, b)
+    tid = @index(Global)
+    atomic_xor!(pointer(input),b[tid]) 
+end
+
+@kernel function atomic_max_kernel!(input, b)
+    tid = @index(Global)
+    atomic_max!(pointer(input,2), b[tid]) 
+end
+
+@kernel function atomic_min_kernel!(input, b)
+    tid = @index(Global)
+    atomic_min!(pointer(input,2), b[tid]) 
+end
+
+@kernel function atomic_cas_kernel!(input, b, c)
+    atomic_cas!(pointer(input,2),b,c) 
+end
+
+function atomics_testsuite(backend, ArrayT)
+
+    @testset "atomic addition tests" begin
+        types = [Int32, Int64, UInt32, UInt64, Float32]
+
+        for T in types
+            A = ArrayT{T}([0,0])
+
+            kernel! = atomic_add_kernel!(backend(), 4)
+            wait(kernel!(A, one(T), ndrange=(1024)))
+
+            @test Array(A)[2] == 1024
+        end
+    end
+
+    @testset "atomic subtraction tests" begin
+        types = [Int32, Int64, UInt32, UInt64, Float32]
+
+        for T in types
+            A = ArrayT{T}([2048,2048])
+
+            kernel! = atomic_sub_kernel!(backend(), 4)
+            wait(kernel!(A, one(T), ndrange=(1024)))
+
+            @test Array(A)[2] == 1024
+        end
+    end
+
+    @testset "atomic inc tests" begin
+        types = [Int32]
+
+        for T in types
+            A = ArrayT{T}([0,0])
+
+            kernel! = atomic_inc_kernel!(backend(), 4)
+            wait(kernel!(A, T(512), ndrange=(768)))
+
+            @test Array(A)[2] == 255
+        end
+    end
+
+    @testset "atomic dec tests" begin
+        types = [Int32]
+
+        for T in types
+            A = ArrayT{T}([1024,1024])
+
+            kernel! = atomic_dec_kernel!(backend(), 4)
+            wait(kernel!(A, T(512), ndrange=(256)))
+
+            @test Array(A)[2] == 257
+        end
+    end
+
+    @testset "atomic xchg tests" begin
+        types = [Int32, Int64, UInt32, UInt64]
+
+        for T in types
+            A = ArrayT{T}([0,0])
+
+            kernel! = atomic_xchg_kernel!(backend(), 4)
+            wait(kernel!(A, T(1), ndrange=(256)))
+
+            @test Array(A)[2] == one(T)
+        end
+    end
+
+    @testset "atomic and tests" begin
+        types = [Int32, Int64, UInt32, UInt64]
+
+        for T in types
+            A = ArrayT{T}([1023])
+            B = ArrayT{T}([1023-2^(i-1) for i = 1:10])
+
+            kernel! = atomic_and_kernel!(backend(), 4)
+            wait(kernel!(A, B, ndrange=length(B)))
+    
+            @test Array(A)[1] == zero(T)
+        end
+    end
+
+    @testset "atomic or tests" begin
+        types = [Int32, Int64, UInt32, UInt64]
+
+        for T in types
+            A = ArrayT{T}([0])
+            B = ArrayT{T}([2^(i-1) for i = 1:10])
+
+            kernel! = atomic_or_kernel!(backend(), 4)
+            wait(kernel!(A, B, ndrange=length(B)))
+    
+            @test Array(A)[1] == T(1023)
+        end
+    end
+
+    @testset "atomic xor tests" begin
+        types = [Int32, Int64, UInt32, UInt64]
+
+        for T in types
+            A = ArrayT{T}([1023])
+            B = ArrayT{T}([2^(i-1) for i = 1:10])
+
+            kernel! = atomic_xor_kernel!(backend(), 4)
+            wait(kernel!(A, B, ndrange=length(B)))
+    
+            @test Array(A)[1] == T(0)
+        end
+    end
+
+    @testset "atomic max tests" begin
+        types = [Int32, Int64, UInt32, UInt64]
+
+        for T in types
+            A = ArrayT{T}([0,0])
+            B = ArrayT{T}([i for i = 1:1024])
+
+            kernel! = atomic_max_kernel!(backend(), 4)
+            wait(kernel!(A, B, ndrange=length(B)))
+    
+            @test Array(A)[2] == T(1024)
+        end
+    end
+
+    @testset "atomic min tests" begin
+        types = [Int32, Int64, UInt32, UInt64]
+
+        for T in types
+            A = ArrayT{T}([1024,1024])
+            B = ArrayT{T}([i for i = 1:1024])
+
+            kernel! = atomic_min_kernel!(backend(), 4)
+            wait(kernel!(A, B, ndrange=length(B)))
+    
+            @test Array(A)[2] == T(1)
+        end
+    end
+
+
+    @testset "atomic cas tests" begin
+        types = [Int32, Int64, UInt32, UInt64]
+
+        for T in types
+            A = ArrayT{T}([0,0])
+
+            kernel! = atomic_cas_kernel!(backend(), 4)
+            wait(kernel!(A, zero(T), one(T), ndrange=(1024)))
+
+            @test Array(A)[2] == 1
+        end
+    end
+
+
+
+end
diff --git a/test/testsuite.jl b/test/testsuite.jl
index b0a81857b..0e92b01d3 100644
--- a/test/testsuite.jl
+++ b/test/testsuite.jl
@@ -15,6 +15,7 @@ include("compiler.jl")
 include("reflection.jl")
 include("examples.jl")
 include("convert.jl")
+include("atomic_test.jl")
 
 function testsuite(backend, backend_str, backend_mod, AT, DAT)
     @testset "Unittests" begin
@@ -69,6 +70,13 @@ function testsuite(backend, backend_str, backend_mod, AT, DAT)
         convert_testsuite(backend, AT)
     end
 
+    if backend_str != "ROCM" &&
+      !(backend_str == "CPU" && Base.VERSION < v"1.7.0")
+        @testset "Atomics" begin
+            atomics_testsuite(backend, AT)
+        end
+    end
+
     if backend_str == "CUDA" || backend_str == "ROCM"
         @testset "Examples" begin
             examples_testsuite(backend_str)