From d030ce2441fc6e88db811347b53771adafa15131 Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Fri, 12 Aug 2022 16:41:32 +0800 Subject: [PATCH] Fix auto_monad problem in vmap scenes. --- .../jit/static_analysis/order_enforce.cc | 3 +- mindspore/core/abstract/ops/prim_others.cc | 7 +---- .../auto_monad/test_auto_monad_expression.py | 29 +++++++++++++++++++ 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc index 11deae09854..0036dcdf808 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc @@ -222,7 +222,8 @@ class OrderEnforcer { for (size_t i = 1; i < inputs.size(); ++i) { auto &input = inputs[i]; // Skip non-ref input and update_state. - if (!IsRef(input) || input == update_state) { + // Skip Load(param, umonad) --> Ref + if (!IsRef(input) || input == update_state || IsPrimitiveCNode(input, prim::kPrimLoad)) { continue; } // The input is a ref (of parameter), find load nodes for it. diff --git a/mindspore/core/abstract/ops/prim_others.cc b/mindspore/core/abstract/ops/prim_others.cc index 2af831116af..f5bc2889fb5 100644 --- a/mindspore/core/abstract/ops/prim_others.cc +++ b/mindspore/core/abstract/ops/prim_others.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019-2021 Huawei Technologies Co., Ltd + * Copyright 2019-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -437,11 +437,6 @@ AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &pri const AbstractBasePtrList &args_spec_list) { // Inputs: Ref/Tensor, universal CheckArgsSize(primitive->name(), args_spec_list, 2); - auto ref_abs = dyn_cast(args_spec_list[0]); - if (ref_abs != nullptr) { - // Return tensor value if input is Ref. - return ref_abs->CloneAsTensor(); - } return args_spec_list[0]->Broaden(); } diff --git a/tests/st/auto_monad/test_auto_monad_expression.py b/tests/st/auto_monad/test_auto_monad_expression.py index cef4bd0727b..05829d77bec 100644 --- a/tests/st/auto_monad/test_auto_monad_expression.py +++ b/tests/st/auto_monad/test_auto_monad_expression.py @@ -228,3 +228,32 @@ def test_load_eliminate(): net = Net() out = net(x) assert out == 5 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_parameter_tuple_assign(): + """ + Feature: Auto monad feature. + Description: Parameter tuple assign. + Expectation: No exception. + """ + class Net(Cell): + def __init__(self): + super().__init__() + self.assign = P.Assign() + self.param1 = Parameter(Tensor(0), name="param1") + self.param2 = Parameter(Tensor(0), name="param2") + + def construct(self, x): + params = (self.param1, self.param2) + self.assign(params[0], x) + return params[0], params[1] + + x = Tensor(2) + net = Net() + out = net(x) + assert out[0] == 2 + assert out[1] == 0