!12289 [MSLITE] remove support train in convert

From: @zhengjun10
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-20 16:14:25 +08:00 committed by Gitee
commit 148da5223d
22 changed files with 241 additions and 334 deletions

View File

@ -114,16 +114,6 @@ int AddN::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs
max_dim = dim;
}
}
#ifndef SUPPORT_TRAIN
for (size_t i = 0; i < inputs.size(); ++i) {
size_t shift = max_dims - inputs.at(i)->shape().size();
size_t dim = (i < shift) ? 1 : inputs.at(i)->shape().at(d);
if ((dim != max_dim) && (dim != 1)) {
MS_LOG(ERROR) << "AddN inputs shape is not equal!";
return RET_INPUT_TENSOR_ERROR;
}
}
#endif
output->shape()[d] = max_dim; // set the biggest dimension in the output tensor
}

View File

@ -149,13 +149,13 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
attr->padRight = pad_list.at(3);
auto dilation = CastToInt(prim.GetAttr("dilation"));
#ifdef SUPPORT_TRAIN
attr->dilateH = dilation.at(2);
attr->dilateW = dilation.at(3);
#else
attr->dilateH = dilation.at(0);
attr->dilateW = dilation.at(1);
#endif
if (train_flag()) {
attr->dilateH = dilation.at(2);
attr->dilateW = dilation.at(3);
} else {
attr->dilateH = dilation.at(0);
attr->dilateW = dilation.at(1);
}
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size.at(0);
attr->kernelW = (kernel_size.size() > 1) ? kernel_size.at(1) : kernel_size.at(0);

View File

