[mlir] harden result type verification in llvm.call

The verifier of the llvm.call operation was not checking for mismatches between
the number of operation results and the number of results in the signature of
the callee. Furthermore, it was possible to construct an llvm.call operation
producing an SSA value of !llvm.void type, which should not exist. Add the
verification and treat !llvm.void result type as absence of call results.
Update the GPU conversions to LLVM that were mistakenly assuming that it was
fine for llvm.call to produce values of !llvm.void type and ensure these calls
do not produce results.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D106937
This commit is contained in:
Alex Zinenko 2021-07-28 10:23:06 +02:00
parent 6e8660a7d6
commit c1f719d1a7
5 changed files with 64 additions and 24 deletions

View File

@ -353,9 +353,7 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
return OpBuilder::atBlockEnd(module.getBody())
.create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
}();
return builder.create<LLVM::CallOp>(
loc, const_cast<LLVM::LLVMFunctionType &>(functionType).getReturnType(),
builder.getSymbolRefAttr(function), arguments);
return builder.create<LLVM::CallOp>(loc, function, arguments);
}
// Returns whether all operands are of LLVM type.

View File

@ -248,7 +248,7 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
}
// Create call to `bindMemRef`.
builder.create<LLVM::CallOp>(
loc, TypeRange{getVoidType()},
loc, TypeRange(),
builder.getSymbolRefAttr(
StringRef(symbolName.data(), symbolName.size())),
ValueRange{vulkanRuntime, descriptorSet, descriptorBinding,
@ -396,32 +396,31 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
// Create call to `setBinaryShader` runtime function with the given pointer to
// SPIR-V binary and binary size.
builder.create<LLVM::CallOp>(
loc, TypeRange{getVoidType()}, builder.getSymbolRefAttr(kSetBinaryShader),
loc, TypeRange(), builder.getSymbolRefAttr(kSetBinaryShader),
ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
// Create LLVM global with entry point name.
Value entryPointName = createEntryPointNameConstant(
spirvAttributes.second.getValue(), loc, builder);
// Create call to `setEntryPoint` runtime function with the given pointer to
// entry point name.
builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
builder.create<LLVM::CallOp>(loc, TypeRange(),
builder.getSymbolRefAttr(kSetEntryPoint),
ValueRange{vulkanRuntime, entryPointName});
// Create number of local workgroup for each dimension.
builder.create<LLVM::CallOp>(
loc, TypeRange{getVoidType()},
builder.getSymbolRefAttr(kSetNumWorkGroups),
loc, TypeRange(), builder.getSymbolRefAttr(kSetNumWorkGroups),
ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
cInterfaceVulkanLaunchCallOp.getOperand(1),
cInterfaceVulkanLaunchCallOp.getOperand(2)});
// Create call to `runOnVulkan` runtime function.
builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
builder.create<LLVM::CallOp>(loc, TypeRange(),
builder.getSymbolRefAttr(kRunOnVulkan),
ValueRange{vulkanRuntime});
// Create call to 'deinitVulkan' runtime function.
builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
builder.create<LLVM::CallOp>(loc, TypeRange(),
builder.getSymbolRefAttr(kDeinitVulkan),
ValueRange{vulkanRuntime});

View File

@ -815,6 +815,19 @@ static LogicalResult verify(CallOp &op) {
<< ": " << op.getOperand(i + isIndirect).getType()
<< " != " << funcType.getParamType(i);
if (op.getNumResults() == 0 &&
!funcType.getReturnType().isa<LLVM::LLVMVoidType>())
return op.emitOpError() << "expected function call to produce a value";
if (op.getNumResults() != 0 &&
funcType.getReturnType().isa<LLVM::LLVMVoidType>())
return op.emitOpError()
<< "calling function with void result must not produce values";
if (op.getNumResults() > 1)
return op.emitOpError()
<< "expected LLVM function call to produce 0 or 1 result";
if (op.getNumResults() &&
op.getResult(0).getType() != funcType.getReturnType())
return op.emitOpError()
@ -874,19 +887,18 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
auto funcType = type.dyn_cast<FunctionType>();
if (!funcType)
return parser.emitError(trailingTypeLoc, "expected function type");
if (funcType.getNumResults() > 1)
return parser.emitError(trailingTypeLoc,
"expected function with 0 or 1 result");
if (isDirect) {
// Make sure types match.
if (parser.resolveOperands(operands, funcType.getInputs(),
parser.getNameLoc(), result.operands))
return failure();
result.addTypes(funcType.getResults());
if (funcType.getNumResults() != 0 &&
!funcType.getResult(0).isa<LLVM::LLVMVoidType>())
result.addTypes(funcType.getResults());
} else {
// Construct the LLVM IR Dialect function type that the first operand
// should match.
if (funcType.getNumResults() > 1)
return parser.emitError(trailingTypeLoc,
"expected function with 0 or 1 result");
Builder &builder = parser.getBuilder();
Type llvmResultType;
if (funcType.getNumResults() == 0) {
@ -921,7 +933,8 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
parser.getNameLoc(), result.operands))
return failure();
result.addTypes(llvmResultType);
if (!llvmResultType.isa<LLVM::LLVMVoidType>())
result.addTypes(llvmResultType);
}
return success();

View File

@ -6,14 +6,14 @@
// CHECK: %[[addressof_SPIRV_BIN:.*]] = llvm.mlir.addressof @SPIRV_BIN
// CHECK: %[[SPIRV_BIN_ptr:.*]] = llvm.getelementptr %[[addressof_SPIRV_BIN]]
// CHECK: %[[SPIRV_BIN_size:.*]] = llvm.mlir.constant
// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> !llvm.void
// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i32) -> !llvm.void
// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i32) -> ()
// CHECK: %[[addressof_entry_point:.*]] = llvm.mlir.addressof @kernel_spv_entry_point_name
// CHECK: %[[entry_point_ptr:.*]] = llvm.getelementptr %[[addressof_entry_point]]
// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> !llvm.void
// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i64, i64, i64) -> !llvm.void
// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> !llvm.void
// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> !llvm.void
// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i64, i64, i64) -> ()
// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<i16>, ptr<i16>, i64, array<1 x i64>, array<1 x i64>)>>)

View File

@ -1089,3 +1089,33 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<
%0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, vector<2xf16>)>
llvm.return
}
// -----
llvm.func @caller() {
// expected-error @below {{expected function call to produce a value}}
llvm.call @callee() : () -> ()
llvm.return
}
llvm.func @callee() -> i32
// -----
llvm.func @caller() {
// expected-error @below {{calling function with void result must not produce values}}
%0 = llvm.call @callee() : () -> i32
llvm.return
}
llvm.func @callee() -> ()
// -----
llvm.func @caller() {
// expected-error @below {{expected function with 0 or 1 result}}
%0:2 = llvm.call @callee() : () -> (i32, f32)
llvm.return
}
llvm.func @callee() -> !llvm.struct<(i32, f32)>