forked from mindspore-Ecosystem/mindspore
commit a77f23643d (parent b9ec533f95): train model
@@ -25,14 +25,11 @@ int BnGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
   }
 
   const TensorC *in = inputs[1];
-  if (inputs[0]->format_ != Format_NHWC || in->format_ != Format_NHWC) {
+  if ((inputs[0]->shape_size_ == 4 && inputs[0]->format_ != Format_NHWC) ||
+      (in->shape_size_ == 4 && in->format_ != Format_NHWC)) {
     return NNACL_FORMAT_ERROR;
   }
   const TensorC *scale = inputs[2];
-  if (in->shape_size_ != 4) {
-    return NNACL_INPUT_TENSOR_ERROR;
-  }
-
   SetShapeTensor(outputs[0], in);
   SetDataTypeFormat(outputs[0], in);
   SetShapeTensor(outputs[1], scale);
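This relaxation lets 2-D tensors reach BN-grad shape inference: the NHWC constraint is only meaningful for 4-D data, so lower-rank inputs no longer trip NNACL_FORMAT_ERROR or NNACL_INPUT_TENSOR_ERROR. A minimal sketch of the new predicate, using a hypothetical LayoutOk helper that is not part of nnacl:

#include <cstddef>

enum Format { Format_NHWC, Format_NCHW };

// Only 4-D tensors carry a spatial layout worth validating; any other
// rank is accepted regardless of its declared format.
bool LayoutOk(std::size_t shape_size, Format format) {
  return shape_size != 4 || format == Format_NHWC;
}

Applying LayoutOk to both inputs[0] and inputs[1] reproduces the combined condition in the hunk above.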
@@ -41,8 +41,15 @@ constexpr int kMaxTaskNum = 4;
 
 int BNGradCPUKernel::ReSize() {
   auto *input_x = in_tensors_.at(1);
-  int channels = input_x->shape().at(kNHWC_C);
-  ws_size_ = kWsMultiplier * channels;
+  if (input_x->shape().size() == 4) {
+    int channels = input_x->shape().at(kNHWC_C);
+    ws_size_ = kWsMultiplier * channels;
+  } else if (input_x->shape().size() == 2) {
+    int channels = input_x->shape().at(1);
+    ws_size_ = kWsMultiplier * channels;
+  } else {
+    MS_LOG(ERROR) << "not support input dims: " << input_x->shape().size();
+  }
   set_workspace_size(ws_size_ * sizeof(float));
   return RET_OK;
 }
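ReSize now sizes the workspace from the channel count of either a 4-D NHWC tensor or a 2-D (N, C) tensor; any other rank is logged as unsupported. In both supported layouts the channel axis is the last one, which the following sketch (hypothetical ChannelOf helper, assuming exactly those two layouts) makes explicit:

#include <cstdio>
#include <vector>

// For NHWC (rank 4) and (N, C) (rank 2) the channel count is the last
// dimension, so both branches of ReSize effectively pick shape.back().
int ChannelOf(const std::vector<int> &shape) { return shape.back(); }

int main() {
  std::printf("%d\n", ChannelOf({8, 32, 32, 16}));  // 16, NHWC
  std::printf("%d\n", ChannelOf({8, 16}));          // 16, (N, C)
  return 0;
}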
@@ -231,7 +231,7 @@ int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const converter:
 int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
-  const_fold_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>());
+  const_fold_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(config->trainModel));
   if (!config->trainModel) {
     const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk));
   }
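Instead of a default-constructed pass, the converter now forwards config->trainModel into RemoveRedundantOpPass, the same flag that already decides whether ConstFoldPass runs. A minimal sketch of this construction-time configuration pattern, using simplified stand-in types rather than the converter API:

#include <memory>
#include <vector>

struct Pass {
  virtual ~Pass() = default;
};

// The mode is fixed when the pass is created, so Run() needs no extra context.
struct RedundantOpPassSketch : Pass {
  explicit RedundantOpPassSketch(bool is_train_model) : is_train_model_(is_train_model) {}
  bool is_train_model_;
};

struct PassManagerSketch {
  void AddPass(std::shared_ptr<Pass> pass) { passes_.push_back(std::move(pass)); }
  std::vector<std::shared_ptr<Pass>> passes_;
};

void BuildConstFoldManager(PassManagerSketch *pm, bool train_model) {
  pm->AddPass(std::make_shared<RedundantOpPassSketch>(train_model));
}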
@@ -347,7 +347,7 @@ bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) {
     if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
       status = ReplaceTupleGetItem(node, manager);
     }
-    if (CheckPrimitiveType(node, prim::kPrimDropout)) {
+    if (!is_train_model_ && CheckPrimitiveType(node, prim::kPrimDropout)) {
       status = RemoveDropoutOp(node, manager);
     }
     if (CheckPrimitiveType(node, prim::kPrimPadFusion)) {
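The guard matters because dropout is only removable when it is a no-op: at inference time dropout is the identity, but in a training graph it masks and rescales activations, so a train-mode export must keep the node. A small numeric sketch of that asymmetry (inverted dropout, hypothetical helpers):

#include <cstdio>

// Inference: dropout passes values through unchanged, so the node can
// safely be folded away by the pass.
float DropoutInfer(float x) { return x; }

// Training (inverted dropout): kept units are rescaled by 1/keep_prob,
// dropped units become zero -- clearly not the identity.
float DropoutTrain(float x, bool keep, float keep_prob) {
  return keep ? x / keep_prob : 0.0f;
}

int main() {
  std::printf("%.1f\n", DropoutInfer(2.0f));               // 2.0
  std::printf("%.1f\n", DropoutTrain(2.0f, true, 0.5f));   // 4.0
  std::printf("%.1f\n", DropoutTrain(2.0f, false, 0.5f));  // 0.0
  return 0;
}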
@@ -26,7 +26,8 @@ using mindspore::converter::FmkType;
 namespace mindspore::opt {
 class RemoveRedundantOpPass : public Pass {
  public:
-  RemoveRedundantOpPass() : Pass("remove_redundant_op_pass") {}
+  explicit RemoveRedundantOpPass(bool is_train_model)
+      : Pass("remove_redundant_op_pass"), is_train_model_(is_train_model) {}
   ~RemoveRedundantOpPass() override = default;
   int ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
   int ReplaceUpdateStateOp(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node);
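The replacement constructor is marked explicit, so a bare bool cannot silently convert into a pass object at a call site. A minimal sketch of what that rules out, with a hypothetical PassLike type:

struct PassLike {
  explicit PassLike(bool is_train_model) : is_train_model_(is_train_model) {}
  bool is_train_model_;
};

void Register(const PassLike &) {}

int main() {
  Register(PassLike(true));  // fine: construction is spelled out
  // Register(true);         // rejected: explicit blocks bool -> PassLike
  return 0;
}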
@@ -37,6 +38,7 @@ class RemoveRedundantOpPass : public Pass {
   bool Run(const FuncGraphPtr &graph) override;
 
  private:
+  bool is_train_model_ = false;
   std::set<AnfNodePtr> remove_cnode_;
 };
 }  // namespace mindspore::opt
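Taken together, the header and converter changes thread one flag from the converter configuration through to the dropout decision. An end-to-end sketch of the resulting behavior, again with simplified stand-ins rather than the real mindspore::opt types:

#include <cstdio>

class RemoveRedundantOpPassSketch {
 public:
  explicit RemoveRedundantOpPassSketch(bool is_train_model) : is_train_model_(is_train_model) {}
  // Mirrors the Run() guard: dropout is folded only for inference models.
  bool ShouldRemoveDropout() const { return !is_train_model_; }

 private:
  bool is_train_model_ = false;  // in-class default, as in the header above
};

int main() {
  RemoveRedundantOpPassSketch infer_pass(/*is_train_model=*/false);
  RemoveRedundantOpPassSketch train_pass(/*is_train_model=*/true);
  std::printf("%d\n", infer_pass.ShouldRemoveDropout());  // 1: fold dropout
  std::printf("%d\n", train_pass.ShouldRemoveDropout());  // 0: keep dropout
  return 0;
}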