-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[MLIR] [Vector] ConstantFold MultiDReduction #122450
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[MLIR] [Vector] ConstantFold MultiDReduction #122450
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Iman Hosseini (ImanHosseini) ChangesIf both source and acc are splat, constant fold the multi-reduction. Full diff: https://github.com/llvm/llvm-project/pull/122450.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ae1cf95732336a..a23d952c5760c5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -21,11 +21,13 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
@@ -44,6 +46,7 @@
#include "llvm/ADT/bit.h"
#include <cassert>
+#include <cmath>
#include <cstdint>
#include <numeric>
@@ -463,10 +466,111 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
build(builder, result, kind, source, acc, reductionDims);
}
+template <typename T>
+OpFoldResult foldSplatReduce(T src, T acc, int64_t times, CombiningKind kind,
+ ShapedType dstType);
+
+template <>
+OpFoldResult foldSplatReduce(FloatAttr src, FloatAttr acc, int64_t times,
+ CombiningKind kind, ShapedType dstType) {
+ APFloat srcVal = src.getValue();
+ APFloat accVal = acc.getValue();
+ switch (kind) {
+ case CombiningKind::ADD:{
+ APFloat n = APFloat(srcVal.getSemantics());
+ n.convertFromAPInt(APInt(64, times, true), true,
+ APFloat::rmNearestTiesToEven);
+ return DenseElementsAttr::get(dstType, {accVal + srcVal * n});
+ }
+ case CombiningKind::MUL: {
+ APFloat result = accVal;
+ for (int i = 0; i < times; ++i) {
+ result = result * srcVal;
+ }
+ return DenseElementsAttr::get(dstType, {result});
+ }
+ case CombiningKind::MINIMUMF:
+ return DenseElementsAttr::get(dstType, {llvm::minimum(accVal, srcVal)});
+ case CombiningKind::MAXIMUMF:
+ return DenseElementsAttr::get(dstType, {llvm::maximum(accVal, srcVal)});
+ case CombiningKind::MINNUMF:
+ return DenseElementsAttr::get(dstType, {llvm::minnum(accVal, srcVal)});
+ case CombiningKind::MAXNUMF:
+ return DenseElementsAttr::get(dstType, {llvm::maxnum(accVal, srcVal)});
+ default:
+ return {};
+ }
+}
+
+template <>
+OpFoldResult foldSplatReduce(IntegerAttr src, IntegerAttr acc, int64_t times,
+ CombiningKind kind, ShapedType dstType) {
+ APInt srcVal = src.getValue();
+ APInt accVal = acc.getValue();
+ switch (kind) {
+ case CombiningKind::ADD:
+ return DenseElementsAttr::get(dstType, {accVal + srcVal * times});
+ case CombiningKind::MUL: {
+ APInt result = accVal;
+ for (int i = 0; i < times; ++i) {
+ result *= srcVal;
+ }
+ return DenseElementsAttr::get(dstType, {result});
+ }
+ case CombiningKind::MINSI:
+ return DenseElementsAttr::get(
+ dstType, {accVal.slt(srcVal) ? accVal : srcVal});
+ case CombiningKind::MAXSI:
+ return DenseElementsAttr::get(
+ dstType, {accVal.ugt(srcVal) ? accVal : srcVal});
+ case CombiningKind::MINUI:
+ return DenseElementsAttr::get(
+ dstType, {accVal.ult(srcVal) ? accVal : srcVal});
+ case CombiningKind::MAXUI:
+ return DenseElementsAttr::get(
+ dstType, {accVal.ugt(srcVal) ? accVal : srcVal});
+ case CombiningKind::AND:
+ return DenseElementsAttr::get(dstType, {accVal & srcVal});
+ case CombiningKind::OR:
+ return DenseElementsAttr::get(dstType, {accVal | srcVal});
+ case CombiningKind::XOR:
+ return DenseElementsAttr::get(dstType,
+ {times & 0x1 ? accVal ^ srcVal : accVal});
+ default:
+ return {};
+ }
+}
+
OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  // Single parallel dim, this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return getSource();
  // Constant-fold only when both the source and the accumulator are splat
  // constants; the result is then also a splat.
  auto srcAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSource());
  auto accAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getAcc());
  if (!srcAttr || !accAttr || !srcAttr.isSplat() || !accAttr.isSplat())
    return {};
  // Each result lane combines `times` source elements: the product of the
  // reduced dimension sizes.
  ArrayRef<int64_t> srcShape = getSourceVectorType().getShape();
  int64_t times = 1;
  for (int64_t dim : getReductionDims())
    times *= srcShape[dim];
  CombiningKind kind = getKind();
  auto dstType = cast<ShapedType>(getDestType());
  Type elType = dstType.getElementType();
  // Dispatch on the element type; other element types are not folded.
  if (isa<FloatType>(elType))
    return foldSplatReduce<FloatAttr>(srcAttr.getSplatValue<FloatAttr>(),
                                      accAttr.getSplatValue<FloatAttr>(), times,
                                      kind, dstType);
  if (isa<IntegerType>(elType))
    return foldSplatReduce<IntegerAttr>(srcAttr.getSplatValue<IntegerAttr>(),
                                        accAttr.getSplatValue<IntegerAttr>(),
                                        times, kind, dstType);
  return {};
}
diff --git a/mlir/test/Dialect/Vector/constant-fold.mlir b/mlir/test/Dialect/Vector/constant-fold.mlir
index 66c91d6b2041bf..43c52b4b36ca53 100644
--- a/mlir/test/Dialect/Vector/constant-fold.mlir
+++ b/mlir/test/Dialect/Vector/constant-fold.mlir
@@ -11,3 +11,57 @@ func.func @fold_extract_transpose_negative(%arg0: vector<4x4xf16>) -> vector<4x4
%2 = vector.extract %1[0] : vector<4x4xf16> from vector<1x4x4xf16>
return %2 : vector<4x4xf16>
}

