[mlir][spirv] Convert functions returning one value

Reviewed By: hanchung, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D93468
This commit is contained in:
Lei Zhang 2020-12-23 13:21:57 -05:00
parent 7ad666798f
commit 42980a789d
3 changed files with 40 additions and 6 deletions

View File

@ -924,10 +924,14 @@ LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
LogicalResult
ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (returnOp.getNumOperands()) {
if (returnOp.getNumOperands() > 1)
return failure();
if (returnOp.getNumOperands() == 1) {
rewriter.replaceOpWithNewOp<spirv::ReturnValueOp>(returnOp, operands[0]);
} else {
rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
}
rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
return success();
}

View File

@ -473,23 +473,27 @@ LogicalResult
FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto fnType = funcOp.getType();
// TODO: support converting functions with one result.
if (fnType.getNumResults())
if (fnType.getNumResults() > 1)
return failure();
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
for (auto argType : enumerate(funcOp.getType().getInputs())) {
for (auto argType : enumerate(fnType.getInputs())) {
auto convertedType = typeConverter.convertType(argType.value());
if (!convertedType)
return failure();
signatureConverter.addInputs(argType.index(), convertedType);
}
Type resultType;
if (fnType.getNumResults() == 1)
resultType = typeConverter.convertType(fnType.getResult(0));
// Create the converted spv.func op.
auto newFuncOp = rewriter.create<spirv::FuncOp>(
funcOp.getLoc(), funcOp.getName(),
rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
llvm::None));
resultType ? TypeRange(resultType)
: TypeRange()));
// Copy over all attributes other than the function name and type.
for (const auto &namedAttr : funcOp.getAttrs()) {

View File

@ -954,3 +954,29 @@ func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
}
} // end module
// -----
//===----------------------------------------------------------------------===//
// std.return
//===----------------------------------------------------------------------===//
module attributes {
spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
} {
// CHECK-LABEL: spv.func @return_one_val
// CHECK-SAME: (%[[ARG:.+]]: f32)
func @return_one_val(%arg0: f32) -> f32 {
// CHECK: spv.ReturnValue %[[ARG]] : f32
return %arg0: f32
}
// Check that multiple-return functions are not converted.
// CHECK-LABEL: func @return_multi_val
func @return_multi_val(%arg0: f32) -> (f32, f32) {
// CHECK: return
return %arg0, %arg0: f32, f32
}
}