forked from OSchip/llvm-project
[mlir][sparse] Factoring out helper functions for generating constants
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D111763
This commit is contained in:
parent
8e184f3d2a
commit
63d4fc9483
|
@ -81,6 +81,30 @@ getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
|
|||
llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
|
||||
}
|
||||
|
||||
/// Generates a constant zero of the given type.
|
||||
inline static Value constantZero(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type t) {
|
||||
return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t));
|
||||
}
|
||||
|
||||
/// Generates a constant of `index` type.
|
||||
inline static Value constantIndex(ConversionPatternRewriter &rewriter,
|
||||
Location loc, unsigned i) {
|
||||
return rewriter.create<arith::ConstantIndexOp>(loc, i);
|
||||
}
|
||||
|
||||
/// Generates a constant of `i64` type.
|
||||
inline static Value constantI64(ConversionPatternRewriter &rewriter,
|
||||
Location loc, int64_t i) {
|
||||
return rewriter.create<arith::ConstantIntOp>(loc, i, 64);
|
||||
}
|
||||
|
||||
/// Generates a constant of `i32` type.
|
||||
inline static Value constantI32(ConversionPatternRewriter &rewriter,
|
||||
Location loc, int32_t i) {
|
||||
return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
|
||||
}
|
||||
|
||||
/// Returns integers of given width and values as a constant tensor.
|
||||
/// We cast the static shape into a dynamic shape to ensure that the
|
||||
/// method signature remains uniform across different tensor dimensions.
|
||||
|
@ -161,18 +185,14 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
|
|||
unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
|
||||
unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
|
||||
assert(primary);
|
||||
params.push_back(rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI64IntegerAttr(secPtr)));
|
||||
params.push_back(rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI64IntegerAttr(secInd)));
|
||||
params.push_back(rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI64IntegerAttr(primary)));
|
||||
params.push_back(constantI64(rewriter, loc, secPtr));
|
||||
params.push_back(constantI64(rewriter, loc, secInd));
|
||||
params.push_back(constantI64(rewriter, loc, primary));
|
||||
// User action and pointer.
|
||||
Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
|
||||
if (!ptr)
|
||||
ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
|
||||
params.push_back(rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getI32IntegerAttr(action)));
|
||||
params.push_back(constantI32(rewriter, loc, action));
|
||||
params.push_back(ptr);
|
||||
// Generate the call to create new tensor.
|
||||
StringRef name = "newSparseTensor";
|
||||
|
@ -182,19 +202,13 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
|
|||
return call.getResult(0);
|
||||
}
|
||||
|
||||
/// Generates a constant zero of the given type.
|
||||
static Value getZero(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Type t) {
|
||||
return rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(t));
|
||||
}
|
||||
|
||||
/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
|
||||
/// For floating types, we use the "unordered" comparator (i.e., returns
|
||||
/// true if `v` is NaN).
|
||||
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value v) {
|
||||
Type t = v.getType();
|
||||
Value zero = getZero(rewriter, loc, t);
|
||||
Value zero = constantZero(rewriter, loc, t);
|
||||
if (t.isa<FloatType>())
|
||||
return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
|
||||
zero);
|
||||
|
@ -221,8 +235,7 @@ static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
|
|||
rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
|
||||
unsigned i = 0;
|
||||
for (auto iv : ivs) {
|
||||
Value idx =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i++));
|
||||
Value idx = constantIndex(rewriter, loc, i++);
|
||||
rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
|
||||
}
|
||||
return val;
|
||||
|
@ -289,8 +302,7 @@ static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
|
|||
unsigned rank) {
|
||||
Location loc = op->getLoc();
|
||||
for (unsigned i = 0; i < rank; i++) {
|
||||
Value idx =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i));
|
||||
Value idx = constantIndex(rewriter, loc, i);
|
||||
Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
|
||||
ValueRange{ivs[0], idx});
|
||||
val =
|
||||
|
@ -308,8 +320,7 @@ static Value allocaIndices(ConversionPatternRewriter &rewriter, Location loc,
|
|||
int64_t rank) {
|
||||
auto indexTp = rewriter.getIndexType();
|
||||
auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
|
||||
Value arg =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(rank));
|
||||
Value arg = constantIndex(rewriter, loc, rank);
|
||||
return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
|
||||
}
|
||||
|
||||
|
@ -352,8 +363,7 @@ public:
|
|||
StringRef name = "sparseDimSize";
|
||||
SmallVector<Value, 2> params;
|
||||
params.push_back(adaptor.getOperands()[0]);
|
||||
params.push_back(rewriter.create<arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIndexAttr(idx)));
|
||||
params.push_back(constantIndex(rewriter, op.getLoc(), idx));
|
||||
rewriter.replaceOpWithNewOp<CallOp>(
|
||||
op, resType, getFunc(op, name, resType, params), params);
|
||||
return success();
|
||||
|
@ -437,10 +447,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
|
|||
SmallVector<Value> lo;
|
||||
SmallVector<Value> hi;
|
||||
SmallVector<Value> st;
|
||||
Value zero =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
|
||||
Value one =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
||||
Value zero = constantIndex(rewriter, loc, 0);
|
||||
Value one = constantIndex(rewriter, loc, 1);
|
||||
auto indicesValues = genSplitSparseConstant(rewriter, op, src);
|
||||
bool isCOOConstant = indicesValues.hasValue();
|
||||
Value indices;
|
||||
|
|
Loading…
Reference in New Issue