diff --git a/examples/histogram.jl b/examples/histogram.jl
new file mode 100644
index 000000000..17943872d
--- /dev/null
+++ b/examples/histogram.jl
@@ -0,0 +1,117 @@
+using KernelAbstractions, Test
+include(joinpath(@__DIR__, "utils.jl")) # Load backend
+
+
+# Function to use as a baseline for CPU metrics
+function create_histogram(input)
+    histogram_output = zeros(Int, maximum(input))
+    for bin in input
+        histogram_output[bin] += 1
+    end
+    return histogram_output
+end
+
+# This is a 1D histogram kernel where the histogramming happens on shmem
+@kernel function histogram_kernel!(histogram_output, input)
+    tid = @index(Global, Linear)
+    lid = @index(Local, Linear)
+
+    @uniform gs = @groupsize()[1]
+    @uniform N = length(histogram_output)
+
+    shared_histogram = @localmem Int (gs)
+
+    # This will go through all input elements and assign them to a location in
+    # shmem. Note that if there is not enough shmem, we create different shmem
+    # blocks to write to. For example, if shmem is of size 256, but it's
+    # possible to get a value of 312, then we will have 2 separate shmem blocks,
+    # one from 1->256, and another from 257->512
+    @uniform max_element = 1
+    for min_element = 1:gs:N
+
+        # Setting shared_histogram to 0
+        @inbounds shared_histogram[lid] = 0
+        @synchronize()
+
+        max_element = min_element + gs
+        if max_element > N
+            max_element = N+1
+        end
+
+        # Defining bin on shared memory and writing to it if possible
+        bin = input[tid]
+        if bin >= min_element && bin < max_element
+            bin -= min_element-1
+            GC.@preserve shared_histogram begin
+                atomic_add!(pointer(shared_histogram, bin), Int(1))
+            end
+        end
+
+        @synchronize()
+
+        if ((lid+min_element-1) <= N)
+            atomic_add!(pointer(histogram_output, lid+min_element-1),
+                        shared_histogram[lid])
+        end
+
+    end
+
+end
+
+function histogram!(histogram_output, input;
+                    numcores = 4, numthreads = 256)
+
+    if isa(input, Array)
+        kernel! = histogram_kernel!(CPU(), numcores)
+    else
+        kernel! = histogram_kernel!(CUDADevice(), numthreads)
+    end
+
+    kernel!(histogram_output, input, ndrange=size(input))
+end
+
+@testset "histogram tests" begin
+
+    rand_input = [rand(1:128) for i = 1:1000]
+    linear_input = [i for i = 1:1024]
+    all_2 = [2 for i = 1:512]
+
+    histogram_rand_baseline = create_histogram(rand_input)
+    histogram_linear_baseline = create_histogram(linear_input)
+    histogram_2_baseline = create_histogram(all_2)
+
+    if Base.VERSION >= v"1.7.0"
+        CPU_rand_histogram = zeros(Int, 128)
+        CPU_linear_histogram = zeros(Int, 1024)
+        CPU_2_histogram = zeros(Int, 2)
+
+        wait(histogram!(CPU_rand_histogram, rand_input))
+        wait(histogram!(CPU_linear_histogram, linear_input))
+        wait(histogram!(CPU_2_histogram, all_2))
+
+        @test CPU_rand_histogram == histogram_rand_baseline
+        @test CPU_linear_histogram == histogram_linear_baseline
+        @test CPU_2_histogram == histogram_2_baseline
+    end
+
+    if has_cuda_gpu()
+        CUDA.allowscalar(false)
+
+        GPU_rand_input = CuArray(rand_input)
+        GPU_linear_input = CuArray(linear_input)
+        GPU_2_input = CuArray(all_2)
+
+        GPU_rand_histogram = CuArray(zeros(Int, 128))
+        GPU_linear_histogram = CuArray(zeros(Int, 1024))
+        GPU_2_histogram = CuArray(zeros(Int, 2))
+
+        wait(histogram!(GPU_rand_histogram, GPU_rand_input))
+        wait(histogram!(GPU_linear_histogram, GPU_linear_input))
+        wait(histogram!(GPU_2_histogram, GPU_2_input))
+
+        @test Array(GPU_rand_histogram) == histogram_rand_baseline
+        @test Array(GPU_linear_histogram) == histogram_linear_baseline
+        @test Array(GPU_2_histogram) == histogram_2_baseline
+    end
+
+end
diff --git a/lib/CUDAKernels/src/CUDAKernels.jl b/lib/CUDAKernels/src/CUDAKernels.jl
index f503b9e11..7a227ed31 100644
--- a/lib/CUDAKernels/src/CUDAKernels.jl
+++ b/lib/CUDAKernels/src/CUDAKernels.jl
@@ -320,6 +320,7 @@ else
 end
 
 import KernelAbstractions: ConstAdaptor, SharedMemory, Scratchpad, __synchronize, __size
