@@ -1877,47 +1877,204 @@ inv(A::AbstractSparseMatrixCSC) = error("The inverse of a sparse matrix can ofte
1877
1877
# # scale methods
1878
1878
1879
1879
# Copy colptr and rowval from one sparse matrix to another
1880
- function copyinds! (C:: AbstractSparseMatrixCSC , A:: AbstractSparseMatrixCSC )
1881
- if getcolptr (C) != = getcolptr (A)
1880
+ function copyinds! (C:: AbstractSparseMatrixCSC , A:: AbstractSparseMatrixCSC ; copy_rows = true , copy_cols = true )
1881
+ if copy_cols && getcolptr (C) != = getcolptr (A)
1882
1882
resize! (getcolptr (C), length (getcolptr (A)))
1883
1883
copyto! (getcolptr (C), getcolptr (A))
1884
1884
end
1885
- if rowvals (C) != = rowvals (A)
1885
+ if copy_rows && rowvals (C) != = rowvals (A)
1886
1886
resize! (rowvals (C), length (rowvals (A)))
1887
1887
copyto! (rowvals (C), rowvals (A))
1888
1888
end
1889
1889
end
1890
1890
1891
+ """
1892
+ rowcheck_index(A::AbstractSparseMatrixCSC, row::Integer, col::Integer)
1893
+
1894
+ Check if A[row, col] is a stored value, and return the index of the row in `rowvals(A)`.
1895
+ Returns `(row_exists, row_ind)`, where `row_exists::Bool` signifies
1896
+ whether the corresponding index is populated, and `row_ind` is the index.
1897
+ If `row_exists` is `false`, the `row_ind` is the index where the value should be inserted into
1898
+ `rowvals(A)` such that the subarray `@view rowvals(A)[nzrange(A, col)]` remains sorted.
1899
+ """
1900
+ @inline function rowcheck_index (A:: AbstractSparseMatrixCSC , row:: Integer , col:: Integer )
1901
+ nzinds = nzrange (A, col)
1902
+ rows_col = @view rowvals (A)[nzinds]
1903
+ # faster implementation of row ∈ rows_col and obtaining the index,
1904
+ # assuming that rows_col is sorted
1905
+ row_ind_col = searchsortedfirst (rows_col, row)
1906
+ row_exists = row_ind_col ∈ axes (rows_col,1 ) && rows_col[row_ind_col] == row
1907
+ row_ind = row_ind_col + first (nzinds) - firstindex (nzinds)
1908
+ row_exists, row_ind
1909
+ end
1910
+
1911
+ """
1912
+ mergeinds!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC)
1913
+
1914
+ Update `C` to contain stored values corresponding to the stored indices of `A`.
1915
+ Stored indices common to `C` and `A` are not touched. Indices of `A` at which
1916
+ `C` did not have a stored value are populated with zeros after the call.
1917
+
1918
+ # Examples
1919
+ ```jldoctest
1920
+ julia> A = spzeros(3,3);
1921
+
1922
+ julia> A[4:4:8] .= 1;
1923
+
1924
+ julia> A
1925
+ 3×3 SparseMatrixCSC{Float64, Int64} with 2 stored entries:
1926
+ ⋅ 1.0 ⋅
1927
+ ⋅ ⋅ 1.0
1928
+ ⋅ ⋅ ⋅
1929
+
1930
+ julia> C = spzeros(3,3);
1931
+
1932
+ julia> C[2:4:6] .= 2;
1933
+
1934
+ julia> C
1935
+ 3×3 SparseMatrixCSC{Float64, Int64} with 2 stored entries:
1936
+ ⋅ ⋅ ⋅
1937
+ 2.0 ⋅ ⋅
1938
+ ⋅ 2.0 ⋅
1939
+
1940
+ julia> SparseArrays.mergeinds!(C, A)
1941
+ 3×3 SparseMatrixCSC{Float64, Int64} with 4 stored entries:
1942
+ ⋅ 0.0 ⋅
1943
+ 2.0 ⋅ 0.0
1944
+ ⋅ 2.0 ⋅
1945
+ ```
1946
+ """
1947
+ function mergeinds! (C:: AbstractSparseMatrixCSC , A:: AbstractSparseMatrixCSC )
1948
+ C_colptr = getcolptr (C)
1949
+ for col in axes (A,2 )
1950
+ n_extra = 0
1951
+ for ind in nzrange (A, col)
1952
+ row = rowvals (A)[ind]
1953
+ row_exists, ind = rowcheck_index (C, row, col)
1954
+ if ! row_exists
1955
+ n_extra += 1
1956
+ insert! (rowvals (C), ind, row)
1957
+ insert! (nonzeros (C), ind, zero (eltype (C)))
1958
+ C_colptr[col+ 1 ] += 1
1959
+ end
1960
+ end
1961
+ if ! iszero (n_extra)
1962
+ @views C_colptr[col+ 2 : end ] .+ = n_extra
1963
+ end
1964
+ end
1965
+ C
1966
+ end
1967
+
1891
1968
# multiply by diagonal matrix as vector
1892
- function mul! (C:: AbstractSparseMatrixCSC , A:: AbstractSparseMatrixCSC , D:: Diagonal )
1969
+ function mul! (C:: AbstractSparseMatrixCSC , A:: AbstractSparseMatrixCSC , D:: Diagonal , alpha :: Number , beta :: Number )
1893
1970
m, n = size (A)
1894
- b = D. diag
1971
+ b = D. diag
1895
1972
lb = length (b)
1896
- n == lb || throw (DimensionMismatch (" A has size ($m , $n ) but D has size ($lb , $lb )" ))
1897
- size (A)== size (C) || throw (DimensionMismatch (" A has size ($m , $n ), D has size ($lb , $lb ), C has size $(size (C)) " ))
1898
- copyinds! (C, A)
1973
+ n == lb || throw (DimensionMismatch (lazy " A has size ($m, $n) but D has size ($lb, $lb)" ))
1974
+ size (A)== size (C) || throw (DimensionMismatch (lazy " A has size ($m, $n), D has size ($lb, $lb), C has size $(size(C))" ))
1975
+ beta_is_zero = iszero (beta)
1976
+ rows_match = rowvals (C) == rowvals (A)
1977
+ cols_match = getcolptr (C) == getcolptr (A)
1978
+ identical_nzinds = rows_match && cols_match
1899
1979
Cnzval = nonzeros (C)
1900
1980
Anzval = nonzeros (A)
1901
- resize! (Cnzval, length (Anzval))
1902
- for col in axes (A,2 ), p in nzrange (A, col)
1903
- @inbounds Cnzval[p] = Anzval[p] * b[col]
1981
+ if beta_is_zero || identical_nzinds
1982
+ identical_nzinds || copyinds! (C, A, copy_rows = ! rows_match, copy_cols = ! cols_match)
1983
+ resize! (Cnzval, length (Anzval))
1984
+ if beta_is_zero
1985
+ if isone (alpha)
1986
+ for col in axes (A,2 ), p in nzrange (A, col)
1987
+ @inbounds Cnzval[p] = Anzval[p] * b[col]
1988
+ end
1989
+ else
1990
+ for col in axes (A,2 ), p in nzrange (A, col)
1991
+ @inbounds Cnzval[p] = Anzval[p] * b[col] * alpha
1992
+ end
1993
+ end
1994
+ else
1995
+ if isone (alpha)
1996
+ for col in axes (A,2 ), p in nzrange (A, col)
1997
+ @inbounds Cnzval[p] = Anzval[p] * b[col] + Cnzval[p] * beta
1998
+ end
1999
+ else
2000
+ for col in axes (A,2 ), p in nzrange (A, col)
2001
+ @inbounds Cnzval[p] = Anzval[p] * b[col] * alpha + Cnzval[p] * beta
2002
+ end
2003
+ end
2004
+ end
2005
+ else
2006
+ mergeinds! (C, A)
2007
+ for col in axes (C,2 ), p in nzrange (C, col)
2008
+ row = rowvals (C)[p]
2009
+ # check if the index (row, col) is stored in A
2010
+ row_exists, row_ind_A = rowcheck_index (A, row, col)
2011
+ if row_exists
2012
+ if isone (alpha)
2013
+ @inbounds Cnzval[p] = Anzval[row_ind_A] * b[col] + Cnzval[p] * beta
2014
+ else
2015
+ @inbounds Cnzval[p] = Anzval[row_ind_A] * b[col] * alpha + Cnzval[p] * beta
2016
+ end
2017
+ else # A[row,col] == 0
2018
+ @inbounds Cnzval[p] = Cnzval[p] * beta
2019
+ end
2020
+ end
1904
2021
end
1905
2022
C
1906
2023
end
1907
2024
1908
- function mul! (C:: AbstractSparseMatrixCSC , D:: Diagonal , A:: AbstractSparseMatrixCSC )
2025
+ function mul! (C:: AbstractSparseMatrixCSC , D:: Diagonal , A:: AbstractSparseMatrixCSC , alpha :: Number , beta :: Number )
1909
2026
m, n = size (A)
1910
2027
b = D. diag
1911
2028
lb = length (b)
1912
- m == lb || throw (DimensionMismatch (" D has size ($lb , $lb ) but A has size ($m , $n )" ))
1913
- size (A)== size (C) || throw (DimensionMismatch (" A has size ($m , $n ), D has size ($lb , $lb ), C has size $(size (C)) " ))
1914
- copyinds! (C, A)
2029
+ m == lb || throw (DimensionMismatch (lazy " D has size ($lb, $lb) but A has size ($m, $n)" ))
2030
+ size (A)== size (C) || throw (DimensionMismatch (lazy " A has size ($m, $n), D has size ($lb, $lb), C has size $(size(C))" ))
2031
+ beta_is_zero = iszero (beta)
2032
+ rows_match = rowvals (C) == rowvals (A)
2033
+ cols_match = getcolptr (C) == getcolptr (A)
2034
+ identical_nzinds = rows_match && cols_match
1915
2035
Cnzval = nonzeros (C)
1916
2036
Anzval = nonzeros (A)
1917
2037
Arowval = rowvals (A)
1918
- resize! (Cnzval, length (Anzval))
1919
- for col in axes (A,2 ), p in nzrange (A, col)
1920
- @inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p]
2038
+ if beta_is_zero || identical_nzinds
2039
+ identical_nzinds || copyinds! (C, A, copy_rows = ! rows_match, copy_cols = ! cols_match)
2040
+ resize! (Cnzval, length (Anzval))
2041
+ if beta_is_zero
2042
+ if isone (alpha)
2043
+ for col in axes (A,2 ), p in nzrange (A, col)
2044
+ @inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p]
2045
+ end
2046
+ else
2047
+ for col in axes (A,2 ), p in nzrange (A, col)
2048
+ @inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p] * alpha
2049
+ end
2050
+ end
2051
+ else
2052
+ if isone (alpha)
2053
+ for col in axes (A,2 ), p in nzrange (A, col)
2054
+ @inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p] + Cnzval[p] * beta
2055
+ end
2056
+ else
2057
+ for col in axes (A,2 ), p in nzrange (A, col)
2058
+ @inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p] * alpha + Cnzval[p] * beta
2059
+ end
2060
+ end
2061
+ end
2062
+ else
2063
+ mergeinds! (C, A)
2064
+ for col in axes (C,2 ), p in nzrange (C, col)
2065
+ row = rowvals (C)[p]
2066
+ # check if the index (row, col) is stored in A
2067
+ row_exists, row_ind_A = rowcheck_index (A, row, col)
2068
+ if row_exists
2069
+ if isone (alpha)
2070
+ @inbounds Cnzval[p] = b[row] * Anzval[row_ind_A] + Cnzval[p] * beta
2071
+ else
2072
+ @inbounds Cnzval[p] = b[row] * Anzval[row_ind_A] * alpha + Cnzval[p] * beta
2073
+ end
2074
+ else # A[row,col] == 0
2075
+ @inbounds Cnzval[p] = Cnzval[p] * beta
2076
+ end
2077
+ end
1921
2078
end
1922
2079
C
1923
2080
end
0 commit comments