forked from mindspore-Ecosystem/mindspore
!21624 [MS][LITE] lite train: support mixed-precision model conversion
Merge pull request !21624 from zhengjun10/cast_transpose
This commit is contained in: 00a5180c35
@@ -46,6 +46,6 @@ if [[ ! -z ${QUANTIZE} ]]; then
   QUANT_OPTIONS="--configFile=${WEIGHT_QUANT_CONFIG}"
 fi
 LD_LIBRARY_PATH=./:${LD_LIBRARY_PATH} $CONVERTER --fmk=MINDIR --trainModel=true --modelFile=lenet_tod.mindir --outputFile=lenet_tod $QUANT_OPTIONS
-if [ -n "$3" ]; then
+if [[ ! -z ${MIX_FLAG} ]]; then
   LD_LIBRARY_PATH=./:${LD_LIBRARY_PATH} $CONVERTER --fmk=MINDIR --trainModel=true --modelFile=mix_lenet_tod.mindir --outputFile=mix_lenet_tod
 fi
@@ -104,7 +104,7 @@ fi

 cd model/ || exit 1
 rm -f *.ms
-EXPORT=${EXPORT} QUANTIZE=${QUANTIZE} ./prepare_model.sh $BATCH $DOCKER $MIX_FLAG || exit 1
+EXPORT=${EXPORT} QUANTIZE=${QUANTIZE} MIX_FLAG=${MIX_FLAG} ./prepare_model.sh $BATCH $DOCKER || exit 1
 cd ../

 # Copy the .ms model to the package folder
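Taken together, the two script changes move the mixed-precision switch from a positional argument into the environment: the extra converter run happens only when MIX_FLAG is set, and the caller now forwards it as `MIX_FLAG=${MIX_FLAG}` rather than as `$3`. A hypothetical invocation (flag value illustrative, not taken from this commit) would be `MIX_FLAG=on ./prepare_model.sh $BATCH $DOCKER`.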
@@ -257,7 +257,7 @@ void NetRunner::Usage() {

 bool NetRunner::ReadArgs(int argc, char *argv[]) {
   int opt;
-  while ((opt = getopt(argc, argv, "f:e:d:s:ihc:vob:")) != -1) {
+  while ((opt = getopt(argc, argv, "f:e:d:s:ihc:vmob:")) != -1) {
     switch (opt) {
       case 'f':
         ms_file_ = std::string(optarg);

@@ -280,8 +280,8 @@ bool NetRunner::ReadArgs(int argc, char *argv[]) {
       case 'b':
         virtual_batch_ = atoi(optarg);
         break;
-      case 'r':
-        is_raw_mix_precision_ = atoi(optarg);
+      case 'm':
+        is_raw_mix_precision_ = true;
         break;
       case 'h':
       default:
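In the new option string, `m` has no trailing colon, so `-m` is a bare switch and getopt never supplies an `optarg` for it; the removed `-r` form instead parsed a numeric argument. A minimal standalone sketch of the same pattern (illustrative, not the MindSpore sources):

```cpp
#include <unistd.h>  // getopt, optarg
#include <cstdio>

int main(int argc, char *argv[]) {
  bool is_raw_mix_precision = false;  // toggled by the bare -m switch
  const char *model_file = nullptr;
  int opt;
  while ((opt = getopt(argc, argv, "mf:")) != -1) {  // "m" takes no ':', "f" takes a value
    switch (opt) {
      case 'm':
        is_raw_mix_precision = true;  // no optarg: presence alone enables it
        break;
      case 'f':
        model_file = optarg;
        break;
      default:
        std::fprintf(stderr, "usage: %s [-m] [-f model.ms]\n", argv[0]);
        return 1;
    }
  }
  std::printf("model=%s raw_mix_precision=%d\n", model_file ? model_file : "(none)", is_raw_mix_precision);
  return 0;
}
```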
@@ -27,7 +27,7 @@ class LiteSession;
 class TrainLoop;

 struct TrainLoopCallBackData {
-  TrainLoopCallBackData(bool train_mode, int epoch, LiteSession *session, TrainLoop *loop)
+  TrainLoopCallBackData(bool train_mode, unsigned int epoch, LiteSession *session, TrainLoop *loop)
       : train_mode_(train_mode), epoch_(epoch), session_(session), loop_(loop) {}

   bool train_mode_; /**< training mode of LiteSession object */
@@ -28,8 +28,8 @@

 namespace mindspore {
 Status Model::Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> i_cbs) {
-  if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
-    MS_LOG(ERROR) << "Model implement is null.";
+  if ((impl_ == nullptr) || (impl_->session_ == nullptr) || ds == nullptr) {
+    MS_LOG(ERROR) << "Model implement or dataset is null.";
     return kLiteUninitializedObj;
   }
   auto loop = std::unique_ptr<session::TrainLoop>(session::TrainLoop::CreateTrainLoop((impl_->session_).get()));

@@ -67,8 +67,8 @@ Status Model::Train(int epochs, std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> i_cbs) {
 }

 Status Model::Evaluate(std::shared_ptr<dataset::Dataset> ds, std::vector<TrainCallBack *> i_cbs) {
-  if ((impl_ == nullptr) || (impl_->session_ == nullptr)) {
-    MS_LOG(ERROR) << "Model implement is null.";
+  if ((impl_ == nullptr) || (impl_->session_ == nullptr) || ds == nullptr) {
+    MS_LOG(ERROR) << "Model implement or dataset is null.";
     return kLiteUninitializedObj;
   }

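Both guards now also reject a null dataset. A hypothetical caller (illustrative; the header paths and surrounding setup are assumptions, not part of this commit) shows what the added `ds == nullptr` check protects against:

```cpp
#include <memory>

#include "include/api/model.h"         // assumed MindSpore Lite API header layout
#include "include/dataset/datasets.h"  // assumed location of dataset::Dataset

// If dataset creation failed upstream, Train() now fails fast with
// kLiteUninitializedObj instead of dereferencing a null dataset later.
int TrainOnce(mindspore::Model *model) {
  std::shared_ptr<mindspore::dataset::Dataset> ds = nullptr;  // e.g. file not found
  auto status = model->Train(/*epochs=*/1, ds, /*i_cbs=*/{});
  return status == mindspore::kSuccess ? 0 : -1;
}
```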
@@ -47,6 +47,10 @@ Status ModelImpl::PrepareMetrics(Model *model, std::vector<session::Metrics *> *
   }
   auto model_metrics = GetMetrics();
   for (auto m : model_metrics) {
+    if (m == nullptr) {
+      MS_LOG(ERROR) << "Null input metrics";
+      return kLiteUninitializedObj;
+    }
     if (m->metrics_impl_) {
       // For off-the-shelf metrics it is guaranteed that we have also an MSLite implementation
       auto internal_m = m->metrics_impl_->GetInternalMetrics();

@@ -79,6 +83,9 @@ Status ModelImpl::ConvertCallbacks(Model *model, std::vector<TrainCallBack *> *i_cbs) {
     return kLiteUninitializedObj;
   }
   for (auto cb : *i_cbs) {
+    if (cb == nullptr) {
+      return kLiteUninitializedObj;
+    }
     if (cb->callback_impl_) {
       // For off-the-shelf callback it is guaranteed that we have also an MSLite implementation
       auto internal_cb = cb->callback_impl_->GetInternalCallback();
@@ -91,6 +91,10 @@ class OptimizerKernel : public InnerKernel {
     indices.push_back(lr_idx_);
     for (size_t ix = 0; ix < indices.size(); ix++) {
       if (param->tensor_name() == in_tensors_.at(indices[ix])->tensor_name()) {
+        if (param->Size() != in_tensors_.at(indices[ix])->Size()) {
+          MS_LOG(ERROR) << "Tensor: " << param->tensor_name() << "set size not same";
+          return false;
+        }
         auto value = static_cast<float *>(param->MutableData())[0];
         static_cast<float *>(in_tensors_.at(indices[ix])->MutableData())[0] = value;
         if (lr_idx_ == indices[ix]) {
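The added Size() guard makes parameter import fail loudly on a buffer mismatch instead of copying into a tensor whose allocation no longer matches the incoming parameter; the new assertions in test_getparams_SUCCESS further down (shape forced to {20, 20}, data type to Int8) appear to exercise exactly this rejection path.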
@@ -22,6 +22,7 @@
 #include "include/errorcode.h"
 #include "include/dataset/iterator.h"
 #include "src/common/log_adapter.h"
+#include "nnacl/op_base.h"

 namespace mindspore {
 namespace lite {
@@ -35,14 +36,24 @@ using session::RET_STOP_TRAINING;
 TrainLoop::~TrainLoop() {}

 int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func) {
-  train_session_->Train();
+  MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "graph data cannot be nullptr");
+  auto ret = train_session_->Train();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "TrainLoop train failed";
+    return RET_ERROR;
+  }
+  MS_CHECK_GT(epochs, 0, RET_ERROR);
   session::TrainLoopCallBackData cb_data(true, epoch_, train_session_, this);

   if (load_func == nullptr) load_func = TrainLoop::LoadData;

-  for (auto cb : cbs) cb->Begin(cb_data);
+  for (auto cb : cbs) {
+    MS_CHECK_TRUE_MSG(cb != nullptr, RET_ERROR, "callback cannot be nullptr");
+    cb->Begin(cb_data);
+  }

   std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  MS_CHECK_TRUE_MSG(iter != nullptr, RET_ERROR, "iterator cannot be nullptr");
   for (int i = 0; i < epochs; i++) {
     cb_data.epoch_ = epoch_++;
     for (auto cb : cbs) cb->EpochBegin(cb_data);
@@ -51,10 +62,9 @@ int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func) {
     int s = 0;

     iter->GetNextRow(&row_vec);
-    while (row_vec.size() != 0) {
-      auto ret = load_func(cb_data.session_->GetInputs(), &row_vec);
+    while (!row_vec.empty()) {
+      ret = load_func(cb_data.session_->GetInputs(), &row_vec);
       if (ret != RET_OK) break;

       cb_data.step_ = s++;
       for (auto cb : cbs) cb->StepBegin(cb_data);
@@ -64,7 +74,7 @@ int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func) {
       }
       int break_loop = false;
       for (auto cb : cbs) {
-        int ret = cb->EpochEnd(cb_data);
+        ret = cb->EpochEnd(cb_data);
         if (ret != RET_CONTINUE) {
           if (ret == RET_EXIT) {
             MS_LOG(ERROR) << "Error in TrainLoop callback";
@@ -85,23 +95,35 @@ int TrainLoop::Train(int epochs, Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func) {
 }

 int TrainLoop::Eval(Dataset *ds, std::vector<session::TrainLoopCallBack *> cbs, LoadDataFunc load_func, int max_steps) {
-  train_session_->Eval();
+  MS_CHECK_TRUE_MSG(train_session_ != nullptr && ds != nullptr, RET_ERROR, "graph data cannot be nullptr");
+  auto ret = train_session_->Eval();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "TrainLoop train failed";
+    return RET_ERROR;
+  }
   session::TrainLoopCallBackData cb_data(false, epoch_, train_session_, this);

   if (load_func == nullptr) load_func = TrainLoop::LoadData;

-  for (auto metric : metrics_) metric->Clear();
-  for (auto cb : cbs) cb->Begin(cb_data);
+  for (auto metric : metrics_) {
+    MS_CHECK_TRUE_MSG(metric != nullptr, RET_ERROR, "metric cannot be nullptr");
+    metric->Clear();
+  }
+  for (auto cb : cbs) {
+    MS_CHECK_TRUE_MSG(cb != nullptr, RET_ERROR, "callback cannot be nullptr");
+    cb->Begin(cb_data);
+  }
   for (auto cb : cbs) cb->EpochBegin(cb_data);

   std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  MS_CHECK_TRUE_MSG(iter != nullptr, RET_ERROR, "iterator cannot be nullptr");
   MSTensorVec row_vec;
   int s = 0;

   iter->GetNextRow(&row_vec);
-  while (row_vec.size() != 0) {
+  while (!row_vec.empty()) {
     if (s >= max_steps) break;
-    auto ret = load_func(cb_data.session_->GetInputs(), &row_vec);
+    ret = load_func(cb_data.session_->GetInputs(), &row_vec);
     if (ret != RET_OK) break;

     cb_data.step_ = ++s;
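The new guards use MS_CHECK_TRUE_MSG and MS_CHECK_GT from nnacl/op_base.h (the include added above). As a rough sketch of the shape of such a macro — an assumption for illustration, not the real definition — a failed predicate logs the message and early-returns the error code, keeping each guard to one line at the call site:

```cpp
// Hypothetical stand-in for MS_CHECK_TRUE_MSG; the real definition lives in
// nnacl/op_base.h and may differ in detail.
#define CHECK_TRUE_MSG_SKETCH(pred, errcode, msg) \
  do {                                            \
    if (!(pred)) {                                \
      MS_LOG(ERROR) << (msg);                     \
      return (errcode);                           \
    }                                             \
  } while (0)
```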
@@ -63,6 +63,10 @@ int TrainSession::Init(InnerContext *context, const TrainCfg *train_cfg) {
     }
     cfg_ = *train_cfg;
   }
+  if (context == nullptr) {
+    MS_LOG(ERROR) << "context cannot be nullptr";
+    return RET_NULL_PTR;
+  }
   allocator_ = context->allocator;
   return lite::LiteSession::Init(context);
 }
@@ -141,7 +141,7 @@ Tensor *CastTensor(Tensor *tensor, TypeId dst_data_type, bool support_fp16) {
   std::vector<TypeId> valid_type = {kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeFloat};
   std::vector<TypeId> fp32_type = {kNumberTypeFloat32, kNumberTypeFloat};
   if (!IsContain(valid_type, tensor->data_type())) {
-    MS_LOG(ERROR) << "source data type must be fp32 or fp16";
+    MS_LOG(ERROR) << "source data type must be fp32 or fp16,cur is " << tensor->data_type();
     return nullptr;
   }

@@ -139,6 +139,18 @@ TEST_F(TestCxxApiLiteModel, test_getparams_SUCCESS) {
   for (size_t ix = 0; ix < params1.size(); ix++) {
     ASSERT_EQ(static_cast<float *>(params1[ix].MutableData())[0], static_cast<float>(ix) + pi);
   }
+  if (!params.empty()) {
+    auto &param = params.at(0);
+    param.SetShape({20, 20});
+    param.SetDataType(DataType::kNumberTypeInt8);
+  }
+  ASSERT_TRUE(model.SetOptimizerParams(params) != kSuccess);
+
+  if (!params.empty()) {
+    auto &param = params.at(0);
+    param.SetTensorName("failed_name");
+  }
+  ASSERT_TRUE(model.SetOptimizerParams(params) != kSuccess);
 }

 TEST_F(TestCxxApiLiteModel, test_getgrads_SUCCESS) {
@@ -159,5 +171,62 @@ TEST_F(TestCxxApiLiteModel, test_getgrads_SUCCESS) {
     static_cast<float *>(graients[ix].MutableData())[0] = static_cast<float>(ix) + pi;
   }
   ASSERT_TRUE(model.ApplyGradients(graients) == kSuccess);
+  if (!graients.empty()) {
+    auto &param = graients.at(0);
+    param.SetShape({20, 20});
+  }
+
+  ASSERT_TRUE(model.ApplyGradients(graients) != kSuccess);
+  if (!graients.empty()) {
+    auto &param = graients.at(0);
+    param.SetTensorName("failed_name");
+  }
+  ASSERT_TRUE(model.ApplyGradients(graients) != kSuccess);
 }
+
+TEST_F(TestCxxApiLiteModel, test_fp32_SUCCESS) {
+  Model model;
+  Graph graph;
+  auto context = std::make_shared<Context>();
+  auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
+  cpu_context->SetEnableFP16(true);
+  context->MutableDeviceInfo().push_back(cpu_context);
+  auto train_cfg = std::make_shared<TrainCfg>();
+  train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true;
+
+  ASSERT_TRUE(Serialization::Load("./nets/conv_train_model.ms", ModelType::kMindIR, &graph) == kSuccess);
+  ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
+
+  train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = false;
+  ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
+
+  cpu_context->SetEnableFP16(false);
+  ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
+
+  train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true;
+  ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
+}
+
+TEST_F(TestCxxApiLiteModel, test_fp16_SUCCESS) {
+  Model model;
+  Graph graph;
+  auto context = std::make_shared<Context>();
+  auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
+  cpu_context->SetEnableFP16(true);
+  context->MutableDeviceInfo().push_back(cpu_context);
+  auto train_cfg = std::make_shared<TrainCfg>();
+  train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true;
+
+  ASSERT_TRUE(Serialization::Load("./nets/mix_lenet_tod.ms", ModelType::kMindIR, &graph) == kSuccess);
+  ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
+
+  train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = false;
+  ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
+
+  cpu_context->SetEnableFP16(false);
+  ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
+
+  train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true;
+  ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
+}
 }  // namespace mindspore

Binary file not shown.

@@ -115,7 +115,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
     fusion_pm->AddPass(std::make_shared<opt::AffineFusion>());
     fusion_pm->AddPass(std::make_shared<opt::AffineActivationFusion>());
   }
-  if (config->fmk == converter::kFmkTypeMs) {
+  if (config->fmk == converter::kFmkTypeMs && !config->trainModel) {
     auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
     if (remove_unused_cast_pass == nullptr) {
       MS_LOG(ERROR) << "RemoveUnusedCastOpPass should be specified";
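Gating RemoveUnusedCastOpPass on `!config->trainModel` keeps Cast nodes in converted training models. For a mixed-precision training graph these casts mark the intended FP16/FP32 boundaries (the branch name cast_transpose hints at this), so removing them as "unused" would discard the precision plan the model was exported with.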
@@ -14,30 +14,30 @@
  * limitations under the License.
  */
 #include "tools/converter/parser/parser_utils.h"
-#include <memory>
+#include <algorithm>
-#include <vector>
+#include <memory>
 #include <set>
 #include <string>
-#include "tools/converter/parser/tf_bidirection_gru_cf_fusion.h"
-#include "tools/converter/parser/unused_node_remove_pass.h"
+#include <vector>
+#include <unordered_map>
+#include "ops/transpose.h"
+#include "tools/common/tensor_util.h"
 #include "tools/converter/parser/conv1d_inout_adjust.h"
 #include "tools/converter/parser/inputs_adjust.h"
-#include "ops/transpose.h"
+#include "tools/converter/parser/tf_bidirection_gru_cf_fusion.h"
+#include "tools/converter/parser/unused_node_remove_pass.h"
 #include "tools/converter/quant_param_holder.h"
-#include "tools/common/tensor_util.h"
 #include "tools/optimizer/common/gllo_utils.h"
 #include "tools/optimizer/format/to_format_base.h"

 namespace mindspore::lite {
 namespace {
-constexpr size_t kNumWeightIndex = 2;
-bool IsWeightNodeSensitive(const AnfNodePtr &node) {
-  return opt::CheckPrimitiveType(node, prim::kPrimConv2DFusion) ||
-         opt::CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) ||
-         opt::CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion) ||
-         opt::CheckPrimitiveType(node, prim::kPrimApplyMomentum) || opt::CheckPrimitiveType(node, prim::kPrimSGD) ||
-         opt::CheckPrimitiveType(node, prim::kPrimAdam);
-}
+std::unordered_map<std::string, size_t> weight_indexs = {{ops::kNameConv2DFusion, 2},
+                                                         {ops::kNameConv2DBackpropInputFusion, 2},
+                                                         {ops::kNameConv2dTransposeFusion, 2},
+                                                         {ops::kNameApplyMomentum, 1},
+                                                         {ops::kNameSGD, 1},
+                                                         {ops::kNameAdam, 1}};
 }  // namespace

 void GetAllFuncGraph(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
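The table replaces a boolean predicate with the weight's input index: convolution-family ops carry their weight at input 2, while the optimizer ops (ApplyMomentum, SGD, Adam) carry the parameter being updated at input 1. A self-contained sketch of the lookup, with literal strings standing in for the ops::kName* constants:

```cpp
#include <cstddef>
#include <cstdio>
#include <string>
#include <unordered_map>

// Weight-input index per primitive, mirroring the weight_indexs table above.
static const std::unordered_map<std::string, std::size_t> kWeightIndex = {
    {"Conv2DFusion", 2}, {"Conv2dTransposeFusion", 2}, {"ApplyMomentum", 1}, {"SGD", 1}, {"Adam", 1}};

int main() {
  for (const auto &name : {"Conv2DFusion", "Adam", "Softmax"}) {
    auto it = kWeightIndex.find(name);
    if (it == kWeightIndex.end()) {
      std::printf("%s: not a member of convolution's family\n", name);  // rejected, as in UnifyConvWeightFormat
    } else {
      std::printf("%s: weight at input %zu\n", name, it->second);
    }
  }
  return 0;
}
```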
@@ -146,15 +146,9 @@ int GetTransposePermSharing(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm) {
   return lite::RET_OK;
 }

-AnfNodePtr GetRealConvWeightNode(const FuncGraphPtr &graph, const CNodePtr &cnode) {
+AnfNodePtr GetRealConvWeightNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index) {
   MS_ASSERT(graph != nullptr && cnode != nullptr);
-  if (!opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) &&
-      !opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) &&
-      !opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) {
-    MS_LOG(ERROR) << "cnode is not a member of convolution's family.";
-    return nullptr;
-  }
-  auto weight_node = cnode->input(opt::kInputIndexTwo);
+  auto weight_node = cnode->input(index);
   bool is_real_weight =
     !opt::CheckPrimitiveType(weight_node, opt::kPrimIdentity) && !opt::CheckPrimitiveType(weight_node, prim::kPrimLoad);
   while (!is_real_weight) {

@@ -169,7 +163,7 @@ AnfNodePtr GetRealConvWeightNode(const FuncGraphPtr &graph, const CNodePtr &cnode) {
   }
   auto manager = Manage(graph);
   MS_ASSERT(manager != nullptr);
-  manager->Replace(cnode->input(opt::kInputIndexTwo), weight_node);
+  manager->Replace(cnode->input(index), weight_node);
   return weight_node;
 }

@@ -179,18 +173,19 @@ int UnifyConvWeightFormat(const FuncGraphPtr &graph, const CNodePtr &cnode, schema::Format src_format,
   if (src_format == dst_format) {
     return lite::RET_OK;
   }
-  if (!opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) &&
-      !opt::CheckPrimitiveType(cnode, opt::kPrimConv2DBackpropInputFusion) &&
-      !opt::CheckPrimitiveType(cnode, prim::kPrimConv2dTransposeFusion)) {
-    MS_LOG(ERROR) << "cnode is not a member of convolution's family.";
+  auto primitive_ptr = GetValueNode<PrimitivePtr>(cnode->input(0));
+  auto primitive_name = primitive_ptr->name();
+  if (weight_indexs.find(primitive_name) == weight_indexs.end()) {
+    MS_LOG(ERROR) << primitive_name << " is not a member of convolution's family.";
     return RET_ERROR;
   }
-  if (GetRealConvWeightNode(graph, cnode) == nullptr) {
+  size_t index = weight_indexs[primitive_name];
+  if (GetRealConvWeightNode(graph, cnode, index) == nullptr) {
     MS_LOG(ERROR) << "current conv node is invalid, node name is " << cnode->fullname_with_scope();
     return RET_ERROR;
   }
   bool is_const_weight = true;
-  auto weight_node = cnode->input(opt::kInputIndexTwo);
+  auto weight_node = cnode->input(index);
   if (utils::isa<CNode>(weight_node)) {
     is_const_weight = false;
   } else if (utils::isa<Parameter>(weight_node)) {
@@ -234,7 +229,7 @@ int UnifyVariableConvWeight(const FuncGraphPtr &graph, const AnfNodePtr &weight_node, schema::Format src_format,
       MS_LOG(ERROR) << "post node is invalid.";
       return RET_ERROR;
     }
-    if (!IsWeightNodeSensitive(post_node)) {
+    if (!opt::ToFormatBase::IsWeightNodeSensitive(post_node)) {
       continue;
     }
     has_visited->insert(post_node);
@@ -285,6 +280,9 @@ int UnifyConstConvWeight(const FuncGraphPtr &graph, const AnfNodePtr &weight_node, schema::Format src_format,
     MS_LOG(ERROR) << "conv weight is non-const.";
     return RET_ERROR;
   }
+  if (weight_value->shape().size() != kShape4dDims) {
+    return lite::RET_OK;
+  }
   auto status = opt::TransFilterFormat(weight_value, src_format, dst_format);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(src_format) << "To" << EnumNameFormat(dst_format)
@@ -328,7 +326,7 @@ int HandleConstConvWeightShared(const FuncGraphPtr &graph, const AnfNodePtr &weight_node,
       MS_LOG(ERROR) << "post node is invalid.";
       return RET_ERROR;
     }
-    if (IsWeightNodeSensitive(post_node)) {
+    if (opt::ToFormatBase::IsWeightNodeSensitive(post_node)) {
       has_visited->insert(post_node);
       continue;
     }
@@ -30,7 +30,7 @@ void GetAllFuncGraph(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs);
 int CommonAnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
 int GetTransposePerm(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm);
 int GetTransposePermSharing(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm);
-AnfNodePtr GetRealConvWeightNode(const FuncGraphPtr &graph, const CNodePtr &cnode);
+AnfNodePtr GetRealConvWeightNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index);
 int UnifyConvWeightFormat(const FuncGraphPtr &graph, const CNodePtr &cnode, schema::Format src_format,
                           schema::Format dst_format, std::set<AnfNodePtr> *has_visited);
 int UnifyVariableConvWeight(const FuncGraphPtr &graph, const AnfNodePtr &weight_node, schema::Format src_format,
@@ -25,6 +25,7 @@
 #include "ops/batch_norm.h"
 #include "ops/batch_to_space.h"
 #include "ops/bias_add.h"
+#include "ops/cast.h"
 #include "ops/concat.h"
 #include "ops/crop.h"
 #include "ops/depth_to_space.h"
@@ -102,12 +103,12 @@ static const std::unordered_map<std::string, std::vector<size_t>> NCHWOpMap = {{

 // a certain op whose input's format is not fixed, bool value determines whether the op has axis attribute or not.
 static const std::unordered_map<std::string, bool> DynamicFormatOpList = {
-  {ops::kNameAddN, false}, {ops::kNameCrop, true}, {ops::kNameSplit, true},
-  {ops::kNameConcat, true}, {ops::kNameEltwise, false}, {ops::kNameMaximum, false},
-  {ops::kNameAddFusion, false}, {ops::kNameDivFusion, false}, {ops::kNameMulFusion, false},
-  {ops::kNamePadFusion, false}, {ops::kNamePowFusion, false}, {ops::kNameActivation, false},
-  {ops::kNameSliceFusion, true}, {ops::kNameStridedSlice, true}, {ops::kNameActivationGrad, false},
-  {ops::kNameQuantDTypeCast, false}};
+  {ops::kNameAddN, false}, {ops::kNameCrop, true}, {ops::kNameSplit, true},
+  {ops::kNameConcat, true}, {ops::kNameEltwise, false}, {ops::kNameMaximum, false},
+  {ops::kNameAddFusion, false}, {ops::kNameDivFusion, false}, {ops::kNameMulFusion, false},
+  {ops::kNamePadFusion, false}, {ops::kNamePowFusion, false}, {ops::kNameActivation, false},
+  {ops::kNameSliceFusion, true}, {ops::kNameStridedSlice, true}, {ops::kNameActivationGrad, false},
+  {ops::kNameQuantDTypeCast, false}, {ops::kNameCast, false}};

 const std::unordered_map<std::string, std::vector<size_t>> &GetNHWCOpMap() { return NHWCOpMap; }
 const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap() { return NCHWOpMap; }
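Adding `{ops::kNameCast, false}` registers Cast as an op whose input format is not fixed (per the comment above the table, the `false` means it has no axis attribute to remap). This should let the Cast nodes carried by a mixed-precision training graph follow their input's layout through format propagation instead of being pinned to NHWC or NCHW, which would otherwise force transposes around them.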
@@ -368,9 +368,7 @@ STATUS ToFormatBase::ConvWeightFormatTrans(const FuncGraphPtr &graph, std::set<AnfNodePtr> *has_visited) {
       }
       continue;
     }
-    if (!CheckPrimitiveType(node, prim::kPrimConv2DFusion) &&
-        !CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) &&
-        !CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) {
+    if (!IsWeightNodeSensitive(cnode)) {
       continue;
     }
     if (has_visited->find(node) != has_visited->end()) {
@@ -26,6 +26,11 @@
 #include "tools/converter/converter_flags.h"
 #include "tools/optimizer/common/format_utils.h"
 #include "tools/optimizer/graph/infershape_pass.h"
+#include "ops/fusion/conv2d_fusion.h"
+#include "ops/fusion/conv2d_transpose_fusion.h"
+#include "ops/adam.h"
+#include "ops/sgd.h"
+#include "ops/apply_momentum.h"

 using mindspore::converter::FmkType;
 namespace mindspore {
@@ -37,6 +42,16 @@ class ToFormatBase : public Pass {
       : Pass(pass_name), fmk_type_(fmk_type), train_flag_(train_flag) {}
   ~ToFormatBase() override = default;
   bool Run(const FuncGraphPtr &func_graph) override;
+  static bool IsConvFamilyNode(const AnfNodePtr &node) {
+    return opt::CheckPrimitiveType(node, prim::kPrimConv2DFusion) ||
+           opt::CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) ||
+           opt::CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion);
+  }
+  static bool IsOptimizerNode(const AnfNodePtr &node) {
+    return opt::CheckPrimitiveType(node, prim::kPrimApplyMomentum) || opt::CheckPrimitiveType(node, prim::kPrimSGD) ||
+           opt::CheckPrimitiveType(node, prim::kPrimAdam);
+  }
+  static bool IsWeightNodeSensitive(const AnfNodePtr &node) { return IsConvFamilyNode(node) || IsOptimizerNode(node); }

  private:
   bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph);
@@ -366,8 +366,10 @@ bool TransposeStrategy::CanFusionIfInsert(const FuncGraphPtr &func_graph, const
   MS_ASSERT(pre_type != nullptr && post_type != nullptr);
   size_t trans_count = 0;
   std::vector<AnfNodePtr> in_nodes;
+  auto graph_inputs = func_graph->get_inputs();
   for (size_t i = 1; i < cnode->size(); ++i) {
-    if (utils::isa<CNodePtr>(cnode->input(i))) {
+    if (utils::isa<CNodePtr>(cnode->input(i)) ||
+        std::find(graph_inputs.begin(), graph_inputs.end(), cnode->input(i)) != graph_inputs.end()) {
       in_nodes.push_back(cnode->input(i));
     }
   }