fix cell list

This commit is contained in:
chenfei 2021-11-01 20:53:35 +08:00
parent 010cc7a435
commit d714fd4695
2 changed files with 75 additions and 5 deletions

View File

@ -1945,6 +1945,19 @@ AnfNodePtr GetSupportedInternalNode(const AnfNodePtr &front_node) {
}
return nullptr;
}
bool IsUnusedInternlOutput(const AnfNodePtr &user) {
if (!CNodeFirstInputIsPrimitive(user)) {
return true;
}
if (IsPrimitiveCNode(user, prim::kPrimSwitch) || IsPrimitiveCNode(user, prim::kPrimSwitchLayer)) {
return true;
}
if (!AnfAlgo::IsRealKernel(user)) {
return true;
}
return false;
}
} // namespace
constexpr auto kMixTarget = "MixTarget";
@ -2028,11 +2041,7 @@ void SessionBasic::HandleInternalOutput(const AnfNodePtr &input_front_node, cons
if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimUpdateState)) {
continue;
}
if (!CNodeFirstInputIsPrimitive(user)) {
internal_output = false;
break;
}
if (!AnfAlgo::IsRealKernel(user)) {
if (IsUnusedInternlOutput(user)) {
internal_output = false;
break;
}

View File

@ -19,6 +19,19 @@ import pytest
import mindspore.context as context
from mindspore import Tensor, nn
from mindspore.common import dtype as mstype
from mindspore.ops.composite import GradOperation
class Grad(nn.Cell):
def __init__(self, net):
super().__init__()
self.grad = GradOperation(get_all=False)
self.net = net
def construct(self, x, y):
grad_net = self.grad(self.net)
grad = grad_net(x, y)
return grad
class CaseNet(nn.Cell):
@ -53,3 +66,51 @@ def test_switch_layer():
true_value = relu(data)
ret = np.allclose(value.asnumpy(), true_value.asnumpy())
assert ret
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_cell_in_list():
"""
Feature: Switch layer in while.
Description: test recursive switch layer.
Expectation: success if grad and output are correct.
"""
class TestCell(nn.Cell):
def __init__(self, i):
super().__init__()
self.i = i
def construct(self, x):
return self.i * x
class CellInList(nn.Cell):
def __init__(self):
super().__init__()
self.cell_list = nn.CellList()
self.cell_list.append(TestCell(4))
self.cell_list.append(TestCell(5))
self.cell_list.append(TestCell(6))
def construct(self, t, x):
out = t
while x < 3:
add = self.cell_list[x](t)
out = out + add
x += 1
return out
net = CellInList()
t = Tensor(10, mstype.int32)
x = Tensor(0, mstype.int32)
out = net(t, x)
grad_net = Grad(net)
grad_out = grad_net(t, x)
assert out == Tensor(160, mstype.int32)
assert grad_out == Tensor(16, mstype.int32)