forked from mindspore-Ecosystem/mindspore
support switch layer infershape
commit fc77c13f5f
parent 07f5702e48

Changed file 1 of 2: the InferShapePass implementation (.cc)
@@ -32,8 +32,6 @@
 #include "nnacl/op_base.h"
 
 using mindspore::converter::kFmkTypeTf;
-namespace mindspore {
-namespace lite {
 namespace {
 constexpr int DEFAULT_DIM_VALUE = -1;
 constexpr size_t kInitialSize = 1024;
@@ -43,8 +41,11 @@ constexpr int kSwitchInputMinSize = 3;
 constexpr int kTypeIndex = 0;
 constexpr int kElementShapeIndex = 1;
 constexpr int kFirstElementShapeIndex = 2;
-constexpr int kTensorListDatasize = 3;
-
+constexpr int kTensorListDataSize = 3;
+} // namespace
+namespace mindspore {
+namespace lite {
+namespace {
 void FreeTensors(std::vector<Tensor *> *input_tensors, std::vector<Tensor *> *output_tensors) {
   if (input_tensors == nullptr) {
     return;
@@ -91,19 +92,19 @@ void ConvertTensorList(const MetaGraphT *graph, uint32_t index, bool *convert_su
     int *data = reinterpret_cast<int *>(tensorT->data.data());
     type = TypeId(data[kTypeIndex]);
     auto basic_data_size = tensorT->data.size() / sizeof(int);
-    if (basic_data_size < static_cast<size_t>(kTensorListDatasize)) {
+    if (basic_data_size < static_cast<size_t>(kTensorListDataSize)) {
       MS_LOG(ERROR) << "tensorlist data length illegal, which should be at least 3, now is " << basic_data_size;
       *convert_succ = false;
       return;
     }
-    if (data[kElementShapeIndex] < 0 || INT_ADD_OVERFLOW(data[kElementShapeIndex], kTensorListDatasize)) {
+    if (data[kElementShapeIndex] < 0 || INT_ADD_OVERFLOW(data[kElementShapeIndex], kTensorListDataSize)) {
       MS_LOG(ERROR) << "int add overflow.";
       *convert_succ = false;
       return;
     }
-    if (static_cast<size_t>((data[kElementShapeIndex] + kTensorListDatasize)) > basic_data_size) {
+    if (static_cast<size_t>((data[kElementShapeIndex] + kTensorListDataSize)) > basic_data_size) {
       MS_LOG(ERROR) << "tensorlist data length illegal. current tensorlist data length should be at least "
-                    << (data[kElementShapeIndex] + kTensorListDatasize) << ", but now is " << basic_data_size;
+                    << (data[kElementShapeIndex] + kTensorListDataSize) << ", but now is " << basic_data_size;
       *convert_succ = false;
       return;
     }
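Note: the three guards above assume a serialized tensorlist header laid out as one int buffer: the element data type at kTypeIndex, the element-shape length at kElementShapeIndex, that many dimension values starting at kFirstElementShapeIndex, and at least one trailing field, which is why the minimum length is kTensorListDataSize = 3. The sketch below mirrors only those bounds checks outside the converter; the helper name and the reading of the trailing field are assumptions, not taken from this commit.

// Standalone illustration of the bounds checks in ConvertTensorList.
// Assumed layout: [data_type, element_shape_size, dim_0 .. dim_{n-1}, trailing field(s)].
#include <climits>
#include <cstddef>
#include <cstdio>
#include <vector>

bool ValidateTensorListHeader(const std::vector<int> &data) {
  constexpr int kTypeIndex = 0;
  constexpr int kElementShapeIndex = 1;
  constexpr int kTensorListDataSize = 3;
  const std::size_t basic_data_size = data.size();
  if (basic_data_size < static_cast<std::size_t>(kTensorListDataSize)) {
    std::fprintf(stderr, "tensorlist data length should be at least 3, now is %zu\n", basic_data_size);
    return false;
  }
  const int shape_size = data[kElementShapeIndex];
  // Same intent as the INT_ADD_OVERFLOW guard: reject negative lengths and sums past INT_MAX.
  if (shape_size < 0 || shape_size > INT_MAX - kTensorListDataSize) {
    return false;
  }
  // The declared dims plus the fixed fields must fit inside the buffer.
  if (static_cast<std::size_t>(shape_size + kTensorListDataSize) > basic_data_size) {
    return false;
  }
  (void)data[kTypeIndex];  // the real pass also consumes the element data type here
  return true;
}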
@@ -311,7 +312,7 @@ int SetDataType(MetaGraphT *graph, const std::vector<Tensor *> &output_tensors,
     if (!tensor_list->tensors().empty()) {
       tensor_shape_dims = static_cast<int>(tensor_list->tensors().front()->shape().size());
     }
-    MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW((tensor_shape_dims + kTensorListDatasize), static_cast<int>(sizeof(int))),
+    MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW((tensor_shape_dims + kTensorListDataSize), static_cast<int>(sizeof(int))),
                        RET_ERROR, "int mul overflow");
     if (tensor_list->tensors_data_type() == kTypeUnknown) {
       if (!tensor_list->tensors().empty()) {
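Note: the size guarded here, (tensor_shape_dims + kTensorListDataSize) * sizeof(int), is the write-side counterpart of the header checked in ConvertTensorList. A hedged sketch of packing such a header follows; PackTensorListHeader is a made-up helper and the trailing element-count field is an assumption, only the field order and the + kTensorListDataSize sizing come from this diff.

// Illustration only: pack a tensorlist header into (shape.size() + kTensorListDataSize) ints,
// the same count whose byte size the MS_CHECK_FALSE_MSG/INT_MUL_OVERFLOW guard above protects.
#include <vector>

std::vector<int> PackTensorListHeader(int data_type, const std::vector<int> &element_shape, int tensor_count) {
  constexpr int kTensorListDataSize = 3;
  std::vector<int> out;
  out.reserve(element_shape.size() + kTensorListDataSize);
  out.push_back(data_type);                               // data[kTypeIndex]
  out.push_back(static_cast<int>(element_shape.size()));  // data[kElementShapeIndex]
  out.insert(out.end(), element_shape.begin(), element_shape.end());  // dims from kFirstElementShapeIndex on
  out.push_back(tensor_count);                            // trailing field (assumed: element count)
  return out;
}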
@@ -440,42 +441,37 @@ void InferShapePass::InitInferTensor(MetaGraphT *graph) {
   }
 }
 
-int InferShapePass::InferSwitchNode(const std::unique_ptr<CNodeT> &switch_node, MetaGraphT *graph) {
-  if (switch_node->inputIndex.size() < kSwitchInputMinSize) {
-    MS_LOG(ERROR) << "switch node input size: " << switch_node->inputIndex.size() << " is less than three.";
+int InferShapePass::InferSwitchOrSwitchLayerNode(const std::unique_ptr<CNodeT> &aim_node, MetaGraphT *graph) {
+  if (aim_node->inputIndex.size() < kSwitchInputMinSize) {
+    MS_LOG(ERROR) << "switch or switch_layer node input size: " << aim_node->inputIndex.size() << " is less than 3.";
     return RET_PARAM_INVALID;
   }
 
+  size_t aim_node_input_size = aim_node->inputIndex.size();
+  std::vector<uint32_t> all_partial_index{};
+  for (size_t i = 1; i < aim_node_input_size; ++i) {
+    all_partial_index.push_back(aim_node->inputIndex.at(i));
+  }
+
+  std::vector<CNodeT *> all_partial_nodes{};
+  for (auto &partial_index : all_partial_index) {
+    for (auto &node : graph->nodes) {
+      if (node->primitive->value.type != PrimitiveType_PartialFusion) {
+        continue;
+      }
+      if (IsContain(node->outputIndex, partial_index)) {
+        all_partial_nodes.push_back(node.get());
+        break;
+      }
+    }
+  }
+
   std::deque<CNodeT *> to_process{};
-  auto true_branch_output_index = switch_node->inputIndex.at(kSwitchTrueIndex);
-  auto false_branch_output_index = switch_node->inputIndex.at(kSwitchFalseIndex);
-  bool find_true_partial = false;
-  bool find_false_partial = false;
-  CNodeT *true_partial_cnode = nullptr;
-  CNodeT *false_partial_cnode = nullptr;
-  for (auto &node : graph->nodes) {
-    if (node->primitive->value.type != PrimitiveType_PartialFusion) {
-      continue;
+  for (auto &partial_node : all_partial_nodes) {
+    if (partial_cnode_inferred_.find(partial_node) == partial_cnode_inferred_.end()) {
+      to_process.push_back(partial_node);
+      partial_cnode_inferred_.insert(partial_node);
     }
-    if (!find_true_partial && IsContain(node->outputIndex, true_branch_output_index)) {
-      true_partial_cnode = node.get();
-      find_true_partial = true;
-    }
-    if (!find_false_partial && IsContain(node->outputIndex, false_branch_output_index)) {
-      false_partial_cnode = node.get();
-      find_false_partial = true;
-    }
-    if (find_true_partial && find_false_partial) {
-      break;
-    }
   }
-  if (partial_cnode_inferred_.find(true_partial_cnode) == partial_cnode_inferred_.end()) {
-    to_process.push_back(true_partial_cnode);
-    partial_cnode_inferred_.insert(true_partial_cnode);
-  }
-  if (partial_cnode_inferred_.find(false_partial_cnode) == partial_cnode_inferred_.end()) {
-    to_process.push_back(false_partial_cnode);
-    partial_cnode_inferred_.insert(false_partial_cnode);
-  }
 
   while (!to_process.empty()) {
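Note: this hunk is the core of the commit. The old InferSwitchNode hard-wired exactly two branches (the inputs at kSwitchTrueIndex and kSwitchFalseIndex), so switch_layer nodes, which carry a branch index plus N partials, could not be inferred. InferSwitchOrSwitchLayerNode instead treats every input after the first as the output of a partial node, which covers both ops. A reduced sketch of that enumeration, with made-up tensor indices:

// Reduced sketch of the branch-input enumeration in InferSwitchOrSwitchLayerNode.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<uint32_t> BranchTensorIndices(const std::vector<uint32_t> &input_index) {
  std::vector<uint32_t> branches;
  for (size_t i = 1; i < input_index.size(); ++i) {  // input 0 is the condition / layer index
    branches.push_back(input_index.at(i));
  }
  return branches;
}

int main() {
  assert(BranchTensorIndices({10, 11, 12}).size() == 2);          // switch: true and false partials
  assert(BranchTensorIndices({20, 21, 22, 23, 24}).size() == 4);  // switch_layer: four partials
  return 0;
}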
@@ -505,7 +501,8 @@ int InferShapePass::InferCallNode(const std::unique_ptr<CNodeT> &call_node, Meta
     case PrimitiveType_PartialFusion:
       return InferPartialNode(node.get(), graph);
     case PrimitiveType_Switch:
-      return InferSwitchNode(node, graph);
+    case PrimitiveType_SwitchLayer:
+      return InferSwitchOrSwitchLayerNode(node, graph);
     default:
       MS_LOG(ERROR) << "not able to call partial or call switch.";
       return RET_ERROR;
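Note: the dispatch change relies on plain case fall-through so that both primitive types reach the shared handler. A stripped-down stand-in (the enum and return codes are placeholders, not converter types):

// Placeholder types; only the fall-through shape matches the diff above.
#include <cstdio>

enum class PrimType { PartialFusion, Switch, SwitchLayer, Other };

int InferByType(PrimType type) {
  switch (type) {
    case PrimType::PartialFusion:
      return 0;  // stands in for InferPartialNode(...)
    case PrimType::Switch:
    case PrimType::SwitchLayer:  // falls through to the shared handler
      return 1;  // stands in for InferSwitchOrSwitchLayerNode(...)
    default:
      std::fprintf(stderr, "not able to call partial or call switch.\n");
      return -1;
  }
}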

Changed file 2 of 2: the InferShapePass declaration (.h)
@@ -31,9 +31,6 @@ using mindspore::converter::kFmkTypeTf;
 using mindspore::schema::TensorT;
 namespace mindspore {
 namespace lite {
-const constexpr int kTensorDataSize = 12;
-const constexpr int kSwitchTrueIndex = 1;
-const constexpr int kSwitchFalseIndex = 2;
 struct InferTensor {
   std::vector<uint32_t> next_nodes_;
   std::vector<uint32_t> prev_nodes_;
@@ -53,7 +50,7 @@ class InferShapePass : public GraphPass {
   void AddOutputNodes(MetaGraphT *graph, std::vector<uint32_t> *infer_node_indexes, uint32_t infer_node_index);
   void ResetIncorrectTensorShape(MetaGraphT *graph);
   int InferPartialNode(const CNodeT *partial_node, MetaGraphT *graph);
-  int InferSwitchNode(const std::unique_ptr<CNodeT> &switch_node, MetaGraphT *graph);
+  int InferSwitchOrSwitchLayerNode(const std::unique_ptr<CNodeT> &switch_node, MetaGraphT *graph);
   int InferCallNode(const std::unique_ptr<CNodeT> &call_node, MetaGraphT *graph);
   int CopyPartialShapeToSubGraph(const CNodeT *partial_node, MetaGraphT *graph);
   void RestoreSubGraphInput(const CNodeT *partial_node, MetaGraphT *graph);