Skip to content

Commit dd7e033

Browse files
authored
Resolve stack-overflow in Diagonal*Bidiagonal and (Sym)Tridiagonal (#242)
* Resolve stack-overflow in Diagonal*Bidiagonal and (Sym)Tridiagonal * Forward to copyto! * Restrict test to version v1.11 * Update copy test * Update more tests * Add test for general method * Bump version to v1.10.2
1 parent 40dc4f3 commit dd7e033

File tree

3 files changed

+85
-36
lines changed

3 files changed

+85
-36
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <solver@mac.com>"]
4-
version = "1.10.3"
4+
version = "1.10.4"
55

66
[deps]
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

src/diagonal.jl

+30-6
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,36 @@ copy(M::Rdiv{<:DiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = diagona
5757

5858

5959
## bi/tridiagonal copy
60-
copy(M::Rmul{<:BidiagonalLayout,<:DiagonalLayout}) = convert(Bidiagonal, M.A) * M.B
61-
copy(M::Lmul{<:DiagonalLayout,<:BidiagonalLayout}) = M.A * convert(Bidiagonal, M.B)
62-
copy(M::Rmul{<:TridiagonalLayout,<:DiagonalLayout}) = convert(Tridiagonal, M.A) * M.B
63-
copy(M::Lmul{<:DiagonalLayout,<:TridiagonalLayout}) = M.A * convert(Tridiagonal, M.B)
64-
copy(M::Rmul{<:SymTridiagonalLayout,<:DiagonalLayout}) = convert(SymTridiagonal, M.A) * M.B
65-
copy(M::Lmul{<:DiagonalLayout,<:SymTridiagonalLayout}) = M.A * convert(SymTridiagonal, M.B)
60+
# hack around the fact that a SymTridiagonal isn't fully mutable
61+
_similar(A) = similar(A)
62+
_similar(A::SymTridiagonal) = similar(Tridiagonal(A.ev, A.dv, A.ev))
63+
_copy_diag(M::T, ::T) where {T<:Rmul} = copyto!(_similar(M.A), M)
64+
_copy_diag(M::T, ::T) where {T<:Lmul} = copyto!(_similar(M.B), M)
65+
_copy_diag(M, _) = copy(M)
66+
function copy(M::Rmul{<:BidiagonalLayout,<:DiagonalLayout})
67+
A = convert(Bidiagonal, M.A)
68+
_copy_diag(Rmul(A, M.B), M)
69+
end
70+
function copy(M::Lmul{<:DiagonalLayout,<:BidiagonalLayout})
71+
B = convert(Bidiagonal, M.B)
72+
_copy_diag(Lmul(M.A, B), M)
73+
end
74+
function copy(M::Rmul{<:TridiagonalLayout,<:DiagonalLayout})
75+
A = convert(Tridiagonal, M.A)
76+
_copy_diag(Rmul(A, M.B), M)
77+
end
78+
function copy(M::Lmul{<:DiagonalLayout,<:TridiagonalLayout})
79+
B = convert(Tridiagonal, M.B)
80+
_copy_diag(Lmul(M.A, B), M)
81+
end
82+
function copy(M::Rmul{<:SymTridiagonalLayout,<:DiagonalLayout})
83+
A = convert(SymTridiagonal, M.A)
84+
_copy_diag(Rmul(A, M.B), M)
85+
end
86+
function copy(M::Lmul{<:DiagonalLayout,<:SymTridiagonalLayout})
87+
B = convert(SymTridiagonal, M.B)
88+
_copy_diag(Lmul(M.A, B), M)
89+
end
6690

6791
copy(M::Lmul{DiagonalLayout{OnesLayout}}) = _copy_oftype(M.B, eltype(M))
6892
copy(M::Lmul{DiagonalLayout{OnesLayout},<:DiagonalLayout}) = Diagonal(_copy_oftype(diagonaldata(M.B), eltype(M)))

test/test_layoutarray.jl

+54-29
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,40 @@ module TestLayoutArray
33
using ArrayLayouts, LinearAlgebra, FillArrays, Test, SparseArrays, Random
44
using ArrayLayouts: sub_materialize, MemoryLayout, ColumnNorm, RowMaximum, CRowMaximum, @_layoutlmul, Mul
55
import ArrayLayouts: triangulardata
6+
import LinearAlgebra: Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal
67

7-
struct MyMatrix <: LayoutMatrix{Float64}
8-
A::Matrix{Float64}
8+
struct MyMatrix{T,M<:AbstractMatrix{T}} <: LayoutMatrix{T}
9+
A::M
910
end
1011

1112
Base.getindex(A::MyMatrix, k::Int, j::Int) = A.A[k,j]
1213
Base.setindex!(A::MyMatrix, v, k::Int, j::Int) = setindex!(A.A, v, k, j)
1314
Base.size(A::MyMatrix) = size(A.A)
1415
Base.strides(A::MyMatrix) = strides(A.A)
15-
Base.elsize(::Type{MyMatrix}) = sizeof(Float64)
16-
Base.cconvert(::Type{Ptr{Float64}}, A::MyMatrix) = A.A
17-
Base.unsafe_convert(::Type{Ptr{Float64}}, A::MyMatrix) = Base.unsafe_convert(Ptr{Float64}, A.A)
18-
MemoryLayout(::Type{MyMatrix}) = DenseColumnMajor()
16+
Base.elsize(::Type{<:MyMatrix{T}}) where {T} = sizeof(T)
17+
Base.cconvert(::Type{Ptr{T}}, A::MyMatrix{T}) where {T} = Base.cconvert(Ptr{T}, A.A)
18+
Base.unsafe_convert(::Type{Ptr{T}}, A::MyMatrix{T}) where {T} = Base.unsafe_convert(Ptr{T}, A.A)
19+
MemoryLayout(::Type{MyMatrix{T,M}}) where {T,M} = MemoryLayout(M)
1920
Base.copy(A::MyMatrix) = MyMatrix(copy(A.A))
21+
ArrayLayouts.bidiagonaluplo(M::MyMatrix) = ArrayLayouts.bidiagonaluplo(M.A)
22+
for MT in (:Diagonal, :Bidiagonal, :Tridiagonal, :SymTridiagonal)
23+
@eval $MT(M::MyMatrix) = $MT(M.A)
24+
end
2025

21-
struct MyVector{T} <: LayoutVector{T}
22-
A::Vector{T}
26+
struct MyVector{T,V<:AbstractVector{T}} <: LayoutVector{T}
27+
A::V
2328
end
2429

30+
MyVector(M::MyVector) = MyVector(M.A)
2531
Base.getindex(A::MyVector, k::Int) = A.A[k]
2632
Base.setindex!(A::MyVector, v, k::Int) = setindex!(A.A, v, k)
2733
Base.size(A::MyVector) = size(A.A)
2834
Base.strides(A::MyVector) = strides(A.A)
29-
Base.elsize(::Type{MyVector}) = sizeof(Float64)
30-
Base.cconvert(::Type{Ptr{T}}, A::MyVector{T}) where {T} = A.A
35+
Base.elsize(::Type{<:MyVector{T}}) where {T} = sizeof(T)
36+
Base.cconvert(::Type{Ptr{T}}, A::MyVector{T}) where {T} = Base.cconvert(Ptr{T}, A.A)
3137
Base.unsafe_convert(::Type{Ptr{T}}, A::MyVector{T}) where T = Base.unsafe_convert(Ptr{T}, A.A)
32-
MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
38+
MemoryLayout(::Type{MyVector{T,V}}) where {T,V} = MemoryLayout(V)
39+
Base.copy(A::MyVector) = MyVector(copy(A.A))
3340

3441
# These need to test dispatch reduces to ArrayLayouts.mul, etc.
3542
@testset "LayoutArray" begin
@@ -42,7 +49,7 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
4249
@test a[1:3] == a.A[1:3]
4350
@test a[:] == a
4451
@test (a')[1,:] == (a')[1,1:3] == a
45-
@test sprint(show, "text/plain", a) == "3-element $MyVector{Float64}:\n 1.0\n 2.0\n 3.0"
52+
@test sprint(show, "text/plain", a) == "$(summary(a)):\n 1.0\n 2.0\n 3.0"
4653
@test B*a B*a.A
4754
@test B'*a B'*a.A
4855
@test transpose(B)*a transpose(B)*a.A
@@ -104,8 +111,8 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
104111
@test_throws ErrorException qr!(A)
105112
@test lu!(copy(A)).factors lu(A.A).factors
106113
b = randn(5)
107-
@test A \ b == A.A \ b == A.A \ MyVector(b) == ldiv!(lu(A.A), copy(MyVector(b)))
108-
@test A \ b == ldiv!(lu(A), copy(MyVector(b))) == ldiv!(lu(A), copy(b))
114+
@test A \ b == A.A \ b == A.A \ MyVector(b) == ldiv!(lu(A.A), copy(b))
115+
@test A \ b == ldiv!(lu(A), copy(b))
109116
@test lu(A).L == lu(A.A).L
110117
@test lu(A).U == lu(A.A).U
111118
@test lu(A).p == lu(A.A).p
@@ -120,7 +127,7 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
120127
@test cholesky!(deepcopy(S), CRowMaximum()).U cholesky(Matrix(S), CRowMaximum()).U
121128
@test cholesky(S) \ b cholesky(Matrix(S)) \ b cholesky(Matrix(S)) \ MyVector(b)
122129
@test cholesky(S, CRowMaximum()) \ b cholesky(Matrix(S), CRowMaximum()) \ b
123-
@test cholesky(S, CRowMaximum()) \ b ldiv!(cholesky(Matrix(S), CRowMaximum()), copy(MyVector(b)))
130+
@test cholesky(S, CRowMaximum()) \ b ldiv!(cholesky(Matrix(S), CRowMaximum()), copy(b))
124131
@test cholesky(S) \ b Matrix(S) \ b Symmetric(Matrix(S)) \ b
125132
@test cholesky(S) \ b Symmetric(Matrix(S)) \ MyVector(b)
126133
if VERSION >= v"1.9"
@@ -140,18 +147,9 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
140147

141148
@testset "ldiv!" begin
142149
c = MyVector(randn(5))
143-
if VERSION < v"1.9"
144-
@test_broken ldiv!(lu(A), MyVector(copy(c))) A \ c
145-
else
146-
@test ldiv!(lu(A), MyVector(copy(c))) A \ c
147-
end
148-
if VERSION < v"1.9" || VERSION >= v"1.10-"
149-
@test_throws ErrorException ldiv!(qr(A), MyVector(copy(c)))
150-
else
151-
@test_throws MethodError ldiv!(qr(A), MyVector(copy(c)))
152-
end
150+
@test ldiv!(lu(A), MyVector(copy(c))) A \ c
153151
@test_throws ErrorException ldiv!(eigen(randn(5,5)), c)
154-
@test ArrayLayouts.ldiv!(svd(A.A), copy(c)) ArrayLayouts.ldiv!(similar(c), svd(A.A), c) A \ c
152+
@test ArrayLayouts.ldiv!(svd(A.A), Vector(c)) ArrayLayouts.ldiv!(similar(c), svd(A.A), c) A \ c
155153
if VERSION v"1.8"
156154
@test ArrayLayouts.ldiv!(similar(c), transpose(lu(A.A)), copy(c)) A'\c
157155
end
@@ -213,8 +211,8 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
213211
@test B == Ones(5,5)*A + 2.0Bin
214212
end
215213

216-
C = MyMatrix([1 2; 3 4])
217-
@test sprint(show, "text/plain", C) == "2×2 $MyMatrix:\n 1.0 2.0\n 3.0 4.0"
214+
C = MyMatrix(Float64[1 2; 3 4])
215+
@test sprint(show, "text/plain", C) == "$(summary(C)):\n 1.0 2.0\n 3.0 4.0"
218216

219217
@testset "layoutldiv" begin
220218
A = MyMatrix(randn(5,5))
@@ -337,6 +335,33 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
337335
@test\D \Matrix(D)
338336
@test D\\D
339337
@test/D /Matrix(D)
338+
339+
@testset "Diagonal * Bidiagonal/Tridiagonal with structured diags" begin
340+
n = size(D,1)
341+
B = Bidiagonal(map(MyVector, (rand(n), rand(n-1)))..., :U)
342+
MB = MyMatrix(B)
343+
S = SymTridiagonal(map(MyVector, (rand(n), rand(n-1)))...)
344+
MS = MyMatrix(S)
345+
T = Tridiagonal(map(MyVector, (rand(n-1), rand(n), rand(n-1)))...)
346+
MT = MyMatrix(T)
347+
DA, BA, SA, TA = map(Array, (D, B, S, T))
348+
if VERSION >= v"1.11"
349+
@test D * B DA * BA
350+
@test B * D BA * DA
351+
@test D * MB DA * BA
352+
@test MB * D BA * DA
353+
end
354+
if VERSION >= v"1.12.0-DEV.824"
355+
@test D * S DA * SA
356+
@test D * MS DA * SA
357+
@test D * T DA * TA
358+
@test D * MT DA * TA
359+
@test S * D SA * DA
360+
@test MS * D SA * DA
361+
@test T * D TA * DA
362+
@test MT * D TA * DA
363+
end
364+
end
340365
end
341366

342367
@testset "Adj/Trans" begin
@@ -400,7 +425,7 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
400425

401426
@testset "dot" begin
402427
a = MyVector(randn(5))
403-
@test dot(a, Zeros(5)) dot(Zeros(5), a) 0.0
428+
@test dot(a, Zeros(5)) dot(Zeros(5), a) == 0.0
404429
end
405430

406431
@testset "layout_getindex scalar" begin

0 commit comments

Comments
 (0)