!26357 [MS][LITE]optimize infershape to support switch_layer

Merge pull request !26357 from mengyuanli/infer_shape2
This commit is contained in:
i-robot 2021-11-16 08:19:40 +00:00 committed by Gitee
commit c784dec1a8
2 changed files with 39 additions and 45 deletions

View File

@@ -32,8 +32,6 @@
#include "nnacl/op_base.h"
using mindspore::converter::kFmkTypeTf;
namespace mindspore {
namespace lite {
namespace {
constexpr int DEFAULT_DIM_VALUE = -1;
constexpr size_t kInitialSize = 1024;
@@ -43,8 +41,11 @@ constexpr int kSwitchInputMinSize = 3;
constexpr int kTypeIndex = 0;
constexpr int kElementShapeIndex = 1;
constexpr int kFirstElementShapeIndex = 2;
constexpr int kTensorListDatasize = 3;
constexpr int kTensorListDataSize = 3;
} // namespace
namespace mindspore {
namespace lite {
namespace {
void FreeTensors(std::vector<Tensor *> *input_tensors, std::vector<Tensor *> *output_tensors) {
if (input_tensors == nullptr) {
return;
@@ -91,19 +92,19 @@ void ConvertTensorList(const MetaGraphT *graph, uint32_t index, bool *convert_su
int *data = reinterpret_cast<int *>(tensorT->data.data());
type = TypeId(data[kTypeIndex]);
auto basic_data_size = tensorT->data.size() / sizeof(int);
if (basic_data_size < static_cast<size_t>(kTensorListDatasize)) {
if (basic_data_size < static_cast<size_t>(kTensorListDataSize)) {
MS_LOG(ERROR) << "tensorlist data length illegal, which should be at least 3, now is " << basic_data_size;
*convert_succ = false;
return;
}
if (data[kElementShapeIndex] < 0 || INT_ADD_OVERFLOW(data[kElementShapeIndex], kTensorListDatasize)) {
if (data[kElementShapeIndex] < 0 || INT_ADD_OVERFLOW(data[kElementShapeIndex], kTensorListDataSize)) {
MS_LOG(ERROR) << "int add overflow.";
*convert_succ = false;
return;
}
if (static_cast<size_t>((data[kElementShapeIndex] + kTensorListDatasize)) > basic_data_size) {
if (static_cast<size_t>((data[kElementShapeIndex] + kTensorListDataSize)) > basic_data_size) {
MS_LOG(ERROR) << "tensorlist data length illegal. current tensorlist data length should be at least "
<< (data[kElementShapeIndex] + kTensorListDatasize) << ", but now is " << basic_data_size;
<< (data[kElementShapeIndex] + kTensorListDataSize) << ", but now is " << basic_data_size;
*convert_succ = false;
return;
}
@@ -311,7 +312,7 @@ int SetDataType(MetaGraphT *graph, const std::vector<Tensor *> &output_tensors,
if (!tensor_list->tensors().empty()) {
tensor_shape_dims = static_cast<int>(tensor_list->tensors().front()->shape().size());
}
MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW((tensor_shape_dims + kTensorListDatasize), static_cast<int>(sizeof(int))),
MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW((tensor_shape_dims + kTensorListDataSize), static_cast<int>(sizeof(int))),
RET_ERROR, "int mul overflow");
if (tensor_list->tensors_data_type() == kTypeUnknown) {
if (!tensor_list->tensors().empty()) {
@@ -440,42 +441,37 @@ void InferShapePass::InitInferTensor(MetaGraphT *graph) {
}
}
int InferShapePass::InferSwitchNode(const std::unique_ptr<CNodeT> &switch_node, MetaGraphT *graph) {
if (switch_node->inputIndex.size() < kSwitchInputMinSize) {
MS_LOG(ERROR) << "switch node input size: " << switch_node->inputIndex.size() << " is less than three.";
int InferShapePass::InferSwitchOrSwitchLayerNode(const std::unique_ptr<CNodeT> &aim_node, MetaGraphT *graph) {
if (aim_node->inputIndex.size() < kSwitchInputMinSize) {
MS_LOG(ERROR) << "switch or switch_layer node input size: " << aim_node->inputIndex.size() << " is less than 3.";
return RET_PARAM_INVALID;
}
size_t aim_node_input_size = aim_node->inputIndex.size();
std::vector<uint32_t> all_partial_index{};
for (size_t i = 1; i < aim_node_input_size; ++i) {
all_partial_index.push_back(aim_node->inputIndex.at(i));
}
std::vector<CNodeT *> all_partial_nodes{};
for (auto &partial_index : all_partial_index) {
for (auto &node : graph->nodes) {
if (node->primitive->value.type != PrimitiveType_PartialFusion) {
continue;
}
if (IsContain(node->outputIndex, partial_index)) {
all_partial_nodes.push_back(node.get());
break;
}
}
}
std::deque<CNodeT *> to_process{};
auto true_branch_output_index = switch_node->inputIndex.at(kSwitchTrueIndex);
auto false_branch_output_index = switch_node->inputIndex.at(kSwitchFalseIndex);
bool find_true_partial = false;
bool find_false_partial = false;
CNodeT *true_partial_cnode = nullptr;
CNodeT *false_partial_cnode = nullptr;
for (auto &node : graph->nodes) {
if (node->primitive->value.type != PrimitiveType_PartialFusion) {
continue;
for (auto &partial_node : all_partial_nodes) {
if (partial_cnode_inferred_.find(partial_node) == partial_cnode_inferred_.end()) {
to_process.push_back(partial_node);
partial_cnode_inferred_.insert(partial_node);
}
if (!find_true_partial && IsContain(node->outputIndex, true_branch_output_index)) {
true_partial_cnode = node.get();
find_true_partial = true;
}
if (!find_false_partial && IsContain(node->outputIndex, false_branch_output_index)) {
false_partial_cnode = node.get();
find_false_partial = true;
}
if (find_true_partial && find_false_partial) {
break;
}
}
if (partial_cnode_inferred_.find(true_partial_cnode) == partial_cnode_inferred_.end()) {
to_process.push_back(true_partial_cnode);
partial_cnode_inferred_.insert(true_partial_cnode);
}
if (partial_cnode_inferred_.find(false_partial_cnode) == partial_cnode_inferred_.end()) {
to_process.push_back(false_partial_cnode);
partial_cnode_inferred_.insert(false_partial_cnode);
}
while (!to_process.empty()) {
@@ -505,7 +501,8 @@ int InferShapePass::InferCallNode(const std::unique_ptr<CNodeT> &call_node, Meta
case PrimitiveType_PartialFusion:
return InferPartialNode(node.get(), graph);
case PrimitiveType_Switch:
return InferSwitchNode(node, graph);
case PrimitiveType_SwitchLayer:
return InferSwitchOrSwitchLayerNode(node, graph);
default:
MS_LOG(ERROR) << "not able to call partial or call switch.";
return RET_ERROR;

View File

@@ -31,9 +31,6 @@ using mindspore::converter::kFmkTypeTf;
using mindspore::schema::TensorT;
namespace mindspore {
namespace lite {
const constexpr int kTensorDataSize = 12;
const constexpr int kSwitchTrueIndex = 1;
const constexpr int kSwitchFalseIndex = 2;
struct InferTensor {
std::vector<uint32_t> next_nodes_;
std::vector<uint32_t> prev_nodes_;
@@ -53,7 +50,7 @@ class InferShapePass : public GraphPass {
void AddOutputNodes(MetaGraphT *graph, std::vector<uint32_t> *infer_node_indexes, uint32_t infer_node_index);
void ResetIncorrectTensorShape(MetaGraphT *graph);
int InferPartialNode(const CNodeT *partial_node, MetaGraphT *graph);
int InferSwitchNode(const std::unique_ptr<CNodeT> &switch_node, MetaGraphT *graph);
int InferSwitchOrSwitchLayerNode(const std::unique_ptr<CNodeT> &switch_node, MetaGraphT *graph);
int InferCallNode(const std::unique_ptr<CNodeT> &call_node, MetaGraphT *graph);
int CopyPartialShapeToSubGraph(const CNodeT *partial_node, MetaGraphT *graph);
void RestoreSubGraphInput(const CNodeT *partial_node, MetaGraphT *graph);