Description
Prerequisites
Please make sure to check off these prerequisites before submitting a bug report.
- Test that the bug appears on the current version of the main branch. Make sure to include the commit hash of the commit you checked out: 6ca8f8e
- Check that the issue hasn't already been reported, by checking the currently open issues.
- If there are steps to reproduce the problem, make sure to write them down below.
- If relevant, please include the ONNX files, which were created directly before and/or after the bug.
Quick summary
While working on our characterization of the transformer data-flow we encountered some discrepancies when validating against the QONNX `inference_cost` estimations of the MatMul operator within the attention mechanism. We are not entirely sure whether this is indeed a bug on the QONNX side or still some confusion/error on our side. Thus we would like to start a discussion to understand this issue.
Details
Multi-Head Scaled Dot-Product Attention involves two consecutive MatMul operations where both inputs dynamically depend on the model inputs. The heads are independent of each other and typically treated in a way similar to a batch dimension. Our cost model assumes HxTxTxd MAC operations for each of the two MatMuls, i.e. H heads each producing a TxT attention matrix (T is the sequence length) where each element is the result of a d-dimensional dot-product. However, the QONNX analysis function `inference_cost_matmul` seems to be off by an additional factor of H (i.e. HxHxTxTxd), indicating that the heads are not treated like a batch dimension.
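To make the counting rule explicit, here is a small sketch of the cost model described above (`expected_matmul_macs` is just an illustrative helper of ours, not QONNX code; the shapes match the example model further below with H=4, T=64 and d_head=32):

```python
import numpy as np

# Sketch of our cost model: for a (possibly batched) MatMul the leading axes
# act like batch/head dimensions and each output element is one dot-product
# over the shared inner dimension. Illustrative helper, not QONNX code.
def expected_matmul_macs(lhs_shape, rhs_shape):
    *batch, m, k = lhs_shape
    assert k == rhs_shape[-2], "inner dimensions must match"
    n = rhs_shape[-1]
    return int(np.prod(batch, dtype=np.int64) * m * n * k)

# Q x K^T per head: (H=4, T=64, d=32) x (H=4, d=32, T=64) -> HxTxTxd MACs
print(expected_matmul_macs((4, 64, 32), (4, 32, 64)))  # 524288
# A x V per head:   (4, T=64, T=64) x (4, T=64, d=32)   -> also HxTxTxd MACs
print(expected_matmul_macs((4, 64, 64), (4, 64, 32)))  # 524288
```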
My suspicion is further raised by the following lines from the QONNX `inference_cost_matmul` function:

```python
# exclude common dim (last axis) from one side to avoid duplication
n_macs = np.prod(i_shape[:-1]) * np.prod(w_shape)
```

Is this actually always the case? At least for the model graph I have attached, it seems like the last axis is not the common dimension.
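To illustrate the suspected overcounting concretely, this is the quoted formula applied by hand to the shapes the first dynamic MatMul should have in the attached graph after the reshape/permute (this is only our reading of the code, so please correct us if the shapes enter differently):

```python
import numpy as np

# Shapes of the first dynamic MatMul (Q x K^T) after reshape/permute:
# i_shape = (batch*heads, T, d_head), w_shape = (batch*heads, d_head, T)
i_shape, w_shape = (4, 64, 32), (4, 32, 64)

# The quoted counting rule from inference_cost_matmul
n_macs = np.prod(i_shape[:-1]) * np.prod(w_shape)
print(n_macs)  # 2097152 = (4*64) * (4*32*64): the head axis enters twice

# Counting rule with the leading axis treated as a batch/head dimension
print(np.prod(i_shape) * w_shape[-1])  # 524288 = 4*64*64*32 = HxTxTxd
```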
In the following, I provide a minimal working example of scaled dot-product attention in isolation in PyTorch, exported to an ONNX graph. I have also attached the already preprocessed graph, which in particular already includes the `InferShapes` transform. Note that running the `qonnx.util.inference_cost` script on the raw PyTorch ONNX export breaks at the `FoldConstants` transform due to an `IndexError`, which is probably unrelated and should be investigated separately (I have "fixed" it by removing that transformation step for now).
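For completeness, the attached preprocessed graph was produced roughly as follows (a minimal sketch assuming the standard QONNX `ModelWrapper` API, with the `FoldConstants` step left out as described above):

```python
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.infer_shapes import InferShapes

# Load the raw PyTorch ONNX export, run shape inference and save the result.
# FoldConstants is deliberately skipped here because of the IndexError above.
model = ModelWrapper("sdp.onnx")
model = model.transform(InferShapes())
model.save("sdp_preprocessed.onnx")
```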
Steps to Reproduce
The following code produces a minimal example of scaled dot-product attention and exports to ONNX.
```python
import torch


# Minimal working example of the Scaled Dot-Product Attention mechanism
class ScaleDotProductAttention(torch.nn.Module):
    # Initializes the module parameters
    def __init__(self, num_heads):
        # Initialize the PyTorch base Module
        super().__init__()
        # Set the number of attention heads
        self.num_heads = num_heads

    # Forward pass computing scaled dot-product attention between q, k and v
    def forward(self, q, k, v):
        # Assume the most simple case of q, k and v all having the same
        # dimensions
        assert q.shape == k.shape == v.shape, \
            "Q, K and V must have the same shape"
        # Embedding dimension must be divisible by the number of heads
        assert q.shape[-1] % self.num_heads == 0, \
            f"Dimensions must be divisible by heads ({self.num_heads})"
        # Assume sequence-first layout and get the sizes per axis
        s, b, d = q.shape
        # Number of heads and dimension per head
        n_head, d_head = self.num_heads, d // self.num_heads
        # Reshape tensors to treat the heads like batch dimensions
        q = q.reshape(s, b, n_head, d_head).reshape(s, b * n_head, d_head)
        k = k.reshape(s, b, n_head, d_head).reshape(s, b * n_head, d_head)
        v = v.reshape(s, b, n_head, d_head).reshape(s, b * n_head, d_head)
        # Compute the not-yet-normalized attention matrices for each head.
        # Note: permute brings batch x heads to the front and transposes k
        a = torch.matmul(q.permute(1, 0, 2), k.permute(1, 2, 0))
        # Scale and normalize the attention matrix
        a = torch.softmax(a * (d_head ** -0.5), dim=-1)
        # Apply the attention matrices to the value projection
        # Note: Second permute brings the sequence dimension back to the front
        o = torch.matmul(a, v.permute(1, 0, 2)).permute(1, 0, 2)
        # Reshape the heads back into the feature dimension
        o = o.reshape(s, b, n_head, d_head).reshape(s, b, n_head * d_head)
        # Return the scaled dot-product attention output
        return o


# Script entrypoint
if __name__ == '__main__':
    # Instantiate a scaled dot-product attention with 4 attention heads
    sdp = ScaleDotProductAttention(num_heads=4)
    # Generate random query, key and value tensors
    # Note: Sequence of length 64, single-instance batch, 128-dim embeddings
    q, k, v = torch.randn(3, 64, 1, 128)
    # Export the attention module to ONNX
    torch.onnx.export(sdp, args=(q, k, v), f='sdp.onnx')
```
Get the MAC operation counts by running:

```
python -m qonnx.util.inference_cost sdp.onnx
```

This outputs something like:

```
{'op_mac_FLOAT32_FLOAT32': 4194304.0, 'mem_w_FLOAT32': 0.0, 'mem_o_FLOAT32': 24576.0, 'unsupported': "{'Softmax', 'Pow', 'Constant'}", 'discount_sparsity': True, 'total_bops': 4294967296.0, 'total_mem_w_bits': 0.0, 'total_mem_o_bits': 786432.0}
```
Expected behavior
According to our cost model, the MAC count should be 2x HxTxTxd, which for the given example model is 2x 4x64x64x32 = 1048576.
Actual behavior
The MAC count is reported as 4194304, which is 4x (Hx) our expectation, indicating a cost function of 2x HxHxTxTxd.
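For a side-by-side sanity check, plain arithmetic with the numbers above:

```python
# Plain arithmetic check of the two cost formulas for the example model
H, T, d_head = 4, 64, 32                 # heads, sequence length, dim per head
expected = 2 * H * T * T * d_head        # heads treated as a batch dimension
reported = 2 * H * H * T * T * d_head    # extra factor of H
print(expected, reported)                # 1048576 4194304
```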