forked from OSchip/llvm-project
[spirv] Define common types using op definition spec
This CL also tightens spv.FMul to only accept 16/32/64-bit floats. PiperOrigin-RevId: 253649352
This commit is contained in:
parent
cf74e41277
commit
1d4c040966
|
@ -62,10 +62,44 @@ def SPV_Dialect : Dialect {
|
|||
// SPIR-V type definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
|
||||
def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
|
||||
def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
|
||||
|
||||
// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
|
||||
// for the definition of the following types and type categories.
|
||||
|
||||
def SPV_Void : TypeAlias<NoneType, "void type">;
|
||||
def SPV_Bool : IntOfWidths<[1]>;
|
||||
def SPV_Integer : IntOfWidths<[8, 16, 32, 64]>;
|
||||
def SPV_Float : FloatOfWidths<[16, 32, 64]>;
|
||||
def SPV_Vector : VectorOf<[SPV_Bool, SPV_Integer, SPV_Float]>;
|
||||
// Component type check is done in the type parser for the following SPIR-V
|
||||
// dialect-specific types so we use "Any" here.
|
||||
def SPV_AnyPtr : Type<SPV_IsPtrType, "any SPIR-V pointer type">;
|
||||
def SPV_AnyArray : Type<SPV_IsArrayType, "any SPIR-V array type">;
|
||||
def SPV_AnyRTArray : Type<SPV_IsRTArrayType, "any SPIR-V runtime array type">;
|
||||
|
||||
def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>;
|
||||
def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
|
||||
def SPV_Aggregrate : AnyTypeOf<[SPV_AnyArray]>;
|
||||
def SPV_Composite: AnyTypeOf<[SPV_Vector, SPV_AnyArray]>;
|
||||
def SPV_Type : AnyTypeOf<[
|
||||
SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
|
||||
SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray
|
||||
]>;
|
||||
|
||||
class SPV_ScalarOrVectorOf<Type type> :
|
||||
Type<Or<[type.predicate, VectorOf<[type]>.predicate]>,
|
||||
"scalar/vector of " # type.description>;
|
||||
|
||||
// TODO(antiagainst): Use a more appropriate way to model optional operands
|
||||
class SPV_Optional<Type type> : Variadic<type>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SPIR-V enum definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Begin enum section. Generated from SPIR-V spec; DO NOT MODIFY!
|
||||
|
||||
def SPV_AM_Logical : EnumAttrCase<"Logical", 0>;
|
||||
|
|
|
@ -48,8 +48,8 @@ def SPV_FMulOp : SPV_Op<"FMul", [NoSideEffect, SameOperandsAndResultType]> {
|
|||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_ScalarOrVectorOf<AnyFloat>:$operand1,
|
||||
SPV_ScalarOrVectorOf<AnyFloat>:$operand2
|
||||
SPV_ScalarOrVectorOf<SPV_Float>:$operand1,
|
||||
SPV_ScalarOrVectorOf<SPV_Float>:$operand2
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
|
|
|
@ -79,7 +79,7 @@ def SPV_ModuleOp : SPV_Op<"module", []> {
|
|||
|
||||
// Custom parser and printer implemented by static functions in SPVOps.cpp
|
||||
let parser = [{ return parseModule(parser, result); }];
|
||||
let printer = [{ printModule(getOperation(), p); }];
|
||||
let printer = [{ printModule(*this, p); }];
|
||||
|
||||
let verifier = [{ return verifyModule(*this); }];
|
||||
}
|
||||
|
|
|
@ -85,8 +85,10 @@ static ParseResult parseModule(OpAsmParser *parser, OperationState *state) {
|
|||
return success();
|
||||
}
|
||||
|
||||
static ParseResult printModule(Operation *op, OpAsmPrinter *printer) {
|
||||
*printer << op->getName();
|
||||
static ParseResult printModule(spirv::ModuleOp moduleOp,
|
||||
OpAsmPrinter *printer) {
|
||||
auto *op = moduleOp.getOperation();
|
||||
*printer << spirv::ModuleOp::getOperationName();
|
||||
printer->printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
*printer << " attributes";
|
||||
|
@ -113,8 +115,6 @@ static LogicalResult verifyModule(spirv::ModuleOp moduleOp) {
|
|||
|
||||
for (auto &block : funcOp)
|
||||
for (auto &op : block) {
|
||||
// TODO(antiagainst): verify that return ops have the same type as the
|
||||
// enclosing function
|
||||
if (op.getDialect() == dialect)
|
||||
continue;
|
||||
|
||||
|
|
|
@ -19,15 +19,23 @@ func @fmul_vector(%arg: vector<4xf32>) -> vector<4xf32> {
|
|||
// -----
|
||||
|
||||
func @fmul_i32(%arg: i32) -> i32 {
|
||||
// expected-error @+1 {{must be scalar/vector of floating-point}}
|
||||
// expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}}
|
||||
%0 = spv.FMul %arg, %arg : i32
|
||||
return %0 : i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @fmul_bf16(%arg: bf16) -> bf16 {
|
||||
// expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}}
|
||||
%0 = spv.FMul %arg, %arg : bf16
|
||||
return %0 : bf16
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// expected-error @+1 {{must be scalar/vector of floating-point}}
|
||||
// expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}}
|
||||
%0 = spv.FMul %arg, %arg : tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue