forked from OSchip/llvm-project
[mlir][spirv] Convert functions returning one value
Reviewed By: hanchung, ThomasRaoux Differential Revision: https://reviews.llvm.org/D93468
This commit is contained in:
parent
7ad666798f
commit
42980a789d
|
@ -924,10 +924,14 @@ LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
|
|||
LogicalResult
|
||||
ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
if (returnOp.getNumOperands()) {
|
||||
if (returnOp.getNumOperands() > 1)
|
||||
return failure();
|
||||
|
||||
if (returnOp.getNumOperands() == 1) {
|
||||
rewriter.replaceOpWithNewOp<spirv::ReturnValueOp>(returnOp, operands[0]);
|
||||
} else {
|
||||
rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -473,23 +473,27 @@ LogicalResult
|
|||
FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto fnType = funcOp.getType();
|
||||
// TODO: support converting functions with one result.
|
||||
if (fnType.getNumResults())
|
||||
if (fnType.getNumResults() > 1)
|
||||
return failure();
|
||||
|
||||
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
|
||||
for (auto argType : enumerate(funcOp.getType().getInputs())) {
|
||||
for (auto argType : enumerate(fnType.getInputs())) {
|
||||
auto convertedType = typeConverter.convertType(argType.value());
|
||||
if (!convertedType)
|
||||
return failure();
|
||||
signatureConverter.addInputs(argType.index(), convertedType);
|
||||
}
|
||||
|
||||
Type resultType;
|
||||
if (fnType.getNumResults() == 1)
|
||||
resultType = typeConverter.convertType(fnType.getResult(0));
|
||||
|
||||
// Create the converted spv.func op.
|
||||
auto newFuncOp = rewriter.create<spirv::FuncOp>(
|
||||
funcOp.getLoc(), funcOp.getName(),
|
||||
rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
|
||||
llvm::None));
|
||||
resultType ? TypeRange(resultType)
|
||||
: TypeRange()));
|
||||
|
||||
// Copy over all attributes other than the function name and type.
|
||||
for (const auto &namedAttr : funcOp.getAttrs()) {
|
||||
|
|
|
@ -954,3 +954,29 @@ func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
|
|||
}
|
||||
|
||||
} // end module
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// std.return
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
module attributes {
|
||||
spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
|
||||
} {
|
||||
|
||||
// CHECK-LABEL: spv.func @return_one_val
|
||||
// CHECK-SAME: (%[[ARG:.+]]: f32)
|
||||
func @return_one_val(%arg0: f32) -> f32 {
|
||||
// CHECK: spv.ReturnValue %[[ARG]] : f32
|
||||
return %arg0: f32
|
||||
}
|
||||
|
||||
// Check that multiple-return functions are not converted.
|
||||
// CHECK-LABEL: func @return_multi_val
|
||||
func @return_multi_val(%arg0: f32) -> (f32, f32) {
|
||||
// CHECK: return
|
||||
return %arg0, %arg0: f32, f32
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue