[mlir][sparse] Code cleanup for SparseTensorConversion

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D115004
This commit is contained in:
wren romano 2021-12-03 15:58:03 -08:00
parent 0e0f1b28fc
commit f527fdf51e
1 changed files with 58 additions and 55 deletions

View File

@ -142,13 +142,19 @@ constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc,
return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2));
}
/// Returns the equivalent of `void*` for opaque arguments to the
/// execution engine.
static Type getOpaquePointerType(PatternRewriter &rewriter) {
return LLVM::LLVMPointerType::get(rewriter.getI8Type());
}
/// Returns a function reference (first hit also inserts into module). Sets
/// the "_emit_c_interface" on the function declaration when requested,
/// so that LLVM lowering generates a wrapper function that takes care
/// of ABI complications with passing in and returning MemRefs to C functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
TypeRange resultType, ValueRange operands,
bool emitCInterface = false) {
bool emitCInterface) {
MLIRContext *context = op->getContext();
auto module = op->getParentOfType<ModuleOp>();
auto result = SymbolRefAttr::get(context, name);
@ -165,6 +171,24 @@ static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
return result;
}
/// Creates a `CallOp` to the function reference returned by `getFunc()`.
static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name,
TypeRange resultType, ValueRange operands,
bool emitCInterface = false) {
auto fn = getFunc(op, name, resultType, operands, emitCInterface);
return builder.create<CallOp>(op->getLoc(), resultType, fn, operands);
}
/// Replaces the `op` with a `CallOp` to the function reference returned
/// by `getFunc()`.
static CallOp replaceOpWithFuncCall(PatternRewriter &rewriter, Operation *op,
StringRef name, TypeRange resultType,
ValueRange operands,
bool emitCInterface = false) {
auto fn = getFunc(op, name, resultType, operands, emitCInterface);
return rewriter.replaceOpWithNewOp<CallOp>(op, resultType, fn, operands);
}
/// Generates dimension size call.
static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
SparseTensorEncodingAttr &enc, Value src,
@ -173,25 +197,20 @@ static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
if (AffineMap p = enc.getDimOrdering())
idx = p.getPermutedPosition(idx);
// Generate the call.
Location loc = op->getLoc();
StringRef name = "sparseDimSize";
SmallVector<Value, 2> params;
params.push_back(src);
params.push_back(constantIndex(rewriter, loc, idx));
SmallVector<Value, 2> params{src, constantIndex(rewriter, op->getLoc(), idx)};
Type iTp = rewriter.getIndexType();
auto fn = getFunc(op, name, iTp, params);
return rewriter.create<CallOp>(loc, iTp, fn, params).getResult(0);
return createFuncCall(rewriter, op, name, iTp, params).getResult(0);
}
/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
ArrayRef<Value> params) {
Location loc = op->getLoc();
StringRef name = "newSparseTensor";
Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
auto call = rewriter.create<CallOp>(loc, pTp, fn, params);
Type pTp = getOpaquePointerType(rewriter);
auto call = createFuncCall(rewriter, op, name, pTp, params,
/*emitCInterface=*/true);
return call.getResult(0);
}
@ -210,8 +229,8 @@ static void sizesFromType(ConversionPatternRewriter &rewriter,
static void sizesFromSrc(ConversionPatternRewriter &rewriter,
SmallVector<Value, 4> &sizes, Location loc,
Value src) {
ShapedType stp = src.getType().cast<ShapedType>();
for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
unsigned rank = src.getType().cast<ShapedType>().getRank();
for (unsigned i = 0; i < rank; i++)
sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
}
@ -221,12 +240,13 @@ static void sizesFromPtr(ConversionPatternRewriter &rewriter,
SmallVector<Value, 4> &sizes, Operation *op,
SparseTensorEncodingAttr &enc, ShapedType stp,
Value src) {
Location loc = op->getLoc();
auto shape = stp.getShape();
for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
if (shape[i] == ShapedType::kDynamicSize)
sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
else
sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i]));
sizes.push_back(constantIndex(rewriter, loc, shape[i]));
}
/// Generates an uninitialized temporary buffer of the given size and
@ -293,16 +313,15 @@ static void newParams(ConversionPatternRewriter &rewriter,
}
params.push_back(genBuffer(rewriter, loc, rev));
// Secondary and primary types encoding.
ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
Type elemTp = op->getResult(0).getType().cast<ShapedType>().getElementType();
params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
params.push_back(
constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType()));
// User action and pointer.
Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
if (!ptr)
ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
params.push_back(constantPrimaryTypeEncoding(rewriter, loc, elemTp));
// User action.
params.push_back(constantAction(rewriter, loc, action));
// Payload pointer.
if (!ptr)
ptr = rewriter.create<LLVM::NullOp>(loc, getOpaquePointerType(rewriter));
params.push_back(ptr);
}
@ -352,7 +371,6 @@ static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
Type eltType, Value ptr, Value val, Value ind,
Value perm) {
Location loc = op->getLoc();
StringRef name;
if (eltType.isF64())
name = "addEltF64";
@ -368,14 +386,9 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
name = "addEltI8";
else
llvm_unreachable("Unknown element type");
SmallVector<Value, 8> params;
params.push_back(ptr);
params.push_back(val);
params.push_back(ind);
params.push_back(perm);
Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
rewriter.create<CallOp>(loc, pTp, fn, params);
SmallVector<Value, 4> params{ptr, val, ind, perm};
Type pTp = getOpaquePointerType(rewriter);
createFuncCall(rewriter, op, name, pTp, params, /*emitCInterface=*/true);
}
/// Generates a call to `iter->getNext()`. If there is a next element,
@ -384,7 +397,6 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
/// the memory for `iter` is freed and the return value is false.
static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
Value iter, Value ind, Value elemPtr) {
Location loc = op->getLoc();
Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
StringRef name;
if (elemTp.isF64())
@ -401,13 +413,10 @@ static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
name = "getNextI8";
else
llvm_unreachable("Unknown element type");
SmallVector<Value, 3> params;
params.push_back(iter);
params.push_back(ind);
params.push_back(elemPtr);
SmallVector<Value, 3> params{iter, ind, elemPtr};
Type i1 = rewriter.getI1Type();
auto fn = getFunc(op, name, i1, params, /*emitCInterface=*/true);
auto call = rewriter.create<CallOp>(loc, i1, fn, params);
auto call = createFuncCall(rewriter, op, name, i1, params,
/*emitCInterface=*/true);
return call.getResult(0);
}
@ -461,7 +470,7 @@ static Value allocDenseTensor(ConversionPatternRewriter &rewriter, Location loc,
}
Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes);
Value zero = constantZero(rewriter, loc, elemTp);
rewriter.create<linalg::FillOp>(loc, zero, mem).result();
rewriter.create<linalg::FillOp>(loc, zero, mem);
return mem;
}
@ -754,9 +763,8 @@ public:
matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
StringRef name = "delSparseTensor";
TypeRange none;
auto fn = getFunc(op, name, none, adaptor.getOperands());
rewriter.create<CallOp>(op.getLoc(), none, fn, adaptor.getOperands());
TypeRange noTp;
createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
rewriter.eraseOp(op);
return success();
}
@ -785,9 +793,8 @@ public:
name = "sparsePointers8";
else
return failure();
auto fn = getFunc(op, name, resType, adaptor.getOperands(),
/*emitCInterface=*/true);
rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
/*emitCInterface=*/true);
return success();
}
};
@ -814,9 +821,8 @@ public:
name = "sparseIndices8";
else
return failure();
auto fn = getFunc(op, name, resType, adaptor.getOperands(),
/*emitCInterface=*/true);
rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
/*emitCInterface=*/true);
return success();
}
};
@ -845,9 +851,8 @@ public:
name = "sparseValuesI8";
else
return failure();
auto fn = getFunc(op, name, resType, adaptor.getOperands(),
/*emitCInterface=*/true);
rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
/*emitCInterface=*/true);
return success();
}
};
@ -863,8 +868,7 @@ public:
// Finalize any pending insertions.
StringRef name = "endInsert";
TypeRange noTp;
auto fn = getFunc(op, name, noTp, adaptor.getOperands());
rewriter.create<CallOp>(op.getLoc(), noTp, fn, adaptor.getOperands());
createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
}
rewriter.replaceOp(op, adaptor.getOperands());
return success();
@ -896,9 +900,8 @@ public:
else
llvm_unreachable("Unknown element type");
TypeRange noTp;
auto fn =
getFunc(op, name, noTp, adaptor.getOperands(), /*emitCInterface=*/true);
rewriter.replaceOpWithNewOp<CallOp>(op, noTp, fn, adaptor.getOperands());
replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
/*emitCInterface=*/true);
return success();
}
};