forked from mindspore-Ecosystem/mindspore
!12289 [MSLITE] remove support train in convert
From: @zhengjun10 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
148da5223d
|
@ -114,16 +114,6 @@ int AddN::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs
|
|||
max_dim = dim;
|
||||
}
|
||||
}
|
||||
#ifndef SUPPORT_TRAIN
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
size_t shift = max_dims - inputs.at(i)->shape().size();
|
||||
size_t dim = (i < shift) ? 1 : inputs.at(i)->shape().at(d);
|
||||
if ((dim != max_dim) && (dim != 1)) {
|
||||
MS_LOG(ERROR) << "AddN inputs shape is not equal!";
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
output->shape()[d] = max_dim; // set the biggest dimension in the output tensor
|
||||
}
|
||||
|
||||
|
|
|
@ -149,13 +149,13 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
|
|||
attr->padRight = pad_list.at(3);
|
||||
|
||||
auto dilation = CastToInt(prim.GetAttr("dilation"));
|
||||
#ifdef SUPPORT_TRAIN
|
||||
if (train_flag()) {
|
||||
attr->dilateH = dilation.at(2);
|
||||
attr->dilateW = dilation.at(3);
|
||||
#else
|
||||
} else {
|
||||
attr->dilateH = dilation.at(0);
|
||||
attr->dilateW = dilation.at(1);
|
||||
#endif
|
||||
}
|
||||
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
|
||||
attr->kernelH = kernel_size.at(0);
|
||||
attr->kernelW = (kernel_size.size() > 1) ? kernel_size.at(1) : kernel_size.at(0);
|
||||
|
|
|
@ -19,9 +19,6 @@
|
|||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
#ifdef SUPPORT_TRAIN
|
||||
#include <tuple>
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -56,20 +53,16 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
|
|||
}
|
||||
string paddingmode = "REFLECT";
|
||||
if (prim.GetAttr("mode") == nullptr) {
|
||||
#ifdef SUPPORT_TRAIN
|
||||
if (prim.name() == "Pad") {
|
||||
paddingmode = "CONSTANT";
|
||||
} else {
|
||||
#endif
|
||||
MS_LOG(ERROR) << "get mode failed!";
|
||||
delete this->primitive_;
|
||||
delete attr;
|
||||
this->primitive_ = nullptr;
|
||||
attr = nullptr;
|
||||
return RET_ERROR;
|
||||
#ifdef SUPPORT_TRAIN
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
paddingmode = GetValue<string>(prim.GetAttr("mode"));
|
||||
}
|
||||
|
@ -77,7 +70,6 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
|
|||
attr->paddingMode = schema::PaddingMode_REFLECT;
|
||||
} else if (paddingmode == "SYMMETRIC") {
|
||||
attr->paddingMode = schema::PaddingMode_SYMMETRIC;
|
||||
#ifdef SUPPORT_TRAIN
|
||||
} else if (paddingmode == "CONSTANT") {
|
||||
attr->paddingMode = schema::PaddingMode_CONSTANT;
|
||||
if (prim.GetAttr("paddings") != nullptr) {
|
||||
|
@ -91,7 +83,6 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
|
|||
attr->paddings.push_back(i);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
MS_LOG(ERROR) << "model type not supported!";
|
||||
delete this->primitive_;
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "src/ops/assert_op.h"
|
||||
#include "src/ops/space_to_batch.h"
|
||||
|
@ -175,8 +174,6 @@
|
|||
#include "src/ops/uniform_real.h"
|
||||
#include "src/ops/rank.h"
|
||||
#include "src/ops/is_finite.h"
|
||||
|
||||
#ifdef SUPPORT_TRAIN
|
||||
#include "src/ops/neg_grad.h"
|
||||
#include "src/ops/activation_grad.h"
|
||||
#include "src/ops/apply_momentum.h"
|
||||
|
@ -210,7 +207,6 @@
|
|||
#include "src/ops/sigmoid_cross_entropy_with_logits_grad.h"
|
||||
#include "src/ops/strided_slice_grad.h"
|
||||
#endif
|
||||
#endif
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -513,13 +509,14 @@ std::shared_ptr<PrimitiveC> GetTupleGetItemPrim() {
|
|||
|
||||
template <typename T, typename = std::enable_if<std::is_base_of<PrimitiveC, T>::value>>
|
||||
std::shared_ptr<PrimitiveC> NewPrimitiveC(const mindspore::Primitive &prim, const std::vector<AnfNodePtr> &inputs,
|
||||
const schema::QuantType &quantType) {
|
||||
const schema::QuantType &quantType, bool train_flag = false) {
|
||||
auto primc = std::make_shared<T>();
|
||||
if (primc == nullptr) {
|
||||
MS_LOG(ERROR) << "make_shared PrimitiveC failed";
|
||||
return nullptr;
|
||||
}
|
||||
primc->set_quant_type(quantType);
|
||||
primc->set_train_flag(train_flag);
|
||||
auto ret = primc->UnPackAttr(prim, inputs);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "UnPackAttr failed";
|
||||
|
@ -529,7 +526,7 @@ std::shared_ptr<PrimitiveC> NewPrimitiveC(const mindspore::Primitive &prim, cons
|
|||
}
|
||||
|
||||
std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
|
||||
const schema::QuantType &quantType) {
|
||||
const schema::QuantType &quantType, bool train_flag) {
|
||||
const auto &op_type = prim.name();
|
||||
if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid" || op_type == "HSwish" || op_type == "HSigmoid") {
|
||||
return NewPrimitiveC<Activation>(prim, inputs, quantType);
|
||||
|
@ -544,7 +541,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
|
|||
} else if (op_type == "Concat") {
|
||||
return NewPrimitiveC<Concat>(prim, inputs, quantType);
|
||||
} else if (op_type == "Conv2D") {
|
||||
return NewPrimitiveC<Conv2D>(prim, inputs, quantType);
|
||||
return NewPrimitiveC<Conv2D>(prim, inputs, quantType, train_flag);
|
||||
} else if (op_type == "Cos") {
|
||||
return NewPrimitiveC<Cos>(prim, inputs, quantType);
|
||||
} else if (op_type == "DepthwiseConv2dNative" || op_type == "DepthwiseConv2D") {
|
||||
|
@ -664,7 +661,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
|
|||
} else if (op_type == "Range") {
|
||||
return NewPrimitiveC<Range>(prim, inputs, quantType);
|
||||
} else if (op_type == "Tile") {
|
||||
return NewPrimitiveC<Tile>(prim, inputs, quantType);
|
||||
return NewPrimitiveC<Tile>(prim, inputs, quantType, train_flag);
|
||||
} else if (op_type == "GatherNd") {
|
||||
return NewPrimitiveC<GatherNd>(prim, inputs, quantType);
|
||||
} else if (op_type == "Square") {
|
||||
|
@ -685,7 +682,6 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
|
|||
return NewPrimitiveC<ArgMax>(prim, inputs, quantType);
|
||||
} else if (op_type == "Gelu") {
|
||||
return NewPrimitiveC<GeLU>(prim, inputs, quantType);
|
||||
#ifdef SUPPORT_TRAIN
|
||||
} else if (op_type == "SoftmaxCrossEntropyWithLogits") {
|
||||
return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType);
|
||||
} else if (op_type == "SparseSoftmaxCrossEntropyWithLogits") {
|
||||
|
@ -706,7 +702,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
|
|||
return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType);
|
||||
} else if (op_type == "Conv2DBackpropFilter") {
|
||||
return NewPrimitiveC<Conv2DGradFilter>(prim, inputs, quantType);
|
||||
} else if (op_type == "Conv2DBackpropInput") {
|
||||
} else if (op_type == "Conv2DBackpropInput" && train_flag) {
|
||||
return NewPrimitiveC<Conv2DGradInput>(prim, inputs, quantType);
|
||||
} else if ((op_type == "BatchNormGrad") || (op_type == "FusedBatchNormGradEx")) {
|
||||
return NewPrimitiveC<BNGrad>(prim, inputs, quantType);
|
||||
|
@ -748,10 +744,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
|
|||
return NewPrimitiveC<StridedSliceGrad>(prim, inputs, quantType);
|
||||
} else if (op_type == "AbsGrad") {
|
||||
return NewPrimitiveC<AbsGrad>(prim, inputs, quantType);
|
||||
#else
|
||||
} else if (op_type == "Conv2DBackpropInput") {
|
||||
} else if (op_type == "Conv2DBackpropInput" && !train_flag) {
|
||||
return NewPrimitiveC<DeConv2D>(prim, inputs, quantType);
|
||||
#endif
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported primitive type in Create : " << op_type;
|
||||
return nullptr;
|
||||
|
@ -1065,7 +1059,6 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
|
|||
return new (std::nothrow) UniformReal(primitive);
|
||||
case schema::PrimitiveType_Rank:
|
||||
return new (std::nothrow) Rank(primitive);
|
||||
#ifdef SUPPORT_TRAIN
|
||||
case schema::PrimitiveType_ActivationGrad:
|
||||
return new (std::nothrow) ActivationGrad(primitive);
|
||||
case schema::PrimitiveType_PoolingGrad:
|
||||
|
@ -1140,7 +1133,6 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
|
|||
return new (std::nothrow) SigmoidCrossEntropyWithLogitsGrad(primitive);
|
||||
case schema::PrimitiveType_StridedSliceGrad:
|
||||
return new (std::nothrow) StridedSliceGrad(primitive);
|
||||
#endif
|
||||
default:
|
||||
MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type);
|
||||
break;
|
||||
|
@ -1170,6 +1162,10 @@ bool PrimitiveC::infer_flag() const { return this->infer_flag_; }
|
|||
|
||||
void PrimitiveC::set_infer_flag(bool flag) { this->infer_flag_ = flag; }
|
||||
|
||||
bool PrimitiveC::train_flag() const { return this->train_flag_; }
|
||||
|
||||
void PrimitiveC::set_train_flag(bool flag) { this->train_flag_ = flag; }
|
||||
|
||||
int PrimitiveC::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
|
||||
auto input = inputs.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
|
|
|
@ -133,6 +133,10 @@ class PrimitiveC : public mindspore::Primitive {
|
|||
|
||||
void set_infer_flag(bool flag);
|
||||
|
||||
bool train_flag() const;
|
||||
|
||||
void set_train_flag(bool flag);
|
||||
|
||||
static PrimitiveC *Create(mindspore::schema::Primitive *primitive) { return Create(primitive->UnPack()); }
|
||||
|
||||
static PrimitiveC *Create(mindspore::schema::PrimitiveT *primitive);
|
||||
|
@ -140,7 +144,7 @@ class PrimitiveC : public mindspore::Primitive {
|
|||
static void GetAttrDataFromInput(const AnfNodePtr &inputNode, std::vector<int> *data);
|
||||
|
||||
static std::shared_ptr<PrimitiveC> Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
|
||||
const schema::QuantType &quantType);
|
||||
const schema::QuantType &quantType, bool train_flag = false);
|
||||
void PopulaterQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
void FillDefaultInputQuantParamIfNeed(const size_t &inputSize);
|
||||
void PopulaterInputQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
|
||||
|
@ -159,6 +163,7 @@ class PrimitiveC : public mindspore::Primitive {
|
|||
bool infer_flag_ = true;
|
||||
int op_type_ = OP_TYPE_NOT_SET;
|
||||
bool enable_huffman_code_ = false;
|
||||
bool train_flag_ = false;
|
||||
};
|
||||
std::shared_ptr<PrimitiveC> GetReturnPrim();
|
||||
|
||||
|
@ -179,6 +184,10 @@ class PrimitiveC {
|
|||
|
||||
void set_infer_flag(bool flag);
|
||||
|
||||
bool train_flag() const;
|
||||
|
||||
void set_train_flag(bool flag);
|
||||
|
||||
virtual int InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs);
|
||||
|
||||
int Type() const;
|
||||
|
@ -238,6 +247,7 @@ class PrimitiveC {
|
|||
bool infer_flag_ = true;
|
||||
schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
|
||||
int op_type_ = OP_TYPE_NOT_SET;
|
||||
bool train_flag_ = false;
|
||||
};
|
||||
using PrimitiveCPtr = std::shared_ptr<PrimitiveC>;
|
||||
typedef PrimitiveC *(*PrimitiveCCreator)(const schema::Primitive *primitive);
|
||||
|
|
|
@ -159,7 +159,7 @@ int Tile::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
|
|||
} else {
|
||||
multiples = GetMultiples();
|
||||
}
|
||||
#ifdef SUPPORT_TRAIN
|
||||
if (train_flag()) {
|
||||
const size_t in_dims = input->shape().size();
|
||||
const size_t delta_dims = in_dims - multiples.size();
|
||||
|
||||
|
@ -172,7 +172,7 @@ int Tile::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
|
|||
int tmp = input->shape().at(i) * (multiples[i - delta_dims]);
|
||||
out_shape.push_back(tmp);
|
||||
}
|
||||
#else
|
||||
} else {
|
||||
std::vector<int> dims = GetDims();
|
||||
if (inputs_.size() == 2 && dims.empty()) {
|
||||
for (int dim = 0; dim < inputs_[1]->ElementsNum(); ++dim) {
|
||||
|
@ -193,7 +193,7 @@ int Tile::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
|
|||
}
|
||||
out_shape.at(dims.at(i)) = input->shape().at(dims.at(i)) * (multiples.at(i));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
output->set_shape(out_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -328,15 +328,15 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
|
|||
}
|
||||
|
||||
RemoveIfMakeTuple(cnode);
|
||||
#ifdef SUPPORT_TRAIN
|
||||
if (train_flag) {
|
||||
RemoveIfDepend(cnode);
|
||||
#endif
|
||||
if (primitive_c->Type() == schema::PrimitiveType_Depend ||
|
||||
primitive_c->Type() == schema::PrimitiveType_ControlDepend) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if ((primitive_c->Type() == schema::PrimitiveType_TupleGetItem) ||
|
||||
#ifdef SUPPORT_TRAIN
|
||||
(primitive_c->Type() == schema::PrimitiveType_Depend) ||
|
||||
(primitive_c->Type() == schema::PrimitiveType_ControlDepend) ||
|
||||
#endif
|
||||
(primitive_c->Type() == schema::PrimitiveType_MakeTuple)) {
|
||||
continue;
|
||||
}
|
||||
|
@ -424,8 +424,10 @@ int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::uniqu
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) {
|
||||
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive,
|
||||
bool train_flag) {
|
||||
static int subgraph_index = 0;
|
||||
this->train_flag = train_flag;
|
||||
auto meta_graphT = std::make_unique<schema::MetaGraphT>();
|
||||
int ret = ExportSubgraph(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive);
|
||||
if (ret != RET_OK) {
|
||||
|
@ -439,24 +441,18 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode,
|
|||
std::string input_name = input_anode->fullname_with_scope();
|
||||
auto input_cnode = utils::cast<CNodePtr>(input_anode);
|
||||
if (!IsPrimitiveCNode(input_cnode, schema::PrimitiveType_TupleGetItem)) {
|
||||
#ifndef SUPPORT_TRAIN
|
||||
if (node_id_map_.find(input_name) != node_id_map_.end()) {
|
||||
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]);
|
||||
}
|
||||
#else
|
||||
bool found = false;
|
||||
if (node_id_map_.find(input_name) != node_id_map_.end()) {
|
||||
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]);
|
||||
found = true;
|
||||
}
|
||||
|
||||
if (found == false) {
|
||||
if (!found) {
|
||||
auto input_index_key = input_name + "_o:" + std::to_string(0);
|
||||
if (node_id_map_.find(input_index_key) != node_id_map_.end()) {
|
||||
output_cnode->inputIndex.emplace_back(node_id_map_[input_index_key]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
auto inputs = input_cnode->inputs();
|
||||
|
||||
|
@ -481,17 +477,12 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode,
|
|||
: GetValue<int>(value_node->value()));
|
||||
auto iter = node_id_map_.find(input_index_key);
|
||||
if (iter == node_id_map_.end()) {
|
||||
#ifdef SUPPORT_TRAIN
|
||||
input_index_key = get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(0); // try name with 0
|
||||
iter = node_id_map_.find(input_index_key);
|
||||
if (iter == node_id_map_.end()) {
|
||||
MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key;
|
||||
return RET_ERROR;
|
||||
}
|
||||
#else
|
||||
MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key;
|
||||
return RET_ERROR;
|
||||
#endif
|
||||
}
|
||||
output_cnode->inputIndex.emplace_back(iter->second);
|
||||
}
|
||||
|
@ -571,9 +562,7 @@ int AnfExporter::ProcessTensor(const ValueNodePtr &valueNode, std::unique_ptr<sc
|
|||
(void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims),
|
||||
[](const int64_t &value) { return static_cast<int32_t>(value); });
|
||||
(*paramTensor)->dims = dims;
|
||||
#ifdef SUPPORT_TRAIN
|
||||
if ((*paramTensor)->dims.size() == 0) (*paramTensor)->dims = {1};
|
||||
#endif
|
||||
if (train_flag && (*paramTensor)->dims.empty()) (*paramTensor)->dims = {1};
|
||||
(*paramTensor)->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
auto data = value->cast<tensor::TensorPtr>();
|
||||
(*paramTensor)->data.resize(data->Size());
|
||||
|
@ -679,11 +668,11 @@ int AnfExporter::ProcessParamValueLite(const ValueNodePtr &valueNode, std::uniqu
|
|||
(*paramTensor)->format = schema::Format(valueLite->format());
|
||||
(*paramTensor)->dataType = valueLite->tensor_type();
|
||||
(*paramTensor)->dims = valueLite->tensor_shape();
|
||||
#ifdef SUPPORT_TRAIN
|
||||
if ((*paramTensor)->dims.size() == 0) {
|
||||
|
||||
if (train_flag && (*paramTensor)->dims.empty()) {
|
||||
(*paramTensor)->dims = {1};
|
||||
}
|
||||
#endif
|
||||
|
||||
ret = memcpy_s((*paramTensor)->data.data(), valueLite->tensor_size() * sizeof(uint8_t), valueLite->tensor_addr(),
|
||||
valueLite->tensor_size());
|
||||
if (ret != EOK) {
|
||||
|
@ -703,9 +692,9 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano
|
|||
auto paramTensor = std::make_unique<schema::TensorT>();
|
||||
auto value = valueNode->value();
|
||||
int ret = RET_OK;
|
||||
#ifdef SUPPORT_TRAIN
|
||||
if (train_flag) {
|
||||
paramTensor->name = valueNode->fullname_with_scope();
|
||||
#endif
|
||||
}
|
||||
if (value->isa<tensor::Tensor>()) {
|
||||
ret = ProcessTensor(valueNode, ¶mTensor, value, output_cnode, meta_graphT);
|
||||
} else if (value->isa<mindspore::Int32Imm>() || value->isa<mindspore::Int64Imm>()) {
|
||||
|
@ -797,7 +786,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
|
|||
}
|
||||
msTensor->nodeType = schema::NodeType_CNode;
|
||||
fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
|
||||
#ifdef SUPPORT_TRAIN
|
||||
if (train_flag) {
|
||||
std::string name = cnode_name + "_o:" + std::to_string(i);
|
||||
node_id_map_[name] = meta_graphT->allTensors.size();
|
||||
meta_graphT->allTensors.emplace_back(msTensor);
|
||||
|
@ -805,7 +794,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
|
|||
IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) ||
|
||||
IsPrimitiveCNode(cnode, schema::PrimitiveType_Adam))
|
||||
break;
|
||||
#else
|
||||
} else {
|
||||
if (elements.size() == 1) {
|
||||
node_id_map_[cnode_name] = meta_graphT->allTensors.size();
|
||||
msTensor->name = cnode_name;
|
||||
|
@ -834,7 +823,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
|
|||
IsPrimitiveCNode(cnode, schema::PrimitiveType_LayerNorm)) {
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto ms_tensor = new (std::nothrow) schema::TensorT();
|
||||
|
@ -927,8 +916,8 @@ CNodePtr AnfExporter::CreatePartialCnode(const FuncGraphPtr &fg, AnfNodePtr node
|
|||
}
|
||||
}
|
||||
|
||||
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) {
|
||||
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive, bool train_flag) {
|
||||
AnfExporter anf_exporter;
|
||||
return anf_exporter.Export(func_graph, keep_graph, copy_primitive);
|
||||
return anf_exporter.Export(func_graph, keep_graph, copy_primitive, train_flag);
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -35,7 +35,8 @@ class AnfExporter {
|
|||
public:
|
||||
AnfExporter() = default;
|
||||
virtual ~AnfExporter() = default;
|
||||
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false);
|
||||
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false,
|
||||
bool train_flag = false);
|
||||
void SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
||||
schema::CNodeT *fb_node);
|
||||
int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
||||
|
@ -91,11 +92,13 @@ class AnfExporter {
|
|||
std::vector<schema::CNodeT *> graph_input_nodes_;
|
||||
std::map<FuncGraphPtr, int> fg_subgraph_map;
|
||||
uint32_t node_idx = 0;
|
||||
bool train_flag = false;
|
||||
};
|
||||
// by default, copy_primitive is false, which means that the MetaGraph and func_graph share the same schema::PrimitiveT.
|
||||
// but in PostQuantization, the func_graph need to transfer to MetaGraph first and do MetaGraph pass, which may modify
|
||||
// the schema::PrimitiveT and cause bug; If all the passes have been done in func_graph, every thing would be simple
|
||||
// and clear.
|
||||
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false);
|
||||
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false,
|
||||
bool train_flag = false);
|
||||
} // namespace mindspore::lite
|
||||
#endif // MINDSPORE_LITE_TOOLS_ANF_EXPORTER_ANF_EXPORTER_H_
|
||||
|
|
|
@ -855,14 +855,14 @@ int AnfImporterFromMindir::ParseModelConfigureInfo(const onnx::ModelProto &model
|
|||
}
|
||||
|
||||
int AnfImporterFromMindir::Import(const converter::Flags *flag) {
|
||||
#if SUPPORT_TRAIN
|
||||
if (flag->trainModel) {
|
||||
func_graph_ = LoadMindIR(flag->modelFile, true);
|
||||
if (func_graph_ != nullptr) {
|
||||
return RET_OK;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Parse new mind_ir proto failed, Trying old onnx format";
|
||||
}
|
||||
#endif
|
||||
}
|
||||
onnx_model_ = ReadOnnxFromBinary(flag->modelFile);
|
||||
if (onnx_model_ == nullptr) {
|
||||
MS_LOG(DEBUG) << "Parse model failed, which is not an old mindir model";
|
||||
|
|
|
@ -24,9 +24,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
static const std::vector<schema::PrimitiveType> nhwcOpList = {
|
||||
#ifdef SUPPORT_TRAIN
|
||||
schema::PrimitiveType_Conv2DGradFilter,
|
||||
static const std::vector<schema::PrimitiveType> nhwcOpList = {schema::PrimitiveType_Conv2DGradFilter,
|
||||
schema::PrimitiveType_Conv2DGradInput,
|
||||
schema::PrimitiveType_GroupConv2DGradInput,
|
||||
schema::PrimitiveType_PoolingGrad,
|
||||
|
@ -35,7 +33,6 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {
|
|||
schema::PrimitiveType_ApplyMomentum,
|
||||
schema::PrimitiveType_Sgd,
|
||||
schema::PrimitiveType_Adam,
|
||||
#endif
|
||||
schema::PrimitiveType_Conv2D,
|
||||
schema::PrimitiveType_DeConv2D,
|
||||
schema::PrimitiveType_DepthwiseConv2D,
|
||||
|
@ -52,11 +49,15 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {
|
|||
schema::PrimitiveType_TopK};
|
||||
|
||||
static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = {
|
||||
#ifdef SUPPORT_TRAIN
|
||||
schema::PrimitiveType_PoolingGrad, schema::PrimitiveType_ActivationGrad, schema::PrimitiveType_Conv2DGradFilter,
|
||||
schema::PrimitiveType_BNGrad
|
||||
#endif
|
||||
};
|
||||
schema::PrimitiveType_BNGrad};
|
||||
|
||||
// index {} mean all inputs need insert
|
||||
static std::unordered_map<schema::PrimitiveType, std::vector<int>> extNhwcInsertIndex = {
|
||||
{schema::PrimitiveType_BNGrad, {0, 1}},
|
||||
{schema::PrimitiveType_ApplyMomentum, {3}},
|
||||
{schema::PrimitiveType_Sgd, {1}},
|
||||
{schema::PrimitiveType_Adam, {9}}};
|
||||
|
||||
static const std::vector<schema::PrimitiveType> fp32FullOpList = {
|
||||
schema::PrimitiveType_Concat, schema::PrimitiveType_Add,
|
||||
|
@ -133,18 +134,10 @@ static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveT
|
|||
schema::PrimitiveType_L2Norm};
|
||||
|
||||
static const std::vector<schema::PrimitiveType> needInsertOpList = {
|
||||
#ifdef SUPPORT_TRAIN
|
||||
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
|
||||
schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Split,
|
||||
schema::PrimitiveType_Crop, schema::PrimitiveType_Mul, schema::PrimitiveType_Add,
|
||||
schema::PrimitiveType_ActivationGrad
|
||||
#else
|
||||
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
|
||||
schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add,
|
||||
schema::PrimitiveType_Split, schema::PrimitiveType_Slice, schema::PrimitiveType_Crop,
|
||||
schema::PrimitiveType_Mul, schema::PrimitiveType_Maximum
|
||||
#endif
|
||||
};
|
||||
schema::PrimitiveType_Mul, schema::PrimitiveType_Maximum, schema::PrimitiveType_ActivationGrad};
|
||||
|
||||
static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}};
|
||||
|
||||
|
@ -156,6 +149,8 @@ std::vector<schema::PrimitiveType> Getfp32FullOpList() { return fp32FullOpList;
|
|||
|
||||
std::vector<schema::PrimitiveType> GetNhwcOpList() { return nhwcOpList; }
|
||||
|
||||
std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes() { return extNhwcInsertIndex; }
|
||||
|
||||
std::vector<schema::PrimitiveType> GetNhwcAllInputOpList() { return nhwcOpAllInputList; }
|
||||
|
||||
std::vector<schema::PrimitiveType> GetUint8NhwcOpList() { return int8NeedNhwcOpList; }
|
||||
|
|
|
@ -62,6 +62,8 @@ std::vector<schema::PrimitiveType> GetNhwcOpList();
|
|||
|
||||
std::vector<schema::PrimitiveType> GetNhwcAllInputOpList();
|
||||
|
||||
std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes();
|
||||
|
||||
std::vector<schema::PrimitiveType> Getfp32FullOpList();
|
||||
|
||||
std::vector<schema::PrimitiveType> GetUint8NhwcOpList();
|
||||
|
|
|
@ -101,12 +101,6 @@ set(LITE_SRC
|
|||
${SRC_DIR}/dequant.cc
|
||||
${SRC_DIR}/huffman_decode.cc
|
||||
)
|
||||
if(SUPPORT_TRAIN)
|
||||
set(LITE_SRC
|
||||
${LITE_SRC}
|
||||
)
|
||||
|
||||
endif()
|
||||
set(ARM_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src/runtime/kernel/arm)
|
||||
file(GLOB KERNEL_SRC
|
||||
${ARM_DIR}/base/*.cc
|
||||
|
|
|
@ -177,6 +177,7 @@ int AnfTransform::RunMindirAdjustPass(const FuncGraphPtr &old_graph, const conve
|
|||
auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>();
|
||||
mindir_adjust_pass->SetFmkType(config->fmk);
|
||||
mindir_adjust_pass->SetQuantType(config->quantType);
|
||||
mindir_adjust_pass->SetTrainFlag(config->trainModel);
|
||||
if (!mindir_adjust_pass->Run(old_graph)) {
|
||||
MS_LOG(ERROR) << "mindir adjust failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
||||
|
|
|
@ -88,7 +88,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
|
|||
}
|
||||
|
||||
// anf -- fb
|
||||
auto meta_graph = Export(graph);
|
||||
auto meta_graph = Export(graph, false, false, flag->trainModel);
|
||||
if (meta_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Export to meta graph return nullptr";
|
||||
return nullptr;
|
||||
|
|
|
@ -48,21 +48,8 @@ STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatT
|
|||
FormatTransNodeType *afterNodeType) {
|
||||
if (fmkType == converter::FmkType_TFLITE) { // inference by nhwc
|
||||
return RET_NO_CHANGE;
|
||||
} else if (fmkType == converter::FmkType_CAFFE) { // inference by nchw
|
||||
if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) {
|
||||
return RET_NO_CHANGE;
|
||||
}
|
||||
*beforeNodeType = kNCHW2NHWC;
|
||||
*afterNodeType = kNHWC2NCHW;
|
||||
return RET_OK;
|
||||
} else if (fmkType == converter::FmkType_MS) {
|
||||
if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) {
|
||||
return RET_NO_CHANGE;
|
||||
}
|
||||
*beforeNodeType = kNCHW2NHWC;
|
||||
*afterNodeType = kNHWC2NCHW;
|
||||
return RET_OK;
|
||||
} else if (fmkType == converter::FmkType_ONNX) {
|
||||
} else if (fmkType == converter::FmkType_CAFFE || fmkType == converter::FmkType_MS ||
|
||||
fmkType == converter::FmkType_ONNX) {
|
||||
if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) {
|
||||
return RET_NO_CHANGE;
|
||||
}
|
||||
|
@ -173,11 +160,19 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
|
|||
if (node->primitive->value.type == schema::PrimitiveType_DepthToSpace) {
|
||||
reinterpret_cast<schema::DepthToSpaceT *>(attr)->format = schema::Format_NHWC;
|
||||
}
|
||||
#ifdef SUPPORT_TRAIN
|
||||
if (IsContain(GetNhwcAllInputOpList(), GetCNodeTType(**iter))) {
|
||||
int idx_num = node->inputIndex.size();
|
||||
if (GetCNodeTType(**iter) == schema::PrimitiveType_BNGrad) idx_num = 2;
|
||||
for (int i = 0; i < idx_num; i++) {
|
||||
auto specInsertIndexes = GetExtNhwcIndexes();
|
||||
auto opType = GetCNodeTType(**iter);
|
||||
if (specInsertIndexes.find(opType) != specInsertIndexes.end()) {
|
||||
for (auto insert_index : specInsertIndexes[opType]) {
|
||||
iter = InsertFormatTransNode(graph, iter, kBefore, insert_index, beforeNodeType, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
} else if (IsContain(GetNhwcAllInputOpList(), opType)) {
|
||||
auto input_size = node->inputIndex.size();
|
||||
for (size_t i = 0; i < input_size; i++) {
|
||||
iter = InsertFormatTransNode(graph, iter, kBefore, i, beforeNodeType, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed";
|
||||
|
@ -185,23 +180,8 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
int idx = 0;
|
||||
if (GetCNodeTType(**iter) == schema::PrimitiveType_ApplyMomentum) idx = 3;
|
||||
if (GetCNodeTType(**iter) == schema::PrimitiveType_Sgd) idx = 1;
|
||||
if (GetCNodeTType(**iter) == schema::PrimitiveType_Adam) idx = 9;
|
||||
iter = InsertFormatTransNode(graph, iter, kBefore, idx, beforeNodeType, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
#else
|
||||
iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
#endif
|
||||
iter = InsertFormatTransNode(graph, iter, kAfter, 0, afterNodeType, &status);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed";
|
||||
|
|
|
@ -194,7 +194,6 @@ STATUS GlobalFormatTransformPass::FindPreNh2NcNodes(MetaGraphT *graph, size_t nc
|
|||
if (!IsContain(bfs_queue, input_node_index)) {
|
||||
bfs_queue.emplace_back(input_node_index);
|
||||
}
|
||||
// todo multi output,other edge need insert nh2nc node
|
||||
auto pre_node_output_indexs = GetOutputNodeIdx(*graph, *pre_node);
|
||||
if (pre_node_output_indexs.size() != 1) {
|
||||
if (node_type == schema::PrimitiveType_Activation || node_type == schema::PrimitiveType_Concat) {
|
||||
|
|
|
@ -29,92 +29,58 @@ std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1};
|
|||
std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2};
|
||||
} // namespace
|
||||
namespace lite {
|
||||
bool IsInOutCanFusion(schema::MetaGraphT *graph, const std::vector<size_t> &node_indexes, size_t *has_trans_count,
|
||||
FormatTransNodeType *trans_type) {
|
||||
for (auto input_node_index : node_indexes) {
|
||||
MS_ASSERT(graph->nodes.size() > input_node_index);
|
||||
auto &pre_node = graph->nodes.at(input_node_index);
|
||||
MS_ASSERT(pre_node != nullptr);
|
||||
MS_ASSERT(pre_node->primitive != nullptr);
|
||||
MS_ASSERT(pre_node->primitive->value != nullptr);
|
||||
if (*trans_type == kNONE) {
|
||||
if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) {
|
||||
MS_ASSERT(pre_node->primitive->value.AsTranspose() != nullptr);
|
||||
if (pre_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
|
||||
*trans_type = kNCHW2NHWC;
|
||||
} else if (pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
|
||||
*trans_type = kNHWC2NCHW;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
(*has_trans_count)++;
|
||||
}
|
||||
} else {
|
||||
if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) {
|
||||
auto cur_type = kNONE;
|
||||
if (pre_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
|
||||
cur_type = kNCHW2NHWC;
|
||||
} else if (pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
|
||||
cur_type = kNHWC2NCHW;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
if (*trans_type != cur_type) {
|
||||
return false;
|
||||
} else {
|
||||
(*has_trans_count)++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
MS_ASSERT(node != nullptr);
|
||||
auto input_node_indexes = GetInputNodeIdx(*graph, *node);
|
||||
pre_type_ = kNONE;
|
||||
size_t has_trans_count = 0;
|
||||
auto can_fusion = true;
|
||||
for (auto input_node_index : input_node_indexes) {
|
||||
MS_ASSERT(graph->nodes.size() > input_node_index);
|
||||
auto &pre_node = graph->nodes.at(input_node_index);
|
||||
MS_ASSERT(pre_node != nullptr);
|
||||
MS_ASSERT(pre_node->primitive != nullptr);
|
||||
MS_ASSERT(pre_node->primitive->value != nullptr);
|
||||
if (pre_type_ == kNONE) {
|
||||
if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) {
|
||||
MS_ASSERT(pre_node->primitive->value.AsTranspose() != nullptr);
|
||||
if (pre_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
|
||||
pre_type_ = kNCHW2NHWC;
|
||||
} else if (pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
|
||||
pre_type_ = kNHWC2NCHW;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
has_trans_count++;
|
||||
}
|
||||
} else {
|
||||
if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) {
|
||||
auto cur_type = kNONE;
|
||||
if (pre_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
|
||||
cur_type = kNCHW2NHWC;
|
||||
} else if (pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
|
||||
cur_type = kNHWC2NCHW;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
if (pre_type_ != cur_type) {
|
||||
can_fusion = false;
|
||||
break;
|
||||
} else {
|
||||
has_trans_count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!can_fusion) {
|
||||
if (!IsInOutCanFusion(graph, input_node_indexes, &has_trans_count, &pre_type_)) {
|
||||
return false;
|
||||
}
|
||||
auto output_node_indexes = GetOutputNodeIdx(*graph, *node);
|
||||
post_type_ = kNONE;
|
||||
for (auto output_node_index : output_node_indexes) {
|
||||
MS_ASSERT(graph->nodes.size() > output_node_index);
|
||||
auto &post_node = graph->nodes.at(output_node_index);
|
||||
MS_ASSERT(post_node != nullptr);
|
||||
MS_ASSERT(post_node->primitive != nullptr);
|
||||
MS_ASSERT(post_node->primitive->value != nullptr);
|
||||
if (post_type_ == kNONE) {
|
||||
if (post_node->primitive->value.type == schema::PrimitiveType_Transpose) {
|
||||
if (post_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
|
||||
post_type_ = kNCHW2NHWC;
|
||||
} else if (post_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
|
||||
post_type_ = kNHWC2NCHW;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
has_trans_count++;
|
||||
}
|
||||
} else {
|
||||
if (post_node->primitive->value.type == schema::PrimitiveType_Transpose) {
|
||||
auto cur_type = kNONE;
|
||||
if (post_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
|
||||
cur_type = kNCHW2NHWC;
|
||||
} else if (post_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
|
||||
cur_type = kNHWC2NCHW;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
if (post_type_ != cur_type) {
|
||||
can_fusion = false;
|
||||
break;
|
||||
} else {
|
||||
has_trans_count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!can_fusion) {
|
||||
if (!IsInOutCanFusion(graph, output_node_indexes, &has_trans_count, &post_type_)) {
|
||||
return false;
|
||||
}
|
||||
if (pre_type_ == kNONE && post_type_ == kNONE) {
|
||||
|
@ -136,10 +102,8 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p
|
|||
if (GetCNodeTType(*node) == schema::PrimitiveType_Split) {
|
||||
return has_trans_count >= half_count;
|
||||
}
|
||||
can_fusion = has_trans_count > half_count;
|
||||
return can_fusion;
|
||||
return has_trans_count > half_count;
|
||||
}
|
||||
|
||||
STATUS TransOpInsertPass::FindOutTransType() {
|
||||
pre_insert_trans_type_ = kNHWC2NCHW;
|
||||
post_insert_trans_type_ = kNHWC2NCHW;
|
||||
|
@ -153,7 +117,7 @@ STATUS TransOpInsertPass::FindOutTransType() {
|
|||
MS_ASSERT(false);
|
||||
} else {
|
||||
if (pre_type_ == post_type_) {
|
||||
MS_LOG(ERROR) << "Unknow error";
|
||||
MS_LOG(ERROR) << "Unknown error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
|
||||
|
@ -200,13 +164,6 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
|
|||
STATUS status = RET_OK;
|
||||
auto input_tensor_size = (*iter)->inputIndex.size();
|
||||
for (size_t i = 0; i < input_tensor_size; i++) {
|
||||
#ifdef SUPPORT_TRAIN
|
||||
auto &tensor = graph->allTensors.at((*iter)->inputIndex[i]);
|
||||
MS_ASSERT(tensor != nullptr);
|
||||
if (tensor->nodeType == schema::NodeType_ValueNode) {
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
auto &input_tensor = graph->allTensors.at((*iter)->inputIndex[i]);
|
||||
if (input_tensor->nodeType == NodeType_ValueNode && input_tensor->dims.size() < 4) {
|
||||
continue;
|
||||
|
|
|
@ -37,6 +37,15 @@ ParamValueLitePtr NewParamValueLitePtr(lite::Tensor *tensor) {
|
|||
return para_value_lite;
|
||||
}
|
||||
|
||||
bool IsSpecialType(schema::PrimitiveType type) {
|
||||
if ((type == schema::PrimitiveType_TupleGetItem) || (type == schema::PrimitiveType_Depend) ||
|
||||
(type == schema::PrimitiveType_ControlDepend) ||
|
||||
(type == schema::PrimitiveType_MakeTuple || type == schema::PrimitiveType_Return)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) {
|
||||
MS_ASSERT(nullptr != tensor);
|
||||
std::vector<int> shape(tensor->shape());
|
||||
|
@ -363,12 +372,7 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
|
|||
return false;
|
||||
}
|
||||
auto type = GetCNodeType(cnode);
|
||||
|
||||
if ((type == schema::PrimitiveType_TupleGetItem) ||
|
||||
#ifdef SUPPORT_TRAIN
|
||||
(type == schema::PrimitiveType_Depend) || (type == schema::PrimitiveType_ControlDepend) ||
|
||||
#endif
|
||||
(type == schema::PrimitiveType_MakeTuple || type == schema::PrimitiveType_Return)) {
|
||||
if (IsSpecialType(type)) {
|
||||
continue;
|
||||
}
|
||||
std::vector<lite::Tensor *> input_tensors;
|
||||
|
|
|
@ -147,7 +147,7 @@ int MindirAdjustPass::PrimitiveConvert(std::shared_ptr<AnfNode> anf_node) {
|
|||
auto inputs = cnode->inputs();
|
||||
inputs.erase(inputs.begin());
|
||||
if (!CheckPrimitiveType(anf_node, prim::kPrimReturn) && !CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) {
|
||||
auto primitive_c = PrimitiveC::Create(*primitive, inputs, quant_type_);
|
||||
auto primitive_c = PrimitiveC::Create(*primitive, inputs, quant_type_, train_flag_);
|
||||
if (primitive_c == nullptr) {
|
||||
MS_LOG(ERROR) << "fail to create a primitive_c: " << cnode->fullname_with_scope();
|
||||
lite::NoSupportOp::GetInstance()->InsertOp(primitive->name());
|
||||
|
|
|
@ -33,6 +33,7 @@ class MindirAdjustPass : public Pass {
|
|||
void SetQuantType(QuantType quant_type) { quant_type_ = quant_type; }
|
||||
void SetFmkType(FmkType fmk_type) { fmk_type_ = fmk_type; }
|
||||
int ValueNodeInt64Convert(AnfNodePtr anf_node);
|
||||
void SetTrainFlag(bool train_flag) { train_flag_ = train_flag; }
|
||||
int ParameterNodeConvert(AnfNodePtr anf_node);
|
||||
int PrimitiveConvert(AnfNodePtr anf_node);
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
@ -40,6 +41,7 @@ class MindirAdjustPass : public Pass {
|
|||
protected:
|
||||
QuantType quant_type_ = QuantType::QuantType_QUANT_NONE;
|
||||
FmkType fmk_type_ = FmkType::FmkType_MS;
|
||||
bool train_flag_ = false;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_
|
||||
|
|
|
@ -131,12 +131,10 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node,
|
|||
param_value->set_format(schema::Format::Format_CKHW);
|
||||
} else if (op_type == schema::PrimitiveType_DeConv2D) {
|
||||
param_value->set_format(schema::Format::Format_KCHW);
|
||||
#ifdef SUPPORT_TRAIN
|
||||
} else if (op_type == schema::PrimitiveType_Conv2DGradInput) {
|
||||
param_value->set_format(schema::Format::Format_KCHW);
|
||||
} else if (op_type == schema::PrimitiveType_GroupConv2DGradInput) {
|
||||
param_value->set_format(schema::Format::Format_CKHW);
|
||||
#endif
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
|
||||
<< ", node: " << conv_node->fullname_with_scope();
|
||||
|
@ -213,10 +211,8 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) {
|
|||
auto conv_cnode = node->cast<CNodePtr>();
|
||||
auto type = opt::GetCNodeType(node);
|
||||
if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D &&
|
||||
#ifdef SUPPORT_TRAIN
|
||||
((type != schema::PrimitiveType_Conv2DGradInput) || (fmk_type != FmkType_MS)) &&
|
||||
((type != schema::PrimitiveType_GroupConv2DGradInput) || (fmk_type != FmkType_MS)) &&
|
||||
#endif
|
||||
type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -43,11 +43,9 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr
|
|||
continue;
|
||||
}
|
||||
auto type = opt::GetCNodeType(node);
|
||||
if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D
|
||||
#ifdef SUPPORT_TRAIN
|
||||
&& type != schema::PrimitiveType_Conv2DGradInput && type != schema::PrimitiveType_GroupConv2DGradInput
|
||||
#endif
|
||||
&& type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) {
|
||||
if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D &&
|
||||
type != schema::PrimitiveType_Conv2DGradInput && type != schema::PrimitiveType_GroupConv2DGradInput &&
|
||||
type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) {
|
||||
continue;
|
||||
}
|
||||
auto conv_cnode = node->cast<CNodePtr>();
|
||||
|
|
Loading…
Reference in New Issue