[mlir][tosa] Tosa shape propagation for tosa.cond_if

We can propagate the shapes from the tosa.cond_if operands into its true/false
regions, and from there through the regions' connected blocks. Then, using the
tosa.yield ops, we can determine the set of possible return types.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D105940
Rob Suderman 2021-08-03 17:25:43 -07:00
parent b4121b335c
commit 1b00b94ffc
5 changed files with 349 additions and 127 deletions

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

@@ -1789,6 +1789,8 @@ def Tosa_CustomOp : Tosa_Op<"custom"> {
// Further described in docs/Rationale/RationaleTOSADialect.md .
//===----------------------------------------------------------------------===//
def Tosa_IfOp : Tosa_Op<"cond_if", [
      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                                ["inferReturnTypeComponents"]>,
      SingleBlockImplicitTerminator<"YieldOp">,
      RecursiveSideEffects]> {
  let summary = "Conditional if operator";

mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h

@@ -0,0 +1,178 @@
//===-- ShapeUtils.h - TOSA shape support declarations ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Class declarations for shape utilities meant to assist shape propagation.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H
#define MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"

#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir {
namespace tosa {
/// Statically known information for a particular Value.
///
/// This struct currently tracks only information relevant for tensor/array-like
/// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped
/// type as long as it is in the default "no knowledge" state returned by
/// `getPessimisticValueState`. The important invariant is that we cannot
/// claim to know something about a value which is false.
///
/// This class could also be called "dataflow facts", "lattice value", etc.
struct ValueKnowledge {
  ValueKnowledge() = delete;
  ValueKnowledge(bool hasRank, llvm::ArrayRef<int64_t> newSizes, Type dtype)
      : hasError(false), hasRank(hasRank), dtype(dtype) {
    sizes.reserve(newSizes.size());
    for (auto size : newSizes)
      sizes.push_back(size);
  }

  operator bool() const { return !hasError; }
  // Get the static knowledge intrinsic to `type`.
  static ValueKnowledge getKnowledgeFromType(Type type) {
    ValueKnowledge result = getPessimisticValueState();
    if (auto shapedType = type.dyn_cast<ShapedType>()) {
      if (shapedType.hasRank()) {
        result.hasRank = true;
        result.sizes.reserve(shapedType.getRank());
        for (auto dim : shapedType.getShape())
          result.sizes.push_back(dim);
      }
      result.dtype = shapedType.getElementType();
    }
    return result;
  }
  // Return a pessimistic/conservative value state without assuming any
  // knowledge about the IR.
  static ValueKnowledge getPessimisticValueState() {
    return ValueKnowledge(false, {}, Type());
  }
  Type getType() const {
    if (hasRank)
      return RankedTensorType::get(llvm::makeArrayRef(sizes), dtype);
    return UnrankedTensorType::get(dtype);
  }

  bool operator==(const ValueKnowledge &rhs) const {
    return hasRank == rhs.hasRank && sizes == rhs.sizes && dtype == rhs.dtype;
  }
  // Given two pieces of static knowledge, calculate conservatively the
  // information we can be sure about.
  static ValueKnowledge join(const ValueKnowledge &lhs,
                             const ValueKnowledge &rhs) {
    // Mental model: All conditions are checking how to change from the safe
    // "no knowledge" default-initialized state to a state with more knowledge
    // consistent with lhs and rhs.
    ValueKnowledge result = getPessimisticValueState();
    result.hasError = true;

    if (!lhs || !rhs || lhs.dtype != rhs.dtype)
      return result;

    result.hasError = false;
    result.dtype = lhs.dtype;

    if (!lhs.hasRank && !rhs.hasRank)
      return result;

    if (!rhs.hasRank) {
      result.hasRank = true;
      result.sizes = lhs.sizes;
      return result;
    }

    if (!lhs.hasRank) {
      result.hasRank = true;
      result.sizes = rhs.sizes;
      return result;
    }

    if (lhs.sizes.size() != rhs.sizes.size())
      return result;

    result.hasRank = true;
    result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
    for (auto i : llvm::seq<unsigned>(0, result.sizes.size())) {
      int64_t lhsSize = lhs.sizes[i];
      int64_t rhsSize = rhs.sizes[i];
      int64_t &resultSize = result.sizes[i];
      if (lhsSize == ShapedType::kDynamicSize) {
        resultSize = rhsSize;
      } else if (rhsSize == ShapedType::kDynamicSize) {
        resultSize = lhsSize;
      } else if (lhsSize == rhsSize) {
        resultSize = lhsSize;
      } else {
        result.hasError = true;
      }
    }

    return result;
  }
  // Given two pieces of knowledge, generate a new ValueKnowledge that is met
  // by both, i.e. that covers both cases. E.g. if the ranks of the LHS and
  // RHS differ, the resulting tensor has unknown rank.
  static ValueKnowledge meet(const ValueKnowledge &lhs,
                             const ValueKnowledge &rhs) {
    ValueKnowledge result = getPessimisticValueState();
    result.hasError = true;

    if (!lhs || !rhs || lhs.dtype != rhs.dtype)
      return result;
    result.hasError = false;
    result.dtype = lhs.dtype;

    if (!lhs.hasRank || !rhs.hasRank) {
      result.hasRank = false;
      return result;
    }

    if (lhs.sizes.size() != rhs.sizes.size()) {
      result.hasRank = false;
      return result;
    }

    result.hasRank = true;
    result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
    for (int i = 0, e = lhs.sizes.size(); i < e; i++) {
      if (lhs.sizes[i] == rhs.sizes[i]) {
        result.sizes[i] = lhs.sizes[i];
      }
    }

    return result;
  }
  // Whether the value information has an error.
  bool hasError;
  // Whether the value has known rank.
  bool hasRank;
  // If `hasRank`, the sizes along each rank. Unknown sizes are represented as
  // `ShapedType::kDynamicSize`.
  llvm::SmallVector<int64_t> sizes;
  // The dtype of a tensor.
  // This is equal to nullptr if we don't know that it is a specific concrete
  // type.
  Type dtype;
};
} // namespace tosa
} // namespace mlir
#endif // MLIR_DIALECT_TOSA_UTILS_SHAPEUTILS_H
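
As an aside (not part of this commit): the two lattice operations pull in opposite directions. join refines a shape wherever either side knows more, while meet keeps only what both sides guarantee. A minimal usage sketch, assuming an MLIRContext is available:

// Hypothetical example (not in the patch) exercising ValueKnowledge.
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;
using namespace mlir::tosa;

void shapeLatticeExample(MLIRContext *ctx) {
  Type f32 = FloatType::getF32(ctx);

  // join refines: tensor<2x?xf32> joined with tensor<?x3xf32> takes the
  // known extent on each axis, yielding tensor<2x3xf32>.
  auto lhs = ValueKnowledge::getKnowledgeFromType(
      RankedTensorType::get({2, ShapedType::kDynamicSize}, f32));
  auto rhs = ValueKnowledge::getKnowledgeFromType(
      RankedTensorType::get({ShapedType::kDynamicSize, 3}, f32));
  Type joined = ValueKnowledge::join(lhs, rhs).getType(); // tensor<2x3xf32>

  // meet generalizes: tensor<2xf32> met with tensor<3xf32> agrees only on
  // the rank, yielding tensor<?xf32>.
  auto a = ValueKnowledge::getKnowledgeFromType(RankedTensorType::get({2}, f32));
  auto b = ValueKnowledge::getKnowledgeFromType(RankedTensorType::get({3}, f32));
  Type met = ValueKnowledge::meet(a, b).getType(); // tensor<?xf32>

  (void)joined;
  (void)met;
}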

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

@@ -15,6 +15,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
@@ -1301,6 +1302,54 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
  return success();
}
LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : regions) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);
  }

  if (yieldOps.empty())
    return failure();

  // Get the initial type information for the yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (auto it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (auto result : resultKnowledge) {
    if (result.hasRank) {
      inferredReturnShapes.push_back(ShapedTypeComponents(result.sizes));
    } else {
      inferredReturnShapes.push_back(ShapedTypeComponents());
    }
  }

  return success();
}
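
For intuition (an illustrative sketch, not code from this patch), the accumulation above amounts to folding meet over the types yielded for each result index:

// Hypothetical helper mirroring the loop above: fold the i-th yielded type
// of every region into one conservative result type.
#include <cassert>

#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"

using namespace mlir;
using namespace mlir::tosa;

static Type meetYieldedTypes(ArrayRef<Type> yieldedTypes) {
  assert(!yieldedTypes.empty() && "expected at least one tosa.yield type");
  ValueKnowledge knowledge =
      ValueKnowledge::getKnowledgeFromType(yieldedTypes.front());
  for (Type type : yieldedTypes.drop_front()) {
    auto met = ValueKnowledge::meet(
        knowledge, ValueKnowledge::getKnowledgeFromType(type));
    // As above, contradictory element types leave the previous knowledge
    // untouched rather than poisoning the result.
    if (!met)
      continue;
    knowledge = met;
  }
  return knowledge.getType();
}

// e.g. {tensor<2xf32>, tensor<3xf32>} folds to tensor<?xf32>, while
// {tensor<f32>, tensor<3xf32>} folds to tensor<*xf32>.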
//===----------------------------------------------------------------------===//
// TOSA Operator Definitions.
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp

@@ -17,6 +17,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@@ -30,137 +31,57 @@ using namespace mlir::tosa;
namespace {
// -----------------------------------------------------------------------------
// Analysis.
// -----------------------------------------------------------------------------
void propagateShapesInRegion(Region &region);

static Type joinElementTypes(Type lhs, Type rhs) {
  return lhs == rhs ? lhs : Type();
}

void propagateShapesToTosaIf(Operation &op) {
  tosa::IfOp ifOp = dyn_cast<tosa::IfOp>(op);
  if (!ifOp)
    return;

  for (auto &region : op.getRegions()) {
    Block &frontBlock = region.front();
    if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
      return;

    for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
      ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
          ifOp.getOperand(i + 1).getType());
      ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
          frontBlock.getArgument(i).getType());
      ValueKnowledge joinedKnowledge =
          ValueKnowledge::join(operandKnowledge, blockKnowledge);
      if (!joinedKnowledge)
        continue;
      frontBlock.getArgument(i).setType(joinedKnowledge.getType());
    }

    propagateShapesInRegion(region);
  }

  return;
}
namespace {
// Statically known information for a particular Value.
//
// This struct currently tracks only information relevant for tensor/array-like
// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped
// type as long as it is in the default "no knowledge" state returned by
// `getPessimisticValueState`. The important invariant is that we cannot
// claim to know something about a value which is false.
//
// This class could also be called "dataflow facts", "lattice value", etc.
struct ValueKnowledge {
  ValueKnowledge() = delete;
  ValueKnowledge(bool hasSizes, std::vector<int64_t> sizes, Type dtype)
      : hasSizes(hasSizes), sizes(sizes), dtype(dtype) {
    assert(sizes.size() == 0 || hasSizes);
  }
  // Get the static knowledge intrinsic to `type`.
  static ValueKnowledge getKnowledgeFromType(Type type) {
    ValueKnowledge result = getPessimisticValueState(type.getContext());
    if (auto shapedType = type.dyn_cast<ShapedType>()) {
      if (shapedType.hasRank()) {
        result.hasSizes = true;
        result.sizes = shapedType.getShape();
      }
      result.dtype = shapedType.getElementType();
    }
    return result;
  }
  // Return a pessimistic/conservative value state without assuming any
  // knowledge about the IR.
  static ValueKnowledge getPessimisticValueState(MLIRContext *context) {
    return ValueKnowledge(false, {}, Type());
  }
  Type getType() const {
    if (hasSizes) {
      return RankedTensorType::get(llvm::makeArrayRef(sizes), dtype);
    }
    return UnrankedTensorType::get(dtype);
  }

  bool operator==(const ValueKnowledge &rhs) const {
    return std::make_tuple(hasSizes, sizes, dtype) ==
           std::make_tuple(rhs.hasSizes, rhs.sizes, rhs.dtype);
  }
  // Given two pieces of static knowledge, calculate conservatively the
  // information we can be sure about.
  static ValueKnowledge join(const ValueKnowledge &lhs,
                             const ValueKnowledge &rhs) {
    // Mental model: All conditions are checking how to change from the safe
    // "no knowledge" default-initialized state to a state with more knowledge
    // consistent with lhs and rhs.
    ValueKnowledge result = getPessimisticValueState(nullptr);
    if (lhs.hasSizes && !rhs.hasSizes) {
      result.hasSizes = true;
      result.sizes = lhs.sizes;
    } else if (!lhs.hasSizes && rhs.hasSizes) {
      result.hasSizes = true;
      result.sizes = rhs.sizes;
    } else if (lhs.hasSizes && rhs.hasSizes &&
               lhs.sizes.size() == rhs.sizes.size()) {
      result.hasSizes = true;
      result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize);
      for (int i = 0, e = result.sizes.size(); i != e; i++) {
        int64_t lhsSize = lhs.sizes[i];
        int64_t rhsSize = rhs.sizes[i];
        int64_t &resultSize = result.sizes[i];
        if (lhsSize == ShapedType::kDynamicSize) {
          resultSize = rhsSize;
        } else if (rhsSize == ShapedType::kDynamicSize) {
          resultSize = lhsSize;
        } else if (lhsSize == rhsSize) {
          resultSize = lhsSize;
        }
      }
    }
    result.dtype = joinElementTypes(lhs.dtype, rhs.dtype);
    return result;
  }
  // Whether the Value is known to have a list of sizes.
  bool hasSizes;
  // If `hasSizes`, the sizes along each rank. Unknown sizes are represented as
  // `ShapedType::kDynamicSize`.
  std::vector<int64_t> sizes;
  // The dtype of a tensor.
  // This is equal to nullptr if we don't know that it is a specific concrete
  // type.
  Type dtype;
};
} // namespace
/// Pass that enables broadcast by making all input arrays have the same
/// number of dimensions. Insert RESHAPE operations to lower rank operand
struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
public:
  void runOnFunction() override {
    FuncOp func = getOperation();
    IRRewriter rewriter(func.getContext());
    func.walk([&](Operation *op) {
      if (op->getDialect()->getNamespace() !=
void propagateShapesInRegion(Region &region) {
  for (auto &block : region) {
    for (Operation &op : block) {
      if (op.getDialect()->getNamespace() !=
          tosa::TosaDialect::getDialectNamespace())
        return;
        continue;

      propagateShapesToTosaIf(op);

      InferShapedTypeOpInterface shapeInterface =
          dyn_cast<InferShapedTypeOpInterface>(op);
      if (!shapeInterface)
        return;
        continue;

      SmallVector<ShapedTypeComponents> returnedShapes;
      if (shapeInterface
              .inferReturnTypeComponents(
                  op->getContext(), op->getLoc(), op->getOperands(),
                  op->getAttrDictionary(), op->getRegions(), returnedShapes)
                  op.getContext(), op.getLoc(), op.getOperands(),
                  op.getAttrDictionary(), op.getRegions(), returnedShapes)
              .succeeded()) {
        for (auto it : llvm::zip(op->getResults(), returnedShapes)) {
        for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
          Value result = std::get<0>(it);
          ShapedTypeComponents predictedShape = std::get<1>(it);
@@ -183,11 +104,10 @@ public:
              ValueKnowledge::getKnowledgeFromType(resultTy);

          // Compute the knowledge based on the inferred type.
          auto inferredKnowledge =
              ValueKnowledge::getPessimisticValueState(op->getContext());
          auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
          inferredKnowledge.dtype =
              resultTy.cast<ShapedType>().getElementType();
          inferredKnowledge.hasSizes = predictedShape.hasRank();
          inferredKnowledge.hasRank = predictedShape.hasRank();
          if (predictedShape.hasRank()) {
            for (auto dim : predictedShape.getDims()) {
              inferredKnowledge.sizes.push_back(dim);
@@ -200,10 +120,25 @@ public:
          // Compute the new type based on the joined version.
          auto newKnowledge =
              ValueKnowledge::join(currentKnowledge, inferredKnowledge);
          if (!newKnowledge)
            continue;
          result.setType(newKnowledge.getType());
        }
      }
    });
    }
  }
}
/// Pass that performs shape propagation across TOSA operations. This includes
/// propagating shapes into the regions of if/while operations.
struct TosaInferShapes : public TosaInferShapesBase<TosaInferShapes> {
public:
  void runOnFunction() override {
    FuncOp func = getOperation();
    IRRewriter rewriter(func.getContext());

    propagateShapesInRegion(func.body());

    // Insert UnrealizedConversionCasts to guarantee the ReturnOp agrees with
    // the FuncOp type.
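
To exercise the pass end to end, a small driver sketch (assumptions: the factory name createTosaInferShapesPass() from the TOSA Passes.h of this era; the tests below drive the same pass via mlir-opt --tosa-infer-shapes):

// Hypothetical driver (not part of the commit): run TOSA shape inference
// over every function in a module.
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

mlir::LogicalResult inferTosaShapes(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // TosaInferShapes is a FunctionPass, so nest it on the module's functions.
  pm.addNestedPass<mlir::FuncOp>(mlir::tosa::createTosaInferShapesPass());
  return pm.run(module);
}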

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

@@ -774,7 +774,6 @@ func @conv2d_dilated(%input: tensor<2x12x14x3xf32>, %weights: tensor<5x3x6x3xf32

// -----

// CHECK-LABEL: @conv2d_strided
func @conv2d_strided(%input: tensor<1x13x14x1xf32>, %weights: tensor<1x1x1x1xf32>, %bias: tensor<1xf32>) -> () {
  // CHECK: -> tensor<1x5x7x1xf32>
@@ -1033,12 +1032,71 @@ func @resize_fp_vertical(%arg0: tensor<1x2x4x1xi32>) {
  %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [0.000000e+00 : f32, 0.000000e+00 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [5.000000e-01 : f32, 1.000000e+00 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
  return
}
// -----

// CHECK-LABEL: @resize_fp_offsetted
func @resize_fp_offsetted(%arg0 : tensor<1x2x4x1xi32>) {
  // CHECK: -> tensor<1x4x6x1xi32>
  %0 = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", offset = [0, 0], offset_fp = [2.500000e-01 : f32, 2.500000e-01 : f32], output_size = [-1, -1], shift = 0 : i32, stride = [0, 0], stride_fp = [2.500000e-01 : f32, 5.000000e-01 : f32]} : (tensor<1x2x4x1xi32>) -> tensor<?x?x?x?xi32>
  return
}

// -----

// CHECK-LABEL: @if_test_simple
func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
  // CHECK: (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
  ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<f32>):
    "tosa.yield"(%arg3) : (tensor<f32>) -> ()
  }, {
  ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<f32>):
    "tosa.yield"(%arg6) : (tensor<f32>) -> ()
  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
  return
}

// -----

// CHECK-LABEL: @if_test_dynamic
func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
  // CHECK: (tensor<i1>, tensor<2xf32>, tensor<3xf32>) -> tensor<?xf32>
  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
  ^bb1(%arg3 : tensor<2xf32>, %arg4 : tensor<3xf32>):
    "tosa.yield"(%arg3) : (tensor<2xf32>) -> ()
  }, {
  ^bb1(%arg5 : tensor<2xf32>, %arg6 : tensor<3xf32>):
    "tosa.yield"(%arg6) : (tensor<3xf32>) -> ()
  }) : (tensor<i1>, tensor<2xf32>, tensor<3xf32>) -> (tensor<*xf32>)
  return
}

// -----

// CHECK-LABEL: @if_test_unranked
func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
  // CHECK: (tensor<i1>, tensor<f32>, tensor<3xf32>) -> tensor<*xf32>
  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
  ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<3xf32>):
    "tosa.yield"(%arg3) : (tensor<f32>) -> ()
  }, {
  ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<3xf32>):
    "tosa.yield"(%arg6) : (tensor<3xf32>) -> ()
  }) : (tensor<i1>, tensor<f32>, tensor<3xf32>) -> (tensor<*xf32>)
  return
}

// -----

// CHECK-LABEL: @if_test_propagate
func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
  // CHECK: (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
  ^bb1(%arg3 : tensor<*xf32>, %arg4 : tensor<*xf32>):
    %1 = "tosa.add"(%arg3, %arg4) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
    "tosa.yield"(%1) : (tensor<*xf32>) -> ()
  }, {
  ^bb1(%arg5 : tensor<*xf32>, %arg6 : tensor<*xf32>):
    %1 = "tosa.sub"(%arg5, %arg6) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
    "tosa.yield"(%1) : (tensor<*xf32>) -> ()
  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
  return
}