!30465 don't evaluate to a specific SymbolicKey if not a direct weight parameter

Merge pull request !30465 from xychow/dont-specialze-refembed-if-not-direct-weight-master
This commit is contained in:
i-robot 2022-02-24 03:33:35 +00:00 committed by Gitee
commit e0da695ba2
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 58 additions and 1 deletions

View File

@ -1296,8 +1296,20 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr.";
return nullptr;
}
// Check if the input of RefEmbed is a weight parameter, if not, don't create the
// specific SymbolicKey.
// Notes: when different weight parameters have the same type and shape and are passed as parameters to the same
// funcgraph which has a RefToEmbed CNode, that funcgraph will not be specialized into different funcgraphs, so the
// RefToEmbed CNode in that funcgraph also should not be evaluated to a specific SymbolicKey.
// Only after that funcgraph is inlined should the RefToEmbed CNode be evaluated to a specific SymbolicKey.
bool ifEmbedIsWeight = false;
if (node_conf->node() != nullptr && node_conf->node()->isa<Parameter>()) {
auto param = node_conf->node()->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
ifEmbedIsWeight = param->has_default();
}
auto refkey = key_value->cast<RefKeyPtr>();
if (refkey == nullptr) {
if (refkey == nullptr || !ifEmbedIsWeight) {
auto ret = std::make_shared<AbstractScalar>(type);
auto ref_value = ref_abs->ref();
MS_EXCEPTION_IF_NULL(ref_value);

View File

@ -13,6 +13,8 @@
# limitations under the License.
# ============================================================================
""" test grad ops """
from dataclasses import dataclass
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore import ms_function
@ -115,3 +117,46 @@ def test_high_order_with_params():
first_grad = Grad(net)
second_grad = GradSec(first_grad)
assert second_grad(x) == (expected,)
def test_reftoembed_with_two_weights():
    """
    Feature: RefToEmbed can properly be evaluated.
    Description: Multiple weights with the same shape and type can be evaluated
        properly even when SimplifyDataStructures (one more round of Renormalize)
        takes effect.
    Expectation: gradients for both weights match the expected values.
    """
    @dataclass
    class SimpleData:
        # Plain data holder routed through the dataclass so that
        # SimplifyDataStructures has something to eliminate.
        a: int

        def get_data(self):
            return self.a

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            # Two weights with identical shape/type — the exact condition under
            # which the funcgraph must NOT be specialized per-weight.
            self.weight = Parameter(Tensor([2], mstype.int32), name="weight", requires_grad=True)
            self.bias = Parameter(Tensor([3], mstype.int32), name="bias", requires_grad=True)

        def construct(self, x):
            simple = SimpleData(x)
            r = self.weight * self.bias * simple.get_data()
            return r

    class Grad(nn.Cell):
        def __init__(self, network):
            super(Grad, self).__init__()
            self.grad = ops.GradOperation(get_by_list=True, sens_param=False)
            self.network = network
            self.params = ParameterTuple(network.trainable_params())

        def construct(self, x):
            output = self.grad(self.network, self.params)(x)
            return output

    # Fix: `x` was referenced below but never defined in the original test.
    # The expected gradients pin its value: d(r)/d(weight) = bias * x = 15
    # and d(r)/d(bias) = weight * x = 10, both of which require x == 5.
    x = Tensor([5], mstype.int32)
    expected_weight_grad = Tensor([15], mstype.int32)
    expected_bias_grad = Tensor([10], mstype.int32)
    net = Net()
    first_grad = Grad(net)
    assert first_grad(x) == (expected_weight_grad, expected_bias_grad)