!13075 move if and while from cops to converter

From: @lyvette
Reviewed-by: @hangangqiang
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-15 20:10:34 +08:00 committed by Gitee
commit e35dfbc006
16 changed files with 37 additions and 82 deletions

View File

@ -1,28 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include <vector>
#include <memory>
#include "ops/if.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
REGISTER_PRIMITIVE_C(kNameIf, If);
} // namespace ops
} // namespace mindspore

View File

@ -45,13 +45,11 @@ constexpr auto kBias = "bias";
constexpr auto kBidirectional = "bidirectional";
constexpr auto kBlockSize = "block_size";
constexpr auto kBlockShape = "block_shape";
constexpr auto kBodySubgraphIndex = "body_subgraph_index";
constexpr auto kCellClip = "cell_clip";
constexpr auto kCellDepth = "cell_depth";
constexpr auto kCenterPointBox = "center_point_box";
constexpr auto kClip = "clip";
constexpr auto kCondition = "condition";
constexpr auto kCondSubgraphIndex = "cond_subgraph_index";
constexpr auto kCrops = "crops";
constexpr auto kCustom = "custom";
constexpr auto kDampening = "dampening";

View File

@ -1006,11 +1006,6 @@ OP_SCHEMA_DEF(Unstack)
OP_ATTR_WITH_VALUE(axis, long, 0)
OP_SCHEMA_DEF_END(Unstack)
OP_SCHEMA_DEF(While)
OP_ATTR(cond_subgraph_index, long)
OP_ATTR(body_subgraph_index, long)
OP_SCHEMA_DEF_END(While)
OP_SCHEMA_DEF(Where)
OP_SCHEMA_DEF_END(Where)
@ -1020,9 +1015,6 @@ OP_SCHEMA_DEF_END(ZerosLike)
OP_SCHEMA_DEF(Select)
OP_SCHEMA_DEF_END(Select)
OP_SCHEMA_DEF(If)
OP_SCHEMA_DEF_END(If)
OP_SCHEMA_DEF(GRU)
OP_ATTR_WITH_VALUE(bidirectional, bool, false)
OP_SCHEMA_DEF_END(GRU)

View File

@ -175,7 +175,6 @@
#include "ops/unsqueeze.h"
#include "ops/unsorted_segment_sum.h"
#include "ops/where.h"
#include "ops/while.h"
#include "ops/zeros_like.h"
#include "ops/grad/activation_grad.h"
#include "ops/grad/add_grad.h"
@ -231,7 +230,6 @@
#include "ops/fusion/sub_fusion.h"
#include "ops/fusion/tile_fusion.h"
#include "ops/fusion/topk_fusion.h"
#include "ops/if.h"
#include "ops/gru.h"
#include "ops/non_zero.h"
#include "ops/invert_permutation.h"
@ -436,11 +434,9 @@ FUNC_MSOP2SCHEMAOP_DECLARE(Unique);
FUNC_MSOP2SCHEMAOP_DECLARE(UnsortedSegmentSum);
FUNC_MSOP2SCHEMAOP_DECLARE(Unsqueeze);
FUNC_MSOP2SCHEMAOP_DECLARE(Unstack);
FUNC_MSOP2SCHEMAOP_DECLARE(While);
FUNC_MSOP2SCHEMAOP_DECLARE(Where);
FUNC_MSOP2SCHEMAOP_DECLARE(ZerosLike);
FUNC_MSOP2SCHEMAOP_DECLARE(Select);
FUNC_MSOP2SCHEMAOP_DECLARE(If);
FUNC_MSOP2SCHEMAOP_DECLARE(GRU);
FUNC_MSOP2SCHEMAOP_DECLARE(NonZero);
FUNC_MSOP2SCHEMAOP_DECLARE(InvertPermutation);

View File

@ -740,10 +740,6 @@ schema::PrimitiveT *ZerosLikePrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::ZerosLike>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
schema::PrimitiveT *WhilePrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::While>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator);
RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator);
@ -953,7 +949,6 @@ RegistryMSOps g_unsortedSegmentSumPrimitiveCreatorRegistry("UnsortedSegmentSum",
RegistryMSOps g_unsqueezePrimitiveCreatorRegistry("Unsqueeze", UnsqueezePrimitiveCreator);
RegistryMSOps g_wherePrimitiveCreatorRegistry("Where", WherePrimitiveCreator);
RegistryMSOps g_zerosLikePrimitiveCreatorRegistry("ZerosLike", ZerosLikePrimitiveCreator);
RegistryMSOps g_whilePrimitiveCreatorRegistry("While", WhilePrimitiveCreator);
} // namespace lite
} // namespace mindspore

View File