// CHECK-LABEL: fold_multid_reduction_f32_add
func.func @fold_multid_reduction_f32_add() -> vector<1xf32> {
  %cst_0 = arith.constant dense<0.000000e+00> : vector<1xf32>
  %0 = arith.constant dense<1.000000e+00> : vector<1x128x128xf32>
  // 0.0 + 1.0 * (128 * 128) == 16384.0
  // CHECK: %{{.*}} = arith.constant dense<1.638400e+04> : vector<1xf32>
  %1 = vector.multi_reduction <add>, %0, %cst_0 [1, 2] : vector<1x128x128xf32> to vector<1xf32>
  return %1 : vector<1xf32>
}

// CHECK-LABEL: fold_multid_reduction_f32_mul
func.func @fold_multid_reduction_f32_mul() -> vector<1xf32> {
  %cst_0 = arith.constant dense<1.000000e+00> : vector<1xf32>
  %0 = arith.constant dense<2.000000e+00> : vector<1x2x2xf32>
  // 1.0 * 2.0^(2 * 2) == 16.0
  // CHECK: %{{.*}} = arith.constant dense<1.600000e+01> : vector<1xf32>
  %1 = vector.multi_reduction <mul>, %0, %cst_0 [1, 2] : vector<1x2x2xf32> to vector<1xf32>
  return %1 : vector<1xf32>
}

// CHECK-LABEL: fold_multid_reduction_i32_add
func.func @fold_multid_reduction_i32_add() -> vector<1xi32> {
  %cst_1 = arith.constant dense<1> : vector<1xi32>
  %0 = arith.constant dense<1> : vector<1x128x128xi32>
  // 1 + 1 * (128 * 128) == 16385
  // CHECK: %{{.*}} = arith.constant dense<16385> : vector<1xi32>
  %1 = vector.multi_reduction <add>, %0, %cst_1 [1, 2] : vector<1x128x128xi32> to vector<1xi32>
  return %1 : vector<1xi32>
}

// CHECK-LABEL: fold_multid_reduction_i32_xor_odd
func.func @fold_multid_reduction_i32_xor_odd() -> vector<1xi32> {
  %cst_1 = arith.constant dense<0xFF> : vector<1xi32>
  %0 = arith.constant dense<0xA0A> : vector<1x3xi32>
  // Odd reduction count: 0xFF ^ 0xA0A == 2805.
  // CHECK: %{{.*}} = arith.constant dense<2805> : vector<1xi32>
  %1 = vector.multi_reduction <xor>, %0, %cst_1 [1] : vector<1x3xi32> to vector<1xi32>
  return %1 : vector<1xi32>
}

// CHECK-LABEL: fold_multid_reduction_i32_xor_even
func.func @fold_multid_reduction_i32_xor_even() -> vector<1xi32> {
  %cst_1 = arith.constant dense<0xFF> : vector<1xi32>
  %0 = arith.constant dense<0xA0A> : vector<1x4xi32>
  // Even reduction count: xor contributions cancel, acc 0xFF == 255 remains.
  // CHECK: %{{.*}} = arith.constant dense<255> : vector<1xi32>
  %1 = vector.multi_reduction <xor>, %0, %cst_1 [1] : vector<1x4xi32> to vector<1xi32>
  return %1 : vector<1xi32>
}

