!29711 return Tensor value for AbstractRef when infer EnvironGet primitive

Merge pull request !29711 from xychow/return-dflt-when-infer-environ-get
This commit is contained in:
i-robot 2022-02-11 05:09:42 +00:00 committed by Gitee
commit f97f16dfb0
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 64 additions and 0 deletions

View File

@ -81,6 +81,12 @@ AbstractBasePtr InferImplEnvironGet(const AnalysisEnginePtr &, const PrimitivePt
auto expected = key_value_track->abstract();
MS_EXCEPTION_IF_NULL(expected);
(void)expected->Join(dflt);
// If expected is AbstractRef, return it's AbstractTensor as Value type other than Reference type.
if (expected->isa<AbstractRef>()) {
const auto &abs_ref = expected->cast<AbstractRefPtr>();
MS_EXCEPTION_IF_NULL(abs_ref);
return abs_ref->CloneAsTensor();
}
return expected;
}

View File

@ -14,9 +14,11 @@
# ============================================================================
""" test grad ops """
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore import ms_function
from mindspore import Tensor, context
from mindspore.common import dtype as mstype
from mindspore import ParameterTuple, Parameter
one = Tensor([1], mstype.int32)
zero = Tensor([0], mstype.int32)
@ -57,3 +59,59 @@ def test_pow_second_order():
sec_grad_net = grad(grad_net)
res = sec_grad_net(x, n)
assert res == 30
def test_high_order_with_params():
"""
Feature: second order test.
Description: second order test with weight.
Expectation: return expected value.
net: (x ** 3) * (self.weight ** 3)
first_grad: 3 * (x ** 2) * (self.weight ** 3)
second_grad: 3 * (x ** 2) * 3 * (self.weight * 2)
"""
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.weight = Parameter(Tensor(Tensor([2], mstype.int32)), name="weight", requires_grad=True)
self.n = Tensor([3], mstype.int32)
def construct(self, x):
r = one
n = self.n
while n > zero:
n = n - one
r = r * x * self.weight
return r
class Grad(nn.Cell):
def __init__(self, network):
super(Grad, self).__init__()
self.grad = ops.GradOperation(get_all=True, sens_param=False)
self.network = network
def construct(self, x):
output = self.grad(self.network)(x)
return output
class GradSec(nn.Cell):
def __init__(self, network):
super(GradSec, self).__init__()
self.grad = ops.GradOperation(get_by_list=True, sens_param=False)
self.network = network
self.params = ParameterTuple(network.trainable_params())
def construct(self, x):
output = self.grad(self.network, self.params)(x)
return output
context.set_context(mode=context.GRAPH_MODE)
x = Tensor([5], mstype.int32)
expected = Tensor([900], mstype.int32)
net = Net()
first_grad = Grad(net)
second_grad = GradSec(first_grad)
assert second_grad(x) == (expected,)