diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 99214fadf4ba3..76ecd4171d81a 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1665,11 +1665,10 @@ LogicalResult ForeachOp::verify() {
   const Dimension dimRank = t.getDimRank();
   const auto args = getBody()->getArguments();
 
-  if (getOrder().has_value() &&
-      (t.getEncoding() || !getOrder()->isPermutation()))
-    return emitError("Only support permuted order on non encoded dense tensor");
+  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
+    return emitError("Level traverse order does not match tensor's level rank");
 
-  if (static_cast<size_t>(dimRank) + 1 + getInitArgs().size() != args.size())
+  if (dimRank + 1 + getInitArgs().size() != args.size())
     return emitError("Unmatched number of arguments in the block");
 
   if (getNumResults() != getInitArgs().size())
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index f6fb59fa2c3b8..db969436a3071 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -421,8 +421,11 @@ Operation *mlir::sparse_tensor::getTop(Operation *op) {
 void sparse_tensor::foreachInSparseConstant(
     OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
     function_ref<void(ArrayRef<Value>, Value)> callback) {
-  const Dimension dimRank =
-      SparseTensorType(getRankedTensorType(attr)).getDimRank();
+  if (!order)
+    order = builder.getMultiDimIdentityMap(attr.getType().getRank());
+
+  auto stt = SparseTensorType(getRankedTensorType(attr));
+  const Dimension dimRank = stt.getDimRank();
   const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
   const auto values = attr.getValues().getValues<Attribute>();
 
@@ -446,20 +449,23 @@ void sparse_tensor::foreachInSparseConstant(
 
   // Sorts the sparse element attribute based on coordinates.
   std::sort(elems.begin(), elems.end(),
-            [order, dimRank](const ElementAttr &lhs, const ElementAttr &rhs) {
-              const auto &lhsCoords = lhs.first;
-              const auto &rhsCoords = rhs.first;
-              for (Dimension d = 0; d < dimRank; d++) {
-                // FIXME: This only makes sense for permutations.
-                // And since we don't check that `order` is a permutation,
-                // it can also cause OOB errors when we use `l`.
-                const Level l = order ? order.getDimPosition(d) : d;
-                if (lhsCoords[l].getInt() == rhsCoords[l].getInt())
-                  continue;
-                return lhsCoords[l].getInt() < rhsCoords[l].getInt();
-              }
+            [order](const ElementAttr &lhs, const ElementAttr &rhs) {
               if (std::addressof(lhs) == std::addressof(rhs))
                 return false;
+
+              auto lhsCoords = llvm::map_to_vector(
+                  lhs.first, [](IntegerAttr i) { return i.getInt(); });
+              auto rhsCoords = llvm::map_to_vector(
+                  rhs.first, [](IntegerAttr i) { return i.getInt(); });
+
+              SmallVector<int64_t> lhsLvlCrds = order.compose(lhsCoords);
+              SmallVector<int64_t> rhsLvlCrds = order.compose(rhsCoords);
+              // Sort the element based on the lvl coordinates.
+              for (Level l = 0; l < order.getNumResults(); l++) {
+                if (lhsLvlCrds[l] == rhsLvlCrds[l])
+                  continue;
+                return lhsLvlCrds[l] < rhsLvlCrds[l];
+              }
               llvm_unreachable("no equal coordinate in sparse element attr");
             });
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 13388dce6bbb5..7770bd857e880 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1129,14 +1129,11 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
 
     SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
     if (op.getOrder()) {
-      // FIXME: There is some dim/lvl confusion here since `dimRank != lvlRank`
-      const Dimension dimRank = stt.getDimRank();
-      SmallVector<Value> dcvs = lcvs; // keep a copy
-      for (Dimension d = 0; d < dimRank; d++) {
-        auto l = op.getOrder()->getDimPosition(d);
-        lcvs[l] = dcvs[d];
-      }
+      // TODO: Support it so that we can do direct conversion from CSR->BSR.
+      llvm_unreachable(
+          "Level order not yet implemented on non-constant input tensors.");
     }
+
     Value vals = loopEmitter.getValBuffer()[0];
     Value pos = loopEmitter.getPosits()[0].back();
     // Loads the value from sparse tensor using position-index;
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
index ec14492c5b449..ccd61aea0ba16 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_block.mlir
@@ -74,26 +74,34 @@ module {
     //
    // Initialize a 2-dim dense tensor.
     //
-    %t = arith.constant dense<[
-       [ 1.0, 2.0, 3.0, 4.0 ],
-       [ 5.0, 6.0, 7.0, 8.0 ]
-    ]> : tensor<2x4xf64>
+    %t = arith.constant sparse<[[0, 0], [0, 1], [0, 2], [0, 3],
+                                [1, 0], [1, 1], [1, 2], [1, 3]],
+                               [ 1.0, 2.0, 3.0, 4.0,
+                                 5.0, 6.0, 7.0, 8.0 ]> : tensor<2x4xf64>
+    %td = arith.constant dense<[[ 1.0, 2.0, 3.0, 4.0 ],
+                                [ 5.0, 6.0, 7.0, 8.0 ]]> : tensor<2x4xf64>
 
-    %1 = sparse_tensor.convert %t : tensor<2x4xf64> to tensor<2x4xf64, #CSR>
-    %2 = sparse_tensor.convert %1 : tensor<2x4xf64, #CSR> to tensor<2x4xf64, #BSR>
-    %3 = sparse_tensor.convert %2 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSC>
+    // constant -> BSR (either from SparseElementsAttr or DenseElementsAttr)
+    %1 = sparse_tensor.convert %t : tensor<2x4xf64> to tensor<2x4xf64, #BSR>
+    %2 = sparse_tensor.convert %td : tensor<2x4xf64> to tensor<2x4xf64, #BSR>
+    %3 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSR>
+    %4 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSC>
 
-    %v1 = sparse_tensor.values %1 : tensor<2x4xf64, #CSR> to memref<?xf64>
+    %v1 = sparse_tensor.values %1 : tensor<2x4xf64, #BSR> to memref<?xf64>
     %v2 = sparse_tensor.values %2 : tensor<2x4xf64, #BSR> to memref<?xf64>
-    %v3 = sparse_tensor.values %3 : tensor<2x4xf64, #CSC> to memref<?xf64>
+    %v3 = sparse_tensor.values %3 : tensor<2x4xf64, #CSR> to memref<?xf64>
+    %v4 = sparse_tensor.values %4 : tensor<2x4xf64, #CSC> to memref<?xf64>
 
-    // CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8 )
+
+    // CHECK: ( 1, 2, 5, 6, 3, 4, 7, 8 )
     // CHECK-NEXT: ( 1, 2, 5, 6, 3, 4, 7, 8 )
+    // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8 )
     // CHECK-NEXT: ( 1, 5, 2, 6, 3, 7, 4, 8 )
 
     call @dumpf64(%v1) : (memref<?xf64>) -> ()
     call @dumpf64(%v2) : (memref<?xf64>) -> ()
     call @dumpf64(%v3) : (memref<?xf64>) -> ()
+    call @dumpf64(%v4) : (memref<?xf64>) -> ()
 
     return
   }
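
Note (not part of the patch): the rewritten comparator in foreachInSparseConstant relies on AffineMap::compose constant-folding dimension coordinates into level coordinates. The standalone C++ sketch below illustrates that step in isolation; the 2x2 BSR-style dim-to-lvl map and the main() driver are assumptions for illustration only and are not taken from the patch or from the test's #BSR encoding.

// Sketch only: shows AffineMap::compose on constant coordinates, the same
// call the new sort comparator uses. Build against MLIR's IR library.
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

int main() {
  MLIRContext ctx;
  // Hypothetical dim-to-lvl map of a 2x2 blocked layout:
  //   (i, j) -> (i floordiv 2, j floordiv 2, i mod 2, j mod 2)
  AffineExpr i = getAffineDimExpr(0, &ctx);
  AffineExpr j = getAffineDimExpr(1, &ctx);
  AffineMap dimToLvl =
      AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                     {i.floorDiv(2), j.floorDiv(2), i % 2, j % 2}, &ctx);

  // Dim coordinates (1, 3) land in block (0, 1) at in-block offset (1, 1),
  // i.e. lvl coordinates (0, 1, 1, 1). Sorting elements lexicographically by
  // these lvl coordinates is what the rewritten comparator does.
  SmallVector<int64_t> lvlCrds = dimToLvl.compose(ArrayRef<int64_t>{1, 3});
  for (int64_t c : lvlCrds)
    llvm::outs() << c << " ";
  llvm::outs() << "\n";
  return 0;
}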