Skip to content

Commit 20b0f6c

Browse files
authored
Store if C isa Zeros in a MulAdd (#201)
* Store if C isa Zeros in MulAdd
* Bump version to v1.6.0
* Rename field to Czero
* Update comment
1 parent d542ead commit 20b0f6c

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <solver@mac.com>"]
4-
version = "1.5.3"
4+
version = "1.6.0"
55

66
[deps]
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

src/muladd.jl

+15-10
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,31 @@
11
### This supports BLAS-style multiplication
2-
# α * A * B + β C
2+
# A * B * α + C * β
33
# but avoids the broadcast machinery
44

5-
# Lazy representation of α*A*B + β*C
5+
# Lazy representation of A*B*α + C*β
66
struct MulAdd{StyleA, StyleB, StyleC, T, AA, BB, CC}
77
α::T
88
A::AA
99
B::BB
1010
β::T
1111
C::CC
12+
Czero::Bool # this flag indicates whether C isa Zeros, or a copy of one
13+
# the idea is that if Czero == true, then downstream packages don't need to
14+
# fill C with zero before performing the muladd
1215
end
1316

14-
@inline MulAdd{StyleA,StyleB,StyleC}(α::T, A::AA, B::BB, β::T, C::CC) where {StyleA,StyleB,StyleC,T,AA,BB,CC} =
15-
MulAdd{StyleA,StyleB,StyleC,T,AA,BB,CC}(α,A,B,β,C)
17+
@inline function MulAdd{StyleA,StyleB,StyleC}(α::T, A::AA, B::BB, β::T, C::CC;
18+
Czero = C isa Zeros) where {StyleA,StyleB,StyleC,T,AA,BB,CC}
19+
MulAdd{StyleA,StyleB,StyleC,T,AA,BB,CC}(α,A,B,β,C,Czero)
20+
end
1621

17-
@inline function MulAdd{StyleA,StyleB,StyleC}(αT, A, B, βV, C) where {StyleA,StyleB,StyleC}
22+
@inline function MulAdd{StyleA,StyleB,StyleC}(αT, A, B, βV, C; kw...) where {StyleA,StyleB,StyleC}
1823
α,β = promote(αT,βV)
19-
MulAdd{StyleA,StyleB,StyleC}(α, A, B, β, C)
24+
MulAdd{StyleA,StyleB,StyleC}(α, A, B, β, C; kw...)
2025
end
2126

22-
@inline MulAdd(α, A::AA, B::BB, β, C::CC) where {AA,BB,CC} =
23-
MulAdd{typeof(MemoryLayout(AA)), typeof(MemoryLayout(BB)), typeof(MemoryLayout(CC))}(α, A, B, β, C)
27+
@inline MulAdd(α, A::AA, B::BB, β, C::CC; kw...) where {AA,BB,CC} =
28+
MulAdd{typeof(MemoryLayout(AA)), typeof(MemoryLayout(BB)), typeof(MemoryLayout(CC))}(α, A, B, β, C; kw...)
2429

2530
MulAdd(A, B) = MulAdd(Mul(A, B))
2631
function MulAdd(M::Mul)
@@ -67,15 +72,15 @@ const BlasMatMulVecAdd{StyleA,StyleB,StyleC,T<:BlasFloat} = MulAdd{StyleA,StyleB
6772
const BlasMatMulMatAdd{StyleA,StyleB,StyleC,T<:BlasFloat} = MulAdd{StyleA,StyleB,StyleC,T,<:AbstractMatrix{T},<:AbstractMatrix{T},<:AbstractMatrix{T}}
6873
const BlasVecMulMatAdd{StyleA,StyleB,StyleC,T<:BlasFloat} = MulAdd{StyleA,StyleB,StyleC,T,<:AbstractVector{T},<:AbstractMatrix{T},<:AbstractMatrix{T}}
6974

70-
muladd!(α, A, B, β, C) = materialize!(MulAdd(α, A, B, β, C))
75+
muladd!(α, A, B, β, C; kw...) = materialize!(MulAdd(α, A, B, β, C; kw...))
7176
materialize(M::MulAdd) = copy(instantiate(M))
7277
copy(M::MulAdd) = copyto!(similar(M), M)
7378

7479
_fill_copyto!(dest, C) = copyto!(dest, C)
7580
_fill_copyto!(dest, C::Zeros) = zero!(dest) # exploit special fill! overload
7681

7782
@inline copyto!(dest::AbstractArray{T}, M::MulAdd) where T =
78-
muladd!(M.α, unalias(dest,M.A), unalias(dest,M.B), M.β, _fill_copyto!(dest, M.C))
83+
muladd!(M.α, unalias(dest,M.A), unalias(dest,M.B), M.β, _fill_copyto!(dest, M.C); Czero = M.Czero)
7984

8085
# Modified from LinearAlgebra._generic_matmatmul!
8186
const tilebufsize = 10800 # Approximately 32k/3

0 commit comments

Comments
 (0)