!30465 don't evaluate to a specific SymbolicKey if not a direct weight parameter

Merge pull request !30465 from xychow/dont-specialze-refembed-if-not-direct-weight-master
i-robot 2022-02-24 03:33:35 +00:00 committed by Gitee
commit e0da695ba2
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 58 additions and 1 deletion


@@ -1296,8 +1296,20 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
       MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr.";
       return nullptr;
     }
+    // Check whether the input of RefToEmbed is a weight parameter; if not, don't create the
+    // specific SymbolicKey.
+    // Notes: when different weight parameters with the same type and shape are passed to the same
+    // funcgraph which has a RefToEmbed CNode, that funcgraph will not be specialized into different
+    // funcgraphs, so the RefToEmbed CNode in that funcgraph should also not be evaluated to a
+    // specific SymbolicKey. Only after that funcgraph is inlined should it be evaluated to one.
+    bool ifEmbedIsWeight = false;
+    if (node_conf->node() != nullptr && node_conf->node()->isa<Parameter>()) {
+      auto param = node_conf->node()->cast<ParameterPtr>();
+      MS_EXCEPTION_IF_NULL(param);
+      ifEmbedIsWeight = param->has_default();
+    }
     auto refkey = key_value->cast<RefKeyPtr>();
-    if (refkey == nullptr) {
+    if (refkey == nullptr || !ifEmbedIsWeight) {
       auto ret = std::make_shared<AbstractScalar>(type);
       auto ref_value = ref_abs->ref();
       MS_EXCEPTION_IF_NULL(ref_value);
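For context, here is a minimal sketch (not part of this commit) of the situation the comment above describes: two weight parameters with identical type and shape flow through the same inner graph, so that graph is not specialized per parameter, and a RefToEmbed inside it must stay generic until inlining. The SharedNet and scale names are illustrative, and whether scale remains a shared funcgraph depends on compiler pass ordering; only the standard MindSpore Python API is assumed.

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore import dtype as mstype

class SharedNet(nn.Cell):
    def __init__(self):
        super(SharedNet, self).__init__()
        # Same shape and same type, so the inner graph is shared, not specialized.
        self.w1 = Parameter(Tensor([2], mstype.int32), name="w1")
        self.w2 = Parameter(Tensor([3], mstype.int32), name="w2")

    def scale(self, w, x):
        # Shared subgraph receiving both w1 and w2; a RefToEmbed inside it must
        # not be fixed to either parameter's SymbolicKey before inlining.
        return w * x

    def construct(self, x):
        return self.scale(self.w1, x) + self.scale(self.w2, x)

net = SharedNet()
grad_by_params = ops.GradOperation(get_by_list=True)
grads = grad_by_params(net, ParameterTuple(net.trainable_params()))(Tensor([4], mstype.int32))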


@@ -13,6 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """ test grad ops """
+from dataclasses import dataclass
 import mindspore.ops as ops
 import mindspore.nn as nn
 from mindspore import ms_function
@@ -115,3 +117,46 @@ def test_high_order_with_params():
     first_grad = Grad(net)
     second_grad = GradSec(first_grad)
     assert second_grad(x) == (expected,)
+
+
+def test_reftoembed_with_two_weights():
+    """
+    Feature: RefToEmbed can be properly evaluated.
+    Description: Multiple weights with the same shape and type can be evaluated properly
+    even when SimplifyDataStructures (one more round of Renormalize) takes effect.
+    Expectation: return the expected value.
+    """
+    @dataclass
+    class SimpleData:
+        a: int
+
+        def get_data(self):
+            return self.a
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.weight = Parameter(Tensor([2], mstype.int32), name="weight", requires_grad=True)
+            self.bias = Parameter(Tensor([3], mstype.int32), name="bias", requires_grad=True)
+
+        def construct(self, x):
+            simple = SimpleData(x)
+            r = self.weight * self.bias * simple.get_data()
+            return r
+
+    class Grad(nn.Cell):
+        def __init__(self, network):
+            super(Grad, self).__init__()
+            self.grad = ops.GradOperation(get_by_list=True, sens_param=False)
+            self.network = network
+            self.params = ParameterTuple(network.trainable_params())
+
+        def construct(self, x):
+            output = self.grad(self.network, self.params)(x)
+            return output
+
+    x = Tensor([5], mstype.int32)  # value implied by the expected gradients below
+    expected_weight_grad = Tensor([15], mstype.int32)
+    expected_bias_grad = Tensor([10], mstype.int32)
+    net = Net()
+    first_grad = Grad(net)
+    assert first_grad(x) == (expected_weight_grad, expected_bias_grad)
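For reference, the asserted values follow directly from r = weight * bias * x with weight = 2, bias = 3, and x = 5:

    dr/dweight = bias * x   = 3 * 5 = 15
    dr/dbias   = weight * x = 2 * 5 = 10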