Skip to content

Commit ba2e989

Browse files
committed
Add accumulators to AArch64 GEMV Kernels
Using several independent accumulators helps to reduce the small values that go missing (are rounded away) as the sums accumulate.
1 parent b26424c commit ba2e989

File tree

3 files changed

+74
-36
lines changed

3 files changed

+74
-36
lines changed

kernel/arm64/KERNEL.NEOVERSEV1

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
11
include $(KERNELDIR)/KERNEL.ARMV8SVE
2+
3+
SGEMVTKERNEL = gemv_t_sve.c
4+
DGEMVTKERNEL = gemv_t_sve.c

kernel/arm64/gemv_t.S

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
Copyright (c) 2015, The OpenBLAS Project
2+
Copyright (c) 2015, 2024 The OpenBLAS Project
33
All rights reserved.
44
Redistribution and use in source and binary forms, with or without
55
modification, are permitted provided that the following conditions are
@@ -170,39 +170,48 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
170170

171171
.macro KERNEL_F32_FINALIZE
172172
#if !defined(DOUBLE)
173-
fadd v1.4s, v1.4s, v2.4s
173+
// The single-precision F8 kernel only has 2 accumulators (v1, v2),
174+
// so fold v3 into v1 and v4 into v2
174175
fadd v1.4s, v1.4s, v3.4s
175-
fadd v1.4s, v1.4s, v4.4s
176-
#else
177-
fadd v1.2d, v1.2d, v2.2d
178-
fadd v1.2d, v1.2d, v3.2d
179-
fadd v1.2d, v1.2d, v4.2d
176+
fadd v2.4s, v2.4s, v4.4s
180177
#endif
181178
.endm
182179

183-
.macro KERNEL_F4
180+
.macro KERNEL_F8
184181
#if !defined(DOUBLE)
185-
ld1 {v2.4s}, [A_PTR], #16
186-
ld1 {v3.4s}, [X_PTR], #16
187-
fmla v1.4s, v2.4s, v3.4s
188-
#else
189-
ld1 {v2.2d}, [A_PTR], #16
190-
ld1 {v3.2d}, [X_PTR], #16
191-
fmla v1.2d, v2.2d, v3.2d
192-
193-
ld1 {v4.2d}, [A_PTR], #16
194-
ld1 {v5.2d}, [X_PTR], #16
195-
fmla v1.2d, v4.2d, v5.2d
182+
ld1 {v13.4s, v14.4s}, [A_PTR], #32
183+
ld1 {v17.4s, v18.4s}, [X_PTR], #32
184+
fmla v1.4s, v13.4s, v17.4s
185+
fmla v2.4s, v14.4s, v18.4s
186+
#else
187+
ld1 {v13.2d, v14.2d, v15.2d, v16.2d}, [A_PTR], #64
188+
ld1 {v17.2d, v18.2d, v19.2d, v20.2d}, [X_PTR], #64
189+
fmla v1.2d, v13.2d, v17.2d
190+
fmla v2.2d, v14.2d, v18.2d
191+
fmla v3.2d, v15.2d, v19.2d
192+
fmla v4.2d, v16.2d, v20.2d
196193
#endif
197194
.endm
198195

199-
.macro KERNEL_F4_FINALIZE
196+
.macro KERNEL_F8_FINALIZE
200197
#if !defined(DOUBLE)
201-
ext v2.16b, v1.16b, v1.16b, #8
198+
// Take the top two elements of v1 and
199+
// put them into the first two lanes of v3
200+
ext v3.16b, v1.16b, v1.16b, #8
201+
fadd v1.2s, v1.2s, v3.2s
202+
ext v4.16b, v2.16b, v2.16b, #8
203+
fadd v2.2s, v2.2s, v4.2s
204+
// Add the two partial sums, then reduce the final pair of lanes
202205
fadd v1.2s, v1.2s, v2.2s
203206
faddp TEMP, v1.2s
204207
#else
205208
faddp TEMP, v1.2d
209+
faddp TEMP1, v2.2d
210+
faddp TEMP2, v3.2d
211+
faddp TEMP3, v4.2d
212+
fadd TEMP, TEMP, TEMP1
213+
fadd TEMP2, TEMP2, TEMP3
214+
fadd TEMP, TEMP, TEMP2
206215
#endif
207216
.endm
208217

