Skip to content

Commit 9afd0c8

Browse files
authored
Merge pull request #4814 from Mousius/gemv-proxy
Forward GEMM to GEMV when one argument is actually a vector
2 parents edbf093 + ba2e989 commit 9afd0c8

File tree

6 files changed

+141
-36
lines changed

6 files changed

+141
-36
lines changed

Makefile.system

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,18 @@ endif
274274
ifeq ($(ARCH), loongarch64)
275275
SMALL_MATRIX_OPT = 1
276276
endif
277+
# arm64 builds switch on GEMM_GEMV_FORWARD, i.e. forwarding of GEMM to GEMV
# when one argument is actually a vector.
ifeq ($(ARCH), arm64)
278+
GEMM_GEMV_FORWARD = 1
279+
endif
280+
277281
ifeq ($(SMALL_MATRIX_OPT), 1)
278282
CCOMMON_OPT += -DSMALL_MATRIX_OPT
279283
endif
284+
# Expose the setting to the C sources as the GEMM_GEMV_FORWARD preprocessor
# macro. The macro is not added when ONLY_CBLAS is 1.
ifeq ($(GEMM_GEMV_FORWARD), 1)
285+
ifneq ($(ONLY_CBLAS), 1)
286+
CCOMMON_OPT += -DGEMM_GEMV_FORWARD
287+
endif
288+
endif
280289

281290
# This operation is expensive, so execution should be once.
282291
ifndef GOTOBLAS_MAKEFILE

cmake/system.cmake

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,13 @@ endif ()
391391
if (X86_64 OR ${CORE} STREQUAL POWER10)
392392
set(SMALL_MATRIX_OPT TRUE)
393393
endif ()
394+
# ARM64 builds switch on GEMM_GEMV_FORWARD, i.e. forwarding of GEMM to GEMV
# when one argument is actually a vector.
if (ARM64)
395+
set(GEMM_GEMV_FORWARD TRUE)
396+
endif ()
397+
398+
# Expose the setting to the C sources as the GEMM_GEMV_FORWARD preprocessor
# macro. The macro is not added for ONLY_CBLAS builds.
if (GEMM_GEMV_FORWARD AND NOT ONLY_CBLAS)
399+
set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD")
400+
endif ()
394401
if (SMALL_MATRIX_OPT)
395402
set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT")
396403
endif ()

interface/gemm.c

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
/*********************************************************************/
2+
/* Copyright 2024 The OpenBLAS Project */
23
/* Copyright 2009, 2010 The University of Texas at Austin. */
34
/* All rights reserved. */
45
/* */
@@ -47,12 +48,16 @@
4748
#define SMP_THRESHOLD_MIN 65536.0
4849
#ifdef XDOUBLE
4950
#define ERROR_NAME "QGEMM "
51+
#define GEMV BLASFUNC(qgemv)
5052
#elif defined(DOUBLE)
5153
#define ERROR_NAME "DGEMM "
54+
#define GEMV BLASFUNC(dgemv)
5255
#elif defined(BFLOAT16)
5356
#define ERROR_NAME "SBGEMM "
57+
#define GEMV BLASFUNC(sbgemv)
5458
#else
5559
#define ERROR_NAME "SGEMM "
60+
#define GEMV BLASFUNC(sgemv)
5661
#endif
5762
#else
5863
#define SMP_THRESHOLD_MIN 8192.0
@@ -493,6 +498,52 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
493498
args.m, args.n, args.k, args.lda, args.ldb, args.ldc);
494499
#endif
495500

501+
#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX)
502+
// Check if we can convert GEMM -> GEMV: when args.n == 1 or args.m == 1 one
// of the operands is effectively a vector, so the call is handed to GEMV and
// this function returns early. Nothing is forwarded when args.k == 0, and the
// whole block is compiled out for GEMM3M and COMPLEX builds.
503+
if (args.k != 0) {
504+
// Case n == 1: args.a is the GEMV matrix, args.b the x vector, args.c the y vector.
if (args.n == 1) {
505+
blasint inc_x = 1;
506+
blasint inc_y = 1;
507+
// These were passed in as blasint, but the args struct holds them as BLASLONG,
// so copy them into blasint locals that can be passed to GEMV by address.
508+
blasint m = args.m;
509+
blasint n = args.k;
510+
blasint lda = args.lda;
511+
// Create new transpose parameters: start from 'N' with an (args.m x args.k)
// matrix; if the low bit of transa is set, use 'T' and swap the two dimensions.
512+
char NT = 'N';
513+
if (transa & 1) {
514+
NT = 'T';
515+
m = args.k;
516+
n = args.m;
517+
}
518+
// If the low bit of transb is set, x is read from args.b with stride args.ldb
// instead of stride 1.
if (transb & 1) {
519+
inc_x = args.ldb;
520+
}
521+
GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y);
522+
return;
523+
}
524+
// Case m == 1: the roles swap - args.b is the GEMV matrix, args.a the x vector,
// and args.c is written with stride args.ldc.
if (args.m == 1) {
525+
blasint inc_x = args.lda;
526+
blasint inc_y = args.ldc;
527+
// These were passed in as blasint, but the args struct holds them as BLASLONG,
// so copy them into blasint locals that can be passed to GEMV by address.
528+
blasint m = args.k;
529+
blasint n = args.n;
530+
blasint ldb = args.ldb;
531+
// Create new transpose parameters: start from 'T' with an (args.k x args.n)
// matrix; if the low bit of transb is set, use 'N' and swap the two dimensions.
532+
char NT = 'T';
533+
// If the low bit of transa is set, x is read from args.a with stride 1
// instead of stride args.lda.
if (transa & 1) {
534+
inc_x = 1;
535+
}
536+
if (transb & 1) {
537+
NT = 'N';
538+
m = args.n;
539+
n = args.k;
540+
}
541+
GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y);
542+
return;
543+
}
544+
}
545+
#endif
546+
496547
IDEBUG_START;
497548

