[MLIR] [Vector] ConstantFold MultiDReduction #122450


Open: wants to merge 16 commits into main
Conversation


@ImanHosseini ImanHosseini commented Jan 10, 2025

If both source and acc are splat, constant fold the multi-reduction. @apaszke


llvmbot commented Jan 10, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Iman Hosseini (ImanHosseini)

Changes

If both source and acc are splat, constant fold the multi-reduction.


Full diff: https://github.com/llvm/llvm-project/pull/122450.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+104)
  • (modified) mlir/test/Dialect/Vector/constant-fold.mlir (+54)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ae1cf95732336a..a23d952c5760c5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -21,11 +21,13 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/IRMapping.h"
@@ -44,6 +46,7 @@
 #include "llvm/ADT/bit.h"
 
 #include <cassert>
+#include <cmath>
 #include <cstdint>
 #include <numeric>
 
@@ -463,10 +466,111 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
   build(builder, result, kind, source, acc, reductionDims);
 }
 
+template <typename T>
+OpFoldResult foldSplatReduce(T src, T acc, int64_t times, CombiningKind kind,
+                             ShapedType dstType);
+
+template <>
+OpFoldResult foldSplatReduce(FloatAttr src, FloatAttr acc, int64_t times,
+                             CombiningKind kind, ShapedType dstType) {
+  APFloat srcVal = src.getValue();
+  APFloat accVal = acc.getValue();
+  switch (kind) {
+  case CombiningKind::ADD: {
+    // Fold to acc + src * times; convert the element count to the source
+    // float semantics first.
+    APFloat n = APFloat(srcVal.getSemantics());
+    n.convertFromAPInt(APInt(64, times, true), true,
+                       APFloat::rmNearestTiesToEven);
+    return DenseElementsAttr::get(dstType, {accVal + srcVal * n});
+  }
+  case CombiningKind::MUL: {
+    // Fold to acc * src^times.
+    APFloat result = accVal;
+    for (int64_t i = 0; i < times; ++i)
+      result = result * srcVal;
+    return DenseElementsAttr::get(dstType, {result});
+  }
+  case CombiningKind::MINIMUMF:
+    return DenseElementsAttr::get(dstType, {llvm::minimum(accVal, srcVal)});
+  case CombiningKind::MAXIMUMF:
+    return DenseElementsAttr::get(dstType, {llvm::maximum(accVal, srcVal)});
+  case CombiningKind::MINNUMF:
+    return DenseElementsAttr::get(dstType, {llvm::minnum(accVal, srcVal)});
+  case CombiningKind::MAXNUMF:
+    return DenseElementsAttr::get(dstType, {llvm::maxnum(accVal, srcVal)});
+  default:
+    return {};
+  }
+}
+
+template <>
+OpFoldResult foldSplatReduce(IntegerAttr src, IntegerAttr acc, int64_t times,
+                             CombiningKind kind, ShapedType dstType) {
+  APInt srcVal = src.getValue();
+  APInt accVal = acc.getValue();
+  switch (kind) {
+  case CombiningKind::ADD:
+    // Fold to acc + src * times.
+    return DenseElementsAttr::get(dstType, {accVal + srcVal * times});
+  case CombiningKind::MUL: {
+    // Fold to acc * src^times.
+    APInt result = accVal;
+    for (int64_t i = 0; i < times; ++i)
+      result *= srcVal;
+    return DenseElementsAttr::get(dstType, {result});
+  }
+  case CombiningKind::MINSI:
+    return DenseElementsAttr::get(dstType,
+                                  {accVal.slt(srcVal) ? accVal : srcVal});
+  case CombiningKind::MAXSI:
+    return DenseElementsAttr::get(dstType,
+                                  {accVal.sgt(srcVal) ? accVal : srcVal});
+  case CombiningKind::MINUI:
+    return DenseElementsAttr::get(dstType,
+                                  {accVal.ult(srcVal) ? accVal : srcVal});
+  case CombiningKind::MAXUI:
+    return DenseElementsAttr::get(dstType,
+                                  {accVal.ugt(srcVal) ? accVal : srcVal});
+  case CombiningKind::AND:
+    return DenseElementsAttr::get(dstType, {accVal & srcVal});
+  case CombiningKind::OR:
+    return DenseElementsAttr::get(dstType, {accVal | srcVal});
+  case CombiningKind::XOR:
+    // XOR of an even number of identical values cancels out.
+    return DenseElementsAttr::get(dstType,
+                                  {times & 0x1 ? accVal ^ srcVal : accVal});
+  default:
+    return {};
+  }
+
 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
   // Single parallel dim, this is a noop.
   if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
     return getSource();
+  auto srcAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSource());
+  auto accAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getAcc());
+  if (!srcAttr || !accAttr)
+    return {};
+  if (!srcAttr.isSplat() || !accAttr.isSplat())
+    return {};
+  auto reductionDims = getReductionDims();
+  auto srcType = mlir::cast<ShapedType>(getSourceVectorType());
+  auto srcDims = srcType.getShape();
+  int64_t times = 1;
+  for (auto dim : reductionDims) {
+    times *= srcDims[dim];
+  }
+  CombiningKind kind = getKind();
+  auto dstType = mlir::cast<ShapedType>(getDestType());
+  auto eltype = dstType.getElementType();
+  if (isa<FloatType>(eltype))
+    return foldSplatReduce<FloatAttr>(srcAttr.getSplatValue<FloatAttr>(),
+                                      accAttr.getSplatValue<FloatAttr>(), times,
+                                      kind, dstType);
+  if (isa<IntegerType>(eltype))
+    return foldSplatReduce<IntegerAttr>(srcAttr.getSplatValue<IntegerAttr>(),
+                                        accAttr.getSplatValue<IntegerAttr>(),
+                                        times, kind, dstType);
   return {};
 }
 
