Skip to content

[mlir][sparse] support sparse constant to BSR conversion. #71114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1665,11 +1665,10 @@ LogicalResult ForeachOp::verify() {
const Dimension dimRank = t.getDimRank();
const auto args = getBody()->getArguments();

if (getOrder().has_value() &&
(t.getEncoding() || !getOrder()->isPermutation()))
return emitError("Only support permuted order on non encoded dense tensor");
if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
return emitError("Level traverse order does not match tensor's level rank");

if (static_cast<size_t>(dimRank) + 1 + getInitArgs().size() != args.size())
if (dimRank + 1 + getInitArgs().size() != args.size())
return emitError("Unmatched number of arguments in the block");

if (getNumResults() != getInitArgs().size())
Expand Down
34 changes: 20 additions & 14 deletions mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,8 +421,11 @@ Operation *mlir::sparse_tensor::getTop(Operation *op) {
void sparse_tensor::foreachInSparseConstant(
OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
function_ref<void(ArrayRef<Value>, Value)> callback) {
const Dimension dimRank =
SparseTensorType(getRankedTensorType(attr)).getDimRank();
if (!order)
order = builder.getMultiDimIdentityMap(attr.getType().getRank());

auto stt = SparseTensorType(getRankedTensorType(attr));
const Dimension dimRank = stt.getDimRank();
const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
const auto values = attr.getValues().getValues<Attribute>();

Expand All @@ -446,20 +449,23 @@ void sparse_tensor::foreachInSparseConstant(

// Sorts the sparse element attribute based on coordinates.
std::sort(elems.begin(), elems.end(),
[order, dimRank](const ElementAttr &lhs, const ElementAttr &rhs) {
const auto &lhsCoords = lhs.first;
const auto &rhsCoords = rhs.first;
for (Dimension d = 0; d < dimRank; d++) {
// FIXME: This only makes sense for permutations.
// And since we don't check that `order` is a permutation,
// it can also cause OOB errors when we use `l`.
const Level l = order ? order.getDimPosition(d) : d;
if (lhsCoords[l].getInt() == rhsCoords[l].getInt())
continue;
return lhsCoords[l].getInt() < rhsCoords[l].getInt();
}
[order](const ElementAttr &lhs, const ElementAttr &rhs) {
if (std::addressof(lhs) == std::addressof(rhs))
return false;

auto lhsCoords = llvm::map_to_vector(
lhs.first, [](IntegerAttr i) { return i.getInt(); });
auto rhsCoords = llvm::map_to_vector(
rhs.first, [](IntegerAttr i) { return i.getInt(); });

SmallVector<int64_t, 4> lhsLvlCrds = order.compose(lhsCoords);
SmallVector<int64_t, 4> rhsLvlCrds = order.compose(rhsCoords);
// Sort the element based on the lvl coordinates.
for (Level l = 0; l < order.getNumResults(); l++) {
if (lhsLvlCrds[l] == rhsLvlCrds[l])
continue;
return lhsLvlCrds[l] < rhsLvlCrds[l];
}
llvm_unreachable("no equal coordinate in sparse element attr");
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1129,14 +1129,11 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {

SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
if (op.getOrder()) {
// FIXME: There is some dim/lvl confusion here since `dimRank != lvlRank`
const Dimension dimRank = stt.getDimRank();
SmallVector<Value> dcvs = lcvs; // keep a copy
for (Dimension d = 0; d < dimRank; d++) {
auto l = op.getOrder()->getDimPosition(d);
lcvs[l] = dcvs[d];
}
// TODO: Support it so that we can do direct conversion from CSR->BSR.
llvm_unreachable(
"Level order not yet implemented on non-constant input tensors.");
}

Value vals = loopEmitter.getValBuffer()[0];
Value pos = loopEmitter.getPosits()[0].back();
// Loads the value from sparse tensor using position-index;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,26 +74,34 @@ module {
//
// Initialize a 2-dim dense tensor.
//
%t = arith.constant dense<[
[ 1.0, 2.0, 3.0, 4.0 ],
[ 5.0, 6.0, 7.0, 8.0 ]
]> : tensor<2x4xf64>
%t = arith.constant sparse<[[0, 0], [0, 1], [0, 2], [0, 3],
[1, 0], [1, 1], [1, 2], [1, 3]],
[ 1.0, 2.0, 3.0, 4.0,
5.0, 6.0, 7.0, 8.0 ]> : tensor<2x4xf64>

%td = arith.constant dense<[[ 1.0, 2.0, 3.0, 4.0 ],
[ 5.0, 6.0, 7.0, 8.0 ]]> : tensor<2x4xf64>

%1 = sparse_tensor.convert %t : tensor<2x4xf64> to tensor<2x4xf64, #CSR>
%2 = sparse_tensor.convert %1 : tensor<2x4xf64, #CSR> to tensor<2x4xf64, #BSR>
%3 = sparse_tensor.convert %2 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSC>
// constant -> BSR (either from SparseElementsAttr or DenseElementsAttr)
%1 = sparse_tensor.convert %t : tensor<2x4xf64> to tensor<2x4xf64, #BSR>
%2 = sparse_tensor.convert %td : tensor<2x4xf64> to tensor<2x4xf64, #BSR>
%3 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSR>
%4 = sparse_tensor.convert %1 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSC>

%v1 = sparse_tensor.values %1 : tensor<2x4xf64, #CSR> to memref<?xf64>
%v1 = sparse_tensor.values %1 : tensor<2x4xf64, #BSR> to memref<?xf64>
%v2 = sparse_tensor.values %2 : tensor<2x4xf64, #BSR> to memref<?xf64>
%v3 = sparse_tensor.values %3 : tensor<2x4xf64, #CSC> to memref<?xf64>
%v3 = sparse_tensor.values %3 : tensor<2x4xf64, #CSR> to memref<?xf64>
%v4 = sparse_tensor.values %4 : tensor<2x4xf64, #CSC> to memref<?xf64>

// CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8 )

// CHECK: ( 1, 2, 5, 6, 3, 4, 7, 8 )
// CHECK-NEXT: ( 1, 2, 5, 6, 3, 4, 7, 8 )
// CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8 )
// CHECK-NEXT: ( 1, 5, 2, 6, 3, 7, 4, 8 )
call @dumpf64(%v1) : (memref<?xf64>) -> ()
call @dumpf64(%v2) : (memref<?xf64>) -> ()
call @dumpf64(%v3) : (memref<?xf64>) -> ()
call @dumpf64(%v4) : (memref<?xf64>) -> ()

return
}
Expand Down