forked from OSchip/llvm-project
Add lowering of constant ops to SPIR-V.
The lowering is specified as a pattern and is done only if the result is a SPIR-V scalar type or vector type. Handling ConstantOp with index return type needs special handling since SPIR-V dialect does not have index types. Based on the bitwidth of the attribute value, either i32 or i64 is chosen. Other constant lowerings are left as a TODO. PiperOrigin-RevId: 274056805
This commit is contained in:
parent
736f80d0dd
commit
28d7f9c052
|
@ -31,6 +31,21 @@ using namespace mlir;
|
|||
// Type Conversion
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static Type convertIndexType(MLIRContext *context) {
|
||||
// Convert to 32-bit integers for now. Might need a way to control this in
|
||||
// future.
|
||||
// TODO(ravishankarm): It is porbably better to make it 64-bit integers. To
|
||||
// this some support is needed in SPIR-V dialect for Conversion
|
||||
// instructions. The Vulkan spec requires the builtins like
|
||||
// GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be
|
||||
// SExtended to 64-bit for index computations.
|
||||
return IntegerType::get(32, context);
|
||||
}
|
||||
|
||||
static Type convertIndexType(IndexType t) {
|
||||
return convertIndexType(t.getContext());
|
||||
}
|
||||
|
||||
static Type basicTypeConversion(Type t) {
|
||||
// Check if the type is SPIR-V supported. If so return the type.
|
||||
if (spirv::SPIRVDialect::isValidType(t)) {
|
||||
|
@ -38,8 +53,7 @@ static Type basicTypeConversion(Type t) {
|
|||
}
|
||||
|
||||
if (auto indexType = t.dyn_cast<IndexType>()) {
|
||||
// Return I32 for index types.
|
||||
return IntegerType::get(32, t.getContext());
|
||||
return convertIndexType(indexType);
|
||||
}
|
||||
|
||||
if (auto memRefType = t.dyn_cast<MemRefType>()) {
|
||||
|
@ -122,9 +136,9 @@ static Value *createAndLoadGlobalVarForEntryFnArg(PatternRewriter &rewriter,
|
|||
// Insert the addressOf and load instructions, to get back the converted value
|
||||
// type.
|
||||
auto addressOf = rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
|
||||
auto zero = rewriter.create<spirv::ConstantOp>(funcOp.getLoc(),
|
||||
rewriter.getIntegerType(32),
|
||||
rewriter.getI32IntegerAttr(0));
|
||||
auto indexType = convertIndexType(funcOp.getContext());
|
||||
auto zero = rewriter.create<spirv::ConstantOp>(
|
||||
funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0));
|
||||
auto accessChain = rewriter.create<spirv::AccessChainOp>(
|
||||
funcOp.getLoc(), addressOf.pointer(), zero.constant());
|
||||
// If the original argument is a tensor/memref type, the value is not
|
||||
|
@ -269,6 +283,46 @@ LogicalResult finalizeEntryFunction(FuncOp newFuncOp, OpBuilder &builder) {
|
|||
|
||||
namespace {
|
||||
|
||||
/// Convert constant operation with IndexType return to SPIR-V constant
|
||||
/// operation. Since IndexType is not used within SPIR-V dialect, this needs
|
||||
/// special handling to make sure the result type and the type of the value
|
||||
/// attribute are consistent.
|
||||
class ConstantIndexOpConversion final : public ConversionPattern {
|
||||
public:
|
||||
ConstantIndexOpConversion(MLIRContext *context)
|
||||
: ConversionPattern(ConstantOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto constIndexOp = cast<ConstantOp>(op);
|
||||
if (!constIndexOp.getResult()->getType().isa<IndexType>()) {
|
||||
return matchFailure();
|
||||
}
|
||||
// The attribute has index type. Get the integer value and create a new
|
||||
// IntegerAttr.
|
||||
auto constAttr = constIndexOp.value().dyn_cast<IntegerAttr>();
|
||||
if (!constAttr) {
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
// Use the bitwidth set in the value attribute to decide the result type of
|
||||
// the SPIR-V constant operation since SPIR-V does not support index types.
|
||||
auto constVal = constAttr.getValue();
|
||||
auto constValType = constAttr.getType().dyn_cast<IndexType>();
|
||||
if (!constValType) {
|
||||
return matchFailure();
|
||||
}
|
||||
auto spirvConstType = convertIndexType(constValType);
|
||||
auto spirvConstVal =
|
||||
rewriter.getIntegerAttr(spirvConstType, constAttr.getInt());
|
||||
auto spirvConstantOp = rewriter.create<spirv::ConstantOp>(
|
||||
op->getLoc(), spirvConstType, spirvConstVal);
|
||||
rewriter.replaceOp(op, spirvConstantOp.constant(), {});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
/// Convert integer binary operations to SPIR-V operations. Cannot use tablegen
|
||||
/// for this. If the integer operation is on variables of IndexType, the type of
|
||||
/// the return value of the replacement operation differs from that of the
|
||||
|
@ -375,7 +429,8 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
|
|||
OwningRewritePatternList &patterns) {
|
||||
populateWithGenerated(context, &patterns);
|
||||
// Add the return op conversion.
|
||||
patterns.insert<IntegerOpConversion<AddIOp, spirv::IAddOp>,
|
||||
patterns.insert<ConstantIndexOpConversion,
|
||||
IntegerOpConversion<AddIOp, spirv::IAddOp>,
|
||||
IntegerOpConversion<MulIOp, spirv::IMulOp>, LoadOpConversion,
|
||||
ReturnToSPIRVConversion, StoreOpConversion>(context);
|
||||
}
|
||||
|
|
|
@ -31,4 +31,10 @@ class BinaryOpPattern<Op src, Op tgt> :
|
|||
def : BinaryOpPattern<AddFOp, SPV_FAddOp>;
|
||||
def : BinaryOpPattern<MulFOp, SPV_FMulOp>;
|
||||
|
||||
// Constant Op
|
||||
// TODO(ravishankarm): Handle lowering other constant types.
|
||||
def : Pat<(ConstantOp:$result $valueAttr),
|
||||
(SPV_ConstantOp $valueAttr),
|
||||
[(SPV_ScalarOrVector $result)]>;
|
||||
|
||||
#endif // MLIR_CONVERSION_STANDARDTOSPIRV_TD
|
||||
|
|
|
@ -44,3 +44,17 @@ func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
|
|||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @constval
|
||||
func @constval() {
|
||||
// CHECK: spv.constant true
|
||||
%0 = constant true
|
||||
// CHECK: spv.constant 42 : i64
|
||||
%1 = constant 42
|
||||
// CHECK: spv.constant {{[0-9]*\.[0-9]*e?-?[0-9]*}} : f32
|
||||
%2 = constant 0.5 : f32
|
||||
// CHECK: spv.constant dense<[2, 3]> : vector<2xi32>
|
||||
%3 = constant dense<[2, 3]> : vector<2xi32>
|
||||
// CHECK: spv.constant 1 : i32
|
||||
%4 = constant 1 : index
|
||||
return
|
||||
}
|
Loading…
Reference in New Issue