Open
Description
Hi!
I'm trying to reshape CSR/CSC tensors. Reshaping CSR tensors works, but running the PassManager
on a CSC reshape fails with:
Traceback (most recent call last):
File "/home/mtsokol/sparse/csc_reshape.py", line 29, in <module>
pm.run(module.operation)
mlir._mlir_libs._site_initialize.<locals>.MLIRError: Failure while executing pass pipeline:
error: "-":6:16: ConvertOp not staged.
note: "-":6:16: see current operation: %8 = "sparse_tensor.convert"(%7) : (tensor<25x200xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique, nonordered), d1 : singleton(nonordered)), posWidth = 64, crdWidth = 64 }>>) -> tensor<25x200xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 64, crdWidth = 64 }>>
error: "-":6:16: ConvertOp not staged.
note: "-":6:16: see current operation: %9 = "sparse_tensor.convert"(%8) : (tensor<25x200xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique, nonordered), d1 : singleton(nonordered)), posWidth = 64, crdWidth = 64 }>>) -> tensor<25x200xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 64, crdWidth = 64 }>>
error: "-":6:16: ConvertOp not staged.
note: "-":6:16: see current operation: %6 = "sparse_tensor.convert"(%5) : (tensor<25x200xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique, nonordered), d1 : singleton(nonordered)), posWidth = 64, crdWidth = 64 }>>) -> tensor<25x200xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 64, crdWidth = 64 }>>
error: "-":6:16: failed to legalize operation 'sparse_tensor.convert' that was explicitly marked illegal
note: "-":6:16: see current operation: %44 = "sparse_tensor.convert"(%43) : (tensor<25x200xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique, nonordered), d1 : singleton(nonordered)), posWidth = 64, crdWidth = 64 }>>) -> tensor<25x200xf64, #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 64, crdWidth = 64 }>>
Here's a short script that reproduces the failing CSC reshape: https://gist.github.com/mtsokol/cc67c576172c67141126307713f4cb96
# Reproducer: run the sparsifier pipeline on a `tensor.reshape` of a
# CSC-encoded sparse tensor. `pm.run` is where the failure is reported
# (an MLIRError with "ConvertOp not staged" on `sparse_tensor.convert` ops).
import ctypes
import ctypes.util
import pathlib

from mlir import execution_engine, passmanager
from mlir.ir import Context, Module

# Path of the MLIR C runner utils shared library, or None if it was not found.
MLIR_C_RUNNER_UTILS = ctypes.util.find_library("mlir_c_runner_utils")

# Input IR: reshape a 100x50 #CSC tensor into a 25x200 #CSC tensor using a
# runtime shape operand. (`@add` is just the symbol name; the body reshapes.)
MODULE_ASM = """
#CSC = #sparse_tensor.encoding<{
map = (i, j) -> (j : dense, i : compressed), posWidth = 64, crdWidth = 64
}>
func.func @add(%st_0 : tensor<100x50xf64, #CSC>, %st_1 : tensor<2xi64>) -> tensor<25x200xf64, #CSC> attributes { llvm.emit_c_interface } {
%dst = tensor.reshape %st_0(%st_1) : (tensor<100x50xf64, #CSC>, tensor<2xi64>) -> tensor<25x200xf64, #CSC>
return %dst : tensor<25x200xf64, #CSC>
}
"""

# Pass pipeline under test.
PIPELINE = "builtin.module(sparse-assembler{direct-out=true}, sparsifier{create-sparse-deallocs=1 enable-runtime-library=false})"


def main() -> None:
    """Parse the CSC reshape module, run the sparsifier pipeline, then JIT it.

    Writes the parsed IR to ``module.mlir`` and, if the pipeline succeeds,
    the lowered IR to ``module_opt.mlir`` in the current directory.

    Raises:
        FileNotFoundError: if ``mlir_c_runner_utils`` could not be located.
            This is checked only after the pass pipeline has run, so a
            pipeline failure is still reported first.
    """
    cwd = pathlib.Path(".")
    with Context():
        module = Module.parse(MODULE_ASM)
        (cwd / "module.mlir").write_text(str(module), encoding="utf-8")

        pm = passmanager.PassManager.parse(PIPELINE)
        pm.run(module.operation)
        (cwd / "module_opt.mlir").write_text(str(module), encoding="utf-8")

        # find_library returns None when the library is missing; fail with a
        # clear message instead of handing [None] to the engine.
        if MLIR_C_RUNNER_UTILS is None:
            raise FileNotFoundError(
                "could not locate the mlir_c_runner_utils shared library"
            )
        # Constructed only to check that the lowered module can be JIT-built.
        execution_engine.ExecutionEngine(
            module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS]
        )


if __name__ == "__main__":
    main()
And here's a script showing a CSR reshape that works: https://gist.github.com/mtsokol/a9950a60f39983bab6b365081c8d2b3a
# Working counterpart: the same reshape on a CSR-encoded sparse tensor, which
# runs through the sparsifier pipeline without error.
import ctypes
import ctypes.util
import pathlib

from mlir import execution_engine, passmanager
from mlir.ir import Context, Module

# Path of the MLIR C runner utils shared library, or None if it was not found.
MLIR_C_RUNNER_UTILS = ctypes.util.find_library("mlir_c_runner_utils")

# Input IR: reshape a 100x50 #CSR tensor into a 25x200 #CSR tensor using a
# runtime shape operand. (`@add` is just the symbol name; the body reshapes.)
MODULE_ASM = """
#CSR = #sparse_tensor.encoding<{
map = (i, j) -> (i : dense, j : compressed), posWidth = 64, crdWidth = 64
}>
func.func @add(%st_0 : tensor<100x50xf64, #CSR>, %st_1 : tensor<2xi64>) -> tensor<25x200xf64, #CSR> attributes { llvm.emit_c_interface } {
%dst = tensor.reshape %st_0(%st_1) : (tensor<100x50xf64, #CSR>, tensor<2xi64>) -> tensor<25x200xf64, #CSR>
return %dst : tensor<25x200xf64, #CSR>
}
"""

# Pass pipeline under test (identical to the CSC reproducer).
PIPELINE = "builtin.module(sparse-assembler{direct-out=true}, sparsifier{create-sparse-deallocs=1 enable-runtime-library=false})"


def main() -> None:
    """Parse the CSR reshape module, run the sparsifier pipeline, then JIT it.

    Writes the parsed IR to ``module.mlir`` and, if the pipeline succeeds,
    the lowered IR to ``module_opt.mlir`` in the current directory.

    Raises:
        FileNotFoundError: if ``mlir_c_runner_utils`` could not be located.
            This is checked only after the pass pipeline has run, so a
            pipeline failure is still reported first.
    """
    cwd = pathlib.Path(".")
    with Context():
        module = Module.parse(MODULE_ASM)
        (cwd / "module.mlir").write_text(str(module), encoding="utf-8")

        pm = passmanager.PassManager.parse(PIPELINE)
        pm.run(module.operation)
        (cwd / "module_opt.mlir").write_text(str(module), encoding="utf-8")

        # find_library returns None when the library is missing; fail with a
        # clear message instead of handing [None] to the engine.
        if MLIR_C_RUNNER_UTILS is None:
            raise FileNotFoundError(
                "could not locate the mlir_c_runner_utils shared library"
            )
        # Constructed only to check that the lowered module can be JIT-built.
        execution_engine.ExecutionEngine(
            module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS]
        )


if __name__ == "__main__":
    main()
I'm using LLVM 19.0.1-rc3