+import KernelAbstractions: atomic_add!, atomic_and!, atomic_cas!, atomic_dec!, atomic_inc!, atomic_max!, atomic_min!, atomic_op!, atomic_or!, atomic_sub!, atomic_xchg!, atomic_xor!
 
 ###
 # GPU implementation of shared memory
@@ -381,4 +382,29 @@ end
     CUDA.ptx_isa_version(args...)
 end
 
+###
+# GPU implementation of atomics
+###
+
+afxs = Dict(
+    atomic_add! => CUDA.atomic_add!,
+    atomic_and! => CUDA.atomic_and!,
+    atomic_cas! => CUDA.atomic_cas!,
+    atomic_dec! => CUDA.atomic_dec!,
+    atomic_inc! => CUDA.atomic_inc!,
+    atomic_max! => CUDA.atomic_max!,
+    atomic_min! => CUDA.atomic_min!,
+    atomic_op! => CUDA.atomic_op!,
+    atomic_or! => CUDA.atomic_or!,
+    atomic_sub! => CUDA.atomic_sub!,
+    atomic_xchg! => CUDA.atomic_xchg!,
+    atomic_xor! => CUDA.atomic_xor!
+)
+
+for (afx, cfx) in afxs
+    @inline function Cassette.overdub(::CUDACtx, ::typeof(afx), args...)
+        cfx(args...)
+    end
+end
+
 end
diff --git a/src/KernelAbstractions.jl b/src/KernelAbstractions.jl
index 9d17a525e..b10a68004 100644
--- a/src/KernelAbstractions.jl
+++ b/src/KernelAbstractions.jl
@@ -482,6 +482,10 @@ include("extras/extras.jl")
 
 include("reflection.jl")
 
+# Atomics
+
+include("atomics.jl")
+
 # CPU backend
 
 include("cpu.jl")