@ -19,9 +19,6 @@
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
#ifdef SUPPORT_TRAIN
#include <tuple>
#endif
namespace mindspore {
namespace lite {
@ -56,20 +53,16 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
}
string paddingmode = "REFLECT";
if (prim.GetAttr("mode") == nullptr) {
#ifdef SUPPORT_TRAIN
if (prim.name() == "Pad") {
paddingmode = "CONSTANT";
} else {
#endif
MS_LOG(ERROR) << "get mode failed!";
delete this->primitive_;
delete attr;
this->primitive_ = nullptr;
attr = nullptr;
return RET_ERROR;
#ifdef SUPPORT_TRAIN
}
#endif
} else {
paddingmode = GetValue<string>(prim.GetAttr("mode"));
}
@ -77,7 +70,6 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
attr->paddingMode = schema::PaddingMode_REFLECT;
} else if (paddingmode == "SYMMETRIC") {
attr->paddingMode = schema::PaddingMode_SYMMETRIC;
#ifdef SUPPORT_TRAIN
} else if (paddingmode == "CONSTANT") {
attr->paddingMode = schema::PaddingMode_CONSTANT;
if (prim.GetAttr("paddings") != nullptr) {
@ -91,7 +83,6 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
attr->paddings.push_back(i);
}
}
#endif
} else {
MS_LOG(ERROR) << "model type not supported!";
delete this->primitive_;

View File

@ -18,7 +18,6 @@
#ifdef PRIMITIVE_WRITEABLE
#include <memory>
#include <map>
#include "tools/converter/quantizer/quantize_util.h"
#include "src/ops/assert_op.h"
#include "src/ops/space_to_batch.h"
@ -175,8 +174,6 @@
#include "src/ops/uniform_real.h"
#include "src/ops/rank.h"
#include "src/ops/is_finite.h"
#ifdef SUPPORT_TRAIN
#include "src/ops/neg_grad.h"
#include "src/ops/activation_grad.h"
#include "src/ops/apply_momentum.h"
@ -210,7 +207,6 @@
#include "src/ops/sigmoid_cross_entropy_with_logits_grad.h"
#include "src/ops/strided_slice_grad.h"
#endif
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@ -513,13 +509,14 @@ std::shared_ptr<PrimitiveC> GetTupleGetItemPrim() {
template <typename T, typename = std::enable_if<std::is_base_of<PrimitiveC, T>::value>>
std::shared_ptr<PrimitiveC> NewPrimitiveC(const mindspore::Primitive &prim, const std::vector<AnfNodePtr> &inputs,
const schema::QuantType &quantType) {
const schema::QuantType &quantType, bool train_flag = false) {
auto primc = std::make_shared<T>();
if (primc == nullptr) {
MS_LOG(ERROR) << "make_shared PrimitiveC failed";
return nullptr;
}
primc->set_quant_type(quantType);
primc->set_train_flag(train_flag);
auto ret = primc->UnPackAttr(prim, inputs);
if (ret != RET_OK) {
MS_LOG(ERROR) << "UnPackAttr failed";
@ -529,7 +526,7 @@ std::shared_ptr<PrimitiveC> NewPrimitiveC(const mindspore::Primitive &prim, cons
}
std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
const schema::QuantType &quantType) {
const schema::QuantType &quantType, bool train_flag) {
const auto &op_type = prim.name();
if (op_type == "ReLU" || op_type == "ReLU6" || op_type == "Sigmoid" || op_type == "HSwish" || op_type == "HSigmoid") {
return NewPrimitiveC<Activation>(prim, inputs, quantType);
@ -544,7 +541,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
} else if (op_type == "Concat") {
return NewPrimitiveC<Concat>(prim, inputs, quantType);
} else if (op_type == "Conv2D") {
return NewPrimitiveC<Conv2D>(prim, inputs, quantType);
return NewPrimitiveC<Conv2D>(prim, inputs, quantType, train_flag);
} else if (op_type == "Cos") {
return NewPrimitiveC<Cos>(prim, inputs, quantType);
} else if (op_type == "DepthwiseConv2dNative" || op_type == "DepthwiseConv2D") {
@ -664,7 +661,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
} else if (op_type == "Range") {
return NewPrimitiveC<Range>(prim, inputs, quantType);
} else if (op_type == "Tile") {
return NewPrimitiveC<Tile>(prim, inputs, quantType);
return NewPrimitiveC<Tile>(prim, inputs, quantType, train_flag);
} else if (op_type == "GatherNd") {
return NewPrimitiveC<GatherNd>(prim, inputs, quantType);
} else if (op_type == "Square") {
@ -685,7 +682,6 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<ArgMax>(prim, inputs, quantType);
} else if (op_type == "Gelu") {
return NewPrimitiveC<GeLU>(prim, inputs, quantType);
#ifdef SUPPORT_TRAIN
} else if (op_type == "SoftmaxCrossEntropyWithLogits") {
return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType);
} else if (op_type == "SparseSoftmaxCrossEntropyWithLogits") {
@ -706,7 +702,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType);
} else if (op_type == "Conv2DBackpropFilter") {
return NewPrimitiveC<Conv2DGradFilter>(prim, inputs, quantType);
} else if (op_type == "Conv2DBackpropInput") {
} else if (op_type == "Conv2DBackpropInput" && train_flag) {
return NewPrimitiveC<Conv2DGradInput>(prim, inputs, quantType);
} else if ((op_type == "BatchNormGrad") || (op_type == "FusedBatchNormGradEx")) {
return NewPrimitiveC<BNGrad>(prim, inputs, quantType);
@ -748,10 +744,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<StridedSliceGrad>(prim, inputs, quantType);
} else if (op_type == "AbsGrad") {
return NewPrimitiveC<AbsGrad>(prim, inputs, quantType);
#else
} else if (op_type == "Conv2DBackpropInput") {
} else if (op_type == "Conv2DBackpropInput" && !train_flag) {
return NewPrimitiveC<DeConv2D>(prim, inputs, quantType);
#endif
} else {
MS_LOG(ERROR) << "Unsupported primitive type in Create : " << op_type;
return nullptr;
@ -1065,7 +1059,6 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) UniformReal(primitive);
case schema::PrimitiveType_Rank:
return new (std::nothrow) Rank(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:
return new (std::nothrow) ActivationGrad(primitive);
case schema::PrimitiveType_PoolingGrad:
@ -1140,7 +1133,6 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) SigmoidCrossEntropyWithLogitsGrad(primitive);
case schema::PrimitiveType_StridedSliceGrad:
return new (std::nothrow) StridedSliceGrad(primitive);
#endif
default:
MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type);
break;
@ -1170,6 +1162,10 @@ bool PrimitiveC::infer_flag() const { return this->infer_flag_; }
void PrimitiveC::set_infer_flag(bool flag) { this->infer_flag_ = flag; }
bool PrimitiveC::train_flag() const { return this->train_flag_; }
void PrimitiveC::set_train_flag(bool flag) { this->train_flag_ = flag; }
int PrimitiveC::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
auto input = inputs.front();
MS_ASSERT(input != nullptr);

View File

@ -133,6 +133,10 @@ class PrimitiveC : public mindspore::Primitive {
void set_infer_flag(bool flag);
bool train_flag() const;
void set_train_flag(bool flag);
static PrimitiveC *Create(mindspore::schema::Primitive *primitive) { return Create(primitive->UnPack()); }
static PrimitiveC *Create(mindspore::schema::PrimitiveT *primitive);
@ -140,7 +144,7 @@ class PrimitiveC : public mindspore::Primitive {
static void GetAttrDataFromInput(const AnfNodePtr &inputNode, std::vector<int> *data);
static std::shared_ptr<PrimitiveC> Create(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
const schema::QuantType &quantType);
const schema::QuantType &quantType, bool train_flag = false);
void PopulaterQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
void FillDefaultInputQuantParamIfNeed(const size_t &inputSize);
void PopulaterInputQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
@ -159,6 +163,7 @@ class PrimitiveC : public mindspore::Primitive {
bool infer_flag_ = true;
int op_type_ = OP_TYPE_NOT_SET;
bool enable_huffman_code_ = false;
bool train_flag_ = false;
};
std::shared_ptr<PrimitiveC> GetReturnPrim();
@ -179,6 +184,10 @@ class PrimitiveC {
void set_infer_flag(bool flag);
bool train_flag() const;
void set_train_flag(bool flag);
virtual int InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs);
int Type() const;
@ -238,6 +247,7 @@ class PrimitiveC {
bool infer_flag_ = true;
schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
int op_type_ = OP_TYPE_NOT_SET;
bool train_flag_ = false;
};
using PrimitiveCPtr = std::shared_ptr<PrimitiveC>;
typedef PrimitiveC *(*PrimitiveCCreator)(const schema::Primitive *primitive);

View File

@ -159,41 +159,41 @@ int Tile::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
} else {
multiples = GetMultiples();
}
#ifdef SUPPORT_TRAIN
const size_t in_dims = input->shape().size();
const size_t delta_dims = in_dims - multiples.size();
if (train_flag()) {
const size_t in_dims = input->shape().size();
const size_t delta_dims = in_dims - multiples.size();
size_t i = 0;
for (; i < delta_dims; ++i) {
int tmp = input->shape().at(i);
out_shape.push_back(tmp);
}
for (; i < in_dims; ++i) {
int tmp = input->shape().at(i) * (multiples[i - delta_dims]);
out_shape.push_back(tmp);
}
#else
std::vector<int> dims = GetDims();
if (inputs_.size() == 2 && dims.empty()) {
for (int dim = 0; dim < inputs_[1]->ElementsNum(); ++dim) {
dims.push_back(dim);
size_t i = 0;
for (; i < delta_dims; ++i) {
int tmp = input->shape().at(i);
out_shape.push_back(tmp);
}
for (; i < in_dims; ++i) {
int tmp = input->shape().at(i) * (multiples[i - delta_dims]);
out_shape.push_back(tmp);
}
} else {
std::vector<int> dims = GetDims();
if (inputs_.size() == 2 && dims.empty()) {
for (int dim = 0; dim < inputs_[1]->ElementsNum(); ++dim) {
dims.push_back(dim);
}
}
const size_t in_dims = input->shape().size();
MS_ASSERT(multiples.size() == dims.size());
for (size_t i = 0; i < in_dims; ++i) {
out_shape.push_back(input->shape().at(i));
}
for (size_t i = 0; i < dims.size(); ++i) {
if (input->shape().at(dims.at(i)) != 0 &&
multiples.at(i) > std::numeric_limits<int>::max() / input->shape().at(dims.at(i))) {
MS_LOG(ERROR) << "The value of multiples[" << i << "] is too big";
return RET_ERROR;
}
out_shape.at(dims.at(i)) = input->shape().at(dims.at(i)) * (multiples.at(i));
}
}
const size_t in_dims = input->shape().size();
MS_ASSERT(multiples.size() == dims.size());
for (size_t i = 0; i < in_dims; ++i) {
out_shape.push_back(input->shape().at(i));
}
for (size_t i = 0; i < dims.size(); ++i) {
if (input->shape().at(dims.at(i)) != 0 &&
multiples.at(i) > std::numeric_limits<int>::max() / input->shape().at(dims.at(i))) {
MS_LOG(ERROR) << "The value of multiples[" << i << "] is too big";
return RET_ERROR;
}
out_shape.at(dims.at(i)) = input->shape().at(dims.at(i)) * (multiples.at(i));
}
#endif
output->set_shape(out_shape);
return RET_OK;
}

View File

@ -328,15 +328,15 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
}
RemoveIfMakeTuple(cnode);
#ifdef SUPPORT_TRAIN
RemoveIfDepend(cnode);
#endif
if (train_flag) {
RemoveIfDepend(cnode);
if (primitive_c->Type() == schema::PrimitiveType_Depend ||
primitive_c->Type() == schema::PrimitiveType_ControlDepend) {
continue;
}
}
if ((primitive_c->Type() == schema::PrimitiveType_TupleGetItem) ||
#ifdef SUPPORT_TRAIN
(primitive_c->Type() == schema::PrimitiveType_Depend) ||
(primitive_c->Type() == schema::PrimitiveType_ControlDepend) ||
#endif
(primitive_c->Type() == schema::PrimitiveType_MakeTuple)) {
continue;
}
@ -424,8 +424,10 @@ int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::uniqu
return RET_OK;
}
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) {
schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive,
bool train_flag) {
static int subgraph_index = 0;
this->train_flag = train_flag;
auto meta_graphT = std::make_unique<schema::MetaGraphT>();
int ret = ExportSubgraph(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive);
if (ret != RET_OK) {
@ -439,24 +441,18 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode,
std::string input_name = input_anode->fullname_with_scope();
auto input_cnode = utils::cast<CNodePtr>(input_anode);
if (!IsPrimitiveCNode(input_cnode, schema::PrimitiveType_TupleGetItem)) {
#ifndef SUPPORT_TRAIN
if (node_id_map_.find(input_name) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]);
}
#else
bool found = false;
if (node_id_map_.find(input_name) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[input_name]);
found = true;
}
if (found == false) {
if (!found) {
auto input_index_key = input_name + "_o:" + std::to_string(0);
if (node_id_map_.find(input_index_key) != node_id_map_.end()) {
output_cnode->inputIndex.emplace_back(node_id_map_[input_index_key]);
}
}
#endif
} else {
auto inputs = input_cnode->inputs();
@ -481,17 +477,12 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode,
: GetValue<int>(value_node->value()));
auto iter = node_id_map_.find(input_index_key);
if (iter == node_id_map_.end()) {
#ifdef SUPPORT_TRAIN
input_index_key = get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(0); // try name with 0
iter = node_id_map_.find(input_index_key);
if (iter == node_id_map_.end()) {
MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key;
return RET_ERROR;
}
#else
MS_LOG(ERROR) << "Can not find get_item output tensor " << input_index_key;
return RET_ERROR;
#endif
}
output_cnode->inputIndex.emplace_back(iter->second);
}
@ -571,9 +562,7 @@ int AnfExporter::ProcessTensor(const ValueNodePtr &valueNode, std::unique_ptr<sc
(void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims),
[](const int64_t &value) { return static_cast<int32_t>(value); });
(*paramTensor)->dims = dims;
#ifdef SUPPORT_TRAIN
if ((*paramTensor)->dims.size() == 0) (*paramTensor)->dims = {1};
#endif
if (train_flag && (*paramTensor)->dims.empty()) (*paramTensor)->dims = {1};
(*paramTensor)->nodeType = schema::NodeType::NodeType_ValueNode;
auto data = value->cast<tensor::TensorPtr>();
(*paramTensor)->data.resize(data->Size());
@ -679,11 +668,11 @@ int AnfExporter::ProcessParamValueLite(const ValueNodePtr &valueNode, std::uniqu
(*paramTensor)->format = schema::Format(valueLite->format());
(*paramTensor)->dataType = valueLite->tensor_type();
(*paramTensor)->dims = valueLite->tensor_shape();
#ifdef SUPPORT_TRAIN
if ((*paramTensor)->dims.size() == 0) {
if (train_flag && (*paramTensor)->dims.empty()) {
(*paramTensor)->dims = {1};
}
#endif
ret = memcpy_s((*paramTensor)->data.data(), valueLite->tensor_size() * sizeof(uint8_t), valueLite->tensor_addr(),
valueLite->tensor_size());
if (ret != EOK) {
@ -703,9 +692,9 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano
auto paramTensor = std::make_unique<schema::TensorT>();
auto value = valueNode->value();
int ret = RET_OK;
#ifdef SUPPORT_TRAIN
paramTensor->name = valueNode->fullname_with_scope();
#endif
if (train_flag) {
paramTensor->name = valueNode->fullname_with_scope();
}
if (value->isa<tensor::Tensor>()) {
ret = ProcessTensor(valueNode, &paramTensor, value, output_cnode, meta_graphT);
} else if (value->isa<mindspore::Int32Imm>() || value->isa<mindspore::Int64Imm>()) {
@ -797,44 +786,44 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
}
msTensor->nodeType = schema::NodeType_CNode;
fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
#ifdef SUPPORT_TRAIN
std::string name = cnode_name + "_o:" + std::to_string(i);
node_id_map_[name] = meta_graphT->allTensors.size();
meta_graphT->allTensors.emplace_back(msTensor);
if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) ||
IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) ||
IsPrimitiveCNode(cnode, schema::PrimitiveType_Adam))
break;
#else
if (elements.size() == 1) {
node_id_map_[cnode_name] = meta_graphT->allTensors.size();
msTensor->name = cnode_name;
} else {
if (train_flag) {
std::string name = cnode_name + "_o:" + std::to_string(i);
node_id_map_[name] = meta_graphT->allTensors.size();
msTensor->name = name;
}
meta_graphT->allTensors.emplace_back(msTensor);
if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) ||
IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) ||
IsPrimitiveCNode(cnode, schema::PrimitiveType_Adam))
break;
} else {
if (elements.size() == 1) {
node_id_map_[cnode_name] = meta_graphT->allTensors.size();
msTensor->name = cnode_name;
} else {
std::string name = cnode_name + "_o:" + std::to_string(i);
node_id_map_[name] = meta_graphT->allTensors.size();
msTensor->name = name;
}
if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
MS_LOG(ERROR) << "abstract is not AbstractTensor";
delete (msTensor);
return;
if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
MS_LOG(ERROR) << "abstract is not AbstractTensor";
delete (msTensor);
return;
}
auto type = kNumberTypeFloat32;
if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
auto typePtr = abstract_tensor->element()->GetTypeTrack();
type = typePtr->type_id();
}
msTensor->dataType = type;
meta_graphT->allTensors.emplace_back(msTensor);
if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) ||
IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) ||
IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm) ||
IsPrimitiveCNode(cnode, schema::PrimitiveType_LayerNorm)) {
break;
}
}
auto type = kNumberTypeFloat32;
if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
auto typePtr = abstract_tensor->element()->GetTypeTrack();
type = typePtr->type_id();
}
msTensor->dataType = type;
meta_graphT->allTensors.emplace_back(msTensor);
if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) ||
IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) ||
IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm) ||
IsPrimitiveCNode(cnode, schema::PrimitiveType_LayerNorm)) {
break;
}
#endif
}
} else {
auto ms_tensor = new (std::nothrow) schema::TensorT();
@ -927,8 +916,8 @@ CNodePtr AnfExporter::CreatePartialCnode(const FuncGraphPtr &fg, AnfNodePtr node
}
}
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive) {
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive, bool train_flag) {
AnfExporter anf_exporter;
return anf_exporter.Export(func_graph, keep_graph, copy_primitive);
return anf_exporter.Export(func_graph, keep_graph, copy_primitive, train_flag);
}
} // namespace mindspore::lite

