From 38ac9041d6638b63fc8586056cecede10b1babe2 Mon Sep 17 00:00:00 2001
From: zhousiyi
Date: Tue, 22 Feb 2022 11:28:49 +0000
Subject: [PATCH] don't evaluate to a specific SymbolicKey if not a direct
 weight parameter

---
 .../pipeline/jit/static_analysis/prim.cc      | 14 +++++-
 .../ops/ascend/test_aicpu_ops/test_env_ops.py | 46 ++++++++++++++++++++
 2 files changed, 59 insertions(+), 1 deletion(-)

diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
index 18679040138..b1ab74428b6 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
@@ -1296,8 +1296,20 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
       MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr.";
       return nullptr;
     }
+    // Check if the input of RefToEmbed is a weight parameter; if it is not, don't create the
+    // specific SymbolicKey.
+    // Notes: when different weight parameters with the same type and shape are passed to the same funcgraph
+    // which has a RefToEmbed CNode, that funcgraph will not be specialized into different funcgraphs, so the
+    // RefToEmbed CNode in that funcgraph also should not be evaluated to a specific SymbolicKey.
+    // Only after that funcgraph is inlined should the RefToEmbed CNode be evaluated to a specific SymbolicKey.
+    bool ifEmbedIsWeight = false;
+    if (node_conf->node() != nullptr && node_conf->node()->isa<Parameter>()) {
+      auto param = node_conf->node()->cast<ParameterPtr>();
+      MS_EXCEPTION_IF_NULL(param);
+      ifEmbedIsWeight = param->has_default();
+    }
     auto refkey = key_value->cast<RefKeyPtr>();
-    if (refkey == nullptr) {
+    if (refkey == nullptr || !ifEmbedIsWeight) {
       auto ret = std::make_shared<AbstractScalar>(type);
       auto ref_value = ref_abs->ref();
       MS_EXCEPTION_IF_NULL(ref_value);
diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_env_ops.py b/tests/st/ops/ascend/test_aicpu_ops/test_env_ops.py
index 08d8ddc6384..459765f239d 100644
--- a/tests/st/ops/ascend/test_aicpu_ops/test_env_ops.py
+++ b/tests/st/ops/ascend/test_aicpu_ops/test_env_ops.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """ test grad ops """
+from dataclasses import dataclass
+
 import mindspore.ops as ops
 import mindspore.nn as nn
 from mindspore import ms_function
@@ -115,3 +117,47 @@ def test_high_order_with_params():
     first_grad = Grad(net)
     second_grad = GradSec(first_grad)
     assert second_grad(x) == (expected,)
+
+
+def test_reftoembed_with_two_weights():
+    """
+    Feature: RefToEmbed can be properly evaluated.
+    Description: Multiple weights with the same shape and type can be evaluated properly
+                 even when SimplifyDataStructures (one more round of Renormalize) takes effect.
+    Expectation: return the expected value.
+ """ + @dataclass + class SimpleData: + a: int + + def get_data(self): + return self.a + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.weight = Parameter(Tensor([2], mstype.int32), name="weight", requires_grad=True) + self.bias = Parameter(Tensor([3], mstype.int32), name="bias", requires_grad=True) + + def construct(self, x): + simple = SimpleData(x) + r = self.weight * self.bias * simple.get_data() + return r + + class Grad(nn.Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = ops.GradOperation(get_by_list=True, sens_param=False) + self.network = network + self.params = ParameterTuple(network.trainable_params()) + + def construct(self, x): + output = self.grad(self.network, self.params)(x) + return output + + expected_weight_grad = Tensor([15], mstype.int32) + expected_bias_grad = Tensor([10], mstype.int32) + net = Net() + first_grad = Grad(net) + assert first_grad(x) == (expected_weight_grad, expected_bias_grad)