Skip to content

Commit 16bbcbc

Browse files
authored
5-term mul! with Diagonal (#603)
This speeds up certain 5-term multiplications involving sparse matrices and a `Diagonal`, as the operations become O(nnz(A)) from O(N^2).

```julia
julia> using SparseArrays, LinearAlgebra, Chairmarks

julia> S = sprand(5_000, 5_000, 0.00001);

julia> S2 = similar(S);

julia> D = Diagonal(axes(S,1));

julia> @b (S2,S,D) mul!(_[1], _[2], _[3], true, false)
190.253 ms # main
7.143 μs # PR

julia> S2 = sprand(5_000, 5_000, 0.00001);

julia> @b (S2,S,D) mul!(_[1], _[2], _[3], 2, 3)
219.410 ms (without a warmup) # main
23.085 μs # PR
```
1 parent d050b1b commit 16bbcbc

File tree

3 files changed

+212
-20
lines changed

3 files changed

+212
-20
lines changed

src/linalg.jl

Lines changed: 175 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1877,47 +1877,204 @@ inv(A::AbstractSparseMatrixCSC) = error("The inverse of a sparse matrix can ofte
18771877
## scale methods
18781878

18791879
# Copy colptr and rowval from one sparse matrix to another
1880-
function copyinds!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC)
1881-
if getcolptr(C) !== getcolptr(A)
1880+
function copyinds!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC; copy_rows=true, copy_cols=true)
1881+
if copy_cols && getcolptr(C) !== getcolptr(A)
18821882
resize!(getcolptr(C), length(getcolptr(A)))
18831883
copyto!(getcolptr(C), getcolptr(A))
18841884
end
1885-
if rowvals(C) !== rowvals(A)
1885+
if copy_rows && rowvals(C) !== rowvals(A)
18861886
resize!(rowvals(C), length(rowvals(A)))
18871887
copyto!(rowvals(C), rowvals(A))
18881888
end
18891889
end
18901890