498549
FUNCTION_PROFILE_START();

kernel/arm64/KERNEL.NEOVERSEV1

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
11
include $(KERNELDIR)/KERNEL.ARMV8SVE
2+
3+
# Select gemv_t_sve.c as the transposed-GEMV kernel for both the S and D
# variants. NOTE(review): presumably this overrides the choice made in the
# included KERNEL.ARMV8SVE - confirm against that file.
SGEMVTKERNEL = gemv_t_sve.c
4+
DGEMVTKERNEL = gemv_t_sve.c

kernel/arm64/gemv_t.S

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
Copyright (c) 2015, The OpenBLAS Project
2+
Copyright (c) 2015, 2024 The OpenBLAS Project
33
All rights reserved.
44
Redistribution and use in source and binary forms, with or without
55
modification, are permitted provided that the following conditions are
@@ -170,39 +170,48 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
170170

171171
.macro KERNEL_F32_FINALIZE
172172
#if !defined(DOUBLE)
173-
fadd v1.4s, v1.4s, v2.4s
173+
// F8 only has 2 accumulators
174+
// so add into those pairs
174175
fadd v1.4s, v1.4s, v3.4s
175-
fadd v1.4s, v1.4s, v4.4s
176-
#else
177-
fadd v1.2d, v1.2d, v2.2d
178-
fadd v1.2d, v1.2d, v3.2d
179-
fadd v1.2d, v1.2d, v4.2d
176+
fadd v2.4s, v2.4s, v4.4s
180177
#endif
181178
.endm
182179

183-
.macro KERNEL_F4
180+
.macro KERNEL_F8
184181
#if !defined(DOUBLE)
185-
ld1 {v2.4s}, [A_PTR], #16
186-
ld1 {v3.4s}, [X_PTR], #16
187-
fmla v1.4s, v2.4s, v3.4s
188-
#else
189-
ld1 {v2.2d}, [A_PTR], #16
190-
ld1 {v3.2d}, [X_PTR], #16
191-
fmla v1.2d, v2.2d, v3.2d
192-
193-
ld1 {v4.2d}, [A_PTR], #16
194-
ld1 {v5.2d}, [X_PTR], #16
195-
fmla v1.2d, v4.2d, v5.2d
182+
ld1 {v13.4s, v14.4s}, [A_PTR], #32
183+
ld1 {v17.4s, v18.4s}, [X_PTR], #32
184+
fmla v1.4s, v13.4s, v17.4s
185+
fmla v2.4s, v14.4s, v18.4s
186+
#else
187+
ld1 {v13.2d, v14.2d, v15.2d, v16.2d}, [A_PTR], #64
188+
ld1 {v17.2d, v18.2d, v19.2d, v20.2d}, [X_PTR], #64
189+
fmla v1.2d, v13.2d, v17.2d
190+
fmla v2.2d, v14.2d, v18.2d
191+
fmla v3.2d, v15.2d, v19.2d
192+
fmla v4.2d, v16.2d, v20.2d
196193
#endif
197194
.endm
198195