@ -215,6 +215,7 @@ if(ENABLE_CONVERTER)
set(TEST_LITE_SRC
${TEST_LITE_SRC}
${TEST_CASE_TFLITE_PARSERS_SRC}
${LITE_DIR}/tools/converter/ops/while.cc
${LITE_DIR}/tools/common/protobuf_utils.cc
${LITE_DIR}/tools/converter/optimizer.cc
${LITE_DIR}/tools/converter/anf_transform.cc

View File

@ -10,7 +10,9 @@ set(CCSRC_SRC
include(${TOP_DIR}/cmake/external_libs/glog.cmake)
file(GLOB OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/populate/*.cc)
${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/populate/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc
)
file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/../flag/flag_parser.cc

View File

@ -14,29 +14,24 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_IF_H_
#define MINDSPORE_CORE_OPS_IF_H_
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
using mindspore::ops::PrimitiveC;
namespace mindspore {
namespace ops {
namespace lite {
constexpr auto kNameIf = "If";
class If : public PrimitiveC {
public:
If() : PrimitiveC(kNameIf) {}
~If() = default;
MS_DECLARE_PARENT(If, PrimitiveC);
void Init() {}
};
AbstractBasePtr IfInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimIfPtr = std::shared_ptr<If>;
} // namespace ops
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_IF_H_
#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_

View File

@ -15,13 +15,15 @@
*/
#include <vector>
#include "ops/while.h"
#include "ops/op_utils.h"
#include "tools/converter/ops/while.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
constexpr auto kCondSubgraphIndex = "cond_subgraph_index";
constexpr auto kBodySubgraphIndex = "body_subgraph_index";
namespace mindspore {
namespace ops {
namespace lite {
void While::Init(const int64_t cond_subgraph_index, const int64_t body_subgraph_index) {
this->set_cond_subgraph_index(cond_subgraph_index);
this->set_body_subgraph_index(body_subgraph_index);
@ -44,6 +46,7 @@ int64_t While::get_body_subgraph_index() const {
auto value_ptr = this->GetAttr(kBodySubgraphIndex);
return GetValue<int64_t>(value_ptr);
}
AbstractBasePtr WhileInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
@ -58,6 +61,5 @@ AbstractBasePtr WhileInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
}
return std::make_shared<abstract::AbstractTuple>(output);
}
REGISTER_PRIMITIVE_C(kNameWhile, While);
} // namespace ops
} // namespace lite
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_WHILE_H_
#define MINDSPORE_CORE_OPS_WHILE_H_
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_WHILE_H_
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_WHILE_H_
#include <vector>
#include <memory>
@ -23,8 +23,10 @@
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
using mindspore::ops::PrimitiveC;
namespace mindspore {
namespace ops {
namespace lite {
constexpr auto kNameWhile = "While";
class While : public PrimitiveC {
public:
@ -41,7 +43,7 @@ class While : public PrimitiveC {
AbstractBasePtr WhileInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimWhilePtr = std::shared_ptr<While>;
} // namespace ops
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_WHILE_H_
#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_WHILE_H_

View File

@ -17,12 +17,12 @@
#include "tools/converter/parser/onnx/onnx_if_parser.h"
#include <memory>
#include "tools/converter/parser/onnx/onnx_model_parser.h"
#include "ops/if.h"
#include "tools/converter/ops/if.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxIfParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto prim = std::make_unique<ops::If>();
auto prim = std::make_unique<If>();
return prim.release();
}

View File

@ -17,12 +17,12 @@
#include "tools/converter/parser/onnx/onnx_loop_parser.h"
#include <memory>
#include "tools/converter/parser/onnx/onnx_model_parser.h"
#include "ops/while.h"
#include "tools/converter/ops/while.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *OnnxLoopParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
auto prim = std::make_unique<ops::While>();
auto prim = std::make_unique<While>();
return prim.release();
}

View File

@ -19,14 +19,14 @@
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "ops/if.h"
#include "tools/converter/ops/if.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *TFIfParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto prim = std::make_unique<ops::If>();
auto prim = std::make_unique<If>();
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {

View File

@ -19,14 +19,14 @@
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "ops/while.h"
#include "tools/converter/ops/while.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *TFWhileParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
std::vector<std::string> *inputs, int *output_size) {
auto prim = std::make_unique<ops::While>();
auto prim = std::make_unique<While>();
*output_size = tf_op.input_size();
for (int i = 0; i < tf_op.input_size(); i++) {

View File

@ -18,13 +18,13 @@
#include "tools/converter/parser/tflite/tflite_while_parser.h"
#include <vector>
#include <memory>
#include "ops/while.h"
#include "tools/converter/ops/while.h"
namespace mindspore {
namespace lite {
ops::PrimitiveC *TfliteWhileParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model) {
auto prim = std::make_unique<ops::While>();
auto prim = std::make_unique<While>();
MS_ASSERT(tflite_op != nullptr);
const auto &tflite_attr = tflite_op->builtin_options.AsWhileOptions();

View File

@ -22,11 +22,11 @@
#include "ops/make_tuple.h"
#include "ops/return.h"
#include "ops/tuple_get_item.h"
#include "ops/while.h"
#include "tools/converter/ops/while.h"
namespace {
mindspore::ValueNodePtr GetWhileAnfPrim() {
auto while_primc = std::make_shared<mindspore::ops::While>();
auto while_primc = std::make_shared<mindspore::lite::While>();
if (while_primc == nullptr) {
MS_LOG(ERROR) << "new while_primitive failed";
return nullptr;