diff --git a/mlir/include/mlir/SPIRV/SPIRVBase.td b/mlir/include/mlir/SPIRV/SPIRVBase.td index aff5e1a34962..64c692f2ce90 100644 --- a/mlir/include/mlir/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/SPIRV/SPIRVBase.td @@ -62,10 +62,44 @@ def SPV_Dialect : Dialect { // SPIR-V type definitions //===----------------------------------------------------------------------===// +def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">; +def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">; +def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">; + +// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types +// for the definition of the following types and type categories. + +def SPV_Void : TypeAlias; +def SPV_Bool : IntOfWidths<[1]>; +def SPV_Integer : IntOfWidths<[8, 16, 32, 64]>; +def SPV_Float : FloatOfWidths<[16, 32, 64]>; +def SPV_Vector : VectorOf<[SPV_Bool, SPV_Integer, SPV_Float]>; +// Component type check is done in the type parser for the following SPIR-V +// dialect-specific types so we use "Any" here. +def SPV_AnyPtr : Type; +def SPV_AnyArray : Type; +def SPV_AnyRTArray : Type; + +def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>; +def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>; +def SPV_Aggregrate : AnyTypeOf<[SPV_AnyArray]>; +def SPV_Composite: AnyTypeOf<[SPV_Vector, SPV_AnyArray]>; +def SPV_Type : AnyTypeOf<[ + SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector, + SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray + ]>; + class SPV_ScalarOrVectorOf : Type.predicate]>, "scalar/vector of " # type.description>; +// TODO(antiagainst): Use a more appropriate way to model optional operands +class SPV_Optional : Variadic; + +//===----------------------------------------------------------------------===// +// SPIR-V enum definitions +//===----------------------------------------------------------------------===// + // Begin enum section. Generated from SPIR-V spec; DO NOT MODIFY! def SPV_AM_Logical : EnumAttrCase<"Logical", 0>; diff --git a/mlir/include/mlir/SPIRV/SPIRVOps.td b/mlir/include/mlir/SPIRV/SPIRVOps.td index cab3820270bb..a58a17968ab3 100644 --- a/mlir/include/mlir/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/SPIRV/SPIRVOps.td @@ -48,8 +48,8 @@ def SPV_FMulOp : SPV_Op<"FMul", [NoSideEffect, SameOperandsAndResultType]> { }]; let arguments = (ins - SPV_ScalarOrVectorOf:$operand1, - SPV_ScalarOrVectorOf:$operand2 + SPV_ScalarOrVectorOf:$operand1, + SPV_ScalarOrVectorOf:$operand2 ); let results = (outs diff --git a/mlir/include/mlir/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/SPIRV/SPIRVStructureOps.td index 6e084b91d69a..ad86fb200e34 100644 --- a/mlir/include/mlir/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/SPIRV/SPIRVStructureOps.td @@ -79,7 +79,7 @@ def SPV_ModuleOp : SPV_Op<"module", []> { // Custom parser and printer implemented by static functions in SPVOps.cpp let parser = [{ return parseModule(parser, result); }]; - let printer = [{ printModule(getOperation(), p); }]; + let printer = [{ printModule(*this, p); }]; let verifier = [{ return verifyModule(*this); }]; } diff --git a/mlir/lib/SPIRV/SPIRVOps.cpp b/mlir/lib/SPIRV/SPIRVOps.cpp index aa77fdd6ea59..2579f173226e 100644 --- a/mlir/lib/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/SPIRV/SPIRVOps.cpp @@ -85,8 +85,10 @@ static ParseResult parseModule(OpAsmParser *parser, OperationState *state) { return success(); } -static ParseResult printModule(Operation *op, OpAsmPrinter *printer) { - *printer << op->getName(); +static ParseResult printModule(spirv::ModuleOp moduleOp, + OpAsmPrinter *printer) { + auto *op = moduleOp.getOperation(); + *printer << spirv::ModuleOp::getOperationName(); printer->printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false); *printer << " attributes"; @@ -113,8 +115,6 @@ static LogicalResult verifyModule(spirv::ModuleOp moduleOp) { for (auto &block : funcOp) for (auto &op : block) { - // TODO(antiagainst): verify that return ops have the same type as the - // enclosing function if (op.getDialect() == dialect) continue; diff --git a/mlir/test/SPIRV/ops.mlir b/mlir/test/SPIRV/ops.mlir index 225a76e349f0..bfc5f586f515 100644 --- a/mlir/test/SPIRV/ops.mlir +++ b/mlir/test/SPIRV/ops.mlir @@ -19,15 +19,23 @@ func @fmul_vector(%arg: vector<4xf32>) -> vector<4xf32> { // ----- func @fmul_i32(%arg: i32) -> i32 { - // expected-error @+1 {{must be scalar/vector of floating-point}} + // expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}} %0 = spv.FMul %arg, %arg : i32 return %0 : i32 } // ----- +func @fmul_bf16(%arg: bf16) -> bf16 { + // expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}} + %0 = spv.FMul %arg, %arg : bf16 + return %0 : bf16 +} + +// ----- + func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> { - // expected-error @+1 {{must be scalar/vector of floating-point}} + // expected-error @+1 {{must be scalar/vector of 16/32/64-bit float}} %0 = spv.FMul %arg, %arg : tensor<4xf32> return %0 : tensor<4xf32> }