!21624 [MS][LITE] lite train support mix precision model convert

Merge pull request !21624 from zhengjun10/cast_transpose
Committed by i-robot, 2021-09-07 06:33:05 +00:00 (via Gitee)
commit 00a5180c35
19 changed files with 185 additions and 65 deletions

View File

@@ -46,6 +46,6 @@ if [[ ! -z ${QUANTIZE} ]]; then
QUANT_OPTIONS="--configFile=${WEIGHT_QUANT_CONFIG}"
fi
LD_LIBRARY_PATH=./:${LD_LIBRARY_PATH} $CONVERTER --fmk=MINDIR --trainModel=true --modelFile=lenet_tod.mindir --outputFile=lenet_tod $QUANT_OPTIONS
if [ -n "$3" ]; then
if [[ ! -z ${MIX_FLAG} ]]; then
LD_LIBRARY_PATH=./:${LD_LIBRARY_PATH} $CONVERTER --fmk=MINDIR --trainModel=true --modelFile=mix_lenet_tod.mindir --outputFile=mix_lenet_tod
fi

View File

@@ -104,7 +104,7 @@ fi
cd model/ || exit 1
rm -f *.ms
EXPORT=${EXPORT} QUANTIZE=${QUANTIZE} ./prepare_model.sh $BATCH $DOCKER $MIX_FLAG || exit 1
EXPORT=${EXPORT} QUANTIZE=${QUANTIZE} MIX_FLAG=${MIX_FLAG} ./prepare_model.sh $BATCH $DOCKER || exit 1
cd ../
# Copy the .ms model to the package folder

View File

@@ -257,7 +257,7 @@ void NetRunner::Usage() {
bool NetRunner::ReadArgs(int argc, char *argv[]) {
int opt;
while ((opt = getopt(argc, argv, "f:e:d:s:ihc:vob:")) != -1) {
while ((opt = getopt(argc, argv, "f:e:d:s:ihc:vmob:")) != -1) {
switch (opt) {
case 'f':
ms_file_ = std::string(optarg);
@@ -280,8 +280,8 @@ bool NetRunner::ReadArgs(int argc, char *argv[]) {
case 'b':
virtual_batch_ = atoi(optarg);
break;
case 'r':
is_raw_mix_precision_ = atoi(optarg);
case 'm':
is_raw_mix_precision_ = true;
break;
case 'h':
default:
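Note: the new -m switch only flips is_raw_mix_precision_; downstream it still has to reach the training configuration before Build. A minimal hedged sketch of that wiring, reusing the API names that appear in the C++ unit tests later in this change (the surrounding NetRunner members such as ms_file_ are taken from the hunk above, everything else is an assumption):

// Hedged sketch: forwarding the -m flag into the train configuration.
// TrainCfg, mix_precision_cfg_ and CPUDeviceInfo::SetEnableFP16 appear elsewhere
// in this patch; the exact call order inside NetRunner is an assumption.
auto context = std::make_shared<mindspore::Context>();
auto cpu_device = std::make_shared<mindspore::CPUDeviceInfo>();
cpu_device->SetEnableFP16(true);                       // allow fp16 kernels on CPU
context->MutableDeviceInfo().push_back(cpu_device);

auto train_cfg = std::make_shared<mindspore::TrainCfg>();
train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = is_raw_mix_precision_;  // set by -m

mindspore::Graph graph;
mindspore::Serialization::Load(ms_file_, mindspore::ModelType::kMindIR, &graph);
mindspore::Model model;
model.Build(mindspore::GraphCell(graph), context, train_cfg);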

View File

@@ -27,7 +27,7 @@ class LiteSession;
class TrainLoop;
struct TrainLoopCallBackData {
TrainLoopCallBackData(bool train_mode, int epoch, LiteSession *session, TrainLoop *loop)
TrainLoopCallBackData(bool train_mode, unsigned int epoch, LiteSession *session, TrainLoop *loop)
: train_mode_(train_mode), epoch_(epoch), session_(session), loop_(loop) {}
bool train_mode_; /**< training mode of LiteSession object */

View File

@@ -28,8 +28,8 @@
namespace mindspore {
Status Model::Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> i_cbs) {
if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
MS_LOG(ERROR) << "Model implement is null.";
if ((impl_ == nullptr) || (impl_->session_ == nullptr) || ds == nullptr) {
MS_LOG(ERROR) << "Model implement or dataset is null.";
return kLiteUninitializedObj;
}
auto loop = std::unique_ptr<session::TrainLoop>(session::TrainLoop::CreateTrainLoop((impl_->session_).get()));
@@ -67,8 +67,8 @@ Status Model::Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vecto
}
Status Model::Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> i_cbs) {
if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
MS_LOG(ERROR) << "Model implement is null.";
if ((impl_ == nullptr) || (impl_->session_ == nullptr) || ds == nullptr) {
MS_LOG(ERROR) << "Model implement or dataset is null.";
return kLiteUninitializedObj;
}
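With the extra dataset check, a null dataset is rejected before the train loop is even created. A hedged usage sketch of the guarded entry point (only the Train signature comes from this file; everything around it is illustrative):

// Hedged sketch: a null dataset now fails fast with kLiteUninitializedObj
// instead of being dereferenced inside the loop.
std::shared_ptr<mindspore::dataset::Dataset> ds = nullptr;   // dataset was never created
mindspore::Model model;                                      // assume Build(...) succeeded earlier
auto status = model.Train(/*epochs=*/5, ds, /*i_cbs=*/{});
// status == mindspore::kLiteUninitializedObj: handle the error here rather than crashing later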

View File

@@ -47,6 +47,10 @@ Status ModelImpl::PrepareMetrics(Model *model, std::vector<session::Metrics *> *
}
auto model_metrics = GetMetrics();
for (auto m : model_metrics) {
if (m == nullptr) {
MS_LOG(ERROR) << "Null input metrics";
return kLiteUninitializedObj;
}
if (m->metrics_impl_) {
// For off-the-shelf metrics it is guaranteed that we have also an MSLite implementation
auto internal_m = m->metrics_impl_->GetInternalMetrics();
@@ -79,6 +83,9 @@ Status ModelImpl::ConvertCallbacks(Model *model, std::vector<TrainCallBack *> *i
return kLiteUninitializedObj;
}
for (auto cb : *i_cbs) {
if (cb == nullptr) {
return kLiteUninitializedObj;
}
if (cb->callback_impl_) {
// For off-the-shelf callback it is guaranteed that we have also an MSLite implementation
auto internal_cb = cb->callback_impl_->GetInternalCallback();

View File

@@ -91,6 +91,10 @@ class OptimizerKernel : public InnerKernel {
indices.push_back(lr_idx_);
for (size_t ix = 0; ix < indices.size(); ix++) {
if (param->tensor_name() == in_tensors_.at(indices[ix])->tensor_name()) {
if (param->Size() != in_tensors_.at(indices[ix])->Size()) {
MS_LOG(ERROR) << "Tensor: " << param->tensor_name() << "set size not same";
return false;
}
auto value = static_cast<float *>(param->MutableData())[0];
static_cast<float *>(in_tensors_.at(indices[ix])->MutableData())[0] = value;
if (lr_idx_ == indices[ix]) {
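The size guard mirrors the new negative test in the C++ API tests further down: an optimizer parameter is matched by tensor name and must keep its original element count before its first value is copied. A hedged caller-side sketch (GetOptimizerParams/SetOptimizerParams are the public wrappers; the exact flow is an assumption):

// Hedged sketch of the caller-side effect of the new Size() comparison.
auto params = model.GetOptimizerParams();
if (!params.empty()) {
  params.at(0).SetShape({20, 20});              // element count no longer matches the kernel tensor
}
auto st = model.SetOptimizerParams(params);     // rejected instead of writing out of bounds
// st != mindspore::kSuccess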

View File

@@ -22,6 +22,7 @@
#include "include/errorcode.h"
#include "include/dataset/iterator.h"
#include "src/common/log_adapter.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace lite {
@@ -35,14 +36,24 @@ using session::RET_STOP_TRAINING;
TrainLoop::~TrainLoop() {}
int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func) {
train_session_->Train();
MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "train session and dataset cannot be nullptr");
auto ret = train_session_->Train();
if (ret != RET_OK) {
MS_LOG(ERROR) << "TrainLoop train failed";
return RET_ERROR;
}
MS_CHECK_GT(epochs, 0, RET_ERROR);
session::TrainLoopCallBackData cb_data(true, epoch_, train_session_, this);
if (load_func == nullptr) load_func = TrainLoop::LoadData;
for (auto cb : cbs) cb->Begin(cb_data);
for (auto cb : cbs) {
MS_CHECK_TRUE_MSG(cb != nullptr, RET_ERROR, "callback cannot be nullptr");
cb->Begin(cb_data);
}
std::shared_ptr<Iterator> iter = ds->CreateIterator();
MS_CHECK_TRUE_MSG(iter != nullptr, RET_ERROR, "iterator cannot be nullptr");
for (int i = 0; i < epochs; i++) {
cb_data.epoch_ = epoch_++;
for (auto cb : cbs) cb->EpochBegin(cb_data);
@@ -51,10 +62,9 @@ int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCall
int s = 0;
iter->GetNextRow(&row_vec);
while (row_vec.size() != 0) {
auto ret = load_func(cb_data.session_->GetInputs(), &row_vec);
while (!row_vec.empty()) {
ret = load_func(cb_data.session_->GetInputs(), &row_vec);
if (ret != RET_OK) break;
cb_data.step_ = s++;
for (auto cb : cbs) cb->StepBegin(cb_data);
@@ -64,7 +74,7 @@ int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCall
}
int break_loop = false;
for (auto cb : cbs) {
int ret = cb->EpochEnd(cb_data);
ret = cb->EpochEnd(cb_data);
if (ret != RET_CONTINUE) {
if (ret == RET_EXIT) {
MS_LOG(ERROR) << "Error in TrainLoop callback";
@@ -85,23 +95,35 @@ int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCall
}
int TrainLoop::Eval(Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func, int max_steps) {
train_session_->Eval();
MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "train session and dataset cannot be nullptr");
auto ret = train_session_->Eval();
if (ret != RET_OK) {
MS_LOG(ERROR) << "TrainLoop train failed";
return RET_ERROR;
}
session::TrainLoopCallBackData cb_data(false, epoch_, train_session_, this);
if (load_func == nullptr) load_func = TrainLoop::LoadData;
for (auto metric : metrics_) metric->Clear();
for (auto cb : cbs) cb->Begin(cb_data);
for (auto metric : metrics_) {
MS_CHECK_TRUE_MSG(metric != nullptr, RET_ERROR, "metric cannot be nullptr");
metric->Clear();
}
for (auto cb : cbs) {
MS_CHECK_TRUE_MSG(cb != nullptr, RET_ERROR, "callback cannot be nullptr");
cb->Begin(cb_data);
}
for (auto cb : cbs) cb->EpochBegin(cb_data);
std::shared_ptr<Iterator> iter = ds->CreateIterator();
MS_CHECK_TRUE_MSG(iter != nullptr, RET_ERROR, "iterator cannot be nullptr");
MSTensorVec row_vec;
int s = 0;
iter->GetNextRow(&row_vec);
while (row_vec.size() != 0) {
while (!row_vec.empty()) {
if (s >= max_steps) break;
auto ret = load_func(cb_data.session_->GetInputs(), &row_vec);
ret = load_func(cb_data.session_->GetInputs(), &row_vec);
if (ret != RET_OK) break;
cb_data.step_ = ++s;
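The new guards rely on MS_CHECK_TRUE_MSG and MS_CHECK_GT from nnacl/op_base.h, which this hunk newly includes. A hedged sketch of the pattern these macros implement, matching how they are invoked above; the actual expansion in op_base.h may differ:

// Hedged sketch of the check-macro pattern, not the verbatim nnacl definition.
#define MS_CHECK_TRUE_MSG(value, errcode, msg) \
  do {                                         \
    if (!(value)) {                            \
      MS_LOG(ERROR) << (msg);                  \
      return (errcode);                        \
    }                                          \
  } while (0)

#define MS_CHECK_GT(value, thres, errcode) \
  MS_CHECK_TRUE_MSG((value) > (thres), (errcode), #value " must be greater than " #thres)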

View File

@@ -63,6 +63,10 @@ int TrainSession::Init(InnerContext *context, const TrainCfg *train_cfg) {
}
cfg_ = *train_cfg;
}
if (context == nullptr) {
MS_LOG(ERROR) << "context cannot be nullptr";
return RET_NULL_PTR;
}
allocator_ = context->allocator;
return lite::LiteSession::Init(context);
}

View File

@@ -141,7 +141,7 @@ Tensor *CastTensor(Tensor *tensor, TypeId dst_data_type, bool support_fp16) {
std::vector<TypeId> valid_type = {kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeFloat};
std::vector<TypeId> fp32_type = {kNumberTypeFloat32, kNumberTypeFloat};
if (!IsContain(valid_type, tensor->data_type())) {
MS_LOG(ERROR) << "source data type must be fp32 or fp16";
MS_LOG(ERROR) << "source data type must be fp32 or fp16,cur is " << tensor->data_type();
return nullptr;
}
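CastTensor is the helper the train session uses to move weights between fp32 and fp16 in mixed-precision mode; the sharpened log now reports the offending source type. A hedged usage sketch (only the signature comes from this hunk; the call site, kWeightIndex and the error handling are assumptions):

// Hedged sketch of a typical call; restoring or freeing the returned tensor is
// the caller's responsibility and is omitted here.
Tensor *origin = in_tensors_.at(kWeightIndex);               // kWeightIndex is illustrative
Tensor *casted = CastTensor(origin, kNumberTypeFloat16, /*support_fp16=*/true);
if (casted == nullptr) {
  MS_LOG(ERROR) << "cast tensor " << origin->tensor_name() << " to fp16 failed";
  return RET_ERROR;
}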

View File

@@ -139,6 +139,18 @@ TEST_F(TestCxxApiLiteModel, test_getparams_SUCCESS) {
for (size_t ix = 0; ix < params1.size(); ix++) {
ASSERT_EQ(static_cast<float *>(params1[ix].MutableData())[0], static_cast<float>(ix) + pi);
}
if (!params.empty()) {
auto &param = params.at(0);
param.SetShape({20, 20});
param.SetDataType(DataType::kNumberTypeInt8);
}
ASSERT_TRUE(model.SetOptimizerParams(params) != kSuccess);
if (!params.empty()) {
auto &param = params.at(0);
param.SetTensorName("failed_name");
}
ASSERT_TRUE(model.SetOptimizerParams(params) != kSuccess);
}
TEST_F(TestCxxApiLiteModel, test_getgrads_SUCCESS) {
@@ -159,5 +171,62 @@ TEST_F(TestCxxApiLiteModel, test_getgrads_SUCCESS) {
static_cast<float *>(graients[ix].MutableData())[0] = static_cast<float>(ix) + pi;
}
ASSERT_TRUE(model.ApplyGradients(graients) == kSuccess);
if (!graients.empty()) {
auto &param = graients.at(0);
param.SetShape({20, 20});
}
ASSERT_TRUE(model.ApplyGradients(graients) != kSuccess);
if (!graients.empty()) {
auto &param = graients.at(0);
param.SetTensorName("failed_name");
}
ASSERT_TRUE(model.ApplyGradients(graients) != kSuccess);
}
TEST_F(TestCxxApiLiteModel, test_fp32_SUCCESS) {
Model model;
Graph graph;
auto context = std::make_shared<Context>();
auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
cpu_context->SetEnableFP16(true);
context->MutableDeviceInfo().push_back(cpu_context);
auto train_cfg = std::make_shared<TrainCfg>();
train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true;
ASSERT_TRUE(Serialization::Load("./nets/conv_train_model.ms", ModelType::kMindIR, &graph) == kSuccess);
ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = false;
ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
cpu_context->SetEnableFP16(false);
ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true;
ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
}
TEST_F(TestCxxApiLiteModel, test_fp16_SUCCESS) {
Model model;
Graph graph;
auto context = std::make_shared<Context>();
auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
cpu_context->SetEnableFP16(true);
context->MutableDeviceInfo().push_back(cpu_context);
auto train_cfg = std::make_shared<TrainCfg>();
train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true;
ASSERT_TRUE(Serialization::Load("./nets/mix_lenet_tod.ms", ModelType::kMindIR, &graph) == kSuccess);
ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = false;
ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
cpu_context->SetEnableFP16(false);
ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true;
ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
}
} // namespace mindspore
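The two new tests sweep every combination of CPUDeviceInfo::SetEnableFP16 and is_raw_mix_precision_ against a pure fp32 net and the new mixed-precision lenet, and Build is expected to succeed in all four cases. A hedged follow-up sketch of how the raw mixed precision could be observed after Build (tensor inspection only; the expected dtypes are not asserted anywhere in this change):

// Hedged sketch: with is_raw_mix_precision_ = true the session is expected to keep
// the precisions stored in the model instead of forcing a single dtype.
for (auto &tensor : model.GetInputs()) {
  std::cout << tensor.Name() << " dtype=" << static_cast<int>(tensor.DataType()) << std::endl;
}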

View File

@@ -115,7 +115,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
fusion_pm->AddPass(std::make_shared<opt::AffineFusion>());
fusion_pm->AddPass(std::make_shared<opt::AffineActivationFusion>());
}
if (config->fmk == converter::kFmkTypeMs) {
if (config->fmk == converter::kFmkTypeMs && !config->trainModel) {
auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
if (remove_unused_cast_pass == nullptr) {
MS_LOG(ERROR) << "RemoveUnusedCastOpPass should be specified";

View File

@@ -14,30 +14,30 @@
* limitations under the License.
*/
#include "tools/converter/parser/parser_utils.h"
#include <memory>
#include <algorithm>
#include <vector>
#include <memory>
#include <set>
#include <string>
#include "tools/converter/parser/tf_bidirection_gru_cf_fusion.h"
#include "tools/converter/parser/unused_node_remove_pass.h"
#include <vector>
#include <unordered_map>
#include "ops/transpose.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/parser/conv1d_inout_adjust.h"
#include "tools/converter/parser/inputs_adjust.h"
#include "ops/transpose.h"
#include "tools/converter/parser/tf_bidirection_gru_cf_fusion.h"
#include "tools/converter/parser/unused_node_remove_pass.h"
#include "tools/converter/quant_param_holder.h"
#include "tools/common/tensor_util.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/format/to_format_base.h"
namespace mindspore::lite {
namespace {
constexpr size_t kNumWeightIndex = 2;
bool IsWeightNodeSensitive(const AnfNodePtr &node) {
return opt::CheckPrimitiveType(node, prim::kPrimConv2DFusion) ||
opt::CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) ||
opt::CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion) ||
opt::CheckPrimitiveType(node, prim::kPrimApplyMomentum) || opt::CheckPrimitiveType(node, prim::kPrimSGD) ||
opt::CheckPrimitiveType(node, prim::kPrimAdam);
}
std::unordered_map<std::string, size_t> weight_indexs = {{ops::kNameConv2DFusion, 2},
{ops::kNameConv2DBackpropInputFusion, 2},
{ops::kNameConv2dTransposeFusion, 2},
{ops::kNameApplyMomentum, 1},
{ops::kNameSGD, 1},
{ops::kNameAdam, 1}};
} // namespace
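The table replaces the chained CheckPrimitiveType tests and additionally records where each op keeps the weight that needs format unification: the conv family takes its filter at input 2, while ApplyMomentum, SGD and Adam update the weight passed at input 1. A hedged sketch of the lookup, mirroring the UnifyConvWeightFormat change further down in this file:

// Hedged sketch of consuming the table; matches the pattern used below.
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
auto it = weight_indexs.find(primitive->name());
if (it == weight_indexs.end()) {
  MS_LOG(ERROR) << primitive->name() << " is not a member of the convolution family.";
  return RET_ERROR;
}
size_t weight_input_index = it->second;     // 2 for conv ops, 1 for optimizer ops
auto weight_node = cnode->input(weight_input_index);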
void GetAllFuncGraph(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
@@ -146,15 +146,9 @@ int GetTransposePermSharing(schema::Format src_format, schema::Format dst_format
return lite::RET_OK;
}
AnfNodePtr GetRealConvWeightNode(const FuncGraphPtr &graph, const CNodePtr &cnode) {
AnfNodePtr GetRealConvWeightNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index) {
MS_ASSERT(graph != nullptr && cnode != nullptr);
if (!opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) &&
!opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) &&
!opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) {
MS_LOG(ERROR) << "cnode is not a member of convolution's family.";
return nullptr;
}
auto weight_node = cnode->input(opt::kInputIndexTwo);
auto weight_node = cnode->input(index);
bool is_real_weight =
!opt::CheckPrimitiveType(weight_node, opt::kPrimIdentity) && !opt::CheckPrimitiveType(weight_node, prim::kPrimLoad);
while (!is_real_weight) {
@@ -169,7 +163,7 @@ AnfNodePtr GetRealConvWeightNode(const FuncGraphPtr &graph, const CNodePtr &cnod
}
auto manager = Manage(graph);
MS_ASSERT(manager != nullptr);
manager->Replace(cnode->input(opt::kInputIndexTwo), weight_node);
manager->Replace(cnode->input(index), weight_node);
return weight_node;
}
@@ -179,18 +173,19 @@ int UnifyConvWeightFormat(const FuncGraphPtr &graph, const CNodePtr &cnode, sche
if (src_format == dst_format) {
return lite::RET_OK;
}
if (!opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) &&
!opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) &&
!opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) {
MS_LOG(ERROR) << "cnode is not a member of convolution's family.";
auto primitive_ptr = GetValueNode<PrimitivePtr>(cnode->input(0));
auto primitive_name = primitive_ptr->name();
if (weight_indexs.find(primitive_name) == weight_indexs.end()) {
MS_LOG(ERROR) << primitive_name << " is not a member of the convolution family.";
return RET_ERROR;
}
if (GetRealConvWeightNode(graph, cnode) == nullptr) {
size_t index = weight_indexs[primitive_name];
if (GetRealConvWeightNode(graph, cnode, index) == nullptr) {
MS_LOG(ERROR) << "current conv node is invalid, node name is " << cnode->fullname_with_scope();
return RET_ERROR;
}
bool is_const_weight = true;
auto weight_node = cnode->input(opt::kInputIndexTwo);
auto weight_node = cnode->input(index);
if (utils::isa<CNode>(weight_node)) {
is_const_weight = false;
} else if (utils::isa<Parameter>(weight_node)) {
@@ -234,7 +229,7 @@ int UnifyVariableConvWeight(const FuncGraphPtr &graph, const AnfNodePtr &weight_
MS_LOG(ERROR) << "post node is invalid.";
return RET_ERROR;
}
if (!IsWeightNodeSensitive(post_node)) {
if (!opt::ToFormatBase::IsWeightNodeSensitive(post_node)) {
continue;
}
has_visited->insert(post_node);
@@ -285,6 +280,9 @@ int UnifyConstConvWeight(const FuncGraphPtr &graph, const AnfNodePtr &weight_nod
MS_LOG(ERROR) << "conv weight is non-const.";
return RET_ERROR;
}
if (weight_value->shape().size() != kShape4dDims) {
return lite::RET_OK;
}
auto status = opt::TransFilterFormat(weight_value, src_format, dst_format);
if (status != RET_OK) {
MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(src_format) << "To" << EnumNameFormat(dst_format)
@@ -328,7 +326,7 @@ int HandleConstConvWeightShared(const FuncGraphPtr &graph, const AnfNodePtr &wei
MS_LOG(ERROR) << "post node is invalid.";
return RET_ERROR;
}
if (IsWeightNodeSensitive(post_node)) {
if (opt::ToFormatBase::IsWeightNodeSensitive(post_node)) {
has_visited->insert(post_node);
continue;
}

View File

@@ -30,7 +30,7 @@ void GetAllFuncGraph(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all
int CommonAnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
int GetTransposePerm(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm);
int GetTransposePermSharing(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm);
AnfNodePtr GetRealConvWeightNode(const FuncGraphPtr &graph, const CNodePtr &cnode);
AnfNodePtr GetRealConvWeightNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index);
int UnifyConvWeightFormat(const FuncGraphPtr &graph, const CNodePtr &cnode, schema::Format src_format,
schema::Format dst_format, std::set<AnfNodePtr> *has_visited);
int UnifyVariableConvWeight(const FuncGraphPtr &graph, const AnfNodePtr &weight_node, schema::Format src_format,

View File

@@ -25,6 +25,7 @@
#include "ops/batch_norm.h"
#include "ops/batch_to_space.h"
#include "ops/bias_add.h"
#include "ops/cast.h"
#include "ops/concat.h"
#include "ops/crop.h"
#include "ops/depth_to_space.h"
@@ -102,12 +103,12 @@ static const std::unordered_map<std::string, std::vector<size_t>> NCHWOpMap = {{
// a certain op whose input's format is not fixed, bool value determines whether the op has axis attribute or not.
static const std::unordered_map<std::string, bool> DynamicFormatOpList = {
{ops::kNameAddN, false}, {ops::kNameCrop, true}, {ops::kNameSplit, true},
{ops::kNameConcat, true}, {ops::kNameEltwise, false}, {ops::kNameMaximum, false},
{ops::kNameAddFusion, false}, {ops::kNameDivFusion, false}, {ops::kNameMulFusion, false},
{ops::kNamePadFusion, false}, {ops::kNamePowFusion, false}, {ops::kNameActivation, false},
{ops::kNameSliceFusion, true}, {ops::kNameStridedSlice, true}, {ops::kNameActivationGrad, false},
{ops::kNameQuantDTypeCast, false}};
{ops::kNameAddN, false}, {ops::kNameCrop, true}, {ops::kNameSplit, true},
{ops::kNameConcat, true}, {ops::kNameEltwise, false}, {ops::kNameMaximum, false},
{ops::kNameAddFusion, false}, {ops::kNameDivFusion, false}, {ops::kNameMulFusion, false},
{ops::kNamePadFusion, false}, {ops::kNamePowFusion, false}, {ops::kNameActivation, false},
{ops::kNameSliceFusion, true}, {ops::kNameStridedSlice, true}, {ops::kNameActivationGrad, false},
{ops::kNameQuantDTypeCast, false}, {ops::kNameCast, false}};
const std::unordered_map<std::string, std::vector<size_t>> &GetNHWCOpMap() { return NHWCOpMap; }
const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap() { return NCHWOpMap; }

View File

@@ -368,9 +368,7 @@ STATUS ToFormatBase::ConvWeightFormatTrans(const FuncGraphPtr &graph, std::set<A
}
continue;
}
if (!CheckPrimitiveType(node, prim::kPrimConv2DFusion) &&
!CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) &&
!CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
if (!IsWeightNodeSensitive(cnode)) {
continue;
}
if (has_visited->find(node) != has_visited->end()) {

View File

@@ -26,6 +26,11 @@
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/format_utils.h"
#include "tools/optimizer/graph/infershape_pass.h"
#include "ops/fusion/conv2d_fusion.h"
#include "ops/fusion/conv2d_transpose_fusion.h"
#include "ops/adam.h"
#include "ops/sgd.h"
#include "ops/apply_momentum.h"
using mindspore::converter::FmkType;
namespace mindspore {
@@ -37,6 +42,16 @@ class ToFormatBase : public Pass {
: Pass(pass_name), fmk_type_(fmk_type), train_flag_(train_flag) {}
~ToFormatBase() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
static bool IsConvFamilyNode(const AnfNodePtr &node) {
return opt::CheckPrimitiveType(node, prim::kPrimConv2DFusion) ||
opt::CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) ||
opt::CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion);
}
static bool IsOptimizerNode(const AnfNodePtr &node) {
return opt::CheckPrimitiveType(node, prim::kPrimApplyMomentum) || opt::CheckPrimitiveType(node, prim::kPrimSGD) ||
opt::CheckPrimitiveType(node, prim::kPrimAdam);
}
static bool IsWeightNodeSensitive(const AnfNodePtr &node) { return IsConvFamilyNode(node) || IsOptimizerNode(node); }
private:
bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph);

View File

@@ -366,8 +366,10 @@ bool TransposeStrategy::CanFusionIfInsert(const FuncGraphPtr &func_graph, const
MS_ASSERT(pre_type != nullptr && post_type != nullptr);
size_t trans_count = 0;
std::vector<AnfNodePtr> in_nodes;
auto graph_inputs = func_graph->get_inputs();
for (size_t i = 1; i < cnode->size(); ++i) {
if (utils::isa<CNodePtr>(cnode->input(i))) {
if (utils::isa<CNodePtr>(cnode->input(i)) ||
std::find(graph_inputs.begin(), graph_inputs.end(), cnode->input(i)) != graph_inputs.end()) {
in_nodes.push_back(cnode->input(i));
}
}