[mlir][sparse] Factoring out helper functions for generating constants

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D111763
wren romano 2021-10-13 16:03:26 -07:00
parent 8e184f3d2a
commit 63d4fc9483
1 changed file with 35 additions and 27 deletions


@@ -81,6 +81,30 @@ getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
   llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
 }
 
+/// Generates a constant zero of the given type.
+inline static Value constantZero(ConversionPatternRewriter &rewriter,
+                                 Location loc, Type t) {
+  return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t));
+}
+
+/// Generates a constant of `index` type.
+inline static Value constantIndex(ConversionPatternRewriter &rewriter,
+                                  Location loc, unsigned i) {
+  return rewriter.create<arith::ConstantIndexOp>(loc, i);
+}
+
+/// Generates a constant of `i64` type.
+inline static Value constantI64(ConversionPatternRewriter &rewriter,
+                                Location loc, int64_t i) {
+  return rewriter.create<arith::ConstantIntOp>(loc, i, 64);
+}
+
+/// Generates a constant of `i32` type.
+inline static Value constantI32(ConversionPatternRewriter &rewriter,
+                                Location loc, int32_t i) {
+  return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
+}
+
 /// Returns integers of given width and values as a constant tensor.
 /// We cast the static shape into a dynamic shape to ensure that the
 /// method signature remains uniform across different tensor dimensions.
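Taken together, the four helpers replace the attribute-based arith::ConstantOp construction that was previously repeated at every call site. A minimal before/after sketch of one such call site (the variable names are illustrative, not from the patch):

  // Before: each call site builds the attribute and the op by hand.
  Value oldIdx = rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
  // After: the helper hides the attribute plumbing and names the intent.
  Value newIdx = constantIndex(rewriter, loc, 0);
  // Both forms materialize the same op, e.g. `%c0 = arith.constant 0 : index`.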
@@ -161,18 +185,14 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
   unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
   unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
   assert(primary);
-  params.push_back(rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI64IntegerAttr(secPtr)));
-  params.push_back(rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI64IntegerAttr(secInd)));
-  params.push_back(rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI64IntegerAttr(primary)));
+  params.push_back(constantI64(rewriter, loc, secPtr));
+  params.push_back(constantI64(rewriter, loc, secInd));
+  params.push_back(constantI64(rewriter, loc, primary));
   // User action and pointer.
   Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
   if (!ptr)
     ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
-  params.push_back(rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI32IntegerAttr(action)));
+  params.push_back(constantI32(rewriter, loc, action));
   params.push_back(ptr);
   // Generate the call to create new tensor.
   StringRef name = "newSparseTensor";
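A note on the two arguments in this hunk that are not plain constants: the runtime call takes an opaque handle, modeled in the LLVM dialect as a pointer to i8, plus a 32-bit user action code. A hedged sketch of that part of the parameter setup, with the surrounding declarations assumed and the explanatory comments added here:

  // The runtime expects an opaque handle, modeled as !llvm.ptr<i8>.
  // If the caller did not supply a pointer, pass a null pointer instead.
  Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
  if (!ptr)
    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
  // The user action code travels as an i32 constant, via the new helper.
  params.push_back(constantI32(rewriter, loc, action));
  params.push_back(ptr);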
@@ -182,19 +202,13 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
   return call.getResult(0);
 }
 
-/// Generates a constant zero of the given type.
-static Value getZero(ConversionPatternRewriter &rewriter, Location loc,
-                     Type t) {
-  return rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(t));
-}
-
 /// Generates the comparison `v != 0` where `v` is of numeric type `t`.
 /// For floating types, we use the "unordered" comparator (i.e., returns
 /// true if `v` is NaN).
 static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
                           Value v) {
   Type t = v.getType();
-  Value zero = getZero(rewriter, loc, t);
+  Value zero = constantZero(rewriter, loc, t);
   if (t.isa<FloatType>())
     return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                           zero);
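The hunk window cuts the function off after the floating-point case. For readers following along, a hedged sketch of the overall shape of such a nonzero test; the integer branch below is an assumption about the part the diff does not show, not something this commit touches:

  Type t = v.getType();
  Value zero = constantZero(rewriter, loc, t);
  if (t.isa<FloatType>())
    // UNE is "unordered or not equal": it also yields true when `v` is NaN,
    // which is why a NaN entry counts as nonzero here.
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                          zero);
  if (t.isIntOrIndex()) // assumed integer path, outside the shown hunk
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
                                          zero);
  llvm_unreachable("Unknown element type");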
@@ -221,8 +235,7 @@ static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
   rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
   unsigned i = 0;
   for (auto iv : ivs) {
-    Value idx =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i++));
+    Value idx = constantIndex(rewriter, loc, i++);
     rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
   }
   return val;
@@ -289,8 +302,7 @@ static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                        unsigned rank) {
   Location loc = op->getLoc();
   for (unsigned i = 0; i < rank; i++) {
-    Value idx =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i));
+    Value idx = constantIndex(rewriter, loc, i);
     Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                    ValueRange{ivs[0], idx});
     val =
@@ -308,8 +320,7 @@ static Value allocaIndices(ConversionPatternRewriter &rewriter, Location loc,
                            int64_t rank) {
   auto indexTp = rewriter.getIndexType();
   auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
-  Value arg =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(rank));
+  Value arg = constantIndex(rewriter, loc, rank);
   return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
 }
 
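This is the buffer that genIndexAndValueForDense fills above; it is given a dynamic shape (ShapedType::kDynamicSize), presumably for the reason the earlier doc comment gives: the signature stays uniform across ranks. A sketch of how the two pieces combine, with the rank and variable names purely illustrative:

  // Stack-allocate a rank-sized index buffer, then store the j-th loop
  // induction variable into slot j before handing the buffer to the runtime.
  Value ind = allocaIndices(rewriter, loc, /*rank=*/2);
  unsigned i = 0;
  for (Value iv : ivs) {
    Value idx = constantIndex(rewriter, loc, i++);
    rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
  }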
@@ -352,8 +363,7 @@ public:
     StringRef name = "sparseDimSize";
     SmallVector<Value, 2> params;
     params.push_back(adaptor.getOperands()[0]);
-    params.push_back(rewriter.create<arith::ConstantOp>(
-        op.getLoc(), rewriter.getIndexAttr(idx)));
+    params.push_back(constantIndex(rewriter, op.getLoc(), idx));
     rewriter.replaceOpWithNewOp<CallOp>(
         op, resType, getFunc(op, name, resType, params), params);
     return success();
@@ -437,10 +447,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       SmallVector<Value> lo;
      SmallVector<Value> hi;
      SmallVector<Value> st;
-      Value zero =
-          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
-      Value one =
-          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
+      Value zero = constantIndex(rewriter, loc, 0);
+      Value one = constantIndex(rewriter, loc, 1);
       auto indicesValues = genSplitSparseConstant(rewriter, op, src);
       bool isCOOConstant = indicesValues.hasValue();
       Value indices;