train model

This commit is contained in:
yefeng 2021-08-23 21:27:20 +08:00
parent b9ec533f95
commit a77f23643d
5 changed files with 16 additions and 10 deletions

View File

@ -25,14 +25,11 @@ int BnGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
}
const TensorC *in = inputs[1];
if (inputs[0]->format_ != Format_NHWC || in->format_ != Format_NHWC) {
if ((inputs[0]->shape_size_ == 4 && inputs[0]->format_ != Format_NHWC) ||
(in->shape_size_ == 4 && in->format_ != Format_NHWC)) {
return NNACL_FORMAT_ERROR;
}
const TensorC *scale = inputs[2];
if (in->shape_size_ != 4) {
return NNACL_INPUT_TENSOR_ERROR;
}
SetShapeTensor(outputs[0], in);
SetDataTypeFormat(outputs[0], in);
SetShapeTensor(outputs[1], scale);

View File

@ -41,8 +41,15 @@ constexpr int kMaxTaskNum = 4;
int BNGradCPUKernel::ReSize() {
auto *input_x = in_tensors_.at(1);
int channels = input_x->shape().at(kNHWC_C);
ws_size_ = kWsMultiplier * channels;
if (input_x->shape().size() == 4) {
int channels = input_x->shape().at(kNHWC_C);
ws_size_ = kWsMultiplier * channels;
} else if (input_x->shape().size() == 2) {
int channels = input_x->shape().at(1);
ws_size_ = kWsMultiplier * channels;
} else {
MS_LOG(ERROR) << "not support input dims: " << input_x->shape().size();
}
set_workspace_size(ws_size_ * sizeof(float));
return RET_OK;
}

View File

@ -231,7 +231,7 @@ int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const converter:
int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
const_fold_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>());
const_fold_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(config->trainModel));
if (!config->trainModel) {
const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk));
}

View File

@ -347,7 +347,7 @@ bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) {
if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
status = ReplaceTupleGetItem(node, manager);
}
if (CheckPrimitiveType(node, prim::kPrimDropout)) {
if (!is_train_model_ && CheckPrimitiveType(node, prim::kPrimDropout)) {
status = RemoveDropoutOp(node, manager);
}
if (CheckPrimitiveType(node, prim::kPrimPadFusion)) {

View File

@ -26,7 +26,8 @@ using mindspore::converter::FmkType;
namespace mindspore::opt {
class RemoveRedundantOpPass : public Pass {
public:
RemoveRedundantOpPass() : Pass("remove_redundant_op_pass") {}
explicit RemoveRedundantOpPass(bool is_train_model)
: Pass("remove_redundant_op_pass"), is_train_model_(is_train_model) {}
~RemoveRedundantOpPass() override = default;
int ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
int ReplaceUpdateStateOp(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node);
@ -37,6 +38,7 @@ class RemoveRedundantOpPass : public Pass {
bool Run(const FuncGraphPtr &graph) override;
private:
bool is_train_model_ = false;
std::set<AnfNodePtr> remove_cnode_;
};
} // namespace mindspore::opt