forked from mindspore-Ecosystem/mindspore
!30465 don't evaluate to specific SymbolicKey if not direct weight parameter
Merge pull request !30465 from xychow/dont-specialze-refembed-if-not-direct-weight-master
This commit is contained in:
commit
e0da695ba2
|
@ -1296,8 +1296,20 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
||||||
MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr.";
|
MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
// Check if the input of RefToEmbed is a weight parameter; if not, don't create the
|
||||||
|
// specific SymbolicKey.
|
||||||
|
// Note: when different weight parameters with the same type and shape are passed as parameters to the same funcgraph
|
||||||
|
// which has a RefToEmbed CNode, that funcgraph will not be specialized into different funcgraphs, so the
|
||||||
|
// RefToEmbed CNode in that funcgraph also should not be evaluated to specific SymbolicKey.
|
||||||
|
// Only after that funcgraph is inlined, the RefToEmbed CNode should be evaluated to specific SymbolicKey.
|
||||||
|
bool ifEmbedIsWeight = false;
|
||||||
|
if (node_conf->node() != nullptr && node_conf->node()->isa<Parameter>()) {
|
||||||
|
auto param = node_conf->node()->cast<ParameterPtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(param);
|
||||||
|
ifEmbedIsWeight = param->has_default();
|
||||||
|
}
|
||||||
auto refkey = key_value->cast<RefKeyPtr>();
|
auto refkey = key_value->cast<RefKeyPtr>();
|
||||||
if (refkey == nullptr) {
|
if (refkey == nullptr || !ifEmbedIsWeight) {
|
||||||
auto ret = std::make_shared<AbstractScalar>(type);
|
auto ret = std::make_shared<AbstractScalar>(type);
|
||||||
auto ref_value = ref_abs->ref();
|
auto ref_value = ref_abs->ref();
|
||||||
MS_EXCEPTION_IF_NULL(ref_value);
|
MS_EXCEPTION_IF_NULL(ref_value);
|
||||||
|
|
|
@ -13,6 +13,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
""" test grad ops """
|
""" test grad ops """
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import mindspore.ops as ops
|
import mindspore.ops as ops
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import ms_function
|
from mindspore import ms_function
|
||||||
|
@ -115,3 +117,46 @@ def test_high_order_with_params():
|
||||||
first_grad = Grad(net)
|
first_grad = Grad(net)
|
||||||
second_grad = GradSec(first_grad)
|
second_grad = GradSec(first_grad)
|
||||||
assert second_grad(x) == (expected,)
|
assert second_grad(x) == (expected,)
|
||||||
|
|
||||||
|
|
||||||
|
def test_reftoembed_with_two_weights():
|
||||||
|
"""
|
||||||
|
Feature: RefToEmbed can be properly evaluated.
|
||||||
|
Description: Multiple weights with the same shape and type can be evaluated properly
|
||||||
|
even when SimplifyDataStructures (one more round of Renormalize) takes effect.
|
||||||
|
Expectation: return expected value.
|
||||||
|
"""
|
||||||
|
@dataclass
|
||||||
|
class SimpleData:
|
||||||
|
a: int
|
||||||
|
|
||||||
|
def get_data(self):
|
||||||
|
return self.a
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.weight = Parameter(Tensor([2], mstype.int32), name="weight", requires_grad=True)
|
||||||
|
self.bias = Parameter(Tensor([3], mstype.int32), name="bias", requires_grad=True)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
simple = SimpleData(x)
|
||||||
|
r = self.weight * self.bias * simple.get_data()
|
||||||
|
return r
|
||||||
|
|
||||||
|
class Grad(nn.Cell):
|
||||||
|
def __init__(self, network):
|
||||||
|
super(Grad, self).__init__()
|
||||||
|
self.grad = ops.GradOperation(get_by_list=True, sens_param=False)
|
||||||
|
self.network = network
|
||||||
|
self.params = ParameterTuple(network.trainable_params())
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
output = self.grad(self.network, self.params)(x)
|
||||||
|
return output
|
||||||
|
|
||||||
|
expected_weight_grad = Tensor([15], mstype.int32)
|
||||||
|
expected_bias_grad = Tensor([10], mstype.int32)
|
||||||
|
net = Net()
|
||||||
|
first_grad = Grad(net)
|
||||||
|
assert first_grad(x) == (expected_weight_grad, expected_bias_grad)
|
||||||
|
|
Loading…
Reference in New Issue