diff --git a/src/atomics.jl b/src/atomics.jl
new file mode 100644
index 000000000..4913c0d51
--- /dev/null
+++ b/src/atomics.jl
@@ -0,0 +1,203 @@
+###
+# Atomics
+###
+
+export atomic_add!, atomic_sub!, atomic_and!, atomic_or!, atomic_xor!,
+       atomic_min!, atomic_max!, atomic_inc!, atomic_dec!, atomic_xchg!,
+       atomic_op!, atomic_cas!
+
+# helper functions for inc(rement) and dec(rement)
+function dec(a::T,b::T) where T
+    ((a == zero(T)) | (a > b)) ? b : (a - one(T))
+end
+
+function inc(a::T,b::T) where T
+    (a >= b) ? zero(T) : (a + one(T))
+end
+
+# arithmetic, bitwise, min/max, and inc/dec operations
+const ops = Dict(
+    :atomic_add! => +,
+    :atomic_sub! => -,
+    :atomic_and! => &,
+    :atomic_or! => |,
+    :atomic_xor! => ⊻,
+    :atomic_min! => min,
+    :atomic_max! => max,
+    :atomic_inc! => inc,
+    :atomic_dec! => dec,
+)
+
+# Note: `b` must have exactly the pointer's element type T (no implicit conversion).
+#       `atomic_pointermodify` yields `old => new`; `first` returns the documented `old`.
+for (name, op) in ops
+    @eval @inline function $name(ptr::Ptr{T}, b::T) where T
+        first(Core.Intrinsics.atomic_pointermodify(ptr, $op, b, :monotonic))
+    end
+end
+
+"""
+    atomic_cas!(ptr::Ptr{T}, cmp::T, val::T)
+
+This is an atomic Compare And Swap (CAS).
+It reads the value `old` located at address `ptr` and compares it with `cmp`.
+If `old` equals `cmp`, it stores `val` at the same address.
+Otherwise, it leaves the value at `ptr` unchanged.
+These operations are performed in one atomic transaction.
+The function returns `old`.
+
+This operation is supported for values of type Int32, Int64, UInt32 and UInt64.
+Additionally, on GPU hardware with compute capability 7.0+, values of type UInt16 are supported.
+Also: atomic operations for the CPU require a Julia version of 1.7.0 or above.
+"""
+function atomic_cas!(ptr::Ptr{T}, old::T, new::T) where T
+    first(Core.Intrinsics.atomic_pointerreplace(ptr, old, new, :acquire_release, :monotonic))
+end
+
+"""
+    atomic_xchg!(ptr::Ptr{T}, val::T)
+
+This is an atomic exchange.
+It reads the value `old` located at address `ptr` and stores `val` at the same address.
+These operations are performed in one atomic transaction. The function returns `old`.
+
+This operation is supported for values of type Int32, Int64, UInt32 and UInt64.
+Also: atomic operations for the CPU require a Julia version of 1.7.0 or above.
+"""
+function atomic_xchg!(ptr::Ptr{T}, b::T) where T
+    Core.Intrinsics.atomic_pointerswap(ptr, b, :monotonic)
+end
+
+"""
+    atomic_op!(ptr::Ptr{T}, op, val::T)
+
+This is an arbitrary atomic operation.
+It reads the value `old` located at address `ptr`, computes `op(old, val)`, and stores the result back to memory at the same address.
+These operations are performed in one atomic transaction. The function returns `old`.
+
+This function is somewhat experimental.
+Also: atomic operations for the CPU require a Julia version of 1.7.0 or above.
+"""
+function atomic_op!(ptr::Ptr{T}, op, b::T) where T
+    first(Core.Intrinsics.atomic_pointermodify(ptr, op, b, :monotonic))
+end
+
+# Other Documentation
+
+"""
+    atomic_add!(ptr::Ptr{T}, val::T)
+
+This is an atomic addition.
+It reads the value `old` located at address `ptr`, computes `old + val`, and stores the result back to memory at the same address.
+These operations are performed in one atomic transaction.
+The function returns `old`.
+
+This operation is supported for values of type Int32, Int64, UInt32, UInt64, and Float32.
+Additionally, on GPU hardware with compute capability 6.0+, values of type Float64 are supported.
+Also: atomic operations for the CPU require a Julia version of 1.7.0 or above.
+"""
+atomic_add!
+
+"""
+    atomic_sub!(ptr::Ptr{T}, val::T)
+
+This is an atomic subtraction.
+It reads the value `old` located at address `ptr`, computes `old - val`, and stores the result back to memory at the same address.
+These operations are performed in one atomic transaction.
+The function returns `old`.
+
+This operation is supported for values of type Int32, Int64, UInt32 and UInt64.
+Also: atomic operations for the CPU require a Julia version of 1.7.0 or above.
+"""
+atomic_sub!
+
+"""
+    atomic_and!(ptr::Ptr{T}, val::T)
+
+This is an atomic and.
+It reads the value `old` located at address `ptr`, computes `old & val`, and stores the result back to memory at the same address.
+These operations are performed in one atomic transaction.
+The function returns `old`.
+
+This operation is supported for values of type Int32, Int64, UInt32 and UInt64.
+Also: atomic operations for the CPU require a Julia version of 1.7.0 or above.
+"""
+atomic_and!
+
+"""
+    atomic_or!(ptr::Ptr{T}, val::T)
+
+This is an atomic or.
+It reads the value `old` located at address `ptr`, computes `old | val`, and stores the result back to memory at the same address.
+These operations are performed in one atomic transaction.
+The function returns `old`.
+
+This operation is supported for values of type Int32, Int64, UInt32 and UInt64.
+Also: atomic operations for the CPU require a Julia version of 1.7.0 or above.
+"""
+atomic_or!
+
+"""
+    atomic_xor!(ptr::Ptr{T}, val::T)
+
+This is an atomic xor.
+It reads the value `old` located at address `ptr`, computes `old ⊻ val`, and stores the result back to memory at the same address.
+These operations are performed in one atomic transaction.
+The function returns `old`.
+
+This operation is supported for values of type Int32, Int64, UInt32 and UInt64.
+Also: atomic operations for the CPU require a Julia version of 1.7.0 or above.
+"""
+atomic_xor!
+
+"""
+    atomic_min!(ptr::Ptr{T}, val::T)
+
+This is an atomic min.
+It reads the value `old` located at address `ptr`, computes `min(old, val)`, and stores the result back to memory at the same address.
+These operations are performed in one atomic transaction.
+The function returns `old`.
+
+This operation is supported for values of type Int32, Int64, UInt32 and UInt64.
+Also: atomic operations for the CPU require a Julia version of 1.7.0 or above.
+"""
+atomic_min!
+
+"""
+    atomic_max!(ptr::Ptr{T}, val::T)
+
+This is an atomic max.
+It reads the value `old` located at address `ptr`, computes `max(old, val)`, and stores the result back to memory at the same address.
+These operations are performed in one atomic transaction.
+The function returns `old`.
+
+This operation is supported for values of type Int32, Int64, UInt32 and UInt64.
+Also: atomic operations for the CPU require a Julia version of 1.7.0 or above.
+"""
+atomic_max!
+
+"""
+    atomic_inc!(ptr::Ptr{T}, val::T)
+
+This is an atomic increment function that counts up to a certain number before starting again at 0.
+It reads the value `old` located at address `ptr`, computes `((old >= val) ? 0 : (old+1))`, and stores the result back to memory at the same address.
+These three operations are performed in one atomic transaction.
+The function returns `old`.
+
+This operation is only supported for values of type Int32.
+Also: atomic operations for the CPU require a Julia version of 1.7.0 or above.
+"""
+atomic_inc!
+
+"""
+    atomic_dec!(ptr::Ptr{T}, val::T)
+
+This is an atomic decrement function that counts down to 0 from a defined value `val`.
+It reads the value `old` located at address `ptr`, computes `(((old == 0) | (old > val)) ? val : (old-1))`, and stores the result back to memory at the same address.
+These three operations are performed in one atomic transaction.
+The function returns `old`.
+
+This operation is only supported for values of type Int32.
+Also: atomic operations for the CPU require a Julia version of 1.7.0 or above.
+"""
+atomic_dec!
diff --git a/src/cpu.jl b/src/cpu.jl
index cea40e586..1ae8b8afc 100644
--- a/src/cpu.jl
+++ b/src/cpu.jl
@@ -269,3 +269,30 @@ end
 
 # Argument conversion
 KernelAbstractions.argconvert(k::Kernel{CPU}, arg) = arg
