forked from OSchip/llvm-project
[mlir][sparse] Code cleanup for SparseTensorConversion
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D115004
This commit is contained in:
parent
0e0f1b28fc
commit
f527fdf51e
|
@ -142,13 +142,19 @@ constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc,
|
|||
return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2));
|
||||
}
|
||||
|
||||
/// Returns the equivalent of `void*` for opaque arguments to the
|
||||
/// execution engine.
|
||||
static Type getOpaquePointerType(PatternRewriter &rewriter) {
|
||||
return LLVM::LLVMPointerType::get(rewriter.getI8Type());
|
||||
}
|
||||
|
||||
/// Returns a function reference (first hit also inserts into module). Sets
|
||||
/// the "_emit_c_interface" on the function declaration when requested,
|
||||
/// so that LLVM lowering generates a wrapper function that takes care
|
||||
/// of ABI complications with passing in and returning MemRefs to C functions.
|
||||
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
|
||||
TypeRange resultType, ValueRange operands,
|
||||
bool emitCInterface = false) {
|
||||
bool emitCInterface) {
|
||||
MLIRContext *context = op->getContext();
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
auto result = SymbolRefAttr::get(context, name);
|
||||
|
@ -165,6 +171,24 @@ static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
|
|||
return result;
|
||||
}
|
||||
|
||||
/// Creates a `CallOp` to the function reference returned by `getFunc()`.
|
||||
static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name,
|
||||
TypeRange resultType, ValueRange operands,
|
||||
bool emitCInterface = false) {
|
||||
auto fn = getFunc(op, name, resultType, operands, emitCInterface);
|
||||
return builder.create<CallOp>(op->getLoc(), resultType, fn, operands);
|
||||
}
|
||||
|
||||
/// Replaces the `op` with a `CallOp` to the function reference returned
|
||||
/// by `getFunc()`.
|
||||
static CallOp replaceOpWithFuncCall(PatternRewriter &rewriter, Operation *op,
|
||||
StringRef name, TypeRange resultType,
|
||||
ValueRange operands,
|
||||
bool emitCInterface = false) {
|
||||
auto fn = getFunc(op, name, resultType, operands, emitCInterface);
|
||||
return rewriter.replaceOpWithNewOp<CallOp>(op, resultType, fn, operands);
|
||||
}
|
||||
|
||||
/// Generates dimension size call.
|
||||
static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
SparseTensorEncodingAttr &enc, Value src,
|
||||
|
@ -173,25 +197,20 @@ static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
|
|||
if (AffineMap p = enc.getDimOrdering())
|
||||
idx = p.getPermutedPosition(idx);
|
||||
// Generate the call.
|
||||
Location loc = op->getLoc();
|
||||
StringRef name = "sparseDimSize";
|
||||
SmallVector<Value, 2> params;
|
||||
params.push_back(src);
|
||||
params.push_back(constantIndex(rewriter, loc, idx));
|
||||
SmallVector<Value, 2> params{src, constantIndex(rewriter, op->getLoc(), idx)};
|
||||
Type iTp = rewriter.getIndexType();
|
||||
auto fn = getFunc(op, name, iTp, params);
|
||||
return rewriter.create<CallOp>(loc, iTp, fn, params).getResult(0);
|
||||
return createFuncCall(rewriter, op, name, iTp, params).getResult(0);
|
||||
}
|
||||
|
||||
/// Generates a call into the "swiss army knife" method of the sparse runtime
|
||||
/// support library for materializing sparse tensors into the computation.
|
||||
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<Value> params) {
|
||||
Location loc = op->getLoc();
|
||||
StringRef name = "newSparseTensor";
|
||||
Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
|
||||
auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
|
||||
auto call = rewriter.create<CallOp>(loc, pTp, fn, params);
|
||||
Type pTp = getOpaquePointerType(rewriter);
|
||||
auto call = createFuncCall(rewriter, op, name, pTp, params,
|
||||
/*emitCInterface=*/true);
|
||||
return call.getResult(0);
|
||||
}
|
||||
|
||||
|
@ -210,8 +229,8 @@ static void sizesFromType(ConversionPatternRewriter &rewriter,
|
|||
static void sizesFromSrc(ConversionPatternRewriter &rewriter,
|
||||
SmallVector<Value, 4> &sizes, Location loc,
|
||||
Value src) {
|
||||
ShapedType stp = src.getType().cast<ShapedType>();
|
||||
for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
|
||||
unsigned rank = src.getType().cast<ShapedType>().getRank();
|
||||
for (unsigned i = 0; i < rank; i++)
|
||||
sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
|
||||
}
|
||||
|
||||
|
@ -221,12 +240,13 @@ static void sizesFromPtr(ConversionPatternRewriter &rewriter,
|
|||
SmallVector<Value, 4> &sizes, Operation *op,
|
||||
SparseTensorEncodingAttr &enc, ShapedType stp,
|
||||
Value src) {
|
||||
Location loc = op->getLoc();
|
||||
auto shape = stp.getShape();
|
||||
for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
|
||||
if (shape[i] == ShapedType::kDynamicSize)
|
||||
sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
|
||||
else
|
||||
sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i]));
|
||||
sizes.push_back(constantIndex(rewriter, loc, shape[i]));
|
||||
}
|
||||
|
||||
/// Generates an uninitialized temporary buffer of the given size and
|
||||
|
@ -293,16 +313,15 @@ static void newParams(ConversionPatternRewriter &rewriter,
|
|||
}
|
||||
params.push_back(genBuffer(rewriter, loc, rev));
|
||||
// Secondary and primary types encoding.
|
||||
ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
|
||||
Type elemTp = op->getResult(0).getType().cast<ShapedType>().getElementType();
|
||||
params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
|
||||
params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
|
||||
params.push_back(
|
||||
constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType()));
|
||||
// User action and pointer.
|
||||
Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
|
||||
if (!ptr)
|
||||
ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
|
||||
params.push_back(constantPrimaryTypeEncoding(rewriter, loc, elemTp));
|
||||
// User action.
|
||||
params.push_back(constantAction(rewriter, loc, action));
|
||||
// Payload pointer.
|
||||
if (!ptr)
|
||||
ptr = rewriter.create<LLVM::NullOp>(loc, getOpaquePointerType(rewriter));
|
||||
params.push_back(ptr);
|
||||
}
|
||||
|
||||
|
@ -352,7 +371,6 @@ static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
|
|||
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
Type eltType, Value ptr, Value val, Value ind,
|
||||
Value perm) {
|
||||
Location loc = op->getLoc();
|
||||
StringRef name;
|
||||
if (eltType.isF64())
|
||||
name = "addEltF64";
|
||||
|
@ -368,14 +386,9 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
|
|||
name = "addEltI8";
|
||||
else
|
||||
llvm_unreachable("Unknown element type");
|
||||
SmallVector<Value, 8> params;
|
||||
params.push_back(ptr);
|
||||
params.push_back(val);
|
||||
params.push_back(ind);
|
||||
params.push_back(perm);
|
||||
Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
|
||||
auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
|
||||
rewriter.create<CallOp>(loc, pTp, fn, params);
|
||||
SmallVector<Value, 4> params{ptr, val, ind, perm};
|
||||
Type pTp = getOpaquePointerType(rewriter);
|
||||
createFuncCall(rewriter, op, name, pTp, params, /*emitCInterface=*/true);
|
||||
}
|
||||
|
||||
/// Generates a call to `iter->getNext()`. If there is a next element,
|
||||
|
@ -384,7 +397,6 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
|
|||
/// the memory for `iter` is freed and the return value is false.
|
||||
static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
Value iter, Value ind, Value elemPtr) {
|
||||
Location loc = op->getLoc();
|
||||
Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
|
||||
StringRef name;
|
||||
if (elemTp.isF64())
|
||||
|
@ -401,13 +413,10 @@ static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
|
|||
name = "getNextI8";
|
||||
else
|
||||
llvm_unreachable("Unknown element type");
|
||||
SmallVector<Value, 3> params;
|
||||
params.push_back(iter);
|
||||
params.push_back(ind);
|
||||
params.push_back(elemPtr);
|
||||
SmallVector<Value, 3> params{iter, ind, elemPtr};
|
||||
Type i1 = rewriter.getI1Type();
|
||||
auto fn = getFunc(op, name, i1, params, /*emitCInterface=*/true);
|
||||
auto call = rewriter.create<CallOp>(loc, i1, fn, params);
|
||||
auto call = createFuncCall(rewriter, op, name, i1, params,
|
||||
/*emitCInterface=*/true);
|
||||
return call.getResult(0);
|
||||
}
|
||||
|
||||
|
@ -461,7 +470,7 @@ static Value allocDenseTensor(ConversionPatternRewriter &rewriter, Location loc,
|
|||
}
|
||||
Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes);
|
||||
Value zero = constantZero(rewriter, loc, elemTp);
|
||||
rewriter.create<linalg::FillOp>(loc, zero, mem).result();
|
||||
rewriter.create<linalg::FillOp>(loc, zero, mem);
|
||||
return mem;
|
||||
}
|
||||
|
||||
|
@ -754,9 +763,8 @@ public:
|
|||
matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
StringRef name = "delSparseTensor";
|
||||
TypeRange none;
|
||||
auto fn = getFunc(op, name, none, adaptor.getOperands());
|
||||
rewriter.create<CallOp>(op.getLoc(), none, fn, adaptor.getOperands());
|
||||
TypeRange noTp;
|
||||
createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
@ -785,9 +793,8 @@ public:
|
|||
name = "sparsePointers8";
|
||||
else
|
||||
return failure();
|
||||
auto fn = getFunc(op, name, resType, adaptor.getOperands(),
|
||||
/*emitCInterface=*/true);
|
||||
rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
|
||||
replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
|
||||
/*emitCInterface=*/true);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -814,9 +821,8 @@ public:
|
|||
name = "sparseIndices8";
|
||||
else
|
||||
return failure();
|
||||
auto fn = getFunc(op, name, resType, adaptor.getOperands(),
|
||||
/*emitCInterface=*/true);
|
||||
rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
|
||||
replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
|
||||
/*emitCInterface=*/true);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -845,9 +851,8 @@ public:
|
|||
name = "sparseValuesI8";
|
||||
else
|
||||
return failure();
|
||||
auto fn = getFunc(op, name, resType, adaptor.getOperands(),
|
||||
/*emitCInterface=*/true);
|
||||
rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
|
||||
replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
|
||||
/*emitCInterface=*/true);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -863,8 +868,7 @@ public:
|
|||
// Finalize any pending insertions.
|
||||
StringRef name = "endInsert";
|
||||
TypeRange noTp;
|
||||
auto fn = getFunc(op, name, noTp, adaptor.getOperands());
|
||||
rewriter.create<CallOp>(op.getLoc(), noTp, fn, adaptor.getOperands());
|
||||
createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
|
||||
}
|
||||
rewriter.replaceOp(op, adaptor.getOperands());
|
||||
return success();
|
||||
|
@ -896,9 +900,8 @@ public:
|
|||
else
|
||||
llvm_unreachable("Unknown element type");
|
||||
TypeRange noTp;
|
||||
auto fn =
|
||||
getFunc(op, name, noTp, adaptor.getOperands(), /*emitCInterface=*/true);
|
||||
rewriter.replaceOpWithNewOp<CallOp>(op, noTp, fn, adaptor.getOperands());
|
||||
replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
|
||||
/*emitCInterface=*/true);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue