From 05cfa6dd1c3057bd1ff06cc9f5e9bb075b4e3cac Mon Sep 17 00:00:00 2001
From: joylvliang
Date: Sat, 29 May 2021 17:38:40 +0800
Subject: [PATCH] fix_bug_of_efficientnet_not_work

---
 .../pipeline/pynative/pynative_execute.cc     | 22 ++--
 .../pipeline/pynative/pynative_execute.h      |  1 +
 .../cv/efficientnet/src/efficientnet.py       |  2 -
 .../pynative/test_pynative_temporary_cell.py  | 86 +++++++++++++++++++
 4 files changed, 98 insertions(+), 13 deletions(-)
 create mode 100644 tests/st/pynative/test_pynative_temporary_cell.py

diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index 9dec70e6777..070229e35eb 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -1432,10 +1432,6 @@ void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_e
 
   MS_LOG(DEBUG) << "Current op info: " << op_info;
   std::vector<tensor::TensorPtr> all_op_tensors;
-  // Get input tensors
-  for (size_t i = 0; i < op_exec_info->op_inputs.size(); ++i) {
-    TensorValueToTensor(parse::data_converter::PyDataToValue(op_exec_info->op_inputs[i]), &all_op_tensors);
-  }
   // Get output tensors
   TensorValueToTensor(parse::data_converter::PyDataToValue(out_real), &all_op_tensors);
   // Save all tensors info of current op
@@ -1949,22 +1945,24 @@ void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const
   auto cell_id = GetCellId(cell, args);
   MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id;
   if (top_cell_ != nullptr && cell_stack_.empty()) {
-    // non-first step
+    // Non-first step
     if (already_run_top_cell_.find(cell_id) != already_run_top_cell_.end()) {
-      // top cell
+      // Top cell forward run.
       const auto &pre_top_cell = already_run_top_cell_.at(cell_id);
       if (!pre_top_cell->is_dynamic()) {
-        MS_LOG(DEBUG) << "Top cell " << cell_id << " is not dynamic or ms_function, no need to run NewGraphInner again";
+        MS_LOG(DEBUG) << "Top cell " << cell_id << " is not dynamic, no need to run NewGraphInner again";
         ResetTopCellInfo(pre_top_cell, args);
         set_top_cell(pre_top_cell);
+        cached_top_cell_forward_running_ = true;
         return;
       }
-    } else if (top_cell()->IsSubCell(cell_id) && !top_cell()->is_dynamic()) {
-      // non-top cell
-      MS_LOG(DEBUG) << "no need to run NewGraphInner again";
+    } else if (top_cell()->IsSubCell(cell_id) || cached_top_cell_forward_running_) {
+      // Sub cell (may be a temporary cell) forward run while the cached top cell is running.
+      MS_LOG(DEBUG) << "No need to run NewGraphInner again";
       return;
     }
   }
+
   // When the cell has custom bprop, in_custom_bprop_cell is lager than 0
   if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
     custom_bprop_cell_count_ += 1;
@@ -2092,6 +2090,7 @@ void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const
     MS_LOG(DEBUG) << "Current cell " << cell_id << " no need to run EndGraphInner again";
     if (top_cell()->is_topest() && cell_id == top_cell()->cell_id()) {
       set_grad_flag(false);
+      cached_top_cell_forward_running_ = false;
     }
     return;
   }
@@ -2441,7 +2440,7 @@ void GradExecutor::CheckNeedCompileGraph() {
   MS_LOG(DEBUG) << "Pre all op info : " << pre_all_op_info;
   MS_LOG(DEBUG) << "New all op info : " << new_all_op_info;
   if (pre_all_op_info != new_all_op_info) {
-    MS_LOG(DEBUG) << "The op info has been changed or new top cell has ms_function, need to compile graph again";
+    MS_LOG(DEBUG) << "The op info has been changed, need to compile graph again";
     EraseTopCellFromTopCellList(pre_top_cell);
     pre_top_cell->clear();
     already_run_top_cell_[top_cell_id] = new_top_cell;
@@ -2663,6 +2662,7 @@ void GradExecutor::ClearRes() {
   grad_flag_ = false;
   need_renormalize_ = false;
   grad_is_running_ = false;
+  cached_top_cell_forward_running_ = false;
   top_cell_ = nullptr;
   curr_g_ = nullptr;
   bprop_cell_list_.clear();
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
index 3caf3092501..daf6cfa3079 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
@@ -266,6 +266,7 @@ class GradExecutor {
   bool grad_flag_{false};
   bool need_renormalize_{false};
   bool grad_is_running_{false};
+  bool cached_top_cell_forward_running_{false};
   int custom_bprop_cell_count_{0};
   size_t grad_order_{0};
 
diff --git a/model_zoo/official/cv/efficientnet/src/efficientnet.py b/model_zoo/official/cv/efficientnet/src/efficientnet.py
index 8b43ad25d9c..835761d1367 100644
--- a/model_zoo/official/cv/efficientnet/src/efficientnet.py
+++ b/model_zoo/official/cv/efficientnet/src/efficientnet.py
@@ -20,7 +20,6 @@ from copy import deepcopy
 
 import mindspore as ms
 import mindspore.nn as nn
-from mindspore import ms_function
 from mindspore.common.initializer import (Normal, One, Uniform, Zero)
 from mindspore.ops import operations as P
 from mindspore.ops.composite import clip_by_value
@@ -346,7 +345,6 @@ def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
     return arch_args
 
 
-@ms_function
 def hard_swish(x):
     x = P.Cast()(x, ms.float32)
     y = x + 3.0
diff --git a/tests/st/pynative/test_pynative_temporary_cell.py b/tests/st/pynative/test_pynative_temporary_cell.py
new file mode 100644
index 00000000000..817c23daa30
--- /dev/null
+++ b/tests/st/pynative/test_pynative_temporary_cell.py
@@ -0,0 +1,86 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+import mindspore.ops as P
+from mindspore.nn.optim import Momentum
+from mindspore.common import ParameterTuple
+
+
+class GradofParams(nn.Cell):
+    def __init__(self, net, sens=False):
+        super().__init__()
+        self.grad = P.GradOperation(get_all=False, get_by_list=True, sens_param=sens)
+        self.net = net
+        self.params = ParameterTuple(self.net.trainable_params())
+
+    def construct(self, *x):
+        out = self.grad(self.net, self.params)(*x)
+        return out
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_pynative_temporary_cell_variables():
+    context.set_context(mode=context.PYNATIVE_MODE)
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.add = P.Add()
+            self.conv = nn.Conv2d(1, 1, 3, weight_init='ones', pad_mode='pad')
+            self.relu = nn.ReLU()
+
+        def construct(self, x):
+            x = self.conv(x)
+            x = self.relu(x)
+            x = self.add(x, x)
+            return x
+
+    class TempCellNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.add = P.Add()
+            self.conv = nn.Conv2d(1, 1, 3, weight_init='ones', pad_mode='pad')
+
+        def construct(self, x):
+            x = self.conv(x)
+            x = nn.ReLU()(x)
+            x = self.add(x, x)
+            return x
+
+    input_data = Tensor(np.random.randn(1, 1, 224, 224).astype(np.float32))
+    # The first net run
+    net = Net()
+    backnet = GradofParams(net)
+    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
+    grad_first = backnet(input_data)
+    optimizer(grad_first)
+    grad_second = backnet(input_data)
+    # The second net run
+    compare_net = TempCellNet()
+    compare_backnet = GradofParams(compare_net)
+    compare_optimizer = Momentum(filter(lambda x: x.requires_grad, compare_net.get_parameters()), 0.1, 0.9)
+    compare_grad_first = compare_backnet(input_data)
+    compare_optimizer(compare_grad_first)
+    compare_grad_second = compare_backnet(input_data)
+    # compare result
+    assert np.allclose(grad_first[0].asnumpy(), compare_grad_first[0].asnumpy(), 0.01, 0.01)
+    assert np.allclose(grad_second[0].asnumpy(), compare_grad_second[0].asnumpy(), 0.01, 0.01)
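
For context, the situation the new test guards is a cell that builds a throwaway sub cell such as nn.ReLU() inside construct, so a fresh cell object takes part in every forward run while the top cell graph is reused from the cache on non-first steps. Below is a minimal PyNative sketch of that pattern, separate from the patch itself; the DemoNet name, input shape, and print call are illustrative only.

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.common import ParameterTuple

context.set_context(mode=context.PYNATIVE_MODE)


class DemoNet(nn.Cell):
    """Creates a temporary nn.ReLU() on every forward run (illustrative only)."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, weight_init='ones', pad_mode='pad')

    def construct(self, x):
        x = self.conv(x)
        # Temporary cell: a new nn.ReLU() instance is built on each call.
        return nn.ReLU()(x)


net = DemoNet()
grad_fn = ops.GradOperation(get_by_list=True)
params = ParameterTuple(net.trainable_params())
x = Tensor(np.ones((1, 1, 8, 8)).astype(np.float32))
# Run the gradient twice: the second run takes the cached top cell path.
grads_first = grad_fn(net, params)(x)
grads_second = grad_fn(net, params)(x)
print(grads_first[0].asnumpy().shape, grads_second[0].asnumpy().shape)

The gradients from both runs should match those of an equivalent cell that keeps nn.ReLU() as an attribute, which is what test_pynative_temporary_cell_variables above asserts.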