Add lowering of constant ops to SPIR-V.

The lowering is specified as a pattern and is done only if the result
is a SPIR-V scalar type or vector type.
Handling ConstantOp with index return type needs special handling
since SPIR-V dialect does not have index types. Based on the bitwidth
of the attribute value, either i32 or i64 is chosen.
Other constant lowerings are left as a TODO.

PiperOrigin-RevId: 274056805
This commit is contained in:
Mahesh Ravishankar 2019-10-10 15:51:35 -07:00 committed by Jacques Pienaar
parent 736f80d0dd
commit 28d7f9c052
3 changed files with 81 additions and 6 deletions

View File

@ -31,6 +31,21 @@ using namespace mlir;
// Type Conversion
//===----------------------------------------------------------------------===//
static Type convertIndexType(MLIRContext *context) {
// Convert to 32-bit integers for now. Might need a way to control this in
// future.
// TODO(ravishankarm): It is porbably better to make it 64-bit integers. To
// this some support is needed in SPIR-V dialect for Conversion
// instructions. The Vulkan spec requires the builtins like
// GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be
// SExtended to 64-bit for index computations.
return IntegerType::get(32, context);
}
static Type convertIndexType(IndexType t) {
return convertIndexType(t.getContext());
}
static Type basicTypeConversion(Type t) {
// Check if the type is SPIR-V supported. If so return the type.
if (spirv::SPIRVDialect::isValidType(t)) {
@ -38,8 +53,7 @@ static Type basicTypeConversion(Type t) {
}
if (auto indexType = t.dyn_cast<IndexType>()) {
// Return I32 for index types.
return IntegerType::get(32, t.getContext());
return convertIndexType(indexType);
}
if (auto memRefType = t.dyn_cast<MemRefType>()) {
@ -122,9 +136,9 @@ static Value *createAndLoadGlobalVarForEntryFnArg(PatternRewriter &rewriter,
// Insert the addressOf and load instructions, to get back the converted value
// type.
auto addressOf = rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
auto zero = rewriter.create<spirv::ConstantOp>(funcOp.getLoc(),
rewriter.getIntegerType(32),
rewriter.getI32IntegerAttr(0));
auto indexType = convertIndexType(funcOp.getContext());
auto zero = rewriter.create<spirv::ConstantOp>(
funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0));
auto accessChain = rewriter.create<spirv::AccessChainOp>(
funcOp.getLoc(), addressOf.pointer(), zero.constant());
// If the original argument is a tensor/memref type, the value is not
@ -269,6 +283,46 @@ LogicalResult finalizeEntryFunction(FuncOp newFuncOp, OpBuilder &builder) {
namespace {
/// Convert constant operation with IndexType return to SPIR-V constant
/// operation. Since IndexType is not used within SPIR-V dialect, this needs
/// special handling to make sure the result type and the type of the value
/// attribute are consistent.
class ConstantIndexOpConversion final : public ConversionPattern {
public:
ConstantIndexOpConversion(MLIRContext *context)
: ConversionPattern(ConstantOp::getOperationName(), 1, context) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto constIndexOp = cast<ConstantOp>(op);
if (!constIndexOp.getResult()->getType().isa<IndexType>()) {
return matchFailure();
}
// The attribute has index type. Get the integer value and create a new
// IntegerAttr.
auto constAttr = constIndexOp.value().dyn_cast<IntegerAttr>();
if (!constAttr) {
return matchFailure();
}
// Use the bitwidth set in the value attribute to decide the result type of
// the SPIR-V constant operation since SPIR-V does not support index types.
auto constVal = constAttr.getValue();
auto constValType = constAttr.getType().dyn_cast<IndexType>();
if (!constValType) {
return matchFailure();
}
auto spirvConstType = convertIndexType(constValType);
auto spirvConstVal =
rewriter.getIntegerAttr(spirvConstType, constAttr.getInt());
auto spirvConstantOp = rewriter.create<spirv::ConstantOp>(
op->getLoc(), spirvConstType, spirvConstVal);
rewriter.replaceOp(op, spirvConstantOp.constant(), {});
return matchSuccess();
}
};
/// Convert integer binary operations to SPIR-V operations. Cannot use tablegen
/// for this. If the integer operation is on variables of IndexType, the type of
/// the return value of the replacement operation differs from that of the
@ -375,7 +429,8 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
populateWithGenerated(context, &patterns);
// Add the return op conversion.
patterns.insert<IntegerOpConversion<AddIOp, spirv::IAddOp>,
patterns.insert<ConstantIndexOpConversion,
IntegerOpConversion<AddIOp, spirv::IAddOp>,
IntegerOpConversion<MulIOp, spirv::IMulOp>, LoadOpConversion,
ReturnToSPIRVConversion, StoreOpConversion>(context);
}

View File

@ -31,4 +31,10 @@ class BinaryOpPattern<Op src, Op tgt> :
def : BinaryOpPattern<AddFOp, SPV_FAddOp>;
def : BinaryOpPattern<MulFOp, SPV_FMulOp>;
// Constant Op
// TODO(ravishankarm): Handle lowering other constant types.
def : Pat<(ConstantOp:$result $valueAttr),
(SPV_ConstantOp $valueAttr),
[(SPV_ScalarOrVector $result)]>;
#endif // MLIR_CONVERSION_STANDARDTOSPIRV_TD

View File

@ -44,3 +44,17 @@ func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
return %0 : tensor<4xf32>
}
// CHECK-LABEL: @constval
func @constval() {
// CHECK: spv.constant true
%0 = constant true
// CHECK: spv.constant 42 : i64
%1 = constant 42
// CHECK: spv.constant {{[0-9]*\.[0-9]*e?-?[0-9]*}} : f32
%2 = constant 0.5 : f32
// CHECK: spv.constant dense<[2, 3]> : vector<2xi32>
%3 = constant dense<[2, 3]> : vector<2xi32>
// CHECK: spv.constant 1 : i32
%4 = constant 1 : index
return
}