199-
.macro KERNEL_F4_FINALIZE
196+
.macro KERNEL_F8_FINALIZE
200197
#if !defined(DOUBLE)
201-
ext v2.16b, v1.16b, v1.16b, #8
198+
// Take the top two elements of v1 and
199+
// put them into the first two lanes of v3
200+
ext v3.16b, v1.16b, v1.16b, #8
201+
fadd v1.2s, v1.2s, v3.2s
202+
ext v4.16b, v2.16b, v2.16b, #8
203+
fadd v2.2s, v2.2s, v4.2s
204+
// Final pair
202205
fadd v1.2s, v1.2s, v2.2s
203206
faddp TEMP, v1.2s
204207
#else
205208
faddp TEMP, v1.2d
209+
faddp TEMP1, v2.2d
210+
faddp TEMP2, v3.2d
211+
faddp TEMP3, v4.2d
212+
fadd TEMP, TEMP, TEMP1
213+
fadd TEMP2, TEMP2, TEMP3
214+
fadd TEMP, TEMP, TEMP2
206215
#endif
207216
.endm
208217

@@ -258,7 +267,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
258267

259268
asr I, M, #5
260269
cmp I, xzr
261-
beq .Lgemv_t_kernel_F4
270+
beq .Lgemv_t_kernel_F8
262271

263272
.Lgemv_t_kernel_F320:
264273

@@ -269,24 +278,24 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
269278

270279
KERNEL_F32_FINALIZE
271280

272-
.Lgemv_t_kernel_F4:
281+
.Lgemv_t_kernel_F8:
273282
ands I, M, #31
274-
asr I, I, #2
283+
asr I, I, #3
275284
cmp I, xzr
276285
beq .Lgemv_t_kernel_F1
277286

278-
.Lgemv_t_kernel_F40:
287+
.Lgemv_t_kernel_F80:
279288

280-
KERNEL_F4
289+
KERNEL_F8
281290

282291
subs I, I, #1
283-
bne .Lgemv_t_kernel_F40
292+
bne .Lgemv_t_kernel_F80
284293

285294
.Lgemv_t_kernel_F1:
286295

287-
KERNEL_F4_FINALIZE
296+
KERNEL_F8_FINALIZE
288297

289-
ands I, M, #3
298+
ands I, M, #7
290299
ble .Lgemv_t_kernel_F_END
291300

292301
.Lgemv_t_kernel_F10:

kernel/arm64/gemv_t_sve.c

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,20 +59,46 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
5959
a_ptr = a;
6060

6161
if (inc_x == 1) {
62+
svbool_t pg_true = SV_TRUE();
6263
uint64_t sve_size = SV_COUNT();
64+
uint64_t sve_size2 = sve_size * 2;
65+
BLASLONG m1 = m & -sve_size;
66+
BLASLONG m2 = m & -sve_size2;
67+
6368
for (j = 0; j < n; j++) {
69+
BLASLONG i = 0;
70+
71+
SV_TYPE temp_vec_v2_0 = SV_DUP(0.0);
72+
SV_TYPE temp_vec_v2_1 = SV_DUP(0.0);
73+
for (; i < m2; i += sve_size2) {
74+
SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i);
75+
SV_TYPE x_vec0 = svld1(pg_true, x + i);
76+
SV_TYPE a_vec1 = svld1(pg_true, a_ptr + i + sve_size);
77+
SV_TYPE x_vec1 = svld1(pg_true, x + i + sve_size);
78+
temp_vec_v2_0 = svmla_m(pg_true, temp_vec_v2_0, a_vec0, x_vec0);
79+
temp_vec_v2_1 = svmla_m(pg_true, temp_vec_v2_1, a_vec1, x_vec1);
80+
}
81+
82+
SV_TYPE temp_vec_v1 = SV_DUP(0.0);
83+
for (; i < m1; i += sve_size) {
84+
SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i);
85+
SV_TYPE x_vec0 = svld1(pg_true, x + i);
86+
temp_vec_v1 = svmla_m(pg_true, temp_vec_v1, a_vec0, x_vec0);
87+
}
88+
6489
SV_TYPE temp_vec = SV_DUP(0.0);
65-
i = 0;
66-
svbool_t pg = SV_WHILE(i, m);
67-
while (svptest_any(SV_TRUE(), pg)) {
90+
for (; i < m; i += sve_size) {
91+
svbool_t pg = SV_WHILE(i, m);
6892
SV_TYPE a_vec = svld1(pg, a_ptr + i);
6993
SV_TYPE x_vec = svld1(pg, x + i);
7094
temp_vec = svmla_m(pg, temp_vec, a_vec, x_vec);
71-
i += sve_size;
72-
pg = SV_WHILE(i, m);
7395
}
74-
temp = svaddv(SV_TRUE(), temp_vec);
75-
y[iy] += alpha * temp;
96+
97+
y[iy] += alpha * (
98+
(svaddv(SV_TRUE(), temp_vec_v2_0) + svaddv(SV_TRUE(), temp_vec)) +
99+
(svaddv(SV_TRUE(), temp_vec_v2_1) + svaddv(SV_TRUE(), temp_vec_v1))
100+
);
101+
76102
iy += inc_y;
77103
a_ptr += lda;
78104
}

0 commit comments

Comments
 (0)