forked from mindspore-Ecosystem/mindspore
!25761 Fix switch layer recursive bug
Merge pull request !25761 from chenfei_mindspore/switch_layer_reursive_fix
This commit is contained in:
commit
0a4cc28c9d
|
@ -1957,6 +1957,19 @@ AnfNodePtr GetSupportedInternalNode(const AnfNodePtr &front_node) {
|
|||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool IsUnusedInternlOutput(const AnfNodePtr &user) {
|
||||
if (!CNodeFirstInputIsPrimitive(user)) {
|
||||
return true;
|
||||
}
|
||||
if (IsPrimitiveCNode(user, prim::kPrimSwitch) || IsPrimitiveCNode(user, prim::kPrimSwitchLayer)) {
|
||||
return true;
|
||||
}
|
||||
if (!AnfAlgo::IsRealKernel(user)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
constexpr auto kMixTarget = "MixTarget";
|
||||
|
@ -2040,11 +2053,7 @@ void SessionBasic::HandleInternalOutput(const AnfNodePtr &input_front_node, cons
|
|||
if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimUpdateState)) {
|
||||
continue;
|
||||
}
|
||||
if (!CNodeFirstInputIsPrimitive(user)) {
|
||||
internal_output = false;
|
||||
break;
|
||||
}
|
||||
if (!AnfAlgo::IsRealKernel(user)) {
|
||||
if (IsUnusedInternlOutput(user)) {
|
||||
internal_output = false;
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -19,6 +19,19 @@ import pytest
|
|||
import mindspore.context as context
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops.composite import GradOperation
|
||||
|
||||
|
||||
class Grad(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super().__init__()
|
||||
self.grad = GradOperation(get_all=False)
|
||||
self.net = net
|
||||
|
||||
def construct(self, x, y):
|
||||
grad_net = self.grad(self.net)
|
||||
grad = grad_net(x, y)
|
||||
return grad
|
||||
|
||||
|
||||
class CaseNet(nn.Cell):
|
||||
|
@ -53,3 +66,51 @@ def test_switch_layer():
|
|||
true_value = relu(data)
|
||||
ret = np.allclose(value.asnumpy(), true_value.asnumpy())
|
||||
assert ret
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_cell_in_list():
|
||||
"""
|
||||
Feature: Switch layer in while.
|
||||
Description: test recursive switch layer.
|
||||
Expectation: success if grad and output are correct.
|
||||
"""
|
||||
|
||||
class TestCell(nn.Cell):
|
||||
def __init__(self, i):
|
||||
super().__init__()
|
||||
self.i = i
|
||||
|
||||
def construct(self, x):
|
||||
return self.i * x
|
||||
|
||||
class CellInList(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.cell_list = nn.CellList()
|
||||
self.cell_list.append(TestCell(4))
|
||||
self.cell_list.append(TestCell(5))
|
||||
self.cell_list.append(TestCell(6))
|
||||
|
||||
def construct(self, t, x):
|
||||
out = t
|
||||
while x < 3:
|
||||
add = self.cell_list[x](t)
|
||||
out = out + add
|
||||
x += 1
|
||||
return out
|
||||
|
||||
net = CellInList()
|
||||
t = Tensor(10, mstype.int32)
|
||||
x = Tensor(0, mstype.int32)
|
||||
out = net(t, x)
|
||||
grad_net = Grad(net)
|
||||
grad_out = grad_net(t, x)
|
||||
|
||||
assert out == Tensor(160, mstype.int32)
|
||||
assert grad_out == Tensor(16, mstype.int32)
|
||||
|
|
Loading…
Reference in New Issue