forked from OSchip/llvm-project
[mlir] harden result type verification in llvm.call
The verifier of the llvm.call operation was not checking for mismatches between the number of operation results and the number of results in the signature of the callee. Furthermore, it was possible to construct an llvm.call operation producing an SSA value of !llvm.void type, which should not exist. Add the verification and treat !llvm.void result type as absence of call results. Update the GPU conversions to LLVM that were mistakenly assuming that it was fine for llvm.call to produce values of !llvm.void type and ensure these calls do not produce results. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D106937
This commit is contained in:
parent
6e8660a7d6
commit
c1f719d1a7
|
@ -353,9 +353,7 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
|
|||
return OpBuilder::atBlockEnd(module.getBody())
|
||||
.create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
|
||||
}();
|
||||
return builder.create<LLVM::CallOp>(
|
||||
loc, const_cast<LLVM::LLVMFunctionType &>(functionType).getReturnType(),
|
||||
builder.getSymbolRefAttr(function), arguments);
|
||||
return builder.create<LLVM::CallOp>(loc, function, arguments);
|
||||
}
|
||||
|
||||
// Returns whether all operands are of LLVM type.
|
||||
|
|
|
@ -248,7 +248,7 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
|
|||
}
|
||||
// Create call to `bindMemRef`.
|
||||
builder.create<LLVM::CallOp>(
|
||||
loc, TypeRange{getVoidType()},
|
||||
loc, TypeRange(),
|
||||
builder.getSymbolRefAttr(
|
||||
StringRef(symbolName.data(), symbolName.size())),
|
||||
ValueRange{vulkanRuntime, descriptorSet, descriptorBinding,
|
||||
|
@ -396,32 +396,31 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
|
|||
// Create call to `setBinaryShader` runtime function with the given pointer to
|
||||
// SPIR-V binary and binary size.
|
||||
builder.create<LLVM::CallOp>(
|
||||
loc, TypeRange{getVoidType()}, builder.getSymbolRefAttr(kSetBinaryShader),
|
||||
loc, TypeRange(), builder.getSymbolRefAttr(kSetBinaryShader),
|
||||
ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
|
||||
// Create LLVM global with entry point name.
|
||||
Value entryPointName = createEntryPointNameConstant(
|
||||
spirvAttributes.second.getValue(), loc, builder);
|
||||
// Create call to `setEntryPoint` runtime function with the given pointer to
|
||||
// entry point name.
|
||||
builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
|
||||
builder.create<LLVM::CallOp>(loc, TypeRange(),
|
||||
builder.getSymbolRefAttr(kSetEntryPoint),
|
||||
ValueRange{vulkanRuntime, entryPointName});
|
||||
|
||||
// Create number of local workgroup for each dimension.
|
||||
builder.create<LLVM::CallOp>(
|
||||
loc, TypeRange{getVoidType()},
|
||||
builder.getSymbolRefAttr(kSetNumWorkGroups),
|
||||
loc, TypeRange(), builder.getSymbolRefAttr(kSetNumWorkGroups),
|
||||
ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0),
|
||||
cInterfaceVulkanLaunchCallOp.getOperand(1),
|
||||
cInterfaceVulkanLaunchCallOp.getOperand(2)});
|
||||
|
||||
// Create call to `runOnVulkan` runtime function.
|
||||
builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
|
||||
builder.create<LLVM::CallOp>(loc, TypeRange(),
|
||||
builder.getSymbolRefAttr(kRunOnVulkan),
|
||||
ValueRange{vulkanRuntime});
|
||||
|
||||
// Create call to 'deinitVulkan' runtime function.
|
||||
builder.create<LLVM::CallOp>(loc, TypeRange{getVoidType()},
|
||||
builder.create<LLVM::CallOp>(loc, TypeRange(),
|
||||
builder.getSymbolRefAttr(kDeinitVulkan),
|
||||
ValueRange{vulkanRuntime});
|
||||
|
||||
|
|
|
@ -815,6 +815,19 @@ static LogicalResult verify(CallOp &op) {
|
|||
<< ": " << op.getOperand(i + isIndirect).getType()
|
||||
<< " != " << funcType.getParamType(i);
|
||||
|
||||
if (op.getNumResults() == 0 &&
|
||||
!funcType.getReturnType().isa<LLVM::LLVMVoidType>())
|
||||
return op.emitOpError() << "expected function call to produce a value";
|
||||
|
||||
if (op.getNumResults() != 0 &&
|
||||
funcType.getReturnType().isa<LLVM::LLVMVoidType>())
|
||||
return op.emitOpError()
|
||||
<< "calling function with void result must not produce values";
|
||||
|
||||
if (op.getNumResults() > 1)
|
||||
return op.emitOpError()
|
||||
<< "expected LLVM function call to produce 0 or 1 result";
|
||||
|
||||
if (op.getNumResults() &&
|
||||
op.getResult(0).getType() != funcType.getReturnType())
|
||||
return op.emitOpError()
|
||||
|
@ -874,19 +887,18 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
|
|||
auto funcType = type.dyn_cast<FunctionType>();
|
||||
if (!funcType)
|
||||
return parser.emitError(trailingTypeLoc, "expected function type");
|
||||
if (funcType.getNumResults() > 1)
|
||||
return parser.emitError(trailingTypeLoc,
|
||||
"expected function with 0 or 1 result");
|
||||
if (isDirect) {
|
||||
// Make sure types match.
|
||||
if (parser.resolveOperands(operands, funcType.getInputs(),
|
||||
parser.getNameLoc(), result.operands))
|
||||
return failure();
|
||||
if (funcType.getNumResults() != 0 &&
|
||||
!funcType.getResult(0).isa<LLVM::LLVMVoidType>())
|
||||
result.addTypes(funcType.getResults());
|
||||
} else {
|
||||
// Construct the LLVM IR Dialect function type that the first operand
|
||||
// should match.
|
||||
if (funcType.getNumResults() > 1)
|
||||
return parser.emitError(trailingTypeLoc,
|
||||
"expected function with 0 or 1 result");
|
||||
|
||||
Builder &builder = parser.getBuilder();
|
||||
Type llvmResultType;
|
||||
if (funcType.getNumResults() == 0) {
|
||||
|
@ -921,6 +933,7 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
|
|||
parser.getNameLoc(), result.operands))
|
||||
return failure();
|
||||
|
||||
if (!llvmResultType.isa<LLVM::LLVMVoidType>())
|
||||
result.addTypes(llvmResultType);
|
||||
}
|
||||
|
||||
|
|
|
@ -6,14 +6,14 @@
|
|||
// CHECK: %[[addressof_SPIRV_BIN:.*]] = llvm.mlir.addressof @SPIRV_BIN
|
||||
// CHECK: %[[SPIRV_BIN_ptr:.*]] = llvm.getelementptr %[[addressof_SPIRV_BIN]]
|
||||
// CHECK: %[[SPIRV_BIN_size:.*]] = llvm.mlir.constant
|
||||
// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> !llvm.void
|
||||
// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i32) -> !llvm.void
|
||||
// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
|
||||
// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i32) -> ()
|
||||
// CHECK: %[[addressof_entry_point:.*]] = llvm.mlir.addressof @kernel_spv_entry_point_name
|
||||
// CHECK: %[[entry_point_ptr:.*]] = llvm.getelementptr %[[addressof_entry_point]]
|
||||
// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> !llvm.void
|
||||
// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i64, i64, i64) -> !llvm.void
|
||||
// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> !llvm.void
|
||||
// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> !llvm.void
|
||||
// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
|
||||
// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i64, i64, i64) -> ()
|
||||
// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
|
||||
// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
|
||||
|
||||
// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<i16>, ptr<i16>, i64, array<1 x i64>, array<1 x i64>)>>)
|
||||
|
||||
|
|
|
@ -1089,3 +1089,33 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<
|
|||
%0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, vector<2xf16>)>
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @caller() {
|
||||
// expected-error @below {{expected function call to produce a value}}
|
||||
llvm.call @callee() : () -> ()
|
||||
llvm.return
|
||||
}
|
||||
|
||||
llvm.func @callee() -> i32
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @caller() {
|
||||
// expected-error @below {{calling function with void result must not produce values}}
|
||||
%0 = llvm.call @callee() : () -> i32
|
||||
llvm.return
|
||||
}
|
||||
|
||||
llvm.func @callee() -> ()
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @caller() {
|
||||
// expected-error @below {{expected function with 0 or 1 result}}
|
||||
%0:2 = llvm.call @callee() : () -> (i32, f32)
|
||||
llvm.return
|
||||
}
|
||||
|
||||
llvm.func @callee() -> !llvm.struct<(i32, f32)>
|
||||
|
|
Loading…
Reference in New Issue