+
+###
+# CPU error handling if under 1.7
+###
+
+if Base.VERSION < v"1.7.0"
+
+    import KernelAbstractions: atomic_add!, atomic_and!, atomic_cas!,
+                               atomic_dec!, atomic_inc!, atomic_max!,
+                               atomic_min!, atomic_op!, atomic_or!,
+                               atomic_sub!, atomic_xchg!, atomic_xor!
+
+    function atomic_error(args...)
+        error("CPU Atomics are not allowed for julia version under 1.7!")
+    end
+
+    afxs = [atomic_add!, atomic_and!, atomic_cas!, atomic_dec!,
+            atomic_inc!, atomic_max!, atomic_min!, atomic_op!,
+            atomic_or!, atomic_sub!, atomic_xchg!, atomic_xor!]
+
+    for afx in afxs
+        @inline function Cassette.overdub(::CPUCtx, ::typeof(afx), args...)
+            atomic_error(args...)
+        end
+    end
+end
+
diff --git a/test/atomic_test.jl b/test/atomic_test.jl
new file mode 100644
index 000000000..68eed69fc
--- /dev/null
+++ b/test/atomic_test.jl
@@ -0,0 +1,207 @@
+using KernelAbstractions, Test
+
+# Note: kernels affect second element because some CPU defaults will affect the
+#       first element of a pointer if not specified, so I am covering the bases
+@kernel function atomic_add_kernel!(input, b)
+    atomic_add!(pointer(input,2),b)
+end
+
+@kernel function atomic_sub_kernel!(input, b)
+    atomic_sub!(pointer(input,2),b)
+end
+
+@kernel function atomic_inc_kernel!(input, b)
+    atomic_inc!(pointer(input,2),b)
+end
+
+@kernel function atomic_dec_kernel!(input, b)
+    atomic_dec!(pointer(input,2),b)
+end
+
+@kernel function atomic_xchg_kernel!(input, b)
+    atomic_xchg!(pointer(input,2),b)
+end
+
+@kernel function atomic_and_kernel!(input, b)
+    tid = @index(Global)
+    atomic_and!(pointer(input),b[tid])
+end
+
+@kernel function atomic_or_kernel!(input, b)
+    tid = @index(Global)
+    atomic_or!(pointer(input),b[tid])
+end
+
+@kernel function atomic_xor_kernel!(input, b)
+    tid = @index(Global)
+    atomic_xor!(pointer(input),b[tid])
+end
+
+@kernel function atomic_max_kernel!(input, b)
+    tid = @index(Global)
+    atomic_max!(pointer(input,2), b[tid])
+end
+
+@kernel function atomic_min_kernel!(input, b)
+    tid = @index(Global)
+    atomic_min!(pointer(input,2), b[tid])
+end
+
+@kernel function atomic_cas_kernel!(input, b, c)
+    atomic_cas!(pointer(input,2),b,c)
+end
+
+function atomics_testsuite(backend, ArrayT)
+
+    @testset "atomic addition tests" begin
+        types = [Int32, Int64, UInt32, UInt64, Float32]
+
+        for T in types
+            A = ArrayT{T}([0,0])
+
+            kernel! = atomic_add_kernel!(backend(), 4)
+            wait(kernel!(A, one(T), ndrange=(1024)))
+
+            @test Array(A)[2] == 1024
+        end
+    end
+
+    @testset "atomic subtraction tests" begin
+        types = [Int32, Int64, UInt32, UInt64, Float32]
+
+        for T in types
+            A = ArrayT{T}([2048,2048])
+
+            kernel! = atomic_sub_kernel!(backend(), 4)
+            wait(kernel!(A, one(T), ndrange=(1024)))
+
+            @test Array(A)[2] == 1024
+        end
+    end
+
+    @testset "atomic inc tests" begin
+        types = [Int32]
+
+        for T in types
+            A = ArrayT{T}([0,0])
+
+            kernel! = atomic_inc_kernel!(backend(), 4)
+            wait(kernel!(A, T(512), ndrange=(768)))
+
+            @test Array(A)[2] == 255
+        end
+    end
+
+    @testset "atomic dec tests" begin
+        types = [Int32]
+
+        for T in types
+            A = ArrayT{T}([1024,1024])
+
+            kernel! = atomic_dec_kernel!(backend(), 4)
+            wait(kernel!(A, T(512), ndrange=(256)))
+
+            @test Array(A)[2] == 257
+        end
+    end
+
+    @testset "atomic xchg tests" begin
+        types = [Int32, Int64, UInt32, UInt64]
+
+        for T in types
+            A = ArrayT{T}([0,0])
+
+            kernel! = atomic_xchg_kernel!(backend(), 4)
+            wait(kernel!(A, T(1), ndrange=(256)))
+
+            @test Array(A)[2] == one(T)
+        end
+    end
+
+    @testset "atomic and tests" begin
+        types = [Int32, Int64, UInt32, UInt64]
+
+        for T in types
+            A = ArrayT{T}([1023])
+            B = ArrayT{T}([1023-2^(i-1) for i = 1:10])
+
+            kernel! = atomic_and_kernel!(backend(), 4)
+            wait(kernel!(A, B, ndrange=length(B)))
+
+            @test Array(A)[1] == zero(T)
+        end
+    end
+
+    @testset "atomic or tests" begin
+        types = [Int32, Int64, UInt32, UInt64]
+
+        for T in types
+            A = ArrayT{T}([0])
+            B = ArrayT{T}([2^(i-1) for i = 1:10])
+
+            kernel! = atomic_or_kernel!(backend(), 4)
+            wait(kernel!(A, B, ndrange=length(B)))
+
+            @test Array(A)[1] == T(1023)
+        end
+    end
+
+    @testset "atomic xor tests" begin
+        types = [Int32, Int64, UInt32, UInt64]
+
+        for T in types
+            A = ArrayT{T}([1023])
+            B = ArrayT{T}([2^(i-1) for i = 1:10])
+
+            kernel! = atomic_xor_kernel!(backend(), 4)
+            wait(kernel!(A, B, ndrange=length(B)))
+
+            @test Array(A)[1] == T(0)
+        end
+    end
+
+    @testset "atomic max tests" begin
+        types = [Int32, Int64, UInt32, UInt64]
+
+        for T in types
+            A = ArrayT{T}([0,0])
+            B = ArrayT{T}([i for i = 1:1024])
+
+            kernel! = atomic_max_kernel!(backend(), 4)
+            wait(kernel!(A, B, ndrange=length(B)))
+
+            @test Array(A)[2] == T(1024)
+        end
+    end
+
+    @testset "atomic min tests" begin
+        types = [Int32, Int64, UInt32, UInt64]
+
+        for T in types
+            A = ArrayT{T}([1024,1024])
+            B = ArrayT{T}([i for i = 1:1024])
+
+            kernel! = atomic_min_kernel!(backend(), 4)
+            wait(kernel!(A, B, ndrange=length(B)))
+
+            @test Array(A)[2] == T(1)
+        end
+    end
+
+
+    @testset "atomic cas tests" begin
+        types = [Int32, Int64, UInt32, UInt64]
+
+        for T in types
+            A = ArrayT{T}([0,0])
+
+            kernel! = atomic_cas_kernel!(backend(), 4)
+            wait(kernel!(A, zero(T), one(T), ndrange=(1024)))
+
+            @test Array(A)[2] == 1
+        end
+    end
+
+
+
+end
diff --git a/test/testsuite.jl b/test/testsuite.jl
index c38c9baeb..8dab59c7b 100644
--- a/test/testsuite.jl
+++ b/test/testsuite.jl
@@ -15,6 +15,7 @@ include("compiler.jl")
 include("reflection.jl")
 include("examples.jl")
 include("convert.jl")
+include("atomic_test.jl")
 
 function testsuite(backend, backend_str, backend_mod, AT, DAT)
     @testset "Unittests" begin
@@ -69,6 +70,13 @@ function testsuite(backend, backend_str, backend_mod, AT, DAT)
         convert_testsuite(backend, AT)
     end
 
+    if backend_str != "ROCM" &&
+       !(backend_str == "CPU" && Base.VERSION < v"1.7.0")
+        @testset "Atomics" begin
+            atomics_testsuite(backend, AT)
+        end
+    end
+
     if backend_str == "CUDA"
         @testset "Examples" begin
             examples_testsuite()