// CHECK-LABEL: fold_multid_reduction_i64_add
func.func @fold_multid_reduction_i64_add() -> vector<1xi64> {
  %cst_1 = arith.constant dense<1> : vector<1xi64>
  %0 = arith.constant dense<1> : vector<1x128x128xi64>
  // CHECK: %{{.*}} = arith.constant dense<16385> : vector<1xi64>
  %1 = vector.multi_reduction <add>, %0, %cst_1 [1, 2] : vector<1x128x128xi64> to vector<1xi64>
  return %1 : vector<1xi64>
}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
Please make these helper functions `static` (they are file-local), combine the null and splat checks into a single condition, and rename `eltype` to match the surrounding camelCase convention (e.g. `elType`).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, LGTM!
I've left a few minor comments - please address those before landing. Btw, do you have commit access?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please, do not land yet. One sec.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
case CombiningKind::ADD: {
  APFloat n = APFloat(srcVal.getSemantics());
  n.convertFromAPInt(APInt(64, times, true), true,
                     APFloat::rmNearestTiesToEven);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would this be a problem is the user expects a non-default rounding mode? We have been adding FMF and RM bottom-up in the IR but it's lacking at vector level so I'm wondering if this would lead to an unexpected outcome. Perhaps @chelini, @kuhar could provide some feedback?
Worse case, I guess we could enable this folder under a flag...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure. How does the user currently get a non-default rounding mode? And it does not look like this is the only fold
that would be affected. In Arith I see a TODO for adding optional attributes to specify it: https://mlir.llvm.org/docs/Dialects/ArithOps/#arithaddf-arithaddfop. So the rounding-mode issue seems to be a separate problem from constant folding; once support for rounding modes is added, it can be applied here as well, along with the other affected folds.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For integer types, this should be clearly allowed.
For float, everything else being fine, I think in the absence of any rounding modes, this should be permissible as well.
My worry would be that the reduction order is inherently unspecified for this op, and constant folding may produce a different order than runtime evaluation. In the current form, I can see that this is only applied in the splat case, which I think would only produce different results if the runtime implementation ended up doing partial reductions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In case of partial reductions- we would have less accuracy than this right? So do we prefer, that the default behavior be that our program runs slower and give less accurate results?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So do we prefer, that the default behavior be that our program runs slower and give less accurate results?
I don't think this a fair way to put this. The concern is that the same computation gives two answers depending on the operands being compile-time constants or not, assuming reasonable lowering. I think the argument you are trying to make is that partial reductions should not be considered a valid lowering, and I'm honestly not sure either way because I could see this being expanded to something like gpu.subgroup_reduce
that may end up doing something very much hardware-dependent.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gpu::SubgroupReduceOp
and vector::ReductionOp
don't fold when the argument is a constant. Maybe it would help us get unstuck if we decide what to do with these two first? (This could be an RFC on discourse.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I went through other vector dialect ops and I don't see any other 'math' op folding. Because what you propose here is something new, I think it deserves an RFC.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
because I could see this being expanded to something like gpu.subgroup_reduce that may end up doing something very much hardware-dependent.
The concern is that this may make something that is runtime/hw-dependent (and it does not need to be), not hw dependent? Something that:
- Does not need to be rte/hw-dependent. It's a constant.
- There is no runtime cost to this, it would actually be faster to fold the constant.
- There is no precision cost to it, it would in fact be more accurate.
In this case, it's Splat-Splat, but in general, partial reduction does not even return consistent result on the same hw because the order by which it is applied may change from run-to-run. How is that desirable? How is that even consistent? If partial ordering should be canon for reductions- in what order should it be applied then?
I've seen discussion on this where it has been a decision between being fast or not hw-dependent. Being fast or being more accurate. This is neither.
Why would we prefer to be needlessly hw-dependent, less accurate and slower? It's fine if some user somehow wants that- but why should it be the default?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's move this discussion to an RFC. You are proposing much more aggressive folds than any existing ones in vector
AFAICT.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
case CombiningKind::ADD: {
  APFloat n = APFloat(srcVal.getSemantics());
  n.convertFromAPInt(APInt(64, times, true), true,
                     APFloat::rmNearestTiesToEven);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For integer types, this should be clearly allowed.
For float, everything else being fine, I think in the absence of any rounding modes, this should be permissible as well.
My worry would be that the reduction order is inherently unspecified for this op, and constant folding may produce a different order than runtime evaluation. In the current form, I can see that this is only applied in the splat case, which I think would only produce different results if the runtime implementation ended up doing partial reductions?
I am trying to calculate power function for APFloat, APInt to constant fold vector reductions: #122450 I need this utility to fold N `mul`s into power. --------- Co-authored-by: ImanHosseini <imanhosseini.17@gmail.com> Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
I am trying to calculate power function for APFloat, APInt to constant fold vector reductions: llvm/llvm-project#122450 I need this utility to fold N `mul`s into power. --------- Co-authored-by: ImanHosseini <imanhosseini.17@gmail.com> Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
If both source and acc are splat, constant fold the multi-reduction. @apaszke