1891+
"""
1892+
rowcheck_index(A::AbstractSparseMatrixCSC, row::Integer, col::Integer)
1893+
1894+
Check if A[row, col] is a stored value, and return the index of the row in `rowvals(A)`.
1895+
Returns `(row_exists, row_ind)`, where `row_exists::Bool` signifies
1896+
whether the corresponding index is populated, and `row_ind` is the index.
1897+
If `row_exists` is `false`, the `row_ind` is the index where the value should be inserted into
1898+
`rowvals(A)` such that the subarray `@view rowvals(A)[nzrange(A, col)]` remains sorted.
1899+
"""
1900+
@inline function rowcheck_index(A::AbstractSparseMatrixCSC, row::Integer, col::Integer)
1901+
nzinds = nzrange(A, col)
1902+
rows_col = @view rowvals(A)[nzinds]
1903+
# faster implementation of row ∈ rows_col and obtaining the index,
1904+
# assuming that rows_col is sorted
1905+
row_ind_col = searchsortedfirst(rows_col, row)
1906+
row_exists = row_ind_col ∈ axes(rows_col,1) && rows_col[row_ind_col] == row
1907+
row_ind = row_ind_col + first(nzinds) - firstindex(nzinds)
1908+
row_exists, row_ind
1909+
end
1910+
1911+
"""
1912+
mergeinds!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC)
1913+
1914+
Update `C` to contain stored values corresponding to the stored indices of `A`.
1915+
Stored indices common to `C` and `A` are not touched. Indices of `A` at which
1916+
`C` did not have a stored value are populated with zeros after the call.
1917+
1918+
# Examples
1919+
```jldoctest
1920+
julia> A = spzeros(3,3);
1921+
1922+
julia> A[4:4:8] .= 1;
1923+
1924+
julia> A
1925+
3×3 SparseMatrixCSC{Float64, Int64} with 2 stored entries:
1926+
⋅ 1.0 ⋅
1927+
⋅ ⋅ 1.0
1928+
⋅ ⋅ ⋅
1929+
1930+
julia> C = spzeros(3,3);
1931+
1932+
julia> C[2:4:6] .= 2;
1933+
1934+
julia> C
1935+
3×3 SparseMatrixCSC{Float64, Int64} with 2 stored entries:
1936+
⋅ ⋅ ⋅
1937+
2.0 ⋅ ⋅
1938+
⋅ 2.0 ⋅
1939+
1940+
julia> SparseArrays.mergeinds!(C, A)
1941+
3×3 SparseMatrixCSC{Float64, Int64} with 4 stored entries:
1942+
⋅ 0.0 ⋅
1943+
2.0 ⋅ 0.0
1944+
⋅ 2.0 ⋅
1945+
```
1946+
"""
1947+
function mergeinds!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC)
1948+
C_colptr = getcolptr(C)
1949+
for col in axes(A,2)
1950+
n_extra = 0
1951+
for ind in nzrange(A, col)
1952+
row = rowvals(A)[ind]
1953+
row_exists, ind = rowcheck_index(C, row, col)
1954+
if !row_exists
1955+
n_extra += 1
1956+
insert!(rowvals(C), ind, row)
1957+
insert!(nonzeros(C), ind, zero(eltype(C)))
1958+
C_colptr[col+1] += 1
1959+
end
1960+
end
1961+
if !iszero(n_extra)
1962+
@views C_colptr[col+2:end] .+= n_extra
1963+
end
1964+
end
1965+
C
1966+
end
1967+
18911968
# multiply by diagonal matrix as vector
1892-
function mul!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC, D::Diagonal)
1969+
function mul!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC, D::Diagonal, alpha::Number, beta::Number)
18931970
m, n = size(A)
1894-
b = D.diag
1971+
b = D.diag
18951972
lb = length(b)
1896-
n == lb || throw(DimensionMismatch("A has size ($m, $n) but D has size ($lb, $lb)"))
1897-
size(A)==size(C) || throw(DimensionMismatch("A has size ($m, $n), D has size ($lb, $lb), C has size $(size(C))"))
1898-
copyinds!(C, A)
1973+
n == lb || throw(DimensionMismatch(lazy"A has size ($m, $n) but D has size ($lb, $lb)"))
1974+
size(A)==size(C) || throw(DimensionMismatch(lazy"A has size ($m, $n), D has size ($lb, $lb), C has size $(size(C))"))
1975+
beta_is_zero = iszero(beta)
1976+
rows_match = rowvals(C) == rowvals(A)
1977+
cols_match = getcolptr(C) == getcolptr(A)
1978+
identical_nzinds = rows_match && cols_match
18991979
Cnzval = nonzeros(C)
19001980
Anzval = nonzeros(A)
1901-
resize!(Cnzval, length(Anzval))
1902-
for col in axes(A,2), p in nzrange(A, col)
1903-
@inbounds Cnzval[p] = Anzval[p] * b[col]
1981+
if beta_is_zero || identical_nzinds
1982+
identical_nzinds || copyinds!(C, A, copy_rows = !rows_match, copy_cols = !cols_match)
1983+
resize!(Cnzval, length(Anzval))
1984+
if beta_is_zero
1985+
if isone(alpha)
1986+
for col in axes(A,2), p in nzrange(A, col)
1987+
@inbounds Cnzval[p] = Anzval[p] * b[col]
1988+
end
1989+
else
1990+
for col in axes(A,2), p in nzrange(A, col)
1991+
@inbounds Cnzval[p] = Anzval[p] * b[col] * alpha
1992+
end
1993+
end
1994+
else
1995+
if isone(alpha)
1996+
for col in axes(A,2), p in nzrange(A, col)
1997+
@inbounds Cnzval[p] = Anzval[p] * b[col] + Cnzval[p] * beta
1998+
end
1999+
else
2000+
for col in axes(A,2), p in nzrange(A, col)
2001+
@inbounds Cnzval[p] = Anzval[p] * b[col] * alpha + Cnzval[p] * beta
2002+
end
2003+
end
2004+
end
2005+
else
2006+
mergeinds!(C, A)
2007+
for col in axes(C,2), p in nzrange(C, col)
2008+
row = rowvals(C)[p]
2009+
# check if the index (row, col) is stored in A
2010+
row_exists, row_ind_A = rowcheck_index(A, row, col)
2011+
if row_exists
2012+
if isone(alpha)
2013+
@inbounds Cnzval[p] = Anzval[row_ind_A] * b[col] + Cnzval[p] * beta
2014+
else
2015+
@inbounds Cnzval[p] = Anzval[row_ind_A] * b[col] * alpha + Cnzval[p] * beta
2016+
end
2017+
else # A[row,col] == 0
2018+
@inbounds Cnzval[p] = Cnzval[p] * beta
2019+
end
2020+
end
19042021
end
19052022
C
19062023
end
19072024

