forked from mindspore-Ecosystem/mindspore
!20986 revert the modification of ExpandJPrim
Merge pull request !20986 from huangbingjian/revert_jprim
commit b32bfb0111
@@ -65,28 +65,37 @@ AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &r
 }
 } // namespace internal

-bool ExpandJPrim::operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) {
-  AnfNodePtr ret = root->get_return();
-  MS_EXCEPTION_IF_NULL(ret);
-  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
+bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
+  // Search all j nodes.
+  GetJPrim(func_graph);
+  // Get j nodes that don't have embed j nodes.
+  std::vector<CNodePtr> todo;
+  // If graph also contains J(FuncGraph) or J(Primitive), then ignore this graph.
+  // ExpandJ innermost graph or primitive first.
+  std::copy_if(j_nodes_.begin(), j_nodes_.end(), std::back_inserter(todo),
+               [](const CNodePtr &j_node) { return !internal::CheckIfEmbedJ(j_node); });
+  // Expand j nodes that don't have embed j nodes.
   bool change = false;
   auto manager = optimizer->manager();
-  for (auto &node : all_nodes) {
-    if (IsPrimitiveCNode(node, prim::kPrimJ)) {
-      auto j_node = node->cast<CNodePtr>();
-      // If graph also contains J(FuncGraph) or J(Primitive), then ignore this graph.
-      // ExpandJ innermost graph or primitive first.
-      if (internal::CheckIfEmbedJ(j_node)) {
-        continue;
-      }
-      auto expanded_j = internal::ExpandJ(j_node->input(1)->cast<ValueNodePtr>(), optimizer->resource());
-      manager->Replace(j_node, expanded_j);
-      change = true;
-    }
+  for (auto &j_node : todo) {
+    auto expanded_j = internal::ExpandJ(j_node->input(1)->cast<ValueNodePtr>(), optimizer->resource());
+    manager->Replace(j_node, expanded_j);
+    change = true;
   }
   return change;
 }

+void ExpandJPrim::GetJPrim(const FuncGraphPtr &func_graph) {
+  j_nodes_.clear();
+  AnfNodePtr ret = func_graph->get_return();
+  MS_EXCEPTION_IF_NULL(ret);
+  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
+  for (auto &node : all_nodes) {
+    if (IsPrimitiveCNode(node, prim::kPrimJ)) {
+      j_nodes_.push_back(node->cast<CNodePtr>());
+    }
+  }
+}
 } // namespace irpass
 } // namespace opt
 } // namespace mindspore
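For context, a minimal sketch of the kind of program that produces an embedded J node (this is not part of the patch; the class names are illustrative). In GRAPH_MODE, nesting GradOperation for a second-order derivative places one J(FuncGraph) call inside another, which is the situation the CheckIfEmbedJ guard and the innermost-first expansion above handle.

# Illustrative sketch only; assumes a MindSpore install with graph mode available.
import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import context, Tensor
from mindspore.ops.composite import GradOperation

context.set_context(mode=context.GRAPH_MODE)


class Square(nn.Cell):
    def __init__(self):
        super().__init__()
        self.mul = ops.Mul()

    def construct(self, x):
        return self.mul(x, x)


class FirstGrad(nn.Cell):
    """First-order gradient: compiles to one J(FuncGraph) node."""
    def __init__(self, net):
        super().__init__()
        self.grad = GradOperation()
        self.net = net

    def construct(self, x):
        return self.grad(self.net)(x)


class SecondGrad(nn.Cell):
    """Second-order gradient: the inner J node is embedded in the outer one,
    so the pass has to expand the innermost graph first."""
    def __init__(self, first_grad):
        super().__init__()
        self.grad = GradOperation()
        self.first_grad = first_grad

    def construct(self, x):
        return self.grad(self.first_grad)(x)


x = Tensor(np.array([3.0]).astype(np.float32))
print(SecondGrad(FirstGrad(Square()))(x))  # d2/dx2 of x*x is 2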
@@ -36,7 +36,11 @@ class ExpandJPrim {
  public:
   ExpandJPrim() = default;
   virtual ~ExpandJPrim() = default;
-  bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer);
+  bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer);
+  void GetJPrim(const FuncGraphPtr &func_graph);
+
+ private:
+  std::vector<CNodePtr> j_nodes_;
 };
 } // namespace irpass
 } // namespace opt
@@ -0,0 +1,115 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import context, Tensor, Parameter
from mindspore.nn import TrainOneStepCell
from mindspore.nn.optim import Momentum
from mindspore.ops.composite import GradOperation
from mindspore.common import ParameterTuple

context.set_context(mode=context.GRAPH_MODE)


class _Grad(nn.Cell):
    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
        super().__init__()
        self.network = network
        self.grad = grad
        self.sens_param = self.grad.sens_param
        self.wrt_params = wrt_params
        self.real_inputs_count = real_inputs_count
        if self.wrt_params:
            self.params = ParameterTuple(self.network.trainable_params())

    def construct(self, *inputs):
        if self.real_inputs_count is None or self.sens_param is False:
            if self.wrt_params:
                return self.grad(self.network, self.params)(*inputs)
            return self.grad(self.network)(*inputs)

        real_inputs = inputs[:self.real_inputs_count]
        sense_param_inputs = inputs[self.real_inputs_count:]
        if self.wrt_params:
            return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
        return self.grad(self.network)(*real_inputs, sense_param_inputs)


class GradOfFirstInput(_Grad):
    """
    get grad of first input
    """

    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=GradOperation(sens_param=sens_param),
                         network=network, real_inputs_count=real_inputs_count)


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.mul = ops.Mul()
        self.add = ops.TensorAdd()
        weight_np = np.array([2]).astype(np.float32)
        bias_np = np.array([1]).astype(np.float32)
        self.weight = Parameter(Tensor(weight_np),
                                name='weight', requires_grad=True)
        self.bias = Parameter(Tensor(bias_np),
                              name="bias", requires_grad=True)

    def construct(self, x):
        xw = self.mul(x, self.weight)
        output = self.add(xw, self.bias)
        return output


class WithLossCellLocal(nn.Cell):
    def __init__(self, grad, loss):
        super(WithLossCellLocal, self).__init__(auto_prefix=False)
        self.grad = grad
        self.loss = loss

    def construct(self, data, label):
        out = self.grad(data)
        return self.loss(out, label)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_high_grad_train():
    x_pure = np.random.randint(-10, 100, 32)
    x_train = x_pure.astype(np.float32)
    y_noise = 3 * x_pure + 2 + np.random.randn(32) / 10
    y_train = y_noise.astype(np.float32)
    net = Net()
    grad_net = GradOfFirstInput(net, sens_param=False)
    epoch = 2
    momentum = 0.0
    learning_rate = 0.001
    optimizer = Momentum(filter(lambda x: x.requires_grad,
                                grad_net.get_parameters()), learning_rate, momentum)
    criterion = nn.loss.MSELoss()
    net_with_criterion = WithLossCellLocal(grad_net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)
    train_network.set_train()
    for i in range(epoch):
        train_network(Tensor([x_train[i]]), Tensor([y_train[i]]))
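As a rough reading aid, not part of the patch: Net above computes y = x * weight + bias, so GradOfFirstInput(net, sens_param=False) returns dy/dx, i.e. the weight, and TrainOneStepCell then differentiates that wrapper again during training, which is what exercises the nested J expansion. A hedged usage sketch, reusing the Net and GradOfFirstInput classes defined in the test above:

# Assumes the definitions from the test file above are in scope.
net = Net()
grad_net = GradOfFirstInput(net, sens_param=False)
out = grad_net(Tensor(np.array([5.0]).astype(np.float32)))
print(out)  # dy/dx equals the weight, i.e. [2.0]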