Skip to content

Use specialized intrinsics for dot4{I, U}8Packed on SPIR-V and HLSL #7574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ Bottom level categories:

Naga now infers the correct binding layout when a resource appears only in an assignment to `_`. By @andyleiserson in [#7540](https://github.com/gfx-rs/wgpu/pull/7540).

- Add polyfills for `dot4U8Packed` and `dot4I8Packed` for all backends. By @robamler in [#7494](https://github.com/gfx-rs/wgpu/pull/7494).
- Implement `dot4U8Packed` and `dot4I8Packed` for all backends, using specialized intrinsics on SPIR-V and HLSL if available, and polyfills everywhere else. By @robamler in [#7494](https://github.com/gfx-rs/wgpu/pull/7494) and [#7574](https://github.com/gfx-rs/wgpu/pull/7574).
- Add polyfilled `pack4x{I,U}8Clamped` built-ins to all backends and WGSL frontend. By @ErichDonGubler in [#7546](https://github.com/gfx-rs/wgpu/pull/7546).

#### DX12
Expand Down
65 changes: 40 additions & 25 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use super::{
WrappedZeroValue,
},
storage::StoreValue,
BackendResult, Error, FragmentEntryPoint, Options,
BackendResult, Error, FragmentEntryPoint, Options, ShaderModel,
};
use crate::{
back::{self, Baked},
Expand Down Expand Up @@ -3751,33 +3751,48 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
fun @ (Function::Dot4I8Packed | Function::Dot4U8Packed) => {
let arg1 = arg1.unwrap();

write!(self.out, "dot(")?;
if self.options.shader_model >= ShaderModel::V6_4 {
// Intrinsics `dot4add_{i, u}8packed` are available in SM 6.4 and later.
let function_name = match fun {
Function::Dot4I8Packed => "dot4add_i8packed",
Function::Dot4U8Packed => "dot4add_u8packed",
_ => unreachable!(),
};
write!(self.out, "{function_name}(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, ", 0)")?;
} else {
// Fall back to a polyfill as `dot4add_u8packed` is not available.
write!(self.out, "dot(")?;

if matches!(fun, Function::Dot4U8Packed) {
write!(self.out, "u")?;
}
write!(self.out, "int4(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 8, ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 16, ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 24) << 24 >> 24, ")?;
if matches!(fun, Function::Dot4U8Packed) {
write!(self.out, "u")?;
}
write!(self.out, "int4(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 8, ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 16, ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 24) << 24 >> 24, ")?;

if matches!(fun, Function::Dot4U8Packed) {
write!(self.out, "u")?;
if matches!(fun, Function::Dot4U8Packed) {
write!(self.out, "u")?;
}
write!(self.out, "int4(")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, " >> 8, ")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, " >> 16, ")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, " >> 24) << 24 >> 24)")?;
}
write!(self.out, "int4(")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, " >> 8, ")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, " >> 16, ")?;
self.write_expr(module, arg1, func_ctx)?;
write!(self.out, " >> 24) << 24 >> 24)")?;
}
Function::QuantizeToF16 => {
write!(self.out, "f16tof32(f32tof16(")?;
Expand Down
126 changes: 78 additions & 48 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1143,59 +1143,89 @@ impl BlockContext<'_> {
),
},
fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
// TODO: consider using packed integer dot product if PackedVectorFormat4x8Bit is available
let (extract_op, arg0_id, arg1_id) = match fun {
Mf::Dot4U8Packed => (spirv::Op::BitFieldUExtract, arg0_id, arg1_id),
Mf::Dot4I8Packed => {
// Convert both packed arguments to signed integers so that we can apply the
// `BitFieldSExtract` operation on them in `write_dot_product` below.
let new_arg0_id = self.gen_id();
block.body.push(Instruction::unary(
spirv::Op::Bitcast,
result_type_id,
new_arg0_id,
arg0_id,
));
if self.writer.lang_version() >= (1, 6)
&& self
.writer
.require_all(&[
spirv::Capability::DotProduct,
spirv::Capability::DotProductInput4x8BitPacked,
])
.is_ok()
{
// Write optimized code using `PackedVectorFormat4x8Bit`.
self.writer.use_extension("SPV_KHR_integer_dot_product");

let op = match fun {
Mf::Dot4I8Packed => spirv::Op::SDot,
Mf::Dot4U8Packed => spirv::Op::UDot,
_ => unreachable!(),
};

let new_arg1_id = self.gen_id();
block.body.push(Instruction::unary(
spirv::Op::Bitcast,
result_type_id,
new_arg1_id,
arg1_id,
));
block.body.push(Instruction::ternary(
op,
result_type_id,
id,
arg0_id,
arg1_id,
spirv::PackedVectorFormat::PackedVectorFormat4x8Bit as Word,
));
} else {
// Fall back to a polyfill since `PackedVectorFormat4x8Bit` is not available.
let (extract_op, arg0_id, arg1_id) = match fun {
Mf::Dot4U8Packed => (spirv::Op::BitFieldUExtract, arg0_id, arg1_id),
Mf::Dot4I8Packed => {
// Convert both packed arguments to signed integers so that we can apply the
// `BitFieldSExtract` operation on them in `write_dot_product` below.
let new_arg0_id = self.gen_id();
block.body.push(Instruction::unary(
spirv::Op::Bitcast,
result_type_id,
new_arg0_id,
arg0_id,
));

(spirv::Op::BitFieldSExtract, new_arg0_id, new_arg1_id)
}
_ => unreachable!(),
};
let new_arg1_id = self.gen_id();
block.body.push(Instruction::unary(
spirv::Op::Bitcast,
result_type_id,
new_arg1_id,
arg1_id,
));

let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
(spirv::Op::BitFieldSExtract, new_arg0_id, new_arg1_id)
}
_ => unreachable!(),
};

const VEC_LENGTH: u8 = 4;
let bit_shifts: [_; VEC_LENGTH as usize] = core::array::from_fn(|index| {
self.writer
.get_constant_scalar(crate::Literal::U32(index as u32 * 8))
});
let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));