1908-
function mul!(C::AbstractSparseMatrixCSC, D::Diagonal, A::AbstractSparseMatrixCSC)
2025+
function mul!(C::AbstractSparseMatrixCSC, D::Diagonal, A::AbstractSparseMatrixCSC, alpha::Number, beta::Number)
19092026
m, n = size(A)
19102027
b = D.diag
19112028
lb = length(b)
1912-
m == lb || throw(DimensionMismatch("D has size ($lb, $lb) but A has size ($m, $n)"))
1913-
size(A)==size(C) || throw(DimensionMismatch("A has size ($m, $n), D has size ($lb, $lb), C has size $(size(C))"))
1914-
copyinds!(C, A)
2029+
m == lb || throw(DimensionMismatch(lazy"D has size ($lb, $lb) but A has size ($m, $n)"))
2030+
size(A)==size(C) || throw(DimensionMismatch(lazy"A has size ($m, $n), D has size ($lb, $lb), C has size $(size(C))"))
2031+
beta_is_zero = iszero(beta)
2032+
rows_match = rowvals(C) == rowvals(A)
2033+
cols_match = getcolptr(C) == getcolptr(A)
2034+
identical_nzinds = rows_match && cols_match
19152035
Cnzval = nonzeros(C)
19162036
Anzval = nonzeros(A)
19172037
Arowval = rowvals(A)
1918-
resize!(Cnzval, length(Anzval))
1919-
for col in axes(A,2), p in nzrange(A, col)
1920-
@inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p]
2038+
if beta_is_zero || identical_nzinds
2039+
identical_nzinds || copyinds!(C, A, copy_rows = !rows_match, copy_cols = !cols_match)
2040+
resize!(Cnzval, length(Anzval))
2041+
if beta_is_zero
2042+
if isone(alpha)
2043+
for col in axes(A,2), p in nzrange(A, col)
2044+
@inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p]
2045+
end
2046+
else
2047+
for col in axes(A,2), p in nzrange(A, col)
2048+
@inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p] * alpha
2049+
end
2050+
end
2051+
else
2052+
if isone(alpha)
2053+
for col in axes(A,2), p in nzrange(A, col)
2054+
@inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p] + Cnzval[p] * beta
2055+
end
2056+
else
2057+
for col in axes(A,2), p in nzrange(A, col)
2058+
@inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p] * alpha + Cnzval[p] * beta
2059+
end
2060+
end
2061+
end
2062+
else
2063+
mergeinds!(C, A)
2064+
for col in axes(C,2), p in nzrange(C, col)
2065+
row = rowvals(C)[p]
2066+
# check if the index (row, col) is stored in A
2067+
row_exists, row_ind_A = rowcheck_index(A, row, col)
2068+
if row_exists
2069+
if isone(alpha)
2070+
@inbounds Cnzval[p] = b[row] * Anzval[row_ind_A] + Cnzval[p] * beta
2071+
else
2072+
@inbounds Cnzval[p] = b[row] * Anzval[row_ind_A] * alpha + Cnzval[p] * beta
2073+
end
2074+
else # A[row,col] == 0
2075+
@inbounds Cnzval[p] = Cnzval[p] * beta
2076+
end
2077+
end
19212078
end
19222079
C
19232080
end

test/issues.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,8 +445,8 @@ end
445445
A = sprand(5,5,0.5)
446446
D = Diagonal(rand(5))
447447
C = copy(A)
448-
m1 = @which mul!(C,A,D)
449-
m2 = @which mul!(C,D,A)
448+
m1 = @which mul!(C,A,D,true,false)
449+
m2 = @which mul!(C,D,A,true,false)
450450
@test m1.module == SparseArrays
451451
@test m2.module == SparseArrays
452452
end

test/linalg.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,41 @@ end
604604
@test lmul!(D, copy(sA)) ≈ D * dA
605605
@test mul!(sC, D, copy(sA)) ≈ D * dA
606606
end
607+
608+
@testset "5-arg mul!" begin
609+
@testset "merge indices" begin
610+
# for zero arrays, merge and copy are identical
611+
A = spzeros(size(sA))
612+
SparseArrays.mergeinds!(A, sA)
613+
B = spzeros(size(sA))
614+
SparseArrays.copyinds!(B, sA)
615+
@test all(col -> nzrange(A, col) == nzrange(B, col), axes(A,2))
616+
# for arrays with different indices populated, merge should combine these
617+
A = spzeros(5,5)
618+
A[diagind(A,1)] .= 5
619+
B = spzeros(5,5)
620+
B[diagind(A,-1)] .= 10
621+
SparseArrays.mergeinds!(B, A)
622+
@test rowvals(B) == [2, 1,3, 2,4, 3,5, 4]
623+
@test [nzrange(B,col) for col in axes(B,2)] == [1:1, 2:3, 4:5, 6:7, 8:8]
624+
@test nonzeros(B) == [10, 0,10, 0,10, 0,10, 0]
625+
# for arrays with overlapping indices, merge should only add the extra ones
626+
A[diagind(A,2)] .= 5
627+
SparseArrays.mergeinds!(B, A)
628+
@test rowvals(B) == [2, 1,3, 1,2,4, 2,3,5, 3,4]
629+
@test [nzrange(B,col) for col in axes(B,2)] == [1:1, 2:3, 4:6, 7:9, 10:11]
630+
@test nonzeros(B) == [10, 0,10, 0,0,10, 0,0,10, 0,0]
631+
end
632+
for sA2 in (similar(sA), sprand(size(sA)..., 0.1))
633+
nonzeros(sA2) .= 1
634+
@testset for (alpha, beta) in [(true, false), (true, true), (2,3)]
635+
D = Diagonal(rand(size(sA,2)))
636+
@test mul!(copy(sA2), sA, D, alpha, beta) ≈ dA * D * alpha + sA2 * beta
637+
D = Diagonal(rand(size(sA,1)))
638+
@test mul!(copy(sA2), D, sA, alpha, beta) ≈ D * dA * alpha + sA2 * beta
639+
end
640+
end
641+
end
607642
end
608643

609644
@testset "conj" begin

0 commit comments

Comments
 (0)