@@ -258,7 +267,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
258267

259268
asr I, M, #5
260269
cmp I, xzr
261-
beq .Lgemv_t_kernel_F4
270+
beq .Lgemv_t_kernel_F8
262271

263272
.Lgemv_t_kernel_F320:
264273

@@ -269,24 +278,24 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
269278

270279
KERNEL_F32_FINALIZE
271280

272-
.Lgemv_t_kernel_F4:
281+
.Lgemv_t_kernel_F8:
273282
ands I, M, #31
274-
asr I, I, #2
283+
asr I, I, #3
275284
cmp I, xzr
276285
beq .Lgemv_t_kernel_F1
277286

278-
.Lgemv_t_kernel_F40:
287+
.Lgemv_t_kernel_F80:
279288

280-
KERNEL_F4
289+
KERNEL_F8
281290

282291
subs I, I, #1
283-
bne .Lgemv_t_kernel_F40
292+
bne .Lgemv_t_kernel_F80
284293

285294
.Lgemv_t_kernel_F1:
286295

287-
KERNEL_F4_FINALIZE
296+
KERNEL_F8_FINALIZE
288297

289-
ands I, M, #3
298+
ands I, M, #7
290299
ble .Lgemv_t_kernel_F_END
291300

292301
.Lgemv_t_kernel_F10:

kernel/arm64/gemv_t_sve.c

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,20 +59,46 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
5959
a_ptr = a;
6060

6161
if (inc_x == 1) {
62+
svbool_t pg_true = SV_TRUE();
6263
uint64_t sve_size = SV_COUNT();
64+
uint64_t sve_size2 = sve_size * 2;
65+
BLASLONG m1 = m & -sve_size;
66+
BLASLONG m2 = m & -sve_size2;
67+
6368
for (j = 0; j < n; j++) {
69+
BLASLONG i = 0;
70+
71+
SV_TYPE temp_vec_v2_0 = SV_DUP(0.0);
72+
SV_TYPE temp_vec_v2_1 = SV_DUP(0.0);
73+
for (; i < m2; i += sve_size2) {
74+
SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i);
75+
SV_TYPE x_vec0 = svld1(pg_true, x + i);
76+
SV_TYPE a_vec1 = svld1(pg_true, a_ptr + i + sve_size);
77+
SV_TYPE x_vec1 = svld1(pg_true, x + i + sve_size);
78+
temp_vec_v2_0 = svmla_m(pg_true, temp_vec_v2_0, a_vec0, x_vec0);
79+
temp_vec_v2_1 = svmla_m(pg_true, temp_vec_v2_1, a_vec1, x_vec1);
80+
}
81+
82+
SV_TYPE temp_vec_v1 = SV_DUP(0.0);
83+
for (; i < m1; i += sve_size) {
84+
SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i);
85+
SV_TYPE x_vec0 = svld1(pg_true, x + i);
86+
temp_vec_v1 = svmla_m(pg_true, temp_vec_v1, a_vec0, x_vec0);
87+
}
88+
6489
SV_TYPE temp_vec = SV_DUP(0.0);
65-
i = 0;
66-
svbool_t pg = SV_WHILE(i, m);
67-
while (svptest_any(SV_TRUE(), pg)) {
90+
for (; i < m; i += sve_size) {
91+
svbool_t pg = SV_WHILE(i, m);
6892
SV_TYPE a_vec = svld1(pg, a_ptr + i);
6993
SV_TYPE x_vec = svld1(pg, x + i);
7094
temp_vec = svmla_m(pg, temp_vec, a_vec, x_vec);
71-
i += sve_size;
72-
pg = SV_WHILE(i, m);
7395
}
74-
temp = svaddv(SV_TRUE(), temp_vec);
75-
y[iy] += alpha * temp;
96+
97+
y[iy] += alpha * (
98+
(svaddv(SV_TRUE(), temp_vec_v2_0) + svaddv(SV_TRUE(), temp_vec)) +
99+
(svaddv(SV_TRUE(), temp_vec_v2_1) + svaddv(SV_TRUE(), temp_vec_v1))
100+
);
101+
76102
iy += inc_y;
77103
a_ptr += lda;
78104
}

0 commit comments

Comments
 (0)