forked from OSchip/llvm-project
[mlir][sparse] Misc code cleanup
Depends On D111763 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D111766
This commit is contained in:
parent
63d4fc9483
commit
5167c36ab4
|
@ -189,7 +189,7 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
|
|||
params.push_back(constantI64(rewriter, loc, secInd));
|
||||
params.push_back(constantI64(rewriter, loc, primary));
|
||||
// User action and pointer.
|
||||
Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
|
||||
Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
|
||||
if (!ptr)
|
||||
ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
|
||||
params.push_back(constantI32(rewriter, loc, action));
|
||||
|
@ -226,9 +226,8 @@ static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
|
|||
/// if (tensor[ivs]!=0) {
|
||||
/// ind = ivs
|
||||
static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Value tensor, Value ind,
|
||||
Location loc, Value tensor, Value ind,
|
||||
ValueRange ivs) {
|
||||
Location loc = op->getLoc();
|
||||
Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
|
||||
Value cond = genIsNonzero(rewriter, loc, val);
|
||||
scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
|
||||
|
@ -270,7 +269,7 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
|
|||
params.push_back(val);
|
||||
params.push_back(ind);
|
||||
params.push_back(perm);
|
||||
Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
|
||||
Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
|
||||
rewriter.create<CallOp>(
|
||||
loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
|
||||
params);
|
||||
|
@ -279,11 +278,10 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
|
|||
/// If the tensor is a sparse constant, generates and returns the pair of
|
||||
/// the constants for the indices and the values.
|
||||
static Optional<std::pair<Value, Value>>
|
||||
genSplitSparseConstant(ConversionPatternRewriter &rewriter, ConvertOp op,
|
||||
genSplitSparseConstant(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value tensor) {
|
||||
if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
|
||||
if (auto attr = constOp.value().dyn_cast<SparseElementsAttr>()) {
|
||||
Location loc = op->getLoc();
|
||||
DenseElementsAttr indicesAttr = attr.getIndices();
|
||||
Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
|
||||
DenseElementsAttr valuesAttr = attr.getValues();
|
||||
|
@ -297,10 +295,9 @@ genSplitSparseConstant(ConversionPatternRewriter &rewriter, ConvertOp op,
|
|||
/// Generates the code to copy the index at indices[ivs] to ind, and return
|
||||
/// the value at value[ivs].
|
||||
static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Value indices,
|
||||
Location loc, Value indices,
|
||||
Value values, Value ind, ValueRange ivs,
|
||||
unsigned rank) {
|
||||
Location loc = op->getLoc();
|
||||
for (unsigned i = 0; i < rank; i++) {
|
||||
Value idx = constantIndex(rewriter, loc, i);
|
||||
Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
|
||||
|
@ -449,7 +446,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
|
|||
SmallVector<Value> st;
|
||||
Value zero = constantIndex(rewriter, loc, 0);
|
||||
Value one = constantIndex(rewriter, loc, 1);
|
||||
auto indicesValues = genSplitSparseConstant(rewriter, op, src);
|
||||
auto indicesValues = genSplitSparseConstant(rewriter, loc, src);
|
||||
bool isCOOConstant = indicesValues.hasValue();
|
||||
Value indices;
|
||||
Value values;
|
||||
|
@ -474,10 +471,10 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
|
|||
ValueRange args) -> scf::ValueVector {
|
||||
Value val;
|
||||
if (isCOOConstant)
|
||||
val = genIndexAndValueForSparse(rewriter, op, indices, values, ind,
|
||||
val = genIndexAndValueForSparse(rewriter, loc, indices, values, ind,
|
||||
ivs, rank);
|
||||
else
|
||||
val = genIndexAndValueForDense(rewriter, op, src, ind, ivs);
|
||||
val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
|
||||
genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
|
||||
return {};
|
||||
});
|
||||
|
|
Loading…
Reference in New Issue