forked from mindspore-Ecosystem/mindspore
!29711 return Tensor value for AbstractRef when infer EnvironGet primitive
Merge pull request !29711 from xychow/return-dflt-when-infer-environ-get
commit f97f16dfb0
@@ -81,6 +81,12 @@ AbstractBasePtr InferImplEnvironGet(const AnalysisEnginePtr &, const PrimitivePt
  auto expected = key_value_track->abstract();
  MS_EXCEPTION_IF_NULL(expected);
  (void)expected->Join(dflt);
  // If expected is an AbstractRef, return its AbstractTensor as a Value type rather than a Reference type.
  if (expected->isa<AbstractRef>()) {
    const auto &abs_ref = expected->cast<AbstractRefPtr>();
    MS_EXCEPTION_IF_NULL(abs_ref);
    return abs_ref->CloneAsTensor();
  }
  return expected;
}
@@ -14,9 +14,11 @@
# ============================================================================
""" test grad ops """
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore import ms_function
from mindspore import Tensor, context
from mindspore.common import dtype as mstype
from mindspore import ParameterTuple, Parameter

one = Tensor([1], mstype.int32)
zero = Tensor([0], mstype.int32)

@@ -57,3 +59,59 @@ def test_pow_second_order():
    sec_grad_net = grad(grad_net)
    res = sec_grad_net(x, n)
    assert res == 30


def test_high_order_with_params():
    """
    Feature: second order test.
    Description: second order test with weight.
    Expectation: return expected value.
    net: (x ** 3) * (self.weight ** 3)
    first_grad: 3 * (x ** 2) * (self.weight ** 3)
    second_grad: 3 * (x ** 2) * 3 * (self.weight ** 2)
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.weight = Parameter(Tensor([2], mstype.int32), name="weight", requires_grad=True)
            self.n = Tensor([3], mstype.int32)

        def construct(self, x):
            r = one
            n = self.n
            while n > zero:
                n = n - one
                r = r * x * self.weight
            return r

    class Grad(nn.Cell):
        def __init__(self, network):
            super(Grad, self).__init__()
            self.grad = ops.GradOperation(get_all=True, sens_param=False)
            self.network = network

        def construct(self, x):
            output = self.grad(self.network)(x)
            return output

    class GradSec(nn.Cell):
        def __init__(self, network):
            super(GradSec, self).__init__()
            self.grad = ops.GradOperation(get_by_list=True, sens_param=False)
            self.network = network
            self.params = ParameterTuple(network.trainable_params())

        def construct(self, x):
            output = self.grad(self.network, self.params)(x)
            return output

    context.set_context(mode=context.GRAPH_MODE)
    x = Tensor([5], mstype.int32)
    expected = Tensor([900], mstype.int32)
    net = Net()
    first_grad = Grad(net)
    second_grad = GradSec(first_grad)
    assert second_grad(x) == (expected,)
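As a quick cross-check of the value asserted by the new test, the derivative formulas from its docstring can be evaluated by hand with plain Python integers. The sketch below is illustrative only and is not part of the commit; the helper name check_expected_second_grad and the use of bare ints in place of the 1-element int32 tensors are assumptions made here.

# Standalone arithmetic check of the expected value in test_high_order_with_params.
# Plain Python ints stand in for the 1-element int32 tensors; this helper is
# illustrative and does not appear in the test file.
def check_expected_second_grad(x, w):
    # net(x) = (x ** 3) * (w ** 3); first grad w.r.t. x: 3 * (x ** 2) * (w ** 3)
    first_grad = 3 * (x ** 2) * (w ** 3)
    # grad of first_grad w.r.t. the weight w: 3 * (x ** 2) * 3 * (w ** 2)
    second_grad_w = 3 * (x ** 2) * 3 * (w ** 2)
    return first_grad, second_grad_w


first, second = check_expected_second_grad(x=5, w=2)
assert first == 600
assert second == 900  # matches Tensor([900], mstype.int32) asserted in the test

With x = 5 and weight = 2 the hand calculation gives 900, which agrees with the expected tensor in the assertion above.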