[mlir][sparse] Misc code cleanup

Depends On D111763

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D111766
This commit is contained in:
wren romano 2021-10-13 16:08:35 -07:00
parent 63d4fc9483
commit 5167c36ab4
1 changed files with 8 additions and 11 deletions

View File

@ -189,7 +189,7 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
params.push_back(constantI64(rewriter, loc, secInd));
params.push_back(constantI64(rewriter, loc, primary));
// User action and pointer.
Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
if (!ptr)
ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
params.push_back(constantI32(rewriter, loc, action));
@ -226,9 +226,8 @@ static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
/// if (tensor[ivs]!=0) {
/// ind = ivs
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
Operation *op, Value tensor, Value ind,
Location loc, Value tensor, Value ind,
ValueRange ivs) {
Location loc = op->getLoc();
Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
Value cond = genIsNonzero(rewriter, loc, val);
scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
@ -270,7 +269,7 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
params.push_back(val);
params.push_back(ind);
params.push_back(perm);
Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
rewriter.create<CallOp>(
loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
params);
@ -279,11 +278,10 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
static Optional<std::pair<Value, Value>>
genSplitSparseConstant(ConversionPatternRewriter &rewriter, ConvertOp op,
genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
Value tensor) {
if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
if (auto attr = constOp.value().dyn_cast<SparseElementsAttr>()) {
Location loc = op->getLoc();
DenseElementsAttr indicesAttr = attr.getIndices();
Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
DenseElementsAttr valuesAttr = attr.getValues();
@ -297,10 +295,9 @@ genSplitSparseConstant(ConversionPatternRewriter &rewriter, ConvertOp op,
/// Generates the code to copy the index at indices[ivs] to ind, and return
/// the value at value[ivs].
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
Operation *op, Value indices,
Location loc, Value indices,
Value values, Value ind, ValueRange ivs,
unsigned rank) {
Location loc = op->getLoc();
for (unsigned i = 0; i < rank; i++) {
Value idx = constantIndex(rewriter, loc, i);
Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
@ -449,7 +446,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
SmallVector<Value> st;
Value zero = constantIndex(rewriter, loc, 0);
Value one = constantIndex(rewriter, loc, 1);
auto indicesValues = genSplitSparseConstant(rewriter, op, src);
auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
bool isCOOConstant = indicesValues.hasValue();
Value indices;
Value values;
@ -474,10 +471,10 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
ValueRange args) -> scf::ValueVector {
Value val;
if (isCOOConstant)
val = genIndexAndValueForSparse(rewriter, op, indices, values, ind,
val = genIndexAndValueForSparse(rewriter, loc, indices, values, ind,
ivs, rank);
else
val = genIndexAndValueForDense(rewriter, op, src, ind, ivs);
val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
return {};
});