forked from mindspore-Ecosystem/mindspore
[ME][Auto_monad]The load node in print operator inputs should not be replaced.
This commit is contained in:
parent
6ebee5ff2c
commit
7b73c6408b
|
@ -57,16 +57,11 @@ std::optional<std::string> GetRefKey(const AnfNodePtr &node) {
|
|||
return ref_key->name();
|
||||
}
|
||||
|
||||
bool HasMemoryEffect(const CNodePtr &cnode) {
|
||||
bool HasSideEffect(const CNodePtr &cnode) {
|
||||
const auto &inputs = cnode->inputs();
|
||||
if (HasAbstractUMonad(inputs.back())) {
|
||||
// The last input is UMonad.
|
||||
return true;
|
||||
}
|
||||
constexpr size_t kRequiredArgs = 2;
|
||||
if (inputs.size() > kRequiredArgs) {
|
||||
// The last two inputs are UMonad and IOMonad.
|
||||
return HasAbstractIOMonad(inputs.back()) && HasAbstractUMonad(inputs.rbegin()[1]);
|
||||
return HasAbstractMonad(inputs.back());
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
@ -111,8 +106,8 @@ LoadGraphMap GenerateLoadGroups(const FuncGraphPtr &fg, const std::vector<AnfNod
|
|||
continue;
|
||||
}
|
||||
// Record param user in toposort nodes.
|
||||
// We only check memory side effect cnodes or Depend nodes.
|
||||
if (HasMemoryEffect(cnode) || cnode->IsApply(prim::kPrimDepend)) {
|
||||
// We only check side effect cnodes or Depend nodes.
|
||||
if (HasSideEffect(cnode) || cnode->IsApply(prim::kPrimDepend)) {
|
||||
for (size_t n = 1; n < cnode->size(); ++n) {
|
||||
const auto &input = cnode->input(n);
|
||||
auto ref_key = GetRefKey(input);
|
||||
|
|
|
@ -1518,3 +1518,47 @@ def test_multi_abs_add_assign():
|
|||
outputs = [r2.asnumpy(), r1.asnumpy(), net.p.data.asnumpy(), tmp.asnumpy()]
|
||||
expects = numpy_out(p, i0, i1, i2)
|
||||
np.testing.assert_array_equal(outputs, expects)
|
||||
|
||||
|
||||
@security_off_wrap
|
||||
def test_print_assign_print():
|
||||
"""
|
||||
Feature: Auto Monad
|
||||
Description: Test load eliminate when umonad and iomona both exist.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Print(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.print = P.Print()
|
||||
self.assign = P.Assign()
|
||||
self.param = Parameter(Tensor(1, dtype=ms.int32), name='param')
|
||||
|
||||
def func(self):
|
||||
self.assign(self.param, self.param * 5)
|
||||
return self.param + 5
|
||||
|
||||
def construct(self, value):
|
||||
param = self.param
|
||||
self.print("param_1:", param)
|
||||
res = self.func()
|
||||
self.print("res:", res)
|
||||
self.print("param_2:", param)
|
||||
self.param = value
|
||||
self.print("param_3:", param)
|
||||
return res
|
||||
|
||||
cap = Capture()
|
||||
with capture(cap):
|
||||
input_x = Tensor(3, dtype=ms.int32)
|
||||
expect = Tensor(10, dtype=ms.int32)
|
||||
net = Print()
|
||||
out = net(input_x)
|
||||
time.sleep(0.1)
|
||||
|
||||
patterns = {'param_1:\nTensor(shape=[], dtype=Int32, value=1)\n'
|
||||
'res:\nTensor(shape=[], dtype=Int32, value=10)\n'
|
||||
'param_2:\nTensor(shape=[], dtype=Int32, value=5)\n'
|
||||
'param_3:\nTensor(shape=[], dtype=Int32, value=3)'}
|
||||
check_output(cap.output, patterns)
|
||||
np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy())
|
||||
|
|
Loading…
Reference in New Issue