[WebGPU] Optimize GEMM with vec4 #24478

Open
wants to merge 10 commits into
base: main
from

Conversation

xiaofeihan1
Contributor

@xiaofeihan1 xiaofeihan1 commented Apr 21, 2025

Description

In this PR, we use vec4 to optimize GEMM when the column counts of A and B are divisible by 4; otherwise we fall back to the previous shader.
I will add u32/vec2 implementations in the future, at which point we will keep only one shader.
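
The dispatch rule in the description can be sketched on the CPU. This is a minimal illustration, not the shader from this PR: `can_use_vec4`, `gemm_vec4_style`, `gemm_scalar`, and `gemm` are hypothetical names, the divisibility check is applied to the stored column count of each input as the description states, and the 4-wide slices only mimic how a vec4 load groups elements. It covers the non-transposed case without alpha, beta, or bias.

```python
from typing import List

Matrix = List[List[float]]


def can_use_vec4(a: Matrix, b: Matrix) -> bool:
    """Hypothetical check: both inputs have a column count divisible by 4."""
    return len(a[0]) % 4 == 0 and len(b[0]) % 4 == 0


def gemm_vec4_style(a: Matrix, b: Matrix) -> Matrix:
    """Compute A @ B while reading 4 adjacent elements at a time."""
    m, k, n = len(a), len(a[0]), len(b[0])
    out = [[0.0] * n for _ in range(m)]
    for row in range(m):
        for col0 in range(0, n, 4):        # one 4-wide group of output columns
            acc = [0.0, 0.0, 0.0, 0.0]
            for k0 in range(0, k, 4):      # one 4-wide group of A along K
                a_vec = a[row][k0:k0 + 4]
                for lane, a_val in enumerate(a_vec):
                    b_vec = b[k0 + lane][col0:col0 + 4]  # 4-wide group of B
                    for i in range(4):
                        acc[i] += a_val * b_vec[i]
            out[row][col0:col0 + 4] = acc
    return out


def gemm_scalar(a: Matrix, b: Matrix) -> Matrix:
    """Fallback path: plain triple loop, no shape restrictions beyond A @ B."""
    m, k, n = len(a), len(a[0]), len(b[0])
    return [
        [sum(a[i][p] * b[p][j] for p in range(k)) + 0.0 for j in range(n)]
        for i in range(m)
    ]


def gemm(a: Matrix, b: Matrix) -> Matrix:
    """Pick the 4-wide path when eligible, otherwise the scalar path."""
    return gemm_vec4_style(a, b) if can_use_vec4(a, b) else gemm_scalar(a, b)
```

Both paths should return the same result for eligible shapes; the 4-wide path only changes how elements are grouped when they are read.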

Perf comparison

I ran a custom model containing only a GEMM (M = N = K = 1024) with Node.js on M2 and M3 Max. Performance improved by roughly 20%.

| | !transA && !transB | transA | transB | transA && transB |
|---|---|---|---|---|
| M2 | 9.36 -> 7.41 | 9.45 -> 7.54 | 11.21 -> 8.19 | 9.66 -> 8.37 |
| M3 Max | 8.07 -> 6.99 | 7.54 -> 6.53 | 8.42 -> 5.89 | 5.47 -> 5.29 |
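
The "roughly 20%" figure can be checked against the reported numbers. The sketch below assumes each cell is a before -> after timing where lower is better (the PR does not state the unit); `reduction_percent` and `mean_reduction` are helper names introduced here for illustration.

```python
# (before, after) pairs copied from the perf comparison, in column order:
# !transA && !transB, transA, transB, transA && transB.
# Assumption: the values are timings and lower is better.
timings = {
    "M2": [(9.36, 7.41), (9.45, 7.54), (11.21, 8.19), (9.66, 8.37)],
    "M3 Max": [(8.07, 6.99), (7.54, 6.53), (8.42, 5.89), (5.47, 5.29)],
}


def reduction_percent(before: float, after: float) -> float:
    """Relative reduction from `before` to `after`, in percent."""
    return (before - after) / before * 100.0


def mean_reduction(pairs) -> float:
    """Average relative reduction over a list of (before, after) pairs."""
    return sum(reduction_percent(b, a) for b, a in pairs) / len(pairs)


for device, pairs in timings.items():
    per_case = [round(reduction_percent(b, a), 1) for b, a in pairs]
    print(device, per_case, "mean:", round(mean_reduction(pairs), 1))
```

Under that assumption the gain varies noticeably by transpose configuration and by device, so the 20% figure is best read as an approximate summary.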

@fs-eire
Contributor

fs-eire commented Apr 22, 2025

Is there a way to reuse the implementation of MatMul? My understanding is that there is some duplication between GEMM and MatMul, and it would be great if we could reuse the shared code.

@xiaofeihan1
Contributor Author

xiaofeihan1 commented Apr 22, 2025

> Is there a way to reuse the implementation of MatMul? My understanding is that there is some duplication between GEMM and MatMul, and it would be great if we could reuse the shared code.

Thanks for the callout. That's what I plan to do next.
For the current PR, I want to push forward with vec4 support for GEMM. I will take on the refactoring in future PRs because it needs some careful thought: there are differences between GEMM and MatMul (for example, the latter supports batching and the former supports transposed inputs). WDYT?
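
To make the GEMM/MatMul differences concrete, here is a rough CPU sketch of how a shared 2-D core could serve both operators. It is purely illustrative and not the refactor planned for the follow-up PRs: `matmul_2d`, `transpose`, `gemm`, and `batched_matmul` are hypothetical names, and the bias `c` is assumed to already have the output shape (no broadcasting).

```python
from typing import List, Optional

Matrix = List[List[float]]


def matmul_2d(a: Matrix, b: Matrix) -> Matrix:
    """Shared core: plain 2-D matrix product."""
    m, k, n = len(a), len(a[0]), len(b[0])
    return [
        [sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
        for i in range(m)
    ]


def transpose(mat: Matrix) -> Matrix:
    return [list(col) for col in zip(*mat)]


def gemm(a: Matrix, b: Matrix, c: Optional[Matrix] = None,
         alpha: float = 1.0, beta: float = 1.0,
         trans_a: bool = False, trans_b: bool = False) -> Matrix:
    """GEMM-specific parts: optional transposes, alpha/beta scaling, bias."""
    a2 = transpose(a) if trans_a else a
    b2 = transpose(b) if trans_b else b
    out = [[alpha * v for v in row] for row in matmul_2d(a2, b2)]
    if c is not None:
        # Assumption: c has the same shape as the output.
        out = [[o + beta * cv for o, cv in zip(orow, crow)]
               for orow, crow in zip(out, c)]
    return out


def batched_matmul(a_batch: List[Matrix], b_batch: List[Matrix]) -> List[Matrix]:
    """MatMul-specific part: loop the shared core over a leading batch dimension."""
    return [matmul_2d(a, b) for a, b in zip(a_batch, b_batch)]
```

In this framing the operator-specific work (transposes and scaling for GEMM, the batch loop for MatMul) wraps a common core, which is the kind of split a later refactor would need to design for the shader code.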

@fs-eire
Contributor

fs-eire commented Apr 22, 2025

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline, Windows x64 QNN CI Pipeline

Azure Pipelines successfully started running 5 pipeline(s).

@xiaofeihan1 xiaofeihan1 requested a review from qjia7 April 24, 2025 03:00
@fs-eire
Contributor

fs-eire commented Apr 24, 2025

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline, Windows x64 QNN CI Pipeline

Azure Pipelines successfully started running 5 pipeline(s).

@xiaofeihan1 xiaofeihan1 requested a review from qjia7 April 25, 2025 08:48