diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc
index cf0aa3ad929..81bd78a0f16 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc
@@ -65,28 +65,37 @@ AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &r
   }
 }
 }  // namespace internal
-bool ExpandJPrim::operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) {
-  AnfNodePtr ret = root->get_return();
-  MS_EXCEPTION_IF_NULL(ret);
-  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
-
+bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
+  // Search all J nodes.
+  GetJPrim(func_graph);
+  // Get the J nodes that don't have embedded J nodes.
+  std::vector<CNodePtr> todo;
+  // If a graph also contains J(FuncGraph) or J(Primitive), then ignore this graph:
+  // ExpandJ the innermost graph or primitive first.
+  std::copy_if(j_nodes_.begin(), j_nodes_.end(), std::back_inserter(todo),
+               [](const CNodePtr &j_node) { return !internal::CheckIfEmbedJ(j_node); });
+  // Expand the J nodes that don't have embedded J nodes.
   bool change = false;
   auto manager = optimizer->manager();
-  for (auto &node : all_nodes) {
-    if (IsPrimitiveCNode(node, prim::kPrimJ)) {
-      auto j_node = node->cast<CNodePtr>();
-      // If graph also contains J(FuncGraph) or J(Primitive), then ignore this graph.
-      // ExpandJ innermost graph or primitive first.
-      if (internal::CheckIfEmbedJ(j_node)) {
-        continue;
-      }
-      auto expanded_j = internal::ExpandJ(j_node->input(1)->cast<ValueNodePtr>(), optimizer->resource());
-      manager->Replace(j_node, expanded_j);
-      change = true;
-    }
+  for (auto &j_node : todo) {
+    auto expanded_j = internal::ExpandJ(j_node->input(1)->cast<ValueNodePtr>(), optimizer->resource());
+    manager->Replace(j_node, expanded_j);
+    change = true;
   }
   return change;
 }
+
+void ExpandJPrim::GetJPrim(const FuncGraphPtr &func_graph) {
+  j_nodes_.clear();
+  AnfNodePtr ret = func_graph->get_return();
+  MS_EXCEPTION_IF_NULL(ret);
+  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
+  for (auto &node : all_nodes) {
+    if (IsPrimitiveCNode(node, prim::kPrimJ)) {
+      j_nodes_.push_back(node->cast<CNodePtr>());
+    }
+  }
+}
 }  // namespace irpass
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h
index 1ab2f2d703a..1e435be52de 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h
@@ -36,7 +36,11 @@ class ExpandJPrim {
  public:
   ExpandJPrim() = default;
   virtual ~ExpandJPrim() = default;
-  bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer);
+  bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer);
+  void GetJPrim(const FuncGraphPtr &func_graph);
+
+ private:
+  std::vector<CNodePtr> j_nodes_;
 };
 }  // namespace irpass
 }  // namespace opt
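Note on the pass change above: ExpandJPrim used to walk every node, test for kPrimJ inline, and skip any J whose graph still contained an inner J. The rewrite splits this into two phases: GetJPrim() collects all J nodes, and only those without embedded J nodes go into `todo` and get expanded, so nested gradients are unwrapped innermost-first across repeated runs of the pass. The sketch below is illustrative only, not part of this patch (SinNet/FirstGrad/SecondGrad are made-up names): it shows the grad-of-grad construct that lowers to one J application embedded inside another, which is exactly the case CheckIfEmbedJ defers.

    # Illustrative only: grad-of-grad in graph mode nests one J inside another.
    import numpy as np
    import mindspore.nn as nn
    import mindspore.ops as ops
    from mindspore import context, Tensor
    from mindspore.ops.composite import GradOperation

    context.set_context(mode=context.GRAPH_MODE)

    class SinNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.sin = ops.Sin()

        def construct(self, x):
            return self.sin(x)

    class FirstGrad(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.grad = GradOperation()  # lowered to J(net) in the IR
            self.net = net

        def construct(self, x):
            return self.grad(self.net)(x)

    class SecondGrad(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.grad = GradOperation()  # lowered to a J that embeds the inner J
            self.first_grad = FirstGrad(net)

        def construct(self, x):
            return self.grad(self.first_grad)(x)

    x = Tensor(np.array([1.0], np.float32))
    print(SecondGrad(SinNet())(x))  # d2(sin x)/dx2 == -sin(x)

On the first invocation of the pass only the inner J qualifies for `todo`; expanding it removes the embedding, so the outer J is picked up the next time the pass runs.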
diff --git a/tests/st/high_grad/test_highgrad_train.py b/tests/st/high_grad/test_highgrad_train.py
new file mode 100644
index 00000000000..06e782e1435
--- /dev/null
+++ b/tests/st/high_grad/test_highgrad_train.py
@@ -0,0 +1,115 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import pytest
+import numpy as np
+import mindspore.nn as nn
+import mindspore.ops as ops
+from mindspore import context, Tensor, Parameter
+from mindspore.nn import TrainOneStepCell
+from mindspore.nn.optim import Momentum
+from mindspore.ops.composite import GradOperation
+from mindspore.common import ParameterTuple
+
+context.set_context(mode=context.GRAPH_MODE)
+
+
+class _Grad(nn.Cell):
+    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
+        super().__init__()
+        self.network = network
+        self.grad = grad
+        self.sens_param = self.grad.sens_param
+        self.wrt_params = wrt_params
+        self.real_inputs_count = real_inputs_count
+        if self.wrt_params:
+            self.params = ParameterTuple(self.network.trainable_params())
+
+    def construct(self, *inputs):
+        if self.real_inputs_count is None or self.sens_param is False:
+            if self.wrt_params:
+                return self.grad(self.network, self.params)(*inputs)
+            return self.grad(self.network)(*inputs)
+
+        real_inputs = inputs[:self.real_inputs_count]
+        sense_param_inputs = inputs[self.real_inputs_count:]
+        if self.wrt_params:
+            return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
+        return self.grad(self.network)(*real_inputs, sense_param_inputs)
+
+
+class GradOfFirstInput(_Grad):
+    """
+    get grad of first input
+    """
+
+    def __init__(self, network, sens_param=True, real_inputs_count=None):
+        super().__init__(grad=GradOperation(sens_param=sens_param),
+                         network=network, real_inputs_count=real_inputs_count)
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.mul = ops.Mul()
+        self.add = ops.TensorAdd()
+        weight_np = np.array([2]).astype(np.float32)
+        bias_np = np.array([1]).astype(np.float32)
+        self.weight = Parameter(Tensor(weight_np),
+                                name='weight', requires_grad=True)
+        self.bias = Parameter(Tensor(bias_np),
+                              name="bias", requires_grad=True)
+
+    def construct(self, x):
+        xw = self.mul(x, self.weight)
+        output = self.add(xw, self.bias)
+        return output
+
+
+class WithLossCellLocal(nn.Cell):
+    def __init__(self, grad, loss):
+        super(WithLossCellLocal, self).__init__(auto_prefix=False)
+        self.grad = grad
+        self.loss = loss
+
+    def construct(self, data, label):
+        out = self.grad(data)
+        return self.loss(out, label)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu_training
+@pytest.mark.env_onecard
+def test_high_grad_train():
+    x_pure = np.random.randint(-10, 100, 32)
+    x_train = x_pure.astype(np.float32)
+    y_noise = 3 * x_pure + 2 + np.random.randn(32) / 10
+    y_train = y_noise.astype(np.float32)
+    net = Net()
+    grad_net = GradOfFirstInput(net, sens_param=False)
+    epoch = 2
+    momentum = 0.0
+    learning_rate = 0.001
+    optimizer = Momentum(filter(lambda x: x.requires_grad,
+                                grad_net.get_parameters()), learning_rate, momentum)
+    criterion = nn.loss.MSELoss()
+    net_with_criterion = WithLossCellLocal(grad_net, criterion)
+    train_network = TrainOneStepCell(net_with_criterion,
+                                     optimizer)
+    train_network.set_train()
+    for i in range(epoch):
+        train_network(Tensor([x_train[i]]), Tensor([y_train[i]]))
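The new ST test drives exactly this nested-J path: GradOfFirstInput(net) builds a first-derivative network (one J), and TrainOneStepCell then differentiates through it to obtain parameter gradients (an outer J wrapping the inner one). As a standalone illustration of what the derivative network computes, not part of the patch: Net implements y = weight * x + bias, so its derivative with respect to x is the weight itself.

    # Illustrative usage of the classes from the test above.
    net = Net()
    grad_net = GradOfFirstInput(net, sens_param=False)
    out = grad_net(Tensor(np.array([5.0], np.float32)))
    # d(weight * x + bias)/dx == weight, so out is [2.0] for the
    # initial weight, whatever input value is passed in.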
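Because grad_net(x) evaluates to the current weight regardless of x, the training loop is effectively fitting the scalar weight toward the noisy targets 3x + 2, while the bias receives a zero gradient through this path (d(dy/dx)/d(bias) = 0). A hypothetical hand-check of the first step, assuming the initial weight 2.0, momentum 0.0, and ignoring the label noise:

    # Hand-check of one optimizer step under the assumptions above.
    weight = 2.0                  # Net's initial weight
    lr = 0.001
    x = 1.0                       # one sample; the prediction ignores its value
    label = 3 * x + 2             # noiseless target
    pred = weight                 # grad_net(x) == d(weight*x + bias)/dx
    grad_w = 2 * (pred - label)   # d(MSE)/d(weight) for a single element
    weight -= lr * grad_w         # 2.0 - 0.001 * (-6.0) == 2.006

The test asserts nothing about the loss value; it checks that the two training steps through the second-order graph compile and run on each listed backend.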