diff --git a/mlir/test/Dialect/Vector/constant-fold.mlir b/mlir/test/Dialect/Vector/constant-fold.mlir
index 66c91d6b2041bf..43c52b4b36ca53 100644
--- a/mlir/test/Dialect/Vector/constant-fold.mlir
+++ b/mlir/test/Dialect/Vector/constant-fold.mlir
@@ -11,3 +11,57 @@ func.func @fold_extract_transpose_negative(%arg0: vector<4x4xf16>) -> vector<4x4
   %2 = vector.extract %1[0] : vector<4x4xf16> from vector<1x4x4xf16>
   return %2 : vector<4x4xf16>
 }
+
+// CHECK-LABEL: fold_multid_reduction_f32_add
+func.func @fold_multid_reduction_f32_add() -> vector<1xf32> {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<1xf32>
+  %0 = arith.constant dense<1.000000e+00> : vector<1x128x128xf32>
+  // CHECK: %{{.*}} = arith.constant dense<1.638400e+04> : vector<1xf32>
+  %1 = vector.multi_reduction <add>, %0, %cst_0 [1, 2] : vector<1x128x128xf32> to vector<1xf32>
+  return %1 : vector<1xf32>
+}
+
+// CHECK-LABEL: fold_multid_reduction_f32_mul
+func.func @fold_multid_reduction_f32_mul() -> vector<1xf32> {
+  %cst_0 = arith.constant dense<1.000000e+00> : vector<1xf32>
+  %0 = arith.constant dense<2.000000e+00> : vector<1x2x2xf32>
+  // CHECK: %{{.*}} = arith.constant dense<1.600000e+01> : vector<1xf32>
+  %1 = vector.multi_reduction <mul>, %0, %cst_0 [1, 2] : vector<1x2x2xf32> to vector<1xf32>
+  return %1 : vector<1xf32>
+}
+
+// CHECK-LABEL: fold_multid_reduction_i32_add
+func.func @fold_multid_reduction_i32_add() -> vector<1xi32> {
+  %cst_1 = arith.constant dense<1> : vector<1xi32>
+  %0 = arith.constant dense<1> : vector<1x128x128xi32>
+  // CHECK: %{{.*}} = arith.constant dense<16385> : vector<1xi32>
+  %1 = vector.multi_reduction <add>, %0, %cst_1 [1, 2] : vector<1x128x128xi32> to vector<1xi32>
+  return %1 : vector<1xi32>
+}
+
+// CHECK-LABEL: fold_multid_reduction_i32_xor_odd
+func.func @fold_multid_reduction_i32_xor_odd() -> vector<1xi32> {
+  %cst_1 = arith.constant dense<0xFF> : vector<1xi32>
+  %0 = arith.constant dense<0xA0A> : vector<1x3xi32>
+  // CHECK: %{{.*}} = arith.constant dense<2805> : vector<1xi32>
+  %1 = vector.multi_reduction <xor>, %0, %cst_1 [1] : vector<1x3xi32> to vector<1xi32>
+  return %1 : vector<1xi32>
+}
+
+// CHECK-LABEL: fold_multid_reduction_i32_xor_even
+func.func @fold_multid_reduction_i32_xor_even() -> vector<1xi32> {
+  %cst_1 = arith.constant dense<0xFF> : vector<1xi32>
+  %0 = arith.constant dense<0xA0A> : vector<1x4xi32>
+  // CHECK: %{{.*}} = arith.constant dense<255> : vector<1xi32>
+  %1 = vector.multi_reduction <xor>, %0, %cst_1 [1] : vector<1x4xi32> to vector<1xi32>
+  return %1 : vector<1xi32>
+}
+
+// CHECK-LABEL: fold_multid_reduction_i64_add
+func.func @fold_multid_reduction_i64_add() -> vector<1xi64> {
+  %cst_1 = arith.constant dense<1> : vector<1xi64>
+  %0 = arith.constant dense<1> : vector<1x128x128xi64>
+  // CHECK: %{{.*}} = arith.constant dense<16385> : vector<1xi64>
+  %1 = vector.multi_reduction <add>, %0, %cst_1 [1, 2] : vector<1x128x128xi64> to vector<1xi64>
+  return %1 : vector<1xi64>
+}
\ No newline at end of file