View File

@ -35,7 +35,8 @@ class AnfExporter {
public:
AnfExporter() = default;
virtual ~AnfExporter() = default;
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false);
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false,
bool train_flag = false);
void SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
schema::CNodeT *fb_node);
int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
@ -91,11 +92,13 @@ class AnfExporter {
std::vector<schema::CNodeT *> graph_input_nodes_;
std::map<FuncGraphPtr, int> fg_subgraph_map;
uint32_t node_idx = 0;
bool train_flag = false;
};
// by default, copy_primitive is false, which means that the MetaGraph and func_graph share the same schema::PrimitiveT.
// but in PostQuantization, the func_graph need to transfer to MetaGraph first and do MetaGraph pass, which may modify
// the schema::PrimitiveT and cause bug; If all the passes have been done in func_graph, every thing would be simple
// and clear.
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false);
schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph = false, bool copy_primitive = false,
bool train_flag = false);
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_ANF_EXPORTER_ANF_EXPORTER_H_

View File

@ -855,14 +855,14 @@ int AnfImporterFromMindir::ParseModelConfigureInfo(const onnx::ModelProto &model
}
int AnfImporterFromMindir::Import(const converter::Flags *flag) {
#if SUPPORT_TRAIN
func_graph_ = LoadMindIR(flag->modelFile, true);
if (func_graph_ != nullptr) {
return RET_OK;
} else {
MS_LOG(ERROR) << "Parse new mind_ir proto failed, Trying old onnx format";
if (flag->trainModel) {
func_graph_ = LoadMindIR(flag->modelFile, true);
if (func_graph_ != nullptr) {
return RET_OK;
} else {
MS_LOG(ERROR) << "Parse new mind_ir proto failed, Trying old onnx format";
}
}
#endif
onnx_model_ = ReadOnnxFromBinary(flag->modelFile);
if (onnx_model_ == nullptr) {
MS_LOG(DEBUG) << "Parse model failed, which is not an old mindir model";

View File

@ -24,39 +24,40 @@
namespace mindspore {
namespace lite {
static const std::vector<schema::PrimitiveType> nhwcOpList = {
#ifdef SUPPORT_TRAIN
schema::PrimitiveType_Conv2DGradFilter,
schema::PrimitiveType_Conv2DGradInput,
schema::PrimitiveType_GroupConv2DGradInput,
schema::PrimitiveType_PoolingGrad,
schema::PrimitiveType_BiasGrad,
schema::PrimitiveType_BNGrad,
schema::PrimitiveType_ApplyMomentum,
schema::PrimitiveType_Sgd,
schema::PrimitiveType_Adam,
#endif
schema::PrimitiveType_Conv2D,
schema::PrimitiveType_DeConv2D,
schema::PrimitiveType_DepthwiseConv2D,
schema::PrimitiveType_DeDepthwiseConv2D,
schema::PrimitiveType_Pooling,
schema::PrimitiveType_LocalResponseNormalization,
schema::PrimitiveType_Resize,
schema::PrimitiveType_BatchNorm,
schema::PrimitiveType_FusedBatchNorm,
schema::PrimitiveType_PReLU,
schema::PrimitiveType_BiasAdd,
schema::PrimitiveType_SpaceToDepth,
schema::PrimitiveType_DepthToSpace,
schema::PrimitiveType_TopK};
static const std::vector<schema::PrimitiveType> nhwcOpList = {schema::PrimitiveType_Conv2DGradFilter,
schema::PrimitiveType_Conv2DGradInput,
schema::PrimitiveType_GroupConv2DGradInput,
schema::PrimitiveType_PoolingGrad,
schema::PrimitiveType_BiasGrad,
schema::PrimitiveType_BNGrad,
schema::PrimitiveType_ApplyMomentum,
schema::PrimitiveType_Sgd,
schema::PrimitiveType_Adam,
schema::PrimitiveType_Conv2D,
schema::PrimitiveType_DeConv2D,
schema::PrimitiveType_DepthwiseConv2D,
schema::PrimitiveType_DeDepthwiseConv2D,
schema::PrimitiveType_Pooling,
schema::PrimitiveType_LocalResponseNormalization,
schema::PrimitiveType_Resize,
schema::PrimitiveType_BatchNorm,
schema::PrimitiveType_FusedBatchNorm,
schema::PrimitiveType_PReLU,
schema::PrimitiveType_BiasAdd,
schema::PrimitiveType_SpaceToDepth,
schema::PrimitiveType_DepthToSpace,
schema::PrimitiveType_TopK};
static const std::vector<schema::PrimitiveType> nhwcOpAllInputList = {
#ifdef SUPPORT_TRAIN
schema::PrimitiveType_PoolingGrad, schema::PrimitiveType_ActivationGrad, schema::PrimitiveType_Conv2DGradFilter,
schema::PrimitiveType_BNGrad
#endif
};
schema::PrimitiveType_BNGrad};
// index {} mean all inputs need insert
static std::unordered_map<schema::PrimitiveType, std::vector<int>> extNhwcInsertIndex = {
{schema::PrimitiveType_BNGrad, {0, 1}},
{schema::PrimitiveType_ApplyMomentum, {3}},
{schema::PrimitiveType_Sgd, {1}},
{schema::PrimitiveType_Adam, {9}}};
static const std::vector<schema::PrimitiveType> fp32FullOpList = {
schema::PrimitiveType_Concat, schema::PrimitiveType_Add,
@ -133,18 +134,10 @@ static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveT
schema::PrimitiveType_L2Norm};
static const std::vector<schema::PrimitiveType> needInsertOpList = {
#ifdef SUPPORT_TRAIN
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Split,
schema::PrimitiveType_Crop, schema::PrimitiveType_Mul, schema::PrimitiveType_Add,
schema::PrimitiveType_ActivationGrad
#else
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add,
schema::PrimitiveType_Split, schema::PrimitiveType_Slice, schema::PrimitiveType_Crop,
schema::PrimitiveType_Mul, schema::PrimitiveType_Maximum
#endif
};
schema::PrimitiveType_Mul, schema::PrimitiveType_Maximum, schema::PrimitiveType_ActivationGrad};
static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}};
@ -156,6 +149,8 @@ std::vector<schema::PrimitiveType> Getfp32FullOpList() { return fp32FullOpList;
std::vector<schema::PrimitiveType> GetNhwcOpList() { return nhwcOpList; }
std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes() { return extNhwcInsertIndex; }
std::vector<schema::PrimitiveType> GetNhwcAllInputOpList() { return nhwcOpAllInputList; }
std::vector<schema::PrimitiveType> GetUint8NhwcOpList() { return int8NeedNhwcOpList; }

View File

@ -62,6 +62,8 @@ std::vector<schema::PrimitiveType> GetNhwcOpList();
std::vector<schema::PrimitiveType> GetNhwcAllInputOpList();
std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes();
std::vector<schema::PrimitiveType> Getfp32FullOpList();
std::vector<schema::PrimitiveType> GetUint8NhwcOpList();

View File

@ -101,12 +101,6 @@ set(LITE_SRC
${SRC_DIR}/dequant.cc
${SRC_DIR}/huffman_decode.cc
)
if(SUPPORT_TRAIN)
set(LITE_SRC
${LITE_SRC}
)
endif()
set(ARM_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src/runtime/kernel/arm)
file(GLOB KERNEL_SRC
${ARM_DIR}/base/*.cc

View File

@ -177,6 +177,7 @@ int AnfTransform::RunMindirAdjustPass(const FuncGraphPtr &old_graph, const conve
auto mindir_adjust_pass = std::make_shared<opt::MindirAdjustPass>();
mindir_adjust_pass->SetFmkType(config->fmk);
mindir_adjust_pass->SetQuantType(config->quantType);
mindir_adjust_pass->SetTrainFlag(config->trainModel);
if (!mindir_adjust_pass->Run(old_graph)) {
MS_LOG(ERROR) << "mindir adjust failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);

View File

@ -88,7 +88,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
}
// anf -- fb
auto meta_graph = Export(graph);
auto meta_graph = Export(graph, false, false, flag->trainModel);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Export to meta graph return nullptr";
return nullptr;

View File

@ -48,21 +48,8 @@ STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatT
FormatTransNodeType *afterNodeType) {
if (fmkType == converter::FmkType_TFLITE) { // inference by nhwc
return RET_NO_CHANGE;
} else if (fmkType == converter::FmkType_CAFFE) { // inference by nchw
if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) {
return RET_NO_CHANGE;
}
*beforeNodeType = kNCHW2NHWC;
*afterNodeType = kNHWC2NCHW;
return RET_OK;
} else if (fmkType == converter::FmkType_MS) {
if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) {
return RET_NO_CHANGE;
}
*beforeNodeType = kNCHW2NHWC;
*afterNodeType = kNHWC2NCHW;
return RET_OK;
} else if (fmkType == converter::FmkType_ONNX) {
} else if (fmkType == converter::FmkType_CAFFE || fmkType == converter::FmkType_MS ||
fmkType == converter::FmkType_ONNX) {
if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) {
return RET_NO_CHANGE;
}
@ -173,11 +160,19 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
if (node->primitive->value.type == schema::PrimitiveType_DepthToSpace) {
reinterpret_cast<schema::DepthToSpaceT *>(attr)->format = schema::Format_NHWC;
}
#ifdef SUPPORT_TRAIN
if (IsContain(GetNhwcAllInputOpList(), GetCNodeTType(**iter))) {
int idx_num = node->inputIndex.size();
if (GetCNodeTType(**iter) == schema::PrimitiveType_BNGrad) idx_num = 2;
for (int i = 0; i < idx_num; i++) {
auto specInsertIndexes = GetExtNhwcIndexes();
auto opType = GetCNodeTType(**iter);
if (specInsertIndexes.find(opType) != specInsertIndexes.end()) {
for (auto insert_index : specInsertIndexes[opType]) {
iter = InsertFormatTransNode(graph, iter, kBefore, insert_index, beforeNodeType, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed";
return RET_ERROR;
}
}
} else if (IsContain(GetNhwcAllInputOpList(), opType)) {
auto input_size = node->inputIndex.size();
for (size_t i = 0; i < input_size; i++) {
iter = InsertFormatTransNode(graph, iter, kBefore, i, beforeNodeType, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed";
@ -185,23 +180,8 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
}
}
} else {
int idx = 0;
if (GetCNodeTType(**iter) == schema::PrimitiveType_ApplyMomentum) idx = 3;
if (GetCNodeTType(**iter) == schema::PrimitiveType_Sgd) idx = 1;
if (GetCNodeTType(**iter) == schema::PrimitiveType_Adam) idx = 9;
iter = InsertFormatTransNode(graph, iter, kBefore, idx, beforeNodeType, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed";
return RET_ERROR;
}
iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status);
}
#else
iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed";
return RET_ERROR;
}
#endif
iter = InsertFormatTransNode(graph, iter, kAfter, 0, afterNodeType, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed";

View File

@ -194,7 +194,6 @@ STATUS GlobalFormatTransformPass::FindPreNh2NcNodes(MetaGraphT *graph, size_t nc
if (!IsContain(bfs_queue, input_node_index)) {
bfs_queue.emplace_back(input_node_index);
}
// todo multi output,other edge need insert nh2nc node
auto pre_node_output_indexs = GetOutputNodeIdx(*graph, *pre_node);
if (pre_node_output_indexs.size() != 1) {
if (node_type == schema::PrimitiveType_Activation || node_type == schema::PrimitiveType_Concat) {

View File

@ -29,92 +29,58 @@ std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1};
std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2};
} // namespace
namespace lite {
bool IsInOutCanFusion(schema::MetaGraphT *graph, const std::vector<size_t> &node_indexes, size_t *has_trans_count,
FormatTransNodeType *trans_type) {
for (auto input_node_index : node_indexes) {
MS_ASSERT(graph->nodes.size() > input_node_index);
auto &pre_node = graph->nodes.at(input_node_index);
MS_ASSERT(pre_node != nullptr);
MS_ASSERT(pre_node->primitive != nullptr);
MS_ASSERT(pre_node->primitive->value != nullptr);
if (*trans_type == kNONE) {
if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) {
MS_ASSERT(pre_node->primitive->value.AsTranspose() != nullptr);
if (pre_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
*trans_type = kNCHW2NHWC;
} else if (pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
*trans_type = kNHWC2NCHW;
} else {
return false;
}
(*has_trans_count)++;
}
} else {
if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) {
auto cur_type = kNONE;
if (pre_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
cur_type = kNCHW2NHWC;
} else if (pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
cur_type = kNHWC2NCHW;
} else {
return false;
}
if (*trans_type != cur_type) {
return false;
} else {
(*has_trans_count)++;
}
}
}
}
return true;
}
bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(node != nullptr);
auto input_node_indexes = GetInputNodeIdx(*graph, *node);
pre_type_ = kNONE;
size_t has_trans_count = 0;
auto can_fusion = true;
for (auto input_node_index : input_node_indexes) {
MS_ASSERT(graph->nodes.size() > input_node_index);
auto &pre_node = graph->nodes.at(input_node_index);
MS_ASSERT(pre_node != nullptr);
MS_ASSERT(pre_node->primitive != nullptr);
MS_ASSERT(pre_node->primitive->value != nullptr);
if (pre_type_ == kNONE) {
if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) {
MS_ASSERT(pre_node->primitive->value.AsTranspose() != nullptr);
if (pre_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
pre_type_ = kNCHW2NHWC;
} else if (pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
pre_type_ = kNHWC2NCHW;
} else {
return false;
}
has_trans_count++;
}
} else {
if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) {
auto cur_type = kNONE;
if (pre_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
cur_type = kNCHW2NHWC;
} else if (pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
cur_type = kNHWC2NCHW;
} else {
return false;
}
if (pre_type_ != cur_type) {
can_fusion = false;
break;
} else {
has_trans_count++;
}
}
}
}
if (!can_fusion) {
if (!IsInOutCanFusion(graph, input_node_indexes, &has_trans_count, &pre_type_)) {
return false;
}
auto output_node_indexes = GetOutputNodeIdx(*graph, *node);
post_type_ = kNONE;
for (auto output_node_index : output_node_indexes) {
MS_ASSERT(graph->nodes.size() > output_node_index);
auto &post_node = graph->nodes.at(output_node_index);
MS_ASSERT(post_node != nullptr);
MS_ASSERT(post_node->primitive != nullptr);
MS_ASSERT(post_node->primitive->value != nullptr);
if (post_type_ == kNONE) {
if (post_node->primitive->value.type == schema::PrimitiveType_Transpose) {
if (post_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
post_type_ = kNCHW2NHWC;
} else if (post_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
post_type_ = kNHWC2NCHW;
} else {
return false;
}
has_trans_count++;
}
} else {
if (post_node->primitive->value.type == schema::PrimitiveType_Transpose) {
auto cur_type = kNONE;
if (post_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
cur_type = kNCHW2NHWC;
} else if (post_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
cur_type = kNHWC2NCHW;
} else {
return false;
}
if (post_type_ != cur_type) {
can_fusion = false;
break;
} else {
has_trans_count++;
}
}
}
}
if (!can_fusion) {
if (!IsInOutCanFusion(graph, output_node_indexes, &has_trans_count, &post_type_)) {
return false;
}
if (pre_type_ == kNONE && post_type_ == kNONE) {
@ -136,10 +102,8 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p
if (GetCNodeTType(*node) == schema::PrimitiveType_Split) {
return has_trans_count >= half_count;
}
can_fusion = has_trans_count > half_count;
return can_fusion;
return has_trans_count > half_count;
}
STATUS TransOpInsertPass::FindOutTransType() {
pre_insert_trans_type_ = kNHWC2NCHW;
post_insert_trans_type_ = kNHWC2NCHW;
@ -153,7 +117,7 @@ STATUS TransOpInsertPass::FindOutTransType() {
MS_ASSERT(false);
} else {
if (pre_type_ == post_type_) {
MS_LOG(ERROR) << "Unknow error";
MS_LOG(ERROR) << "Unknown error";
return RET_ERROR;
}
pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
@ -200,13 +164,6 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
STATUS status = RET_OK;
auto input_tensor_size = (*iter)->inputIndex.size();
for (size_t i = 0; i < input_tensor_size; i++) {
#ifdef SUPPORT_TRAIN
auto &tensor = graph->allTensors.at((*iter)->inputIndex[i]);
MS_ASSERT(tensor != nullptr);
if (tensor->nodeType == schema::NodeType_ValueNode) {
continue;
}
#endif
auto &input_tensor = graph->allTensors.at((*iter)->inputIndex[i]);
if (input_tensor->nodeType == NodeType_ValueNode && input_tensor->dims.size() < 4) {
continue;

View File

@ -37,6 +37,15 @@ ParamValueLitePtr NewParamValueLitePtr(lite::Tensor *tensor) {
return para_value_lite;
}
bool IsSpecialType(schema::PrimitiveType type) {
if ((type == schema::PrimitiveType_TupleGetItem) || (type == schema::PrimitiveType_Depend) ||
(type == schema::PrimitiveType_ControlDepend) ||
(type == schema::PrimitiveType_MakeTuple || type == schema::PrimitiveType_Return)) {
return true;
}
return false;
}
abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) {
MS_ASSERT(nullptr != tensor);
std::vector<int> shape(tensor->shape());
@ -363,12 +372,7 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
return false;
}
auto type = GetCNodeType(cnode);
if ((type == schema::PrimitiveType_TupleGetItem) ||
#ifdef SUPPORT_TRAIN
(type == schema::PrimitiveType_Depend) || (type == schema::PrimitiveType_ControlDepend) ||
#endif
(type == schema::PrimitiveType_MakeTuple || type == schema::PrimitiveType_Return)) {
if (IsSpecialType(type)) {
continue;
}
std::vector<lite::Tensor *> input_tensors;

View File

@ -147,7 +147,7 @@ int MindirAdjustPass::PrimitiveConvert(std::shared_ptr<AnfNode> anf_node) {
auto inputs = cnode->inputs();
inputs.erase(inputs.begin());
if (!CheckPrimitiveType(anf_node, prim::kPrimReturn) && !CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) {
auto primitive_c = PrimitiveC::Create(*primitive, inputs, quant_type_);
auto primitive_c = PrimitiveC::Create(*primitive, inputs, quant_type_, train_flag_);
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "fail to create a primitive_c: " << cnode->fullname_with_scope();
lite::NoSupportOp::GetInstance()->InsertOp(primitive->name());

View File

@ -33,6 +33,7 @@ class MindirAdjustPass : public Pass {
void SetQuantType(QuantType quant_type) { quant_type_ = quant_type; }
void SetFmkType(FmkType fmk_type) { fmk_type_ = fmk_type; }
int ValueNodeInt64Convert(AnfNodePtr anf_node);
void SetTrainFlag(bool train_flag) { train_flag_ = train_flag; }
int ParameterNodeConvert(AnfNodePtr anf_node);
int PrimitiveConvert(AnfNodePtr anf_node);
bool Run(const FuncGraphPtr &graph) override;
@ -40,6 +41,7 @@ class MindirAdjustPass : public Pass {
protected:
QuantType quant_type_ = QuantType::QuantType_QUANT_NONE;
FmkType fmk_type_ = FmkType::FmkType_MS;
bool train_flag_ = false;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_

View File

@ -131,12 +131,10 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node,
param_value->set_format(schema::Format::Format_CKHW);
} else if (op_type == schema::PrimitiveType_DeConv2D) {
param_value->set_format(schema::Format::Format_KCHW);
#ifdef SUPPORT_TRAIN
} else if (op_type == schema::PrimitiveType_Conv2DGradInput) {
param_value->set_format(schema::Format::Format_KCHW);
} else if (op_type == schema::PrimitiveType_GroupConv2DGradInput) {
param_value->set_format(schema::Format::Format_CKHW);
#endif
} else {
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
<< ", node: " << conv_node->fullname_with_scope();
@ -213,10 +211,8 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) {
auto conv_cnode = node->cast<CNodePtr>();
auto type = opt::GetCNodeType(node);
if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D &&
#ifdef SUPPORT_TRAIN
((type != schema::PrimitiveType_Conv2DGradInput) || (fmk_type != FmkType_MS)) &&
((type != schema::PrimitiveType_GroupConv2DGradInput) || (fmk_type != FmkType_MS)) &&
#endif
type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) {
continue;
}

View File

@ -43,11 +43,9 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr
continue;
}
auto type = opt::GetCNodeType(node);
if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D
#ifdef SUPPORT_TRAIN
&& type != schema::PrimitiveType_Conv2DGradInput && type != schema::PrimitiveType_GroupConv2DGradInput
#endif
&& type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) {
if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D &&
type != schema::PrimitiveType_Conv2DGradInput && type != schema::PrimitiveType_GroupConv2DGradInput &&
type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) {
continue;
}
auto conv_cnode = node->cast<CNodePtr>();