const VEC_LENGTH: u8 = 4;
let bit_shifts: [_; VEC_LENGTH as usize] =
core::array::from_fn(|index| {
self.writer
.get_constant_scalar(crate::Literal::U32(index as u32 * 8))
});

self.write_dot_product(
id,
result_type_id,
arg0_id,
arg1_id,
VEC_LENGTH as Word,
block,
|result_id, composite_id, index| {
Instruction::ternary(
extract_op,
result_type_id,
result_id,
composite_id,
bit_shifts[index as usize],
eight,
)
},
);
}

self.write_dot_product(
id,
result_type_id,
arg0_id,
arg1_id,
VEC_LENGTH as Word,
block,
|result_id, composite_id, index| {
Instruction::ternary(
extract_op,
result_type_id,
result_id,
composite_id,
bit_shifts[index as usize],
eight,
)
},
);
self.cached[expr_handle] = id;
return Ok(());
}
Expand Down
17 changes: 14 additions & 3 deletions naga/src/back/spv/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ use alloc::format;
const GENERATOR: Word = 28;

impl PhysicalLayout {
pub(super) const fn new(version: Word) -> Self {
pub(super) const fn new(major_version: u8, minor_version: u8) -> Self {
let version = ((major_version as u32) << 16) | ((minor_version as u32) << 8);
PhysicalLayout {
magic_number: MAGIC_NUMBER,
version,
Expand All @@ -29,6 +30,13 @@ impl PhysicalLayout {
sink.extend(iter::once(self.bound));
sink.extend(iter::once(self.instruction_schema));
}

/// Returns the language version stored in this layout as `(major, minor)`.
///
/// The major number is read from bits 16..24 of `self.version` and the
/// minor number from bits 8..16.
pub(super) const fn lang_version(&self) -> (u8, u8) {
    // Big-endian byte order puts the most significant byte first, so the
    // two middle bytes are exactly `version >> 16` and `version >> 8`
    // truncated to `u8`.
    let [_, major, minor, _] = self.version.to_be_bytes();
    (major, minor)
}
}

impl super::recyclable::Recyclable for PhysicalLayout {
Expand Down Expand Up @@ -150,10 +158,13 @@ impl Instruction {
#[test]
fn test_physical_layout_in_words() {
let bound = 5;
let version = 0x10203;

// The least and most significant bytes of `version` must both be zero
// according to the SPIR-V spec.
let version = 0x0001_0200;

let mut output = vec![];
let mut layout = PhysicalLayout::new(version);
let mut layout = PhysicalLayout::new(1, 2);
layout.bound = bound;

layout.in_words(&mut output);
Expand Down
45 changes: 43 additions & 2 deletions naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ impl Writer {
if major != 1 {
return Err(Error::UnsupportedVersion(major, minor));
}
let raw_version = ((major as u32) << 16) | ((minor as u32) << 8);

let mut capabilities_used = crate::FastIndexSet::default();
capabilities_used.insert(spirv::Capability::Shader);
Expand All @@ -70,7 +69,7 @@ impl Writer {
let void_type = id_gen.next();

Ok(Writer {
physical_layout: PhysicalLayout::new(raw_version),
physical_layout: PhysicalLayout::new(major, minor),
logical_layout: LogicalLayout::default(),
id_gen,
capabilities_available: options.capabilities.clone(),
Expand Down Expand Up @@ -99,6 +98,11 @@ impl Writer {
})
}

/// Returns the SPIR-V language version as a `(major, minor)` pair, as
/// reported by this writer's physical layout.
pub const fn lang_version(&self) -> (u8, u8) {
    let (major, minor) = self.physical_layout.lang_version();
    (major, minor)
}

/// Reset `Writer` to its initial state, retaining any allocations.
///
/// Why not just implement `Recyclable` for `Writer`? By design,
Expand Down Expand Up @@ -202,6 +206,43 @@ impl Writer {
}
}

/// Indicate that the code requires all of the listed capabilities.
///
/// If all entries of `capabilities` appear in the available capabilities
/// specified in the [`Options`] from which this `Writer` was created
/// (including the case where [`Options::capabilities`] is `None`), add
/// them all to this `Writer`'s [`capabilities_used`] table, and return
/// `Ok(())`. If at least one of the listed capabilities is not available,
/// do not add anything to the `capabilities_used` table, and return the
/// first unavailable requested capability, wrapped in `Err()`.
///
/// This method is does not return an [`enum@Error`] in case of failure
/// because it may be used in cases where the caller can recover (e.g.,
/// with a polyfill) if the requested capabilities are not available. In
/// this case, it would be unnecessary work to find *all* the unavailable
/// requested capabilities, and to allocate a `Vec` for them, just so we
/// could return an [`Error::MissingCapabilities`]).
///
/// [`capabilities_used`]: Writer::capabilities_used
pub(super) fn require_all(
&mut self,
capabilities: &[spirv::Capability],
) -> Result<(), spirv::Capability> {
if let Some(ref available) = self.capabilities_available {
for requested in capabilities {
if !available.contains(requested) {
return Err(*requested);
}
}
}

for requested in capabilities {
self.capabilities_used.insert(*requested);
}

Ok(())
}

/// Indicate that the code uses the given extension.
pub(super) fn use_extension(&mut self, extension: &'static str) {
self.extensions_used.insert(extension);
Expand Down
11 changes: 11 additions & 0 deletions naga/tests/in/wgsl/functions-optimized.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Explicitly turn on optimizations for `dot4I8Packed` and `dot4U8Packed`
# on SPIRV and HLSL.

targets = "SPIRV | HLSL"

[spv]
capabilities = ["DotProduct", "DotProductInput4x8BitPacked"]
version = [1, 6]

[hlsl]
shader_model = "V6_4"
19 changes: 19 additions & 0 deletions naga/tests/in/wgsl/functions-optimized.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Exercises both packed integer dot-product built-ins, first with plain
// `let`-bound operands and then with compound expressions as operands.
// Returns the last unsigned result.
fn test_packed_integer_dot_product() -> u32 {
// Signed variant: `u32` operands, `i32` result.
let a_5 = 1u;
let b_5 = 2u;
let c_5: i32 = dot4I8Packed(a_5, b_5);

// Unsigned variant: `u32` operands, `u32` result.
let a_6 = 3u;
let b_6 = 4u;
let c_6: u32 = dot4U8Packed(a_6, b_6);

// test baking of arguments
// (each operand is an arithmetic expression rather than a named value)
let c_7: i32 = dot4I8Packed(5u + c_6, 6u + c_6);
let c_8: u32 = dot4U8Packed(7u + c_6, 8u + c_6);
return c_8;
}

// Compute entry point with a workgroup size of 1. It calls
// `test_packed_integer_dot_product` and binds the result to `c`, which is
// not used afterwards.
@compute @workgroup_size(1)
fn main() {
let c = test_packed_integer_dot_product();
}
13 changes: 13 additions & 0 deletions naga/tests/in/wgsl/functions-unoptimized.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Explicitly turn off optimizations for `dot4I8Packed` and `dot4U8Packed`
# on SPIRV and HLSL.

targets = "SPIRV | HLSL"

[spv]
# Provide some unrelated capability because an empty list of capabilities would
# get mapped to `None`, which would then be interpreted as "all capabilities
# are available".
capabilities = ["Matrix"]

[hlsl]
shader_model = "V6_3"
Loading