github-actions bot commented Jan 10, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@banach-space (Contributor) left a comment:

Nice!

@banach-space (Contributor) left a comment:

Thanks, LGTM!

I've left a few minor comments - please address those before landing. Btw, do you have commit access?

@dcaballe (Contributor) left a comment:

Please, do not land yet. One sec.

@dcaballe (Contributor) left a comment:

Sorry for jumping out of nowhere :). Removing the block but please wait for the ok from @chelini or @kuhar. I just want to make sure that we are not going into a rounding mode nightmare because of this folding.

> case CombiningKind::ADD: {
>   APFloat n = APFloat(srcVal.getSemantics());
>   n.convertFromAPInt(APInt(64, times, true), true,
>                      APFloat::rmNearestTiesToEven);
Contributor commented:

Would this be a problem if the user expects a non-default rounding mode? We have been adding FMF and RM (fast-math flags and rounding modes) bottom-up in the IR, but they are lacking at the vector level, so I'm wondering if this would lead to an unexpected outcome. Perhaps @chelini or @kuhar could provide some feedback?
Worst case, I guess we could enable this folder under a flag...

ImanHosseini (Author) commented:

I am not sure. How does the user currently get a non-default rounding mode? And it does not look like this is the only fold that would be affected. In Arith I see a TODO for adding optional attributes to specify it: https://mlir.llvm.org/docs/Dialects/ArithOps/#arithaddf-arithaddfop. So the rounding-mode issue seems like a separate problem from constant folding; once support for rounding modes is added, it can be handled here as well, along with the other affected folds.

@kuhar (Member) commented Jan 13, 2025:

For integer types, this should be clearly allowed.
For float, everything else being fine, I think in the absence of any rounding modes, this should be permissible as well.

My worry would be that the reduction order is inherently unspecified for this op, and constant folding may produce a different order than runtime evaluation. In the current form, I can see that this is only applied in the splat case, which I think would only produce different results if the runtime implementation ended up doing partial reductions?

ImanHosseini (Author) commented:

In the case of partial reductions, we would have less accuracy than this, right? So do we prefer a default behavior where our program runs slower and gives less accurate results?

Member commented:

> So do we prefer, that the default behavior be that our program runs slower and give less accurate results?
I don't think this is a fair way to put it. The concern is that the same computation gives two answers depending on whether the operands are compile-time constants, assuming a reasonable lowering. I think the argument you are trying to make is that partial reductions should not be considered a valid lowering, and I'm honestly not sure either way, because I could see this being expanded to something like gpu.subgroup_reduce that may end up doing something very much hardware-dependent.

Member commented:
The reason will be displayed to describe this comment to others. Learn more.

gpu::SubgroupReduceOp and vector::ReductionOp don't fold when the argument is a constant. Maybe it would help us get unstuck if we decide what to do with these two first? (This could be an RFC on Discourse.)

Member commented:

I went through other vector dialect ops and I don't see any other 'math' op folding. Because what you propose here is something new, I think it deserves an RFC.

ImanHosseini (Author) commented:

> because I could see this being expanded to something like gpu.subgroup_reduce that may end up doing something very much hardware-dependent.

The concern is that this would make something that is currently runtime/hardware-dependent (and does not need to be) no longer hardware-dependent? Something that:

  1. Does not need to be runtime/hardware-dependent: it's a constant.
  2. Has no runtime cost: it would actually be faster to fold the constant.
  3. Has no precision cost: it would in fact be more accurate.

In this case it's splat-splat, but in general a partial reduction does not even return consistent results on the same hardware, because the order in which it is applied may change from run to run. How is that desirable? How is that even consistent? If partial ordering should be canon for reductions, in what order should it be applied? I've seen discussions where this was a choice between being fast and being hardware-independent, or between being fast and being more accurate. This is neither. Why would we prefer to be needlessly hardware-dependent, less accurate, and slower? It's fine if some user somehow wants that, but why should it be the default?

Member commented:

Let's move this discussion to an RFC. You are proposing much more aggressive folds than any existing ones in vector AFAICT.


ImanHosseini added 2 commits January 13, 2025 20:52
@ImanHosseini ImanHosseini requested a review from jpienaar January 13, 2025 22:11
@kuhar kuhar mentioned this pull request Jan 16, 2025
ImanHosseini added a commit that referenced this pull request Jan 17, 2025
I am trying to implement a power function for APFloat and APInt in order to constant-fold vector reductions: #122450. I need this utility to fold N `mul`s into a power.

---------

Co-authored-by: ImanHosseini <imanhosseini.17@gmail.com>
Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
github-actions bot pushed a commit to arm/arm-toolchain that referenced this pull request Jan 17, 2025

5 participants