diff --git a/src/SparseArrays.jl b/src/SparseArrays.jl index 2688a449..b4bccf16 100644 --- a/src/SparseArrays.jl +++ b/src/SparseArrays.jl @@ -22,7 +22,7 @@ import LinearAlgebra: mul!, ldiv!, rdiv!, cholesky, adjoint!, diag, eigen, dot, import Base: adjoint, argmin, argmax, Array, broadcast, circshift!, complex, Complex, conj, conj!, convert, copy, copy!, copyto!, count, diff, findall, findmax, findmin, - float, getindex, imag, inv, kron, kron!, length, map, maximum, minimum, permute!, real, + float, getindex, imag, inv, keytype, kron, kron!, length, map, maximum, minimum, permute!, real, rot180, rotl90, rotr90, setindex!, show, similar, size, sum, transpose, vcat, hcat, hvcat, cat, vec, reverse, reverse! @@ -84,7 +84,8 @@ if Base.USE_GPL_LIBS include("solvers/spqr.jl") end -zero(a::AbstractSparseArray) = spzeros(eltype(a), size(a)...) +keytype(::Type{A}) where {Tv, Ti, A<:AbstractSparseArray{Tv,Ti}} = Ti +zero(a::AbstractSparseArray) = spzeros(eltype(a), keytype(a), size(a)...) LinearAlgebra.diagzero(D::Diagonal{<:AbstractSparseMatrix{T}},i,j) where {T} = spzeros(T, size(D.diag[i], 1), size(D.diag[j], 2)) diff --git a/test/sparsevector.jl b/test/sparsevector.jl index c91f0d69..42a1b9d0 100644 --- a/test/sparsevector.jl +++ b/test/sparsevector.jl @@ -12,6 +12,7 @@ include("forbidproperties.jl") ### Data spv_x1 = SparseVector(8, [2, 5, 6], [1.25, -0.75, 3.5]) +spv_x1_32 = SparseVector(8, Int32[2, 5, 6], Float32[1.25, -0.75, 3.5]) @test isa(spv_x1, SparseVector{Float64,Int}) @@ -42,6 +43,14 @@ x1_full[SparseArrays.nonzeroinds(spv_x1)] = nonzeros(spv_x1) @test @inferred size(y) == (@inferred(length(y))::Int8,) end +@testset "Non default index type" begin + x = spv_x1_32 + for func in [identity, copy, empty, similar, zero] + @test eltype(func(spv_x1_32)) == Float32 + @test keytype(func(spv_x1_32)) == Int32 + end +end + @testset "isstored" begin x = spv_x1 stored_inds = [2, 5, 6]