forked from mindspore-Ecosystem/mindspore
!30465 dont evaluated to specific SymbolicKey if not direct weight parameter
Merge pull request !30465 from xychow/dont-specialze-refembed-if-not-direct-weight-master
This commit is contained in:
commit
e0da695ba2
|
@ -1296,8 +1296,20 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
|||
MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
// Check if the input of RefEmbed is a weight parameter, if not, don't create the
|
||||
// specific SymbolicKey.
|
||||
// Notes: when different weight parameter have same type and shape passed as parameter to same funcgraph
|
||||
// which has RefToEmbed CNode, that funcgraph will not be specialized to different funcgraph, so the
|
||||
// RefToEmbed CNode in that funcgraph also should not be evaluated to specific SymbolicKey.
|
||||
// Only after that funcgrpah is inlined, the RefToEmbed CNode should be evaluated to specific SymbolicKey.
|
||||
bool ifEmbedIsWeight = false;
|
||||
if (node_conf->node() != nullptr && node_conf->node()->isa<Parameter>()) {
|
||||
auto param = node_conf->node()->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
ifEmbedIsWeight = param->has_default();
|
||||
}
|
||||
auto refkey = key_value->cast<RefKeyPtr>();
|
||||
if (refkey == nullptr) {
|
||||
if (refkey == nullptr || !ifEmbedIsWeight) {
|
||||
auto ret = std::make_shared<AbstractScalar>(type);
|
||||
auto ref_value = ref_abs->ref();
|
||||
MS_EXCEPTION_IF_NULL(ref_value);
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test grad ops """
|
||||
from dataclasses import dataclass
|
||||
|
||||
import mindspore.ops as ops
|
||||
import mindspore.nn as nn
|
||||
from mindspore import ms_function
|
||||
|
@ -115,3 +117,46 @@ def test_high_order_with_params():
|
|||
first_grad = Grad(net)
|
||||
second_grad = GradSec(first_grad)
|
||||
assert second_grad(x) == (expected,)
|
||||
|
||||
|
||||
def test_reftoembed_with_two_weights():
|
||||
"""
|
||||
Feature: RefToEmbed can be properly be evaluated.
|
||||
Description: Multiple weights with same shape and type can be evaluatd properly
|
||||
even SimplifyDataStructures (one more round of Renormalize) takes effect.
|
||||
Expectation: return expected value.
|
||||
"""
|
||||
@dataclass
|
||||
class SimpleData:
|
||||
a: int
|
||||
|
||||
def get_data(self):
|
||||
return self.a
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.weight = Parameter(Tensor([2], mstype.int32), name="weight", requires_grad=True)
|
||||
self.bias = Parameter(Tensor([3], mstype.int32), name="bias", requires_grad=True)
|
||||
|
||||
def construct(self, x):
|
||||
simple = SimpleData(x)
|
||||
r = self.weight * self.bias * simple.get_data()
|
||||
return r
|
||||
|
||||
class Grad(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(Grad, self).__init__()
|
||||
self.grad = ops.GradOperation(get_by_list=True, sens_param=False)
|
||||
self.network = network
|
||||
self.params = ParameterTuple(network.trainable_params())
|
||||
|
||||
def construct(self, x):
|
||||
output = self.grad(self.network, self.params)(x)
|
||||
return output
|
||||
|
||||
expected_weight_grad = Tensor([15], mstype.int32)
|
||||
expected_bias_grad = Tensor([10], mstype.int32)
|
||||
net = Net()
|
||||
first_grad = Grad(net)
|
||||
assert first_grad(x) == (expected_weight_grad, expected_bias_grad)
|
||||
|
|
Loading…
Reference in New Issue