
[MLIR] [Vector] ConstantFold MultiDReduction #122450

Open. Wants to merge 16 commits into base: main.
139 changes: 138 additions & 1 deletion mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -25,7 +25,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
@@ -463,10 +462,148 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
  build(builder, result, kind, source, acc, reductionDims);
}

/// TODO: Move to APFloat/APInt.
/// Computes `a` raised to the non-negative power `exponent` by iterative
/// square-and-multiply, i.e. in O(log(exponent)) multiplications.
static APFloat computePowerOf(const APFloat &a, int64_t exponent) {
  assert(exponent >= 0 && "negative exponents not supported.");
  APFloat result = APFloat::getOne(a.getSemantics());
  APFloat base = a;
  int64_t remainingExponent = exponent;
  // Consume the exponent one bit at a time: square the base for each bit and
  // multiply it into the result when the bit is set.
  while (remainingExponent > 0) {
    if (remainingExponent & 1)
      result = result * base;
    base = base * base;
    remainingExponent >>= 1;
  }
  return result;
}

static APInt computePowerOf(const APInt &a, int64_t exponent) {
  assert(exponent >= 0 && "negative exponents not supported.");
  APInt result(a.getBitWidth(), 1);
  APInt base = a;
  int64_t remainingExponent = exponent;
  while (remainingExponent > 0) {
    if (remainingExponent & 1)
      result *= base;
    base *= base;
    remainingExponent >>= 1;
  }
  return result;
}
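
As a quick sanity check of the helpers (a standalone sketch with made-up values, not part of the patch):

// Square-and-multiply consumes one exponent bit per iteration, so computing
// a^10 takes a handful of multiplies rather than nine.
assert(computePowerOf(APFloat(2.0f), 10).convertToFloat() == 1024.0f);
assert(computePowerOf(APInt(32, 3), 4) == APInt(32, 81)); // 3^4 == 81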

/// Computes the result of reducing a splat constant vector where the
/// accumulator value, `acc`, is also a splat constant. `times` is the number
/// of source elements combined into each result element, i.e. the number of
/// times the combining operation is applied.
static OpFoldResult computeConstantReduction(FloatAttr src, FloatAttr acc,
                                             int64_t times, CombiningKind kind,
                                             ShapedType dstType) {
  APFloat srcVal = src.getValue();
  APFloat accVal = acc.getValue();
  switch (kind) {
  case CombiningKind::ADD: {
    // Adding the same value `times` times folds to `acc + src * times`;
    // materialize `times` as an APFloat in the source semantics.
    APFloat n = APFloat(srcVal.getSemantics());
    n.convertFromAPInt(APInt(64, times, true), true,
                       APFloat::rmNearestTiesToEven);
Contributor:

Would this be a problem if the user expects a non-default rounding mode? We have been adding FMF and RM bottom-up in the IR, but it's lacking at the vector level, so I'm wondering if this would lead to an unexpected outcome. Perhaps @chelini, @kuhar could provide some feedback?
Worst case, I guess we could enable this folder under a flag...

Contributor Author:

I am not sure. How does the user currently get a non-default rounding mode? And it does not look like this is the only fold that would be affected. In Arith I see a TODO for adding optional attributes to specify it: https://mlir.llvm.org/docs/Dialects/ArithOps/#arithaddf-arithaddfop. So the rounding mode seems a problem separate from constant folding; once support for rounding modes is added, it can be handled here, along with the other affected folds.

Member (@kuhar, Jan 13, 2025):

For integer types, this should be clearly allowed.
For float, everything else being fine, I think in the absence of any rounding modes this should be permissible as well.

My worry would be that the reduction order is inherently unspecified for this op, and constant folding may produce a different order than runtime evaluation. In the current form, I can see that this is only applied in the splat case, which I think would only produce different results if the runtime implementation ended up doing partial reductions?
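
For illustration, a minimal standalone C++ sketch (hypothetical values, not from this patch) of how summation order alone can change a float result:

#include <cstdio>

int main() {
  // The same four values summed left-to-right versus as two partial sums
  // that are combined at the end (what a partial reduction would do).
  float v[4] = {1e8f, 1.0f, -1e8f, 1.0f};
  float sequential = ((v[0] + v[1]) + v[2]) + v[3];
  float pairwise = (v[0] + v[1]) + (v[2] + v[3]);
  std::printf("%f vs %f\n", sequential, pairwise); // 1.000000 vs 0.000000
  return 0;
}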

Contributor Author:

In the case of partial reductions, we would have less accuracy than this, right? So do we prefer that the default behavior be that our program runs slower and gives less accurate results?

Member:

> So do we prefer that the default behavior be that our program runs slower and gives less accurate results?

I don't think this is a fair way to put it. The concern is that the same computation gives two answers depending on whether the operands are compile-time constants or not, assuming a reasonable lowering. I think the argument you are trying to make is that partial reductions should not be considered a valid lowering, and I'm honestly not sure either way, because I could see this being expanded to something like gpu.subgroup_reduce, which may end up doing something very much hardware-dependent.

Member:

gpu::SubgroupReduceOp and vector::ReductionOp don't fold when the argument is a constant. Maybe it would help us get unstuck if we decide what to do with these two first? (This could be an RFC on Discourse.)

Member:

I went through the other vector dialect ops and I don't see any other 'math' op folding. Because what you propose here is something new, I think it deserves an RFC.

Contributor Author:

> because I could see this being expanded to something like gpu.subgroup_reduce that may end up doing something very much hardware-dependent.

The concern is that this may make something that is runtime/hardware-dependent (and it does not need to be) not hardware-dependent? Something that:

  1. Does not need to be runtime/hardware-dependent: it's a constant.
  2. Has no runtime cost: it would actually be faster to fold the constant.
  3. Has no precision cost: it would in fact be more accurate.

In this case it's splat-splat, but in general a partial reduction does not even return consistent results on the same hardware, because the order in which it is applied may change from run to run. How is that desirable? How is that even consistent? If partial ordering is to be canon for reductions, in what order should it be applied? I've seen discussions where this was a decision between being fast and not being hardware-dependent, or between being fast and being more accurate. This is neither. Why would we prefer to be needlessly hardware-dependent, less accurate, and slower? It's fine if some user somehow wants that, but why should it be the default?

Member:

Let's move this discussion to an RFC. You are proposing much more aggressive folds than any existing ones in vector, AFAICT.

    return DenseElementsAttr::get(dstType, {accVal + srcVal * n});
  }
  case CombiningKind::MUL:
    // Multiplying by the same value `times` times folds to
    // `acc * src^times`.
    return DenseElementsAttr::get(dstType,
                                  {accVal * computePowerOf(srcVal, times)});
  case CombiningKind::MINIMUMF:
    return DenseElementsAttr::get(dstType, {llvm::minimum(accVal, srcVal)});
  case CombiningKind::MAXIMUMF:
    return DenseElementsAttr::get(dstType, {llvm::maximum(accVal, srcVal)});
  case CombiningKind::MINNUMF:
    return DenseElementsAttr::get(dstType, {llvm::minnum(accVal, srcVal)});
  case CombiningKind::MAXNUMF:
    return DenseElementsAttr::get(dstType, {llvm::maxnum(accVal, srcVal)});
  default:
    return {};
  }
}
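
The four float min/max kinds differ mainly in NaN handling, which the minnumf/minimumf tests below exercise; a short standalone sketch of the distinction (illustrative values, not part of the patch):

APFloat nan = APFloat::getNaN(APFloat::IEEEsingle());
APFloat one(1.0f);
// minnum ignores a NaN operand; minimum propagates it.
assert(llvm::minnum(one, nan).convertToFloat() == 1.0f);
assert(llvm::minimum(one, nan).isNaN());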

static OpFoldResult computeConstantReduction(IntegerAttr src, IntegerAttr acc,
                                             int64_t times, CombiningKind kind,
                                             ShapedType dstType) {
  APInt srcVal = src.getValue();
  APInt accVal = acc.getValue();

  switch (kind) {
  case CombiningKind::ADD:
    return DenseElementsAttr::get(dstType, {accVal + srcVal * times});
  case CombiningKind::MUL:
    return DenseElementsAttr::get(dstType,
                                  {accVal * computePowerOf(srcVal, times)});
  case CombiningKind::MINSI:
    return DenseElementsAttr::get(dstType,
                                  {accVal.slt(srcVal) ? accVal : srcVal});
  case CombiningKind::MAXSI:
    return DenseElementsAttr::get(dstType,
                                  {accVal.sgt(srcVal) ? accVal : srcVal});
  case CombiningKind::MINUI:
    return DenseElementsAttr::get(dstType,
                                  {accVal.ult(srcVal) ? accVal : srcVal});
  case CombiningKind::MAXUI:
    return DenseElementsAttr::get(dstType,
                                  {accVal.ugt(srcVal) ? accVal : srcVal});
  case CombiningKind::AND:
    return DenseElementsAttr::get(dstType, {accVal & srcVal});
  case CombiningKind::OR:
    return DenseElementsAttr::get(dstType, {accVal | srcVal});
  case CombiningKind::XOR:
    // x ^ x == 0, so an even number of identical operands cancels out and
    // only the parity of `times` matters.
    return DenseElementsAttr::get(dstType,
                                  {(times & 0x1) ? accVal ^ srcVal : accVal});
  default:
    return {};
  }
}
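
The XOR case leans on parity alone; a minimal sketch (hypothetical helper name, values mirroring the tests below):

// The `times` identical source operands cancel pairwise under XOR, so the
// source survives only when `times` is odd.
static APInt reduceXorSplat(const APInt &acc, const APInt &src, int64_t times) {
  return (times & 1) ? acc ^ src : acc;
}
// reduceXorSplat(APInt(32, 0xFF), APInt(32, 0xA0A), 3) == 0xAF5 (2805)
// reduceXorSplat(APInt(32, 0xFF), APInt(32, 0xA0A), 4) == 0xFF  (255)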

OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  // Single parallel dim, this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return getSource();

  // Only fold when both the source and the accumulator are splat constants,
  // i.e. every combining step sees the same pair of values.
  auto srcAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSource());
  auto accAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getAcc());
  if (!srcAttr || !accAttr || !srcAttr.isSplat() || !accAttr.isSplat())
    return {};

  // Each result element combines the accumulator with `times` source
  // elements, where `times` is the product of the reduced dimensions.
  ArrayRef<int64_t> reductionDims = getReductionDims();
  ArrayRef<int64_t> srcDims = getSourceVectorType().getShape();
  int64_t times = 1;
  for (int64_t dim : reductionDims)
    times *= srcDims[dim];

  CombiningKind kind = getKind();
  auto dstType = cast<ShapedType>(getDestType());
  Type dstEltType = dstType.getElementType();

  if (isa<FloatType>(dstEltType))
    return computeConstantReduction(srcAttr.getSplatValue<FloatAttr>(),
                                    accAttr.getSplatValue<FloatAttr>(), times,
                                    kind, dstType);
  if (isa<IntegerType>(dstEltType))
    return computeConstantReduction(srcAttr.getSplatValue<IntegerAttr>(),
                                    accAttr.getSplatValue<IntegerAttr>(), times,
                                    kind, dstType);

  return {};
}

81 changes: 81 additions & 0 deletions mlir/test/Dialect/Vector/constant-fold.mlir
@@ -11,3 +11,84 @@ func.func @fold_extract_transpose_negative(%arg0: vector<4x4xf16>) -> vector<4x4xf16> {
  %2 = vector.extract %1[0] : vector<4x4xf16> from vector<1x4x4xf16>
  return %2 : vector<4x4xf16>
}
}

// CHECK-LABEL: fold_multi_reduction_f32_add
func.func @fold_multi_reduction_f32_add() -> vector<1xf32> {
  %acc = arith.constant dense<0.000000e+00> : vector<1xf32>
  %0 = arith.constant dense<1.000000e+00> : vector<1x128x128xf32>
  // CHECK: %{{.*}} = arith.constant dense<1.638400e+04> : vector<1xf32>
  %1 = vector.multi_reduction <add>, %0, %acc [1, 2] : vector<1x128x128xf32> to vector<1xf32>
  return %1 : vector<1xf32>
}

// CHECK-LABEL: fold_multi_reduction_f32_mul
func.func @fold_multi_reduction_f32_mul() -> vector<1xf32> {
  %acc = arith.constant dense<1.000000e+00> : vector<1xf32>
  %0 = arith.constant dense<2.000000e+00> : vector<1x2x2xf32>
  // CHECK: %{{.*}} = arith.constant dense<1.600000e+01> : vector<1xf32>
  %1 = vector.multi_reduction <mul>, %0, %acc [1, 2] : vector<1x2x2xf32> to vector<1xf32>
  return %1 : vector<1xf32>
}

// CHECK-LABEL: fold_multi_reduction_f32_maximumf
func.func @fold_multi_reduction_f32_maximumf() -> vector<1xf32> {
  %acc = arith.constant dense<1.000000e+00> : vector<1xf32>
  %0 = arith.constant dense<2.000000e+00> : vector<1x2x2xf32>
  // CHECK: %{{.*}} = arith.constant dense<2.000000e+00> : vector<1xf32>
  %1 = vector.multi_reduction <maximumf>, %0, %acc [1, 2] : vector<1x2x2xf32> to vector<1xf32>
  return %1 : vector<1xf32>
}

// CHECK-LABEL: fold_multi_reduction_f32_minnumf
func.func @fold_multi_reduction_f32_minnumf() -> vector<1xf32> {
  %acc = arith.constant dense<1.000000e+00> : vector<1xf32>
  %0 = arith.constant dense<0xFFFFFFFF> : vector<1x2x2xf32>
  // CHECK: %{{.*}} = arith.constant dense<1.000000e+00> : vector<1xf32>
  %1 = vector.multi_reduction <minnumf>, %0, %acc [1, 2] : vector<1x2x2xf32> to vector<1xf32>
  return %1 : vector<1xf32>
}

// CHECK-LABEL: fold_multi_reduction_f32_minimumf
func.func @fold_multi_reduction_f32_minimumf() -> vector<1xf32> {
  %acc = arith.constant dense<1.000000e+00> : vector<1xf32>
  %0 = arith.constant dense<0xFFFFFFFF> : vector<1x2x2xf32>
  // CHECK: %{{.*}} = arith.constant dense<0xFFFFFFFF> : vector<1xf32>
  %1 = vector.multi_reduction <minimumf>, %0, %acc [1, 2] : vector<1x2x2xf32> to vector<1xf32>
  return %1 : vector<1xf32>
}

// CHECK-LABEL: fold_multi_reduction_i32_add
func.func @fold_multi_reduction_i32_add() -> vector<1xi32> {
  %acc = arith.constant dense<1> : vector<1xi32>
  %0 = arith.constant dense<1> : vector<1x128x128xi32>
  // CHECK: %{{.*}} = arith.constant dense<16385> : vector<1xi32>
  %1 = vector.multi_reduction <add>, %0, %acc [1, 2] : vector<1x128x128xi32> to vector<1xi32>
  return %1 : vector<1xi32>
}

// CHECK-LABEL: fold_multi_reduction_i32_xor_odd_num_elements
func.func @fold_multi_reduction_i32_xor_odd_num_elements() -> vector<1xi32> {
  %acc = arith.constant dense<0xFF> : vector<1xi32>
  %0 = arith.constant dense<0xA0A> : vector<1x3xi32>
  // CHECK: %{{.*}} = arith.constant dense<2805> : vector<1xi32>
  %1 = vector.multi_reduction <xor>, %0, %acc [1] : vector<1x3xi32> to vector<1xi32>
  return %1 : vector<1xi32>
}

// CHECK-LABEL: fold_multi_reduction_i32_xor_even_num_elements
func.func @fold_multi_reduction_i32_xor_even_num_elements() -> vector<1xi32> {
  %acc = arith.constant dense<0xFF> : vector<1xi32>
  %0 = arith.constant dense<0xA0A> : vector<1x4xi32>
  // CHECK: %{{.*}} = arith.constant dense<255> : vector<1xi32>
  %1 = vector.multi_reduction <xor>, %0, %acc [1] : vector<1x4xi32> to vector<1xi32>
  return %1 : vector<1xi32>
}

// CHECK-LABEL: fold_multi_reduction_i64_add
func.func @fold_multi_reduction_i64_add() -> vector<1xi64> {
  %acc = arith.constant dense<1> : vector<1xi64>
  %0 = arith.constant dense<1> : vector<1x128x128xi64>
  // CHECK: %{{.*}} = arith.constant dense<16385> : vector<1xi64>
  %1 = vector.multi_reduction <add>, %0, %acc [1, 2] : vector<1x128x128xi64> to vector<1xi64>
  return %1 : vector<1xi64>
}
}
Loading