forked from OSchip/llvm-project
[mlir][tosa] Tosa shape propagation for tosa.cond_if
We can propagate the shape from tosa.cond_if operands into the true/false regions then through the connected blocks. Then, using the tosa.yield ops we can determine what all possible return types are. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D105940
This commit is contained in:
parent
b4121b335c
commit
1b00b94ffc
|
@ -1789,6 +1789,8 @@ def Tosa_CustomOp : Tosa_Op<"custom"> {
|
|||
// Further described in docs/Rationale/RationaleTOSADialect.md .
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_IfOp : Tosa_Op<"cond_if", [
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>,
|
||||
SingleBlockImplicitTerminator<"YieldOp">,
|
||||
RecursiveSideEffects]> {
|
||||
let summary = "Conditional if operator";
|
||||
|
|
|
@ -0,0 +1,178 @@
|
|||
//===-- ShapeUtils.h - TOSA shape support declarations ----------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Class declarations for shape utilities meant to assist shape propagation.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H
|
||||
#define MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tosa {
|
||||
/// Statically known information for a particular Value.
|
||||
///
|
||||
/// This struct currently tracks only information relevant for tensor/array-like
|
||||
/// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped
|
||||
/// type as long as it is in the default "no knowledge" state returned by
|
||||
/// `getPessimisticValueState`. The important invariant is that we cannot
|
||||
/// claim to know something about a value which is false.
|
||||
///
|
||||
/// This class could also be called "dataflow facts", "lattice value", etc.
|
||||
struct ValueKnowledge {
|
||||
ValueKnowledge() = delete;
|
||||
ValueKnowledge(bool hasRank, llvm::ArrayRef<int64_t> newSizes, Type dtype)
|
||||
: hasError(false), hasRank(hasRank), dtype(dtype) {
|
||||
sizes.reserve(newSizes.size());
|
||||
for (auto size : newSizes)
|
||||
sizes.push_back(size);
|
||||
}
|
||||
|
||||
operator bool() const { return !hasError; }
|
||||
|
||||
// Get the static knowledge intrinsic to `type`.
|
||||
static ValueKnowledge getKnowledgeFromType(Type type) {
|
||||
ValueKnowledge result = getPessimisticValueState();
|
||||
if (auto shapedType = type.dyn_cast<ShapedType>()) {
|
||||
if (shapedType.hasRank()) {
|
||||
result.hasRank = true;
|
||||
result.sizes.reserve(shapedType.getRank());
|
||||
for (auto dim : shapedType.getShape())
|
||||
result.sizes.push_back(dim);
|
||||
}
|
||||
result.dtype = shapedType.getElementType();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Return a pessimistic/conservative value state without assuming any knowlege
|
||||
// about the IR.
|
||||
static ValueKnowledge getPessimisticValueState() {
|
||||
return ValueKnowledge(false, {}, Type());
|
||||
}
|
||||
|
||||
Type getType() const {
|
||||
if (hasRank)
|
||||
return RankedTensorType::get(llvm::makeArrayRef(sizes), dtype);
|
||||
return UnrankedTensorType::get(dtype);
|
||||
}
|
||||
|
||||
bool operator==(const ValueKnowledge &rhs) const {
|
||||
return hasRank == rhs.hasRank && sizes == rhs.sizes && dtype == rhs.dtype;
|
||||
}
|
||||
|
||||
// Given two pieces of static knowledge, calculate conservatively the
|
||||
// information we can be sure about.
|
||||
static ValueKnowledge join(const ValueKnowledge &lhs,
|
||||
const ValueKnowledge &rhs) {
|
||||
// Mental model: All conditions are checking how to change from the safe "no
|
||||
// knowledge" default-initialized state to a state with more knowledge
|
||||
// consistent with lhs and rhs.
|
||||
ValueKnowledge result = getPessimisticValueState();
|
||||
result.hasError = true;
|
||||
|
||||
if (!lhs || !rhs || lhs.dtype != rhs.dtype)
|
||||
return result;
|
||||
|
||||
result.hasError = false;
|
||||
result.dtype = lhs.dtype;
|
||||
|
||||
if (!lhs.hasRank && !rhs.hasRank)
|
||||
return result;
|
||||
|
||||
if (!rhs.hasRank) {
|
||||
result.hasRank = true;
|
||||
result.sizes = lhs.sizes;
|
||||
return result;
|
||||
}
|
||||
|
||||
if (!lhs.hasRank) {
|
||||
result.hasRank = true;
|
||||
result.sizes = rhs.sizes;
|
||||
return result;
|
||||
}
|
||||
|
||||
if (lhs.sizes.size() != rhs.sizes.size())
|
||||
return result;
|
||||
|
||||
result.hasRank = true;
|
||||
result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
|
||||
for (auto i : llvm::seq<unsigned>(0, result.sizes.size())) {
|
||||
int64_t lhsSize = lhs.sizes[i];
|
||||
int64_t rhsSize = rhs.sizes[i];
|
||||
int64_t &resultSize = result.sizes[i];
|
||||
if (lhsSize == ShapedType::kDynamicSize) {
|
||||
resultSize = rhsSize;
|
||||
} else if (rhsSize == ShapedType::kDynamicSize) {
|
||||
resultSize = lhsSize;
|
||||
} else if (lhsSize == rhsSize) {
|
||||
resultSize = lhsSize;
|
||||
} else {
|
||||
result.hasError = true;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Given to types, generate a new ValueKnowledge that meets to cover both
|
||||
// cases. E.g. if the rank of the LHS and RHS differ, the resulting tensor
|
||||
// has unknown rank.
|
||||
static ValueKnowledge meet(const ValueKnowledge &lhs,
|
||||
const ValueKnowledge &rhs) {
|
||||
ValueKnowledge result = getPessimisticValueState();
|
||||
result.hasError = true;
|
||||
|
||||
if (!rhs || !rhs || lhs.dtype != rhs.dtype)
|
||||
return result;
|
||||
|
||||
result.hasError = false;
|
||||
result.dtype = lhs.dtype;
|
||||
|
||||
if (!lhs.hasRank || !rhs.hasRank) {
|
||||
result.hasRank = false;
|
||||
return result;
|
||||
}
|
||||
|
||||
if (lhs.sizes.size() != rhs.sizes.size()) {
|
||||
result.hasRank = false;
|
||||
return result;
|
||||
}
|
||||
|
||||
result.hasRank = true;
|
||||
result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
|
||||
for (int i = 0, e = lhs.sizes.size(); i < e; i++) {
|
||||
if (lhs.sizes[i] == rhs.sizes[i]) {
|
||||
result.sizes[i] = lhs.sizes[i];
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Whether the value information has an error.
|
||||
bool hasError;
|
||||
// Whether the value has known rank.
|
||||
bool hasRank;
|
||||
// If `hasRank`, the sizes along each rank. Unknown sizes are represented as
|
||||
// `ShapedType::kDynamicSize`.
|
||||
llvm::SmallVector<int64_t> sizes;
|
||||
// The dtype of a tensor.
|
||||
// This is equal to nullptr if we don't know that it is a specific concrete
|
||||
// type.
|
||||
Type dtype;
|
||||
};
|
||||
} // namespace tosa
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H
|
|
@ -15,6 +15,7 @@
|
|||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
|
||||
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
@ -1301,6 +1302,54 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult IfOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, ::llvm::Optional<Location> location,
|
||||
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
|
||||
llvm::SmallVector<tosa::YieldOp> yieldOps;
|
||||
for (Region *region : regions) {
|
||||
for (auto &block : *region)
|
||||
if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
|
||||
yieldOps.push_back(returnOp);
|
||||
}
|
||||
|
||||
if (yieldOps.empty())
|
||||
return failure();
|
||||
|
||||
// Get the initial type information for the yield op.
|
||||
llvm::SmallVector<ValueKnowledge> resultKnowledge;
|
||||
resultKnowledge.reserve(yieldOps.front().getNumOperands());
|
||||
for (auto operand : yieldOps.front().getOperands()) {
|
||||
resultKnowledge.push_back(
|
||||
ValueKnowledge::getKnowledgeFromType(operand.getType()));
|
||||
}
|
||||
|
||||
for (auto yieldOp : yieldOps) {
|
||||
if (resultKnowledge.size() != yieldOp.getNumOperands())
|
||||
return failure();
|
||||
|
||||
for (auto it : llvm::enumerate(yieldOp.getOperands())) {
|
||||
int32_t index = it.index();
|
||||
auto meet = ValueKnowledge::meet(
|
||||
resultKnowledge[index],
|
||||
ValueKnowledge::getKnowledgeFromType(it.value().getType()));
|
||||
if (!meet)
|
||||
continue;
|
||||
resultKnowledge[index] = meet;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto result : resultKnowledge) {
|
||||
if (result.hasRank) {
|
||||
inferredReturnShapes.push_back(ShapedTypeComponents(result.sizes));
|
||||
} else {
|
||||
inferredReturnShapes.push_back(ShapedTypeComponents());
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TOSA Operator Definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
|
||||
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
@ -30,137 +31,57 @@ using namespace mlir::tosa;
|
|||
|
||||
namespace {
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Analysis.
|
||||
// -----------------------------------------------------------------------------
|
||||
void propagateShapesInRegion(Region ®ion);
|
||||
|
||||
static Type joinElementTypes(Type lhs, Type rhs) {
|
||||
return lhs == rhs ? lhs : Type();
|
||||
void propagateShapesToTosaIf(Operation &op) {
|
||||
tosa::IfOp ifOp = dyn_cast<tosa::IfOp>(op);
|
||||
if (!ifOp)
|
||||
return;
|
||||
|
||||
for (auto ®ion : op.getRegions()) {
|
||||
Block &frontBlock = region.front();
|
||||
if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
|
||||
return;
|
||||
|
||||
for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
|
||||
ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
|
||||
ifOp.getOperand(i + 1).getType());
|
||||
ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
|
||||
frontBlock.getArgument(i).getType());
|
||||
ValueKnowledge joinedKnowledge =
|
||||
ValueKnowledge::join(operandKnowledge, blockKnowledge);
|
||||
if (!joinedKnowledge)
|
||||
continue;
|
||||
frontBlock.getArgument(i).setType(joinedKnowledge.getType());
|
||||
}
|
||||
|
||||
propagateShapesInRegion(region);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Statically known information for a particular Value.
|
||||
//
|
||||
// This struct currently tracks only information relevant for tensor/array-like
|
||||
// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped
|
||||
// type as long as it is in the default "no knowledge" state returned by
|
||||
// `getPessimisticValueState`. The important invariant is that we cannot
|
||||
// claim to know something about a value which is false.
|
||||
//
|
||||
// This class could also be called "dataflow facts", "lattice value", etc.
|
||||
struct ValueKnowledge {
|
||||
ValueKnowledge() = delete;
|
||||
ValueKnowledge(bool hasSizes, std::vector<int64_t> sizes, Type dtype)
|
||||
: hasSizes(hasSizes), sizes(sizes), dtype(dtype) {
|
||||
assert(sizes.size() == 0 || hasSizes);
|
||||
}
|
||||
|
||||
// Get the static knowledge intrinsic to `type`.
|
||||
static ValueKnowledge getKnowledgeFromType(Type type) {
|
||||
ValueKnowledge result = getPessimisticValueState(type.getContext());
|
||||
if (auto shapedType = type.dyn_cast<ShapedType>()) {
|
||||
if (shapedType.hasRank()) {
|
||||
result.hasSizes = true;
|
||||
result.sizes = shapedType.getShape();
|
||||
}
|
||||
result.dtype = shapedType.getElementType();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Return a pessimistic/conservative value state without assuming any knowlege
|
||||
// about the IR.
|
||||
static ValueKnowledge getPessimisticValueState(MLIRContext *context) {
|
||||
return ValueKnowledge(false, {}, Type());
|
||||
}
|
||||
|
||||
Type getType() const {
|
||||
if (hasSizes) {
|
||||
return RankedTensorType::get(llvm::makeArrayRef(sizes), dtype);
|
||||
}
|
||||
return UnrankedTensorType::get(dtype);
|
||||
}
|
||||
|
||||
bool operator==(const ValueKnowledge &rhs) const {
|
||||
return std::make_tuple(hasSizes, sizes, dtype) ==
|
||||
std::make_tuple(rhs.hasSizes, rhs.sizes, rhs.dtype);
|
||||
}
|
||||
|
||||
// Given two pieces of static knowledge, calculate conservatively the
|
||||
// information we can be sure about.
|
||||
static ValueKnowledge join(const ValueKnowledge &lhs,
|
||||
const ValueKnowledge &rhs) {
|
||||
// Mental model: All conditions are checking how to change from the safe "no
|
||||
// knowledge" default-initialized state to a state with more knowledge
|
||||
// consistent with lhs and rhs.
|
||||
ValueKnowledge result = getPessimisticValueState(nullptr);
|
||||
|
||||
if (lhs.hasSizes && !rhs.hasSizes) {
|
||||
result.hasSizes = true;
|
||||
result.sizes = lhs.sizes;
|
||||
} else if (!lhs.hasSizes && rhs.hasSizes) {
|
||||
result.hasSizes = true;
|
||||
result.sizes = rhs.sizes;
|
||||
} else if (lhs.hasSizes && rhs.hasSizes &&
|
||||
lhs.sizes.size() == rhs.sizes.size()) {
|
||||
result.hasSizes = true;
|
||||
result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
|
||||
for (int i = 0, e = result.sizes.size(); i != e; i++) {
|
||||
int64_t lhsSize = lhs.sizes[i];
|
||||
int64_t rhsSize = rhs.sizes[i];
|
||||
int64_t &resultSize = result.sizes[i];
|
||||
if (lhsSize == ShapedType::kDynamicSize) {
|
||||
resultSize = rhsSize;
|
||||
} else if (rhsSize == ShapedType::kDynamicSize) {
|
||||
resultSize = lhsSize;
|
||||
} else if (lhsSize == rhsSize) {
|
||||
resultSize = lhsSize;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result.dtype = joinElementTypes(lhs.dtype, rhs.dtype);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Whether the Value is known to have a list of sizes.
|
||||
bool hasSizes;
|
||||
// If `hasSizes`, the sizes along each rank. Unknown sizes are represented as
|
||||
// `ShapedType::kDynamicSize`.
|
||||
std::vector<int64_t> sizes;
|
||||
// The dtype of a tensor.
|
||||
// This is equal to nullptr if we don't know that it is a specific concrete
|
||||
// type.
|
||||
Type dtype;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Pass that enables broadcast by making all input arrays have the same
|
||||
/// number of dimensions. Insert RESHAPE operations to lower rank operand
|
||||
struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
|
||||
public:
|
||||
void runOnFunction() override {
|
||||
FuncOp func = getOperation();
|
||||
|
||||
IRRewriter rewriter(func.getContext());
|
||||
|
||||
func.walk([&](Operation *op) {
|
||||
if (op->getDialect()->getNamespace() !=
|
||||
void propagateShapesInRegion(Region ®ion) {
|
||||
for (auto &block : region) {
|
||||
for (Operation &op : block) {
|
||||
if (op.getDialect()->getNamespace() !=
|
||||
tosa::TosaDialect::getDialectNamespace())
|
||||
return;
|
||||
continue;
|
||||
|
||||
propagateShapesToTosaIf(op);
|
||||
|
||||
InferShapedTypeOpInterface shapeInterface =
|
||||
dyn_cast<InferShapedTypeOpInterface>(op);
|
||||
if (!shapeInterface)
|
||||
return;
|
||||
continue;
|
||||
|
||||
SmallVector<ShapedTypeComponents> returnedShapes;
|
||||
if (shapeInterface
|
||||
.inferReturnTypeComponents(
|
||||
op->getContext(), op->getLoc(), op->getOperands(),
|
||||
op->getAttrDictionary(), op->getRegions(), returnedShapes)
|
||||
op.getContext(), op.getLoc(), op.getOperands(),
|
||||
op.getAttrDictionary(), op.getRegions(), returnedShapes)
|
||||
.succeeded()) {
|
||||
for (auto it : llvm::zip(op->getResults(), returnedShapes)) {
|
||||
for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
|
||||
Value result = std::get<0>(it);
|
||||
ShapedTypeComponents predictedShape = std::get<1>(it);
|
||||
|
||||
|
@ -183,11 +104,10 @@ public:
|
|||
ValueKnowledge::getKnowledgeFromType(resultTy);
|
||||
|
||||
// Compute the knowledge based on the inferred type.
|
||||
auto inferredKnowledge =
|
||||
ValueKnowledge::getPessimisticValueState(op->getContext());
|
||||
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
|
||||
inferredKnowledge.dtype =
|
||||
resultTy.cast<ShapedType>().getElementType();
|
||||
inferredKnowledge.hasSizes = predictedShape.hasRank();
|
||||
inferredKnowledge.hasRank = predictedShape.hasRank();
|
||||
if (predictedShape.hasRank()) {
|
||||
for (auto dim : predictedShape.getDims()) {
|
||||
inferredKnowledge.sizes.push_back(dim);
|
||||
|
@ -200,10 +120,25 @@ public:
|
|||
// Compute the new type based on the joined version.
|
||||
auto newKnowledge =
|
||||
ValueKnowledge::join(currentKnowledge, inferredKnowledge);
|
||||
if (!newKnowledge)
|
||||
continue;
|
||||
result.setType(newKnowledge.getType());
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Pass that performs shape propagation across TOSA operations. This includes
|
||||
/// migrating to within the regions of if/while operations.
|
||||
struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
|
||||
public:
|
||||
void runOnFunction() override {
|
||||
FuncOp func = getOperation();
|
||||
|
||||
IRRewriter rewriter(func.getContext());
|
||||
|
||||
propagateShapesInRegion(func.body());
|
||||
|
||||
// Insert UnrealizedConversionCasts to guarantee ReturnOp agress with
|
||||
// the FuncOp type.
|
||||
|
|
|
@ -774,7 +774,6 @@ func @conv2d_dilated(%input: tensor<2x12x14x3xf32>, %weights: tensor<5x3x6x3xf32
|
|||
|
||||
// -----
|
||||
|
||||
|
||||
// CHECK-LABEL: @conv2d_strided
|
||||
func @conv2d_strided(%input: tensor<1x13x14x1xf32>, %weights: tensor<1x1x1x1xf32>, %bias: tensor<1xf32>) -> () {
|
||||
// CHECK: -> tensor<1x5x7x1xf32>
|
||||
|
@ -1033,12 +1032,71 @@ func @resize_fp_vertical(%arg0: tensor<1x2x4x1xi32>) {
|
|||
%0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [5.000000e-01 : f32, 1.000000e+00 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @resize_fp_offsetted
|
||||
func @resize_fp_offsetted(%arg0: tensor<1x2x4x1xi32>) {
|
||||
// CHECK: -> tensor<1x4x6x1xi32>
|
||||
%0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [2.500000e-01 : f32, 2.500000e-01 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [2.500000e-01 : f32, 5.000000e-01 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @if_test_simple
|
||||
func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
|
||||
// CHECK: (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
|
||||
^bb1(%arg3 : tensor<f32>, %arg4 : tensor<f32>):
|
||||
"tosa.yield"(%arg3) : (tensor<f32>) -> ()
|
||||
}, {
|
||||
^bb1(%arg5 : tensor<f32>, %arg6 : tensor<f32>):
|
||||
"tosa.yield"(%arg6) : (tensor<f32>) -> ()
|
||||
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @if_test_dynamic
|
||||
func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
|
||||
// CHECK: (tensor<i1>, tensor<2xf32>, tensor<3xf32>) -> tensor<?xf32>
|
||||
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
|
||||
^bb1(%arg3 : tensor<2xf32>, %arg4 : tensor<3xf32>):
|
||||
"tosa.yield"(%arg3) : (tensor<2xf32>) -> ()
|
||||
}, {
|
||||
^bb1(%arg5 : tensor<2xf32>, %arg6 : tensor<3xf32>):
|
||||
"tosa.yield"(%arg6) : (tensor<3xf32>) -> ()
|
||||
}) : (tensor<i1>, tensor<2xf32>, tensor<3xf32>) -> (tensor<*xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @if_test_unranked
|
||||
func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
|
||||
// CHECK: (tensor<i1>, tensor<f32>, tensor<3xf32>) -> tensor<*xf32>
|
||||
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
|
||||
^bb1(%arg3 : tensor<f32>, %arg4 : tensor<3xf32>):
|
||||
"tosa.yield"(%arg3) : (tensor<f32>) -> ()
|
||||
}, {
|
||||
^bb1(%arg5 : tensor<f32>, %arg6 : tensor<3xf32>):
|
||||
"tosa.yield"(%arg6) : (tensor<3xf32>) -> ()
|
||||
}) : (tensor<i1>, tensor<f32>, tensor<3xf32>) -> (tensor<*xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @if_test_propagate
|
||||
func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
|
||||
// CHECK: (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
|
||||
^bb1(%arg3 : tensor<*xf32>, %arg4 : tensor<*xf32>):
|
||||
%1 = "tosa.add"(%arg3, %arg4) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
"tosa.yield"(%1) : (tensor<*xf32>) -> ()
|
||||
}, {
|
||||
^bb1(%arg5 : tensor<*xf32>, %arg6 : tensor<*xf32>):
|
||||
%1 = "tosa.sub"(%arg5, %arg6) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
"tosa.yield"(%1) : (tensor<*xf32>) -> ()
|
||||
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue