[spirv] Define common types using op definition spec

This CL also tightens spv.FMul to only accept 16/32/64-bit floats.

PiperOrigin-RevId: 253649352
This commit is contained in:
Lei Zhang 2019-06-17 13:29:06 -07:00 committed by Mehdi Amini
parent cf74e41277
commit 1d4c040966
5 changed files with 51 additions and 9 deletions

View File

@ -62,10 +62,44 @@ def SPV_Dialect : Dialect {
// SPIR-V type definitions
//===----------------------------------------------------------------------===//
def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
// for the definition of the following types and type categories.
def SPV_Void : TypeAlias<NoneType, "void type">;
def SPV_Bool : IntOfWidths<[1]>;
def SPV_Integer : IntOfWidths<[8, 16, 32, 64]>;
def SPV_Float : FloatOfWidths<[16, 32, 64]>;
def SPV_Vector : VectorOf<[SPV_Bool, SPV_Integer, SPV_Float]>;
// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
def SPV_AnyPtr : Type<SPV_IsPtrType, "any SPIR-V pointer type">;
def SPV_AnyArray : Type<SPV_IsArrayType, "any SPIR-V array type">;
def SPV_AnyRTArray : Type<SPV_IsRTArrayType, "any SPIR-V runtime array type">;
def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>;
def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
def SPV_Aggregrate : AnyTypeOf<[SPV_AnyArray]>;
def SPV_Composite: AnyTypeOf<[SPV_Vector, SPV_AnyArray]>;
def SPV_Type : AnyTypeOf<[
SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray
]>;
class SPV_ScalarOrVectorOf<Type type> :
Type<Or<[type.predicate, VectorOf<[type]>.predicate]>,
"scalar/vector of " # type.description>;
// TODO(antiagainst): Use a more appropriate way to model optional operands
class SPV_Optional<Type type> : Variadic<type>;
//===----------------------------------------------------------------------===//
// SPIR-V enum definitions
//===----------------------------------------------------------------------===//
// Begin enum section. Generated from SPIR-V spec; DO NOT MODIFY!
def SPV_AM_Logical : EnumAttrCase<"Logical", 0>;

View File

@ -48,8 +48,8 @@ def SPV_FMulOp : SPV_Op<"FMul", [NoSideEffect, SameOperandsAndResultType]> {
}];
let arguments = (ins
SPV_ScalarOrVectorOf<AnyFloat>:$operand1,
SPV_ScalarOrVectorOf<AnyFloat>:$operand2
SPV_ScalarOrVectorOf<SPV_Float>:$operand1,
SPV_ScalarOrVectorOf<SPV_Float>:$operand2
);
let results = (outs

View File

@ -79,7 +79,7 @@ def SPV_ModuleOp : SPV_Op<"module", []> {
// Custom parser and printer implemented by static functions in SPVOps.cpp
let parser = [{ return parseModule(parser, result); }];
let printer = [{ printModule(getOperation(), p); }];
let printer = [{ printModule(*this, p); }];
let verifier = [{ return verifyModule(*this); }];
}

View File

@ -85,8 +85,10 @@ static ParseResult parseModule(OpAsmParser *parser, OperationState *state) {
return success();
}
static ParseResult printModule(Operation *op, OpAsmPrinter *printer) {
*printer << op->getName();
static ParseResult printModule(spirv::ModuleOp moduleOp,
OpAsmPrinter *printer) {
auto *op = moduleOp.getOperation();
*printer << spirv::ModuleOp::getOperationName();
printer->printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
*printer << " attributes";
@ -113,8 +115,6 @@ static LogicalResult verifyModule(spirv::ModuleOp moduleOp) {
for (auto &block : funcOp)
for (auto &op : block) {
// TODO(antiagainst): verify that return ops have the same type as the
// enclosing function
if (op.getDialect() == dialect)
continue;

View File

@ -19,15 +19,23 @@ func @fmul_vector(%arg: vector<4xf32>) -> vector<4xf32> {
// -----
func @fmul_i32(%arg: i32) -> i32 {
// expected-error @+1 {{must be scalar/vector of floating-point}}
// expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}}
%0 = spv.FMul %arg, %arg : i32
return %0 : i32
}
// -----
func @fmul_bf16(%arg: bf16) -> bf16 {
// expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}}
%0 = spv.FMul %arg, %arg : bf16
return %0 : bf16
}
// -----
func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
// expected-error @+1 {{must be scalar/vector of floating-point}}
// expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}}
%0 = spv.FMul %arg, %arg : tensor<4xf32>
return %0 : tensor<4xf32>
}