[ME][Auto_monad]The load node in print operator inputs should not be replaced.

This commit is contained in:
Margaret_wangrui 2022-02-21 10:59:51 +08:00
parent 6ebee5ff2c
commit 7b73c6408b
2 changed files with 48 additions and 9 deletions

View File

@ -57,16 +57,11 @@ std::optional<std::string> GetRefKey(const AnfNodePtr &node) {
return ref_key->name();
}
bool HasMemoryEffect(const CNodePtr &cnode) {
bool HasSideEffect(const CNodePtr &cnode) {
const auto &inputs = cnode->inputs();
if (HasAbstractUMonad(inputs.back())) {
// The last input is UMonad.
return true;
}
constexpr size_t kRequiredArgs = 2;
if (inputs.size() > kRequiredArgs) {
// The last two inputs are UMonad and IOMonad.
return HasAbstractIOMonad(inputs.back()) && HasAbstractUMonad(inputs.rbegin()[1]);
return HasAbstractMonad(inputs.back());
}
return false;
}
@ -111,8 +106,8 @@ LoadGraphMap GenerateLoadGroups(const FuncGraphPtr &fg, const std::vector<AnfNod
continue;
}
// Record param user in toposort nodes.
// We only check memory side effect cnodes or Depend nodes.
if (HasMemoryEffect(cnode) || cnode->IsApply(prim::kPrimDepend)) {
// We only check side effect cnodes or Depend nodes.
if (HasSideEffect(cnode) || cnode->IsApply(prim::kPrimDepend)) {
for (size_t n = 1; n < cnode->size(); ++n) {
const auto &input = cnode->input(n);
auto ref_key = GetRefKey(input);

View File

@ -1518,3 +1518,47 @@ def test_multi_abs_add_assign():
outputs = [r2.asnumpy(), r1.asnumpy(), net.p.data.asnumpy(), tmp.asnumpy()]
expects = numpy_out(p, i0, i1, i2)
np.testing.assert_array_equal(outputs, expects)
@security_off_wrap
def test_print_assign_print():
"""
Feature: Auto Monad
Description: Test load eliminate when umonad and iomona both exist.
Expectation: No exception.
"""
class Print(Cell):
def __init__(self):
super().__init__()
self.print = P.Print()
self.assign = P.Assign()
self.param = Parameter(Tensor(1, dtype=ms.int32), name='param')
def func(self):
self.assign(self.param, self.param * 5)
return self.param + 5
def construct(self, value):
param = self.param
self.print("param_1:", param)
res = self.func()
self.print("res:", res)
self.print("param_2:", param)
self.param = value
self.print("param_3:", param)
return res
cap = Capture()
with capture(cap):
input_x = Tensor(3, dtype=ms.int32)
expect = Tensor(10, dtype=ms.int32)
net = Print()
out = net(input_x)
time.sleep(0.1)
patterns = {'param_1:\nTensor(shape=[], dtype=Int32, value=1)\n'
'res:\nTensor(shape=[], dtype=Int32, value=10)\n'
'param_2:\nTensor(shape=[], dtype=Int32, value=5)\n'
'param_3:\nTensor(shape=[], dtype=Int32, value=3)'}
check_output(cap.output, patterns)
np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy())