[WebGPU] Optimize GEMM with vec4 #24478
Conversation
Is there a way to reuse the implementation of …
Thanks for the callout. That's what I'm going to do next.
/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline, Windows x64 QNN CI Pipeline
Azure Pipelines successfully started running 5 pipeline(s).
Description
In this PR, we use vec4 to optimize GEMM when the number of columns of A and B is divisible by 4; otherwise we fall back to the previous shader.
I will add a u32/vec2 implementation in the future, and we will keep only one shader at that point.
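For illustration, here is a minimal sketch of what a vec4 GEMM kernel can look like in WGSL. This is not the PR's actual shader: the binding layout, the fixed M/N/K sizes, and the row-major vec4 packing of A, B, and C are all assumptions made for the sketch, which only works when K and N are divisible by 4.

```wgsl
// Hypothetical sketch of a vec4 GEMM kernel (not the PR's shader).
// Assumes A (M x K), B (K x N), C (M x N) are row-major and packed
// as vec4<f32>, which requires K % 4 == 0 and N % 4 == 0.
@group(0) @binding(0) var<storage, read> A : array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> B : array<vec4<f32>>;
@group(0) @binding(2) var<storage, read_write> C : array<vec4<f32>>;

// Hard-coded for the sketch; a real shader would take these as uniforms.
const M : u32 = 1024u;
const N : u32 = 1024u;
const K : u32 = 1024u;

@compute @workgroup_size(8, 8, 1)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  let row = gid.y;  // one row of C per invocation
  let col = gid.x;  // one vec4-wide group of 4 columns of C
  if (row >= M || col >= N / 4u) {
    return;
  }

  var acc = vec4<f32>(0.0);
  // Walk K in steps of 4: load one vec4 of A, then accumulate it
  // against four consecutive vec4 rows of B (N/4 vec4s per row).
  for (var k = 0u; k < K / 4u; k = k + 1u) {
    let a = A[row * (K / 4u) + k];
    acc = acc + a.x * B[(k * 4u + 0u) * (N / 4u) + col];
    acc = acc + a.y * B[(k * 4u + 1u) * (N / 4u) + col];
    acc = acc + a.z * B[(k * 4u + 2u) * (N / 4u) + col];
    acc = acc + a.w * B[(k * 4u + 3u) * (N / 4u) + col];
  }
  C[row * (N / 4u) + col] = acc;
}
```

Each invocation produces one vec4 of C, so a single K-step consumes four elements of A with one load and four vec4 loads of B, cutting the number of memory accesses to roughly a quarter of the scalar version and letting the multiply-accumulates use the GPU's vector lanes.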
Perf comparison
I ran a custom model containing only a GEMM (M = N = K = 1024) with Node.js on M2/M3 Max. Roughly a 20% performance increase.