|
1 | 1 | ### This support BLAS style multiplication
|
2 |
| -# α * A * B + β C |
| 2 | +# A * B * α + C * β |
3 | 3 | # but avoids the broadcast machinery
|
4 | 4 |
|
5 |
| -# Lazy representation of α*A*B + β*C |
| 5 | +# Lazy representation of A*B*α + C*β |
6 | 6 | struct MulAdd{StyleA, StyleB, StyleC, T, AA, BB, CC}
|
7 | 7 | α::T
|
8 | 8 | A::AA
|
9 | 9 | B::BB
|
10 | 10 | β::T
|
11 | 11 | C::CC
|
| 12 | + Czero::Bool # this flag indicates whether C isa Zeros, or a copy of one |
| 13 | + # the idea is that if Czero == true, then downstream packages don't need to |
| 14 | + # fill C with zero before performing the muladd |
12 | 15 | end
|
13 | 16 |
|
14 |
| -@inline MulAdd{StyleA,StyleB,StyleC}(α::T, A::AA, B::BB, β::T, C::CC) where {StyleA,StyleB,StyleC,T,AA,BB,CC} = |
15 |
| - MulAdd{StyleA,StyleB,StyleC,T,AA,BB,CC}(α,A,B,β,C) |
| 17 | +@inline function MulAdd{StyleA,StyleB,StyleC}(α::T, A::AA, B::BB, β::T, C::CC; |
| 18 | + Czero = C isa Zeros) where {StyleA,StyleB,StyleC,T,AA,BB,CC} |
| 19 | + MulAdd{StyleA,StyleB,StyleC,T,AA,BB,CC}(α,A,B,β,C,Czero) |
| 20 | +end |
16 | 21 |
|
17 |
| -@inline function MulAdd{StyleA,StyleB,StyleC}(αT, A, B, βV, C) where {StyleA,StyleB,StyleC} |
| 22 | +@inline function MulAdd{StyleA,StyleB,StyleC}(αT, A, B, βV, C; kw...) where {StyleA,StyleB,StyleC} |
18 | 23 | α,β = promote(αT,βV)
|
19 |
| - MulAdd{StyleA,StyleB,StyleC}(α, A, B, β, C) |
| 24 | + MulAdd{StyleA,StyleB,StyleC}(α, A, B, β, C; kw...) |
20 | 25 | end
|
21 | 26 |
|
22 |
| -@inline MulAdd(α, A::AA, B::BB, β, C::CC) where {AA,BB,CC} = |
23 |
| - MulAdd{typeof(MemoryLayout(AA)), typeof(MemoryLayout(BB)), typeof(MemoryLayout(CC))}(α, A, B, β, C) |
| 27 | +@inline MulAdd(α, A::AA, B::BB, β, C::CC; kw...) where {AA,BB,CC} = |
| 28 | + MulAdd{typeof(MemoryLayout(AA)), typeof(MemoryLayout(BB)), typeof(MemoryLayout(CC))}(α, A, B, β, C; kw...) |
24 | 29 |
|
25 | 30 | MulAdd(A, B) = MulAdd(Mul(A, B))
|
26 | 31 | function MulAdd(M::Mul)
|
@@ -67,15 +72,15 @@ const BlasMatMulVecAdd{StyleA,StyleB,StyleC,T<:BlasFloat} = MulAdd{StyleA,StyleB
|
67 | 72 | const BlasMatMulMatAdd{StyleA,StyleB,StyleC,T<:BlasFloat} = MulAdd{StyleA,StyleB,StyleC,T,<:AbstractMatrix{T},<:AbstractMatrix{T},<:AbstractMatrix{T}}
|
68 | 73 | const BlasVecMulMatAdd{StyleA,StyleB,StyleC,T<:BlasFloat} = MulAdd{StyleA,StyleB,StyleC,T,<:AbstractVector{T},<:AbstractMatrix{T},<:AbstractMatrix{T}}
|
69 | 74 |
|
70 |
| -muladd!(α, A, B, β, C) = materialize!(MulAdd(α, A, B, β, C)) |
| 75 | +muladd!(α, A, B, β, C; kw...) = materialize!(MulAdd(α, A, B, β, C; kw...)) |
71 | 76 | materialize(M::MulAdd) = copy(instantiate(M))
|
72 | 77 | copy(M::MulAdd) = copyto!(similar(M), M)
|
73 | 78 |
|
74 | 79 | _fill_copyto!(dest, C) = copyto!(dest, C)
|
75 | 80 | _fill_copyto!(dest, C::Zeros) = zero!(dest) # exploit special fill! overload
|
76 | 81 |
|
77 | 82 | @inline copyto!(dest::AbstractArray{T}, M::MulAdd) where T =
|
78 |
| - muladd!(M.α, unalias(dest,M.A), unalias(dest,M.B), M.β, _fill_copyto!(dest, M.C)) |
| 83 | + muladd!(M.α, unalias(dest,M.A), unalias(dest,M.B), M.β, _fill_copyto!(dest, M.C); Czero = M.Czero) |
79 | 84 |
|
80 | 85 | # Modified from LinearAlgebra._generic_matmatmul!
|
81 | 86 | const tilebufsize = 10800 # Approximately 32k/3
|
|
0 commit comments