forked from mindspore-Ecosystem/mindspore
!13075 move if and while from cops to converter
From: @lyvette Reviewed-by: @hangangqiang Signed-off-by:
This commit is contained in:
commit
e35dfbc006
|
@ -1,28 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/if.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
REGISTER_PRIMITIVE_C(kNameIf, If);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -45,13 +45,11 @@ constexpr auto kBias = "bias";
|
|||
constexpr auto kBidirectional = "bidirectional";
|
||||
constexpr auto kBlockSize = "block_size";
|
||||
constexpr auto kBlockShape = "block_shape";
|
||||
constexpr auto kBodySubgraphIndex = "body_subgraph_index";
|
||||
constexpr auto kCellClip = "cell_clip";
|
||||
constexpr auto kCellDepth = "cell_depth";
|
||||
constexpr auto kCenterPointBox = "center_point_box";
|
||||
constexpr auto kClip = "clip";
|
||||
constexpr auto kCondition = "condition";
|
||||
constexpr auto kCondSubgraphIndex = "cond_subgraph_index";
|
||||
constexpr auto kCrops = "crops";
|
||||
constexpr auto kCustom = "custom";
|
||||
constexpr auto kDampening = "dampening";
|
||||
|
|
|
@ -1006,11 +1006,6 @@ OP_SCHEMA_DEF(Unstack)
|
|||
OP_ATTR_WITH_VALUE(axis, long, 0)
|
||||
OP_SCHEMA_DEF_END(Unstack)
|
||||
|
||||
OP_SCHEMA_DEF(While)
|
||||
OP_ATTR(cond_subgraph_index, long)
|
||||
OP_ATTR(body_subgraph_index, long)
|
||||
OP_SCHEMA_DEF_END(While)
|
||||
|
||||
OP_SCHEMA_DEF(Where)
|
||||
OP_SCHEMA_DEF_END(Where)
|
||||
|
||||
|
@ -1020,9 +1015,6 @@ OP_SCHEMA_DEF_END(ZerosLike)
|
|||
OP_SCHEMA_DEF(Select)
|
||||
OP_SCHEMA_DEF_END(Select)
|
||||
|
||||
OP_SCHEMA_DEF(If)
|
||||
OP_SCHEMA_DEF_END(If)
|
||||
|
||||
OP_SCHEMA_DEF(GRU)
|
||||
OP_ATTR_WITH_VALUE(bidirectional, bool, false)
|
||||
OP_SCHEMA_DEF_END(GRU)
|
||||
|
|
|
@ -175,7 +175,6 @@
|
|||
#include "ops/unsqueeze.h"
|
||||
#include "ops/unsorted_segment_sum.h"
|
||||
#include "ops/where.h"
|
||||
#include "ops/while.h"
|
||||
#include "ops/zeros_like.h"
|
||||
#include "ops/grad/activation_grad.h"
|
||||
#include "ops/grad/add_grad.h"
|
||||
|
@ -231,7 +230,6 @@
|
|||
#include "ops/fusion/sub_fusion.h"
|
||||
#include "ops/fusion/tile_fusion.h"
|
||||
#include "ops/fusion/topk_fusion.h"
|
||||
#include "ops/if.h"
|
||||
#include "ops/gru.h"
|
||||
#include "ops/non_zero.h"
|
||||
#include "ops/invert_permutation.h"
|
||||
|
@ -436,11 +434,9 @@ FUNC_MSOP2SCHEMAOP_DECLARE(Unique);
|
|||
FUNC_MSOP2SCHEMAOP_DECLARE(UnsortedSegmentSum);
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(Unsqueeze);
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(Unstack);
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(While);
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(Where);
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(ZerosLike);
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(Select);
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(If);
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(GRU);
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(NonZero);
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(InvertPermutation);
|
||||
|
|
|
@ -740,10 +740,6 @@ schema::PrimitiveT *ZerosLikePrimitiveCreator(const AnfNodePtr &node) {
|
|||
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::ZerosLike>>(node);
|
||||
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
|
||||
}
|
||||
schema::PrimitiveT *WhilePrimitiveCreator(const AnfNodePtr &node) {
|
||||
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::While>>(node);
|
||||
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
|
||||
}
|
||||
|
||||
RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator);
|
||||
RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator);
|
||||
|
@ -953,7 +949,6 @@ RegistryMSOps g_unsortedSegmentSumPrimitiveCreatorRegistry("UnsortedSegmentSum",
|
|||
RegistryMSOps g_unsqueezePrimitiveCreatorRegistry("Unsqueeze", UnsqueezePrimitiveCreator);
|
||||
RegistryMSOps g_wherePrimitiveCreatorRegistry("Where", WherePrimitiveCreator);
|
||||
RegistryMSOps g_zerosLikePrimitiveCreatorRegistry("ZerosLike", ZerosLikePrimitiveCreator);
|
||||
RegistryMSOps g_whilePrimitiveCreatorRegistry("While", WhilePrimitiveCreator);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -215,6 +215,7 @@ if(ENABLE_CONVERTER)
|
|||
set(TEST_LITE_SRC
|
||||
${TEST_LITE_SRC}
|
||||
${TEST_CASE_TFLITE_PARSERS_SRC}
|
||||
${LITE_DIR}/tools/converter/ops/while.cc
|
||||
${LITE_DIR}/tools/common/protobuf_utils.cc
|
||||
${LITE_DIR}/tools/converter/optimizer.cc
|
||||
${LITE_DIR}/tools/converter/anf_transform.cc
|
||||
|
|
|
@ -10,7 +10,9 @@ set(CCSRC_SRC
|
|||
include(${TOP_DIR}/cmake/external_libs/glog.cmake)
|
||||
|
||||
file(GLOB OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/*.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/populate/*.cc)
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/populate/*.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc
|
||||
)
|
||||
|
||||
file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../flag/flag_parser.cc
|
||||
|
|
|
@ -14,29 +14,24 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_IF_H_
|
||||
#define MINDSPORE_CORE_OPS_IF_H_
|
||||
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_
|
||||
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
using mindspore::ops::PrimitiveC;
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace lite {
|
||||
constexpr auto kNameIf = "If";
|
||||
class If : public PrimitiveC {
|
||||
public:
|
||||
If() : PrimitiveC(kNameIf) {}
|
||||
~If() = default;
|
||||
MS_DECLARE_PARENT(If, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
|
||||
AbstractBasePtr IfInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimIfPtr = std::shared_ptr<If>;
|
||||
} // namespace ops
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_IF_H_
|
||||
#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_IF_H_
|
|
@ -15,13 +15,15 @@
|
|||
*/
|
||||
|
||||
#include <vector>
|
||||
#include "ops/while.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "tools/converter/ops/while.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
constexpr auto kCondSubgraphIndex = "cond_subgraph_index";
|
||||
constexpr auto kBodySubgraphIndex = "body_subgraph_index";
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace lite {
|
||||
void While::Init(const int64_t cond_subgraph_index, const int64_t body_subgraph_index) {
|
||||
this->set_cond_subgraph_index(cond_subgraph_index);
|
||||
this->set_body_subgraph_index(body_subgraph_index);
|
||||
|
@ -44,6 +46,7 @@ int64_t While::get_body_subgraph_index() const {
|
|||
auto value_ptr = this->GetAttr(kBodySubgraphIndex);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
AbstractBasePtr WhileInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
@ -58,6 +61,5 @@ AbstractBasePtr WhileInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
|||
}
|
||||
return std::make_shared<abstract::AbstractTuple>(output);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameWhile, While);
|
||||
} // namespace ops
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_WHILE_H_
|
||||
#define MINDSPORE_CORE_OPS_WHILE_H_
|
||||
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_WHILE_H_
|
||||
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_WHILE_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
|
@ -23,8 +23,10 @@
|
|||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
using mindspore::ops::PrimitiveC;
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace lite {
|
||||
constexpr auto kNameWhile = "While";
|
||||
class While : public PrimitiveC {
|
||||
public:
|
||||
|
@ -41,7 +43,7 @@ class While : public PrimitiveC {
|
|||
AbstractBasePtr WhileInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimWhilePtr = std::shared_ptr<While>;
|
||||
} // namespace ops
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_WHILE_H_
|
||||
#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_WHILE_H_
|
|
@ -17,12 +17,12 @@
|
|||
#include "tools/converter/parser/onnx/onnx_if_parser.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||
#include "ops/if.h"
|
||||
#include "tools/converter/ops/if.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ops::PrimitiveC *OnnxIfParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
|
||||
auto prim = std::make_unique<ops::If>();
|
||||
auto prim = std::make_unique<If>();
|
||||
return prim.release();
|
||||
}
|
||||
|
||||
|
|
|
@ -17,12 +17,12 @@
|
|||
#include "tools/converter/parser/onnx/onnx_loop_parser.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||
#include "ops/while.h"
|
||||
#include "tools/converter/ops/while.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ops::PrimitiveC *OnnxLoopParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
|
||||
auto prim = std::make_unique<ops::While>();
|
||||
auto prim = std::make_unique<While>();
|
||||
return prim.release();
|
||||
}
|
||||
|
||||
|
|
|
@ -19,14 +19,14 @@
|
|||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
#include "ops/if.h"
|
||||
#include "tools/converter/ops/if.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ops::PrimitiveC *TFIfParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
auto prim = std::make_unique<ops::If>();
|
||||
auto prim = std::make_unique<If>();
|
||||
|
||||
*output_size = 1;
|
||||
for (int i = 0; i < tf_op.input_size(); i++) {
|
||||
|
|
|
@ -19,14 +19,14 @@
|
|||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
#include "ops/while.h"
|
||||
#include "tools/converter/ops/while.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ops::PrimitiveC *TFWhileParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
auto prim = std::make_unique<ops::While>();
|
||||
auto prim = std::make_unique<While>();
|
||||
|
||||
*output_size = tf_op.input_size();
|
||||
for (int i = 0; i < tf_op.input_size(); i++) {
|
||||
|
|
|
@ -18,13 +18,13 @@
|
|||
#include "tools/converter/parser/tflite/tflite_while_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/while.h"
|
||||
#include "tools/converter/ops/while.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ops::PrimitiveC *TfliteWhileParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto prim = std::make_unique<ops::While>();
|
||||
auto prim = std::make_unique<While>();
|
||||
|
||||
MS_ASSERT(tflite_op != nullptr);
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsWhileOptions();
|
||||
|
|
|
@ -22,11 +22,11 @@
|
|||
#include "ops/make_tuple.h"
|
||||
#include "ops/return.h"
|
||||
#include "ops/tuple_get_item.h"
|
||||
#include "ops/while.h"
|
||||
#include "tools/converter/ops/while.h"
|
||||
|
||||
namespace {
|
||||
mindspore::ValueNodePtr GetWhileAnfPrim() {
|
||||
auto while_primc = std::make_shared<mindspore::ops::While>();
|
||||
auto while_primc = std::make_shared<mindspore::lite::While>();
|
||||
if (while_primc == nullptr) {
|
||||
MS_LOG(ERROR) << "new while_primitive failed";
|
||||
return nullptr;
|
||||
|
|
Loading…
Reference in New Issue