Fix parsing/printing of spv.globalVariable and spv._address_of

Change the prining/parsing of spv.globalVariable to print the type of
the variable after the ':' to be consistent with MLIR convention.
The spv._address_of should print the variable type after the ':'. It was
mistakenly printing the address of the return value. Add a (missing)
test that should have caught that.
Also move spv.globalVariable and spv._address_of tests to
structure-ops.mlir.

PiperOrigin-RevId: 264204686
This commit is contained in:
Mahesh Ravishankar 2019-08-19 11:38:53 -07:00 committed by A. Unique TensorFlower
parent ba0fa92524
commit 377bfb3a14
8 changed files with 236 additions and 207 deletions

View File

@ -461,7 +461,7 @@ static void print(spirv::AddressOfOp addressOfOp, OpAsmPrinter *printer) {
*printer << " @" << addressOfOp.variable();
// Print the type.
*printer << " : " << addressOfOp.pointer();
*printer << " : " << addressOfOp.pointer()->getType();
}
static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
@ -676,9 +676,8 @@ static ParseResult parseEntryPointOp(OpAsmParser *parser,
}
interfaceVars.push_back(var);
} while (!parser->parseOptionalComma());
state->attributes.push_back(
{parser->getBuilder().getIdentifier(kInterfaceAttrName),
parser->getBuilder().getArrayAttr(interfaceVars)});
state->addAttribute(kInterfaceAttrName,
parser->getBuilder().getArrayAttr(interfaceVars));
}
return success();
}
@ -748,18 +747,6 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter *printer) {
static ParseResult parseGlobalVariableOp(OpAsmParser *parser,
OperationState *state) {
// Parse variable type.
TypeAttr typeAttr;
auto loc = parser->getCurrentLocation();
if (parser->parseAttribute(typeAttr, Type(), kTypeAttrName,
state->attributes)) {
return failure();
}
auto ptrType = typeAttr.getValue().dyn_cast<spirv::PointerType>();
if (!ptrType) {
return parser->emitError(loc, "expected spv.ptr type");
}
// Parse variable name.
StringAttr nameAttr;
if (parser->parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
@ -781,6 +768,16 @@ static ParseResult parseGlobalVariableOp(OpAsmParser *parser,
return failure();
}
Type type;
auto loc = parser->getCurrentLocation();
if (parser->parseColonType(type)) {
return failure();
}
if (!type.isa<spirv::PointerType>()) {
return parser->emitError(loc, "expected spv.ptr type");
}
state->addAttribute(kTypeAttrName, parser->getBuilder().getTypeAttr(type));
return success();
}
@ -790,10 +787,6 @@ static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter *printer) {
spirv::attributeName<spirv::StorageClass>()};
*printer << spirv::GlobalVariableOp::getOperationName();
// Print variable type.
*printer << " " << varOp.type();
elidedAttrs.push_back(kTypeAttrName);
// Print variable name.
*printer << " @" << varOp.sym_name();
elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
@ -804,7 +797,10 @@ static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter *printer) {
<< ")";
elidedAttrs.push_back(kInitializerAttrName);
}
elidedAttrs.push_back(kTypeAttrName);
printVariableDecorations(op, printer, elidedAttrs);
*printer << " : " << varOp.type();
}
static LogicalResult verify(spirv::GlobalVariableOp varOp) {

View File

@ -1,8 +1,8 @@
// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
// CHECK: spv.module "Logical" "VulkanKHR" {
// CHECK-NEXT: spv.globalVariable !spv.ptr<f32, StorageBuffer> [[VAR1:@.*]] bind(0, 0)
// CHECK-NEXT: spv.globalVariable !spv.ptr<!spv.array<12 x f32>, StorageBuffer> [[VAR2:@.*]] bind(0, 1)
// CHECK-NEXT: spv.globalVariable [[VAR1:@.*]] bind(0, 0) : !spv.ptr<f32, StorageBuffer>
// CHECK-NEXT: spv.globalVariable [[VAR2:@.*]] bind(0, 1) : !spv.ptr<!spv.array<12 x f32>, StorageBuffer>
// CHECK-NEXT: func @kernel_1
// CHECK-NEXT: spv.Return
// CHECK: spv.EntryPoint "GLCompute" @kernel_1, [[VAR1]], [[VAR2]]

View File

@ -2,12 +2,12 @@
func @spirv_loadstore() -> () {
spv.module "Logical" "VulkanKHR" {
// CHECK: spv.globalVariable !spv.ptr<f32, Input> @var2
// CHECK-NEXT: spv.globalVariable !spv.ptr<f32, Output> @var3
// CHECK: spv.globalVariable @var2 : !spv.ptr<f32, Input>
// CHECK-NEXT: spv.globalVariable @var3 : !spv.ptr<f32, Output>
// CHECK-NEXT: func @noop({{%.*}}: !spv.ptr<f32, Input>, {{%.*}}: !spv.ptr<f32, Output>)
// CHECK: spv.EntryPoint "GLCompute" @noop, @var2, @var3
spv.globalVariable !spv.ptr<f32, Input> @var2
spv.globalVariable !spv.ptr<f32, Output> @var3
spv.globalVariable @var2 : !spv.ptr<f32, Input>
spv.globalVariable @var3 : !spv.ptr<f32, Output>
func @noop(%arg0 : !spv.ptr<f32, Input>, %arg1 : !spv.ptr<f32, Output>) -> () {
spv.Return
}

View File

@ -0,0 +1,18 @@
// RUN: mlir-translate -serialize-spirv %s
// TODO: This example doesn't work on deserialization since constants
// are always added to module scope and need to be materialized into
// function scope. So for now just run the serialization.
func @spirv_global_vars() -> () {
spv.module "Logical" "VulkanKHR" {
spv.globalVariable @globalInvocationID built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
func @foo() {
%0 = spv._address_of @globalInvocationID : !spv.ptr<vector<3xi32>, Input>
%1 = spv.constant 0: i32
%2 = spv.AccessChain %0[%1] : !spv.ptr<vector<3xi32>, Input>
spv.Return
}
}
return
}

View File

@ -1,15 +1,15 @@
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
// CHECK: spv.globalVariable !spv.ptr<f32, Input> @var0 bind(1, 0)
// CHECK-NEXT: spv.globalVariable !spv.ptr<f32, Output> @var1 bind(0, 1)
// CHECK-NEXT: spv.globalVariable !spv.ptr<vector<3xi32>, Input> @var2 built_in("GlobalInvocationId")
// CHECK-NEXT: spv.globalVariable !spv.ptr<vector<3xi32>, Input> @var3 built_in("GlobalInvocationId")
// CHECK: spv.globalVariable @var0 bind(1, 0) : !spv.ptr<f32, Input>
// CHECK-NEXT: spv.globalVariable @var1 bind(0, 1) : !spv.ptr<f32, Output>
// CHECK-NEXT: spv.globalVariable @var2 built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
// CHECK-NEXT: spv.globalVariable @var3 built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
func @spirv_variables() -> () {
spv.module "Logical" "VulkanKHR" {
spv.globalVariable !spv.ptr<f32, Input> @var0 bind(1, 0)
spv.globalVariable !spv.ptr<f32, Output> @var1 bind(0, 1)
spv.globalVariable !spv.ptr<vector<3xi32>, Input> @var2 {built_in = "GlobalInvocationId"}
spv.globalVariable !spv.ptr<vector<3xi32>, Input> @var3 built_in("GlobalInvocationId")
spv.globalVariable @var0 bind(1, 0) : !spv.ptr<f32, Input>
spv.globalVariable @var1 bind(0, 1) : !spv.ptr<f32, Output>
spv.globalVariable @var2 {built_in = "GlobalInvocationId"} : !spv.ptr<vector<3xi32>, Input>
spv.globalVariable @var3 built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
}
return
}

View File

@ -2,10 +2,10 @@
func @spirv_variables() -> () {
spv.module "Logical" "VulkanKHR" {
// CHECK: spv.globalVariable !spv.ptr<f32, Input> @var1
// CHECK-NEXT: spv.globalVariable !spv.ptr<f32, Input> @var2 initializer(@var1) bind(1, 0)
spv.globalVariable !spv.ptr<f32, Input> @var1
spv.globalVariable !spv.ptr<f32, Input> @var2 initializer(@var1) bind(1, 0)
// CHECK: spv.globalVariable @var1 : !spv.ptr<f32, Input>
// CHECK-NEXT: spv.globalVariable @var2 initializer(@var1) bind(1, 0) : !spv.ptr<f32, Input>
spv.globalVariable @var1 : !spv.ptr<f32, Input>
spv.globalVariable @var2 initializer(@var1) bind(1, 0) : !spv.ptr<f32, Input>
}
return
}

View File

@ -117,19 +117,6 @@ func @access_chain_invalid_accessing_type(%index0 : i32) -> () {
// -----
spv.module "Logical" "VulkanKHR" {
spv.globalVariable !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input> @var1
func @access_chain() -> () {
%0 = spv.constant 1: i32
%1 = spv._address_of @var1 : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
// CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr<!spv.struct<f32, !spv.array<4 x f32>>, Input>
%2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
spv.Return
}
}
// -----
//===----------------------------------------------------------------------===//
// spv.CompositeExtractOp
//===----------------------------------------------------------------------===//
@ -274,88 +261,6 @@ func @composite_extract_result_type_mismatch(%arg0: !spv.array<4xf32>) -> i32 {
// -----
//===----------------------------------------------------------------------===//
// spv.EntryPoint
//===----------------------------------------------------------------------===//
spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () {
spv.Return
}
// CHECK: spv.EntryPoint "GLCompute" @do_nothing
spv.EntryPoint "GLCompute" @do_nothing
}
spv.module "Logical" "VulkanKHR" {
spv.globalVariable !spv.ptr<f32, Input> @var2
spv.globalVariable !spv.ptr<f32, Output> @var3
func @do_something(%arg0 : !spv.ptr<f32, Input>, %arg1 : !spv.ptr<f32, Output>) -> () {
%1 = spv.Load "Input" %arg0 : f32
spv.Store "Output" %arg1, %1 : f32
spv.Return
}
// CHECK: spv.EntryPoint "GLCompute" @do_something, @var2, @var3
spv.EntryPoint "GLCompute" @do_something, @var2, @var3
}
// -----
spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () {
spv.Return
}
// expected-error @+1 {{invalid kind of constant specified}}
spv.EntryPoint "GLCompute" "do_nothing"
}
// -----
spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () {
spv.Return
}
// expected-error @+1 {{function 'do_something' not found in 'spv.module'}}
spv.EntryPoint "GLCompute" @do_something
}
/// TODO(ravishankarm) : Add a test that verifies an error is thrown
/// when interface entries of EntryPointOp are not
/// spv.Variables. There is currently no other op that has a spv.ptr
/// return type
// -----
spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () {
// expected-error @+1 {{op must appear in a 'spv.module' block}}
spv.EntryPoint "GLCompute" @do_something
}
}
// -----
spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () {
spv.Return
}
spv.EntryPoint "GLCompute" @do_nothing
// expected-error @+1 {{duplicate of a previous EntryPointOp}}
spv.EntryPoint "GLCompute" @do_nothing
}
// -----
spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () {
spv.Return
}
spv.EntryPoint "GLCompute" @do_nothing
// expected-error @+1 {{custom op 'spv.EntryPoint' invalid execution_model attribute specification: "ContractionOff"}}
spv.EntryPoint "ContractionOff" @do_nothing
}
// -----
//===----------------------------------------------------------------------===//
// spv.ExecutionMode
//===----------------------------------------------------------------------===//
@ -391,74 +296,6 @@ spv.module "Logical" "VulkanKHR" {
// -----
//===----------------------------------------------------------------------===//
// spv.globalVariable
//===----------------------------------------------------------------------===//
spv.module "Logical" "VulkanKHR" {
// CHECK: spv.globalVariable !spv.ptr<f32, Input> @var0
spv.globalVariable !spv.ptr<f32, Input> @var0
}
// TODO: Fix test case after initialization with constant is addressed
// spv.module "Logical" "VulkanKHR" {
// %0 = spv.constant 4.0 : f32
// // CHECK1: spv.Variable init(%0) : !spv.ptr<f32, Private>
// spv.globalVariable !spv.ptr<f32, Private> @var1 init(%0)
// }
spv.module "Logical" "VulkanKHR" {
// CHECK: spv.globalVariable !spv.ptr<f32, Uniform> @var0 bind(1, 2)
spv.globalVariable !spv.ptr<f32, Uniform> @var0 bind(1, 2)
}
// TODO: Fix test case after initialization with constant is addressed
// spv.module "Logical" "VulkanKHR" {
// %0 = spv.constant 4.0 : f32
// // CHECK1: spv.globalVariable !spv.ptr<f32, Private> @var1 initializer(%0) {binding = 5 : i32} : !spv.ptr<f32, Private>
// spv.globalVariable !spv.ptr<f32, Private> @var1 initializer(%0) {binding = 5 : i32} :
// }
spv.module "Logical" "VulkanKHR" {
// CHECK: spv.globalVariable !spv.ptr<vector<3xi32>, Input> @var1 built_in("GlobalInvocationID")
spv.globalVariable !spv.ptr<vector<3xi32>, Input> @var1 built_in("GlobalInvocationID")
// CHECK: spv.globalVariable !spv.ptr<vector<3xi32>, Input> @var2 built_in("GlobalInvocationID")
spv.globalVariable !spv.ptr<vector<3xi32>, Input> @var2 {built_in = "GlobalInvocationID"}
}
// -----
spv.module "Logical" "VulkanKHR" {
// expected-error @+1 {{expected spv.ptr type}}
spv.globalVariable f32 @var0
}
// -----
spv.module "Logical" "VulkanKHR" {
// expected-error @+1 {{op initializer must be result of a spv.globalVariable op}}
spv.globalVariable !spv.ptr<f32, Private> @var0 initializer(@var1)
}
// -----
spv.module "Logical" "VulkanKHR" {
// expected-error @+1 {{storage class cannot be 'Generic'}}
spv.globalVariable !spv.ptr<f32, Generic> @var0
}
// -----
spv.module "Logical" "VulkanKHR" {
func @foo() {
// expected-error @+1 {{op must appear in a 'spv.module' block}}
spv.globalVariable !spv.ptr<f32, Input> @var0
spv.Return
}
}
// -----
//===----------------------------------------------------------------------===//
// spv.FAdd
//===----------------------------------------------------------------------===//
@ -750,7 +587,7 @@ func @aligned_load_incorrect_attributes() -> () {
// -----
spv.module "Logical" "VulkanKHR" {
spv.globalVariable !spv.ptr<f32, Input> @var0
spv.globalVariable @var0 : !spv.ptr<f32, Input>
// CHECK_LABEL: @simple_load
func @simple_load() -> () {
// CHECK: spv.Load "Input" {{%.*}} : f32
@ -1011,7 +848,7 @@ func @aligned_store_incorrect_attributes(%arg0 : f32) -> () {
// -----
spv.module "Logical" "VulkanKHR" {
spv.globalVariable !spv.ptr<f32, Input> @var0
spv.globalVariable @var0 : !spv.ptr<f32, Input>
func @simple_store(%arg0 : f32) -> () {
%0 = spv._address_of @var0 : !spv.ptr<f32, Input>
// CHECK: spv.Store "Input" {{%.*}}, {{%.*}} : f32

View File

@ -4,14 +4,41 @@
// spv._address_of
//===----------------------------------------------------------------------===//
spv.module "Logical" "GLSL450" {
spv.globalVariable !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input> @var
// expected-error @+1 {{op must appear in a 'func' block}}
%1 = spv._address_of @var : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
spv.module "Logical" "VulkanKHR" {
spv.globalVariable @var1 : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
func @access_chain() -> () {
%0 = spv.constant 1: i32
// CHECK: [[VAR1:%.*]] = spv._address_of @var1 : !spv.ptr<!spv.struct<f32, !spv.array<4 x f32>>, Input>
// CHECK-NEXT: spv.AccessChain [[VAR1]][{{.*}}, {{.*}}] : !spv.ptr<!spv.struct<f32, !spv.array<4 x f32>>, Input>
%1 = spv._address_of @var1 : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
%2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
spv.Return
}
}
// -----
spv.module "Logical" "VulkanKHR" {
spv.globalVariable @var1 : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
func @foo() -> () {
// expected-error @+1{{expected spv.globalVariable symbol}}
%0 = spv._address_of @var2 : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
}
}
// -----
spv.module "Logical" "VulkanKHR" {
spv.globalVariable @var1 : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
func @foo() -> () {
// expected-error @+1{{mismatch in result type and type of global variable referenced}}
%0 = spv._address_of @var1 : !spv.ptr<f32, Input>
}
}
// -----
//===----------------------------------------------------------------------===//
// spv.constant
//===----------------------------------------------------------------------===//
@ -62,6 +89,157 @@ func @value_result_type_mismatch() -> () {
%0 = "spv.constant"() {value = dense<0> : tensor<4xi32>} : () -> (vector<4xi32>)
}
// -----
//===----------------------------------------------------------------------===//
// spv.EntryPoint
//===----------------------------------------------------------------------===//
spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () {
spv.Return
}
// CHECK: spv.EntryPoint "GLCompute" @do_nothing
spv.EntryPoint "GLCompute" @do_nothing
}
spv.module "Logical" "VulkanKHR" {
spv.globalVariable @var2 : !spv.ptr<f32, Input>
spv.globalVariable @var3 : !spv.ptr<f32, Output>
func @do_something(%arg0 : !spv.ptr<f32, Input>, %arg1 : !spv.ptr<f32, Output>) -> () {
%1 = spv.Load "Input" %arg0 : f32
spv.Store "Output" %arg1, %1 : f32
spv.Return
}
// CHECK: spv.EntryPoint "GLCompute" @do_something, @var2, @var3
spv.EntryPoint "GLCompute" @do_something, @var2, @var3
}
// -----
spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () {
spv.Return
}
// expected-error @+1 {{invalid kind of constant specified}}
spv.EntryPoint "GLCompute" "do_nothing"
}
// -----
spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () {
spv.Return
}
// expected-error @+1 {{function 'do_something' not found in 'spv.module'}}
spv.EntryPoint "GLCompute" @do_something
}
/// TODO(ravishankarm) : Add a test that verifies an error is thrown
/// when interface entries of EntryPointOp are not
/// spv.Variables. There is currently no other op that has a spv.ptr
/// return type
// -----
spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () {
// expected-error @+1 {{'spv.EntryPoint' op failed to verify that op must appear in a 'spv.module' block}}
spv.EntryPoint "GLCompute" @do_something
}
}
// -----
spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () {
spv.Return
}
spv.EntryPoint "GLCompute" @do_nothing
// expected-error @+1 {{duplicate of a previous EntryPointOp}}
spv.EntryPoint "GLCompute" @do_nothing
}
// -----
spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () {
spv.Return
}
spv.EntryPoint "GLCompute" @do_nothing
// expected-error @+1 {{custom op 'spv.EntryPoint' invalid execution_model attribute specification: "ContractionOff"}}
spv.EntryPoint "ContractionOff" @do_nothing
}
// -----
//===----------------------------------------------------------------------===//
// spv.globalVariable
//===----------------------------------------------------------------------===//
spv.module "Logical" "VulkanKHR" {
// CHECK: spv.globalVariable @var0 : !spv.ptr<f32, Input>
spv.globalVariable @var0 : !spv.ptr<f32, Input>
}
// TODO: Fix test case after initialization with constant is addressed
// spv.module "Logical" "VulkanKHR" {
// %0 = spv.constant 4.0 : f32
// // CHECK1: spv.Variable init(%0) : !spv.ptr<f32, Private>
// spv.globalVariable @var1 init(%0) : !spv.ptr<f32, Private>
// }
spv.module "Logical" "VulkanKHR" {
// CHECK: spv.globalVariable @var0 bind(1, 2) : !spv.ptr<f32, Uniform>
spv.globalVariable @var0 bind(1, 2) : !spv.ptr<f32, Uniform>
}
// TODO: Fix test case after initialization with constant is addressed
// spv.module "Logical" "VulkanKHR" {
// %0 = spv.constant 4.0 : f32
// // CHECK1: spv.globalVariable @var1 initializer(%0) {binding = 5 : i32} : !spv.ptr<f32, Private>
// spv.globalVariable @var1 initializer(%0) {binding = 5 : i32} : !spv.ptr<f32, Private>
// }
spv.module "Logical" "VulkanKHR" {
// CHECK: spv.globalVariable @var1 built_in("GlobalInvocationID") : !spv.ptr<vector<3xi32>, Input>
spv.globalVariable @var1 built_in("GlobalInvocationID") : !spv.ptr<vector<3xi32>, Input>
// CHECK: spv.globalVariable @var2 built_in("GlobalInvocationID") : !spv.ptr<vector<3xi32>, Input>
spv.globalVariable @var2 {built_in = "GlobalInvocationID"} : !spv.ptr<vector<3xi32>, Input>
}
// -----
spv.module "Logical" "VulkanKHR" {
// expected-error @+1 {{expected spv.ptr type}}
spv.globalVariable @var0 : f32
}
// -----
spv.module "Logical" "VulkanKHR" {
// expected-error @+1 {{op initializer must be result of a spv.globalVariable op}}
spv.globalVariable @var0 initializer(@var1) : !spv.ptr<f32, Private>
}
// -----
spv.module "Logical" "VulkanKHR" {
// expected-error @+1 {{storage class cannot be 'Generic'}}
spv.globalVariable @var0 : !spv.ptr<f32, Generic>
}
// -----
spv.module "Logical" "VulkanKHR" {
func @foo() {
// expected-error @+1 {{op failed to verify that op must appear in a 'spv.module' block}}
spv.globalVariable @var0 : !spv.ptr<f32, Input>
spv.Return
}
}
// -----
//===----------------------------------------------------------------------===//