forked from mindspore-Ecosystem/mindspore
!2875 add grad all in pynative
Merge pull request !2875 from wangqiuliang/add-grad-all-in-pynative
commit d925c52bb1
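In short, this change wires GradOperation / C.grad_all into PyNative mode: grad inputs must be Tensors, and the forward pass is replayed through _pynative_exec to build the grad graph before _pynative_exec.grad runs. A minimal usage sketch, modeled on the test_grad_all case added in this diff (the Net cell and the input values below simply mirror that test; this is illustrative, not part of the patch):

import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.ops import composite as C

# PyNative mode, as exercised by the pynative test suites touched in this diff
context.set_context(mode=context.PYNATIVE_MODE)

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()

    def construct(self, x, y):
        return 2 * x * x + y * y

net = Net()
x = Tensor(np.array([1, 2, 3]).astype(np.int32))
y = Tensor(np.array([2, 3, 4]).astype(np.int32))
# grad_all returns the gradients of the output w.r.t. every input
print(C.grad_all(net)(x, y))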
@@ -980,7 +980,7 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
       }
     }
   } else {
-    MS_LOG(EXCEPTION) << "training not paramter_tuple";
+    MS_LOG(DEBUG) << "training not paramter_tuple";
   }
   return w_args;
 }

@@ -181,6 +181,9 @@ class Tensor(Tensor_):
     def __imod__(self, other):
         return self.__mod__(other)
 
+    def __pow__(self, other):
+        return tensor_operator_registry.get('__pow__')(self, other)
+
     def __floordiv__(self, other):
         return tensor_operator_registry.get('__floordiv__')(self, other)
 

@@ -176,7 +176,10 @@ class _Context:
             self._context_switches.push(True, None)
         else:
             if self.enable_debug_runtime:
-                self.set_backend_policy("ge")
+                if self.device_target == "CPU":
+                    self.set_backend_policy("vm")
+                else:
+                    self.set_backend_policy("ge")
             self._context_switches.push(False, None)
 
     def set_backend_policy(self, policy):

@@ -16,6 +16,7 @@
 import time
 import gc
 from collections import OrderedDict
+import numpy
 from mindspore import log as logger
 from .. import context
 from ..common import dtype as mstype

@@ -211,6 +212,9 @@ class Cell:
         if context.get_context("mode") == context.GRAPH_MODE:
             out = self.compile_and_run(*inputs)
             return out
+        for item in inputs:
+            if isinstance(item, numpy.ndarray):
+                raise TypeError("cell inputs should not be numpy array.")
         self.init_parameters_data()
         orign_grad = []
         if self.requires_grad is True:

@@ -17,6 +17,7 @@
 
 """Basic composite operations."""
 from functools import partial
+from types import FunctionType
 
 from mindspore import context
 from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \

@@ -25,6 +26,7 @@ from ...common import dtype as mstype
 from ...common.api import ms_function, _pynative_exec, _wrap_func
 from .. import functional as F
 from ...common.parameter import Parameter
+from ...common.tensor import Tensor
 
 
 __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]

@@ -114,37 +116,48 @@ class GradOperation(GradOperation_):
self.fn = None
self.need_forward = False

def _pynative_forward_run(self, args, fn):
""" Pynative forward run to build grad graph. """
if self.sens_param:
args = args[:-1]
if isinstance(fn, FunctionType):
_pynative_exec.set_grad_flag(True)
_pynative_exec.new_graph(fn, *args)
output = fn(*args)
_pynative_exec.end_graph(fn, output, *args)
else:
if fn.is_run and not fn.requires_grad:
raise ValueError("obj must set_grad.")
if not fn.is_run:
self.need_forward = True
print("already has forward run before grad by user")
if self.need_forward:
fn.set_grad()
fn(*args)

def __call__(self, fn, weights=None):
grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param)
if self.grad_fn is None or self.fn != fn:
if self.get_by_list:
if context.get_context("mode") == context.GRAPH_MODE:
if context.get_context("mode") == context.GRAPH_MODE:
if self.get_by_list:
@ms_function(obj=fn)
def after_grad(*args):
return grad_(fn, weights)(*args)
else:
@_wrap_func
@ms_function(obj=fn)
def after_grad(*args):
if fn.is_run and not fn.requires_grad:
raise ValueError("obj must set_grad.")
if not fn.is_run:
self.need_forward = True
print("already has forward run before grad by user")
if self.need_forward:
fn.set_grad()
if self.sens_param:
f_args = args[:-1]
fn(*f_args)
else:
fn(*args)
_pynative_exec.grad(grad_, fn, weights, *args)
out = _pynative_exec(*args)
_pynative_exec.clear()
return out
return grad_(fn)(*args)
else:
@ms_function(obj=fn)
@_wrap_func
def after_grad(*args):
return grad_(fn)(*args)
for arg in args:
if not isinstance(arg, Tensor):
raise TypeError("grad inputs should be tensor in pynative mode")
self._pynative_forward_run(args, fn)
_pynative_exec.grad(grad_, fn, weights, *args)
out = _pynative_exec(*args)
_pynative_exec.clear()
return out
self.grad_fn = after_grad
self.fn = fn
return self.grad_fn

@@ -166,6 +166,7 @@ tensor_operator_registry.register('__sub__', tensor_sub)
 tensor_operator_registry.register('__mul__', tensor_mul)
 tensor_operator_registry.register('__truediv__', tensor_div)
 tensor_operator_registry.register('__mod__', tensor_mod)
+tensor_operator_registry.register('__pow__', tensor_pow)
 tensor_operator_registry.register('__floordiv__', tensor_floordiv)
 #ms cannot support Tensor(True) compare
 tensor_operator_registry.register('__eq__', equal)

@@ -228,6 +228,7 @@ def test_biasadd_3d():
     error = np.ones(shape=[3, 4, 8]) * 1.0e-6
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
     net = BiasAdd()
+    net.set_grad()
     result = net(x, b)
     diff = result.asnumpy() - expect
     assert np.all(diff < error)

@@ -45,6 +45,7 @@ def test_net_infer():
 
 
 def test_assign_in_while():
+    context.set_context(device_target="Ascend")
     context.set_context(mode=context.GRAPH_MODE)
     class Net(nn.Cell):
         def __init__(self, input_shape):

@@ -16,6 +16,7 @@
 import numpy as np
 import pytest
 
+import mindspore as ms
 import mindspore.common.dtype as mstype
 import mindspore.nn as nn
 from mindspore import Parameter

@@ -24,12 +25,15 @@ from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
-from ....mindspore_test_framework.utils.bprop_util import bprop
+from .....mindspore_test_framework.utils.bprop_util import bprop
 
 
 def setup_module(module):
-    context.set_context(mode=context.PYNATIVE_MODE)
     context.set_context(device_target="CPU")
+    context.set_context(mode=context.GRAPH_MODE)
+
+def teardown_module(module):
+    context.set_context(device_target="Ascend")
 
 class MulAdd(nn.Cell):
     def __init__(self):

@@ -45,7 +49,9 @@ class MulAdd(nn.Cell):
 
 def test_grad_mul_add():
     mul_add = MulAdd()
-    assert C.grad_all(mul_add)(1, 2) == (2, 4)
+    x = Tensor(1, dtype=ms.int32)
+    y = Tensor(2, dtype=ms.int32)
+    assert C.grad_all(mul_add)(x, y) == (2, 4)
 
 
 class InlineMulADD(nn.Cell):

@@ -60,7 +66,9 @@ class InlineMulADD(nn.Cell):
 
 def test_grad_inline_mul_add():
     inline_mul_add = InlineMulADD()
-    assert C.grad_all(inline_mul_add)(1, 2) == (3, 6)
+    x = Tensor(1, dtype=ms.int32)
+    y = Tensor(2, dtype=ms.int32)
+    assert C.grad_all(inline_mul_add)(x, y) == (3, 6)
 
 
 class WithParameter(nn.Cell):

@@ -93,7 +101,9 @@ class WithNoBprop(nn.Cell):
 
 def test_with_no_bprop():
     with_no_bprop = WithNoBprop()
-    assert C.grad_all(with_no_bprop)(1, 2) == (2, 1)
+    x = Tensor(1, dtype=ms.int32)
+    y = Tensor(2, dtype=ms.int32)
+    assert C.grad_all(with_no_bprop)(x, y) == (2, 1)
 
 
 def test_grad_in_bprop_1():

@@ -19,21 +19,27 @@
@Desc :
"""
import logging
import pytest
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import composite as C
from mindspore.common.api import ms_function, _executor
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
from mindspore.ops.functional import tensor_add
from ...ut_filter import non_graph_engine

# pylint: disable=W0613
# pylint: disable=W0613,W0612
# W0613: unused-argument


log = logging.getLogger("test")
log.setLevel(level=logging.ERROR)
context.set_context(mode=context.GRAPH_MODE)


# Test case: use the parse obj interface use default parameter

@@ -135,3 +141,113 @@ def test_net_with_ndarray():
     input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
 
     net(ms.Tensor(input_data))
+
+
+def test_bprop_with_wrong_output_num():
+    context.set_context(check_bprop=True)
+    class BpropWithWrongOutputNum(PrimitiveWithInfer):
+        @prim_attr_register
+        def __init__(self):
+            super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')
+
+        def __call__(self, x, y):
+            return x
+
+        def infer_shape(self, x_shape, yshape):
+            return x_shape
+
+        def infer_dtype(self, x_type, y_type):
+            return x_type
+
+    @bprop_getters.register(BpropWithWrongOutputNum)
+    def get_bprop_with_wrong_output_num(self):
+        """Generate bprop for BpropWithWrongOutputNum"""
+
+        def bprop(x, y, out, dout):
+            return (dout,)
+
+        return bprop
+
+    class BpropWithWrongOutputNumCell(nn.Cell):
+        def __init__(self):
+            super(BpropWithWrongOutputNumCell, self).__init__()
+
+        def construct(self, x, y):
+            return BpropWithWrongOutputNum()(x, y)
+
+    with pytest.raises(TypeError):
+        C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
+
+def test_bprop_with_wrong_output_type():
+    context.set_context(check_bprop=True)
+    class BpropWithWrongOutputType(PrimitiveWithInfer):
+        @prim_attr_register
+        def __init__(self):
+            super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')
+
+        def __call__(self, x):
+            return x
+
+        def infer_shape(self, x_shape):
+            return x_shape
+
+        def infer_dtype(self, x_type):
+            return x_type
+
+    @bprop_getters.register(BpropWithWrongOutputType)
+    def get_bprop_with_wrong_output_type(self):
+        """Generate bprop for BpropWithWrongOutputType"""
+
+        def bprop(x, out, dout):
+            return (1,)
+
+        return bprop
+
+    class BpropWithWrongOutputTypeCell(nn.Cell):
+        def __init__(self):
+            super(BpropWithWrongOutputTypeCell, self).__init__()
+
+        def construct(self, x):
+            return BpropWithWrongOutputType()(x)
+
+    with pytest.raises(TypeError):
+        C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
+
+
+def test_bprop_with_wrong_output_shape():
+    context.set_context(check_bprop=True)
+    class BpropWithWrongOutputShape(PrimitiveWithInfer):
+        @prim_attr_register
+        def __init__(self):
+            super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')
+
+        def __call__(self, x):
+            return x
+
+        def infer_shape(self, x_shape):
+            return x_shape
+
+        def infer_dtype(self, x_type):
+            return x_type
+
+    @bprop_getters.register(BpropWithWrongOutputShape)
+    def get_bprop_with_wrong_output_shape(self):
+        """Generate bprop for BpropWithWrongOutputShape"""
+        ones = Tensor(np.ones([2,]).astype(np.int32))
+
+        def bprop(x, out, dout):
+            return (ones,)
+
+        return bprop
+
+    class BpropWithWrongOutputShapeCell(nn.Cell):
+        def __init__(self):
+            super(BpropWithWrongOutputShapeCell, self).__init__()
+
+        def construct(self, x):
+            return BpropWithWrongOutputShape()(x)
+
+    with pytest.raises(TypeError):
+        net = BpropWithWrongOutputShapeCell()
+        net.set_grad()
+        C.grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32)))

@@ -78,3 +78,9 @@ def test_tensor_imul():
     y = Tensor(np.ones([3, 3, 3, 3]).astype(np.float32))
     x *= y
     assert x.asnumpy()[0][0][0][0] == 1.0
+
+
+def test_tensor_pow():
+    x = Tensor(np.ones([3, 3, 3, 3]).astype(np.float32) * 2)
+    y = x ** 3
+    assert y.asnumpy()[0][0][0][0] == 8.0

@@ -89,7 +89,11 @@ def test_scalar_cast_grad():
         output = F.scalar_cast(x, input_t)
         return output
 
-    gfn = C.grad(fx_cast)(input_x)
+    @ms_function
+    def grad_fx_cast(input_x):
+        return C.grad(fx_cast)(input_x)
+
+    gfn = grad_fx_cast(input_x)
     expect_dx = 1
     assert gfn == expect_dx
 

@@ -133,25 +137,6 @@ def test_transpose_grad():
     assert np.all(gout[0].asnumpy() == expect)
 
 
-@non_graph_engine
-def test_squeeze_grad():
-    """ test_squeeze_grad """
-    input_tensor = Tensor(np.ones(shape=[3, 2, 1]))
-    squeeze = P.Squeeze(2)
-
-    def fn(x):
-        output = squeeze(x)
-        return output
-
-    out = fn(input_tensor)
-    gfn = grad_all_with_sens(fn)
-    sens = Tensor(np.ones_like(out.asnumpy()))
-    args = [input_tensor, sens]
-    gout = gfn(*args)
-    expect = np.ones([3, 2, 1])
-    assert np.all(gout[0].asnumpy() == expect)
-
-
 def test_select_grad():
     """ test_select_grad """
     select = P.Select()

@@ -176,6 +161,25 @@ def test_select_grad():
     assert np.all(gout[2].asnumpy() == expect_y)
 
 
+@non_graph_engine
+def test_squeeze_grad():
+    """ test_squeeze_grad """
+    input_tensor = Tensor(np.ones(shape=[3, 2, 1]))
+    squeeze = P.Squeeze(2)
+
+    def fn(x):
+        output = squeeze(x)
+        return output
+
+    out = fn(input_tensor)
+    gfn = grad_all_with_sens(fn)
+    sens = Tensor(np.ones_like(out.asnumpy()))
+    args = [input_tensor, sens]
+    gout = gfn(*args)
+    expect = np.ones([3, 2, 1])
+    assert np.all(gout[0].asnumpy() == expect)
+
+
 def test_SubGrad():
     """ test_SubGrad """
     input_x = Tensor(np.array([[2, 2]]))

@@ -16,6 +16,7 @@
 import numpy as np
 import pytest
 
+import mindspore as ms
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.common import dtype as mstype

@@ -23,8 +24,6 @@ from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common.tensor import Tensor
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
-from mindspore.ops._grad.grad_base import bprop_getters
-from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
 from ..ut_filter import non_graph_engine
 from ....mindspore_test_framework.utils.check_gradient import (
     ms_function, check_jacobian, Tensor, NNGradChecker,

@@ -156,14 +155,14 @@ def test_if_always_true():
 @non_graph_engine
 def test_f():
     """ test_f """
-    res = mainf(3, 2)
+    res = mainf(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32))
     assert res == (2, 3)
 
 
 @non_graph_engine
 def test_grad_add_mul():
     """ test_grad_add_mul """
-    res = grad_add_mul(3, 2)
+    res = grad_add_mul(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32))
     assert res == (2, 7)
 
 

@@ -262,17 +261,19 @@ def test_if_tensor():
     assert res == Tensor(np.ones([1]).astype(np.int32) * 4)
 
 
-@ms_function
 def rec(x):
     """ rec """
     if x > 0:
         return rec(x - 1)
     return x
 
+@ms_function
+def grad_rec(input_x):
+    return C.grad(rec)(input_x)
 
 def test_grad_rec():
     """ test_grad_rec """
-    res = C.grad(rec)(10)
+    res = grad_rec(3)
     assert res == 1
 
 

@@ -282,7 +283,6 @@ def test_me_rec():
     assert res == 0
 
 
-@ms_function
 def t2_while(x, y):
     out = y - x
     i = 0

@@ -298,8 +298,10 @@ def test_while2():
 
 
 def test_grad_while2():
-    res = C.grad(t2_while)(2, 3)
-    assert res == 3
+    @ms_function
+    def df_t2_while(input_x, input_y):
+        return C.grad(t2_while)(input_x, input_y)
+    assert df_t2_while(2, 3) == 3
 
 
 def if_test(a, b):

@@ -316,7 +318,7 @@ def grad_if(x, y):
 
 def test_grad_if():
     """ test_grad_if """
-    assert grad_if(5, 4) == (3, 0)
+    assert grad_if(Tensor(5, dtype=ms.int32), Tensor(4, dtype=ms.int32)) == (3, 0)
 
 
 # While loop is not unrolled in forward and backward graphs.

@@ -421,7 +423,7 @@ def grad_while(x):
 
 def test_grad_while():
     """ test_grad_while """
-    assert grad_while(5) == (60,)
+    assert grad_while(Tensor(5, dtype=ms.int32)) == (60,)
 
 
 @ms_function

@@ -438,8 +440,10 @@ def test_factorial():
 
 
 def test_grad_factorial():
-    res = C.grad(factorial)(3)
-    assert res == 11
+    @ms_function
+    def df_factorial(x):
+        return C.grad(factorial)(x)
+    assert df_factorial(3) == 11
 
 
 @ms_function

@@ -513,7 +517,7 @@ def _for(x):
         ret = ret * i
     return ret
 
-
+@ms_function
 def grad_for(x):
     """ grad_for """
     return C.grad_all(_for)(x)

@@ -786,7 +790,10 @@ def multi_outputs(x, y):
 
 
 def test_grad_multi_outputs():
-    assert C.grad_all_with_sens(multi_outputs)(2, 3, (1, 1)) == (4, 4)
+    @ms_function
+    def df_multi_outputs(x, y):
+        return C.grad_all_with_sens(multi_outputs)(x, y, (1, 1))
+    assert df_multi_outputs(2, 3) == (4, 4)
 
 
 @ms_function

@@ -813,7 +820,7 @@ def grad_refactor_simple_1(x, y):
 
 
 def test_grad_refactor_simple_1():
-    assert C.grad_all(grad_refactor_simple_1)(2, 1) == (4, 2)
+    assert C.grad_all(grad_refactor_simple_1)(Tensor(2, dtype=ms.int32), Tensor(1, dtype=ms.int32)) == (4, 2)
 
 
 def grad_refactor_simple_2(x, y, z):

@@ -822,7 +829,10 @@ def grad_refactor_simple_2(x, y, z):
 
 
 def test_grad_refactor_simple_2():
-    assert C.grad_all(grad_refactor_simple_2)(2, 3, 0) == (7, 4, 7)
+    x = Tensor(2, dtype=ms.int32)
+    y = Tensor(3, dtype=ms.int32)
+    z = Tensor(0, dtype=ms.int32)
+    assert C.grad_all(grad_refactor_simple_2)(x, y, z) == (7, 4, 7)
 
 
 def grad_refactor_1(a, b):

@@ -835,7 +845,7 @@ def grad_refactor_1(a, b):
 
 
 def test_grad_refactor_1():
-    assert C.grad_all(grad_refactor_1)(2, 3) == (3, 2)
+    assert C.grad_all(grad_refactor_1)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (3, 2)
 
 
 def grad_refactor_2(a, b):

@@ -848,7 +858,7 @@ def grad_refactor_2(a, b):
 
 
 def test_grad_refactor_2():
-    assert C.grad_all(grad_refactor_2)(2, 3) == (27, 54)
+    assert C.grad_all(grad_refactor_2)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (27, 54)
 
 
 def grad_refactor_3(a):

@@ -859,7 +869,10 @@ def grad_refactor_3(a):
 
 
 def test_grad_refactor_3():
-    assert C.grad_all(grad_refactor_3)(3) == (3,)
+    @ms_function
+    def df_refactor_3(x):
+        return C.grad_all(grad_refactor_3)(x)
+    assert df_refactor_3(3) == (3,)
 
 
 def grad_refactor_4(a):

@@ -870,7 +883,7 @@ def grad_refactor_4(a):
 
 
 def test_grad_refactor_4():
-    assert C.grad_all(grad_refactor_4)(4) == (3,)
+    assert C.grad_all(grad_refactor_4)(Tensor(4, dtype=ms.int32)) == (3,)
 
 
 def grad_refactor_5(a):

@@ -881,7 +894,10 @@ def grad_refactor_5(a):
 
 
 def test_grad_refactor_5():
-    assert C.grad_all(grad_refactor_5)(1) == (1,)
+    @ms_function
+    def df_refactor_5(x):
+        return C.grad_all(grad_refactor_5)(x)
+    assert df_refactor_5(1) == (1,)
 
 
 def grad_refactor_6(a, b):

@@ -892,7 +908,7 @@ def grad_refactor_6(a, b):
 
 
 def test_grad_refactor_6():
-    assert C.grad_all(grad_refactor_6)(3, 2) == (3, 1)
+    assert C.grad_all(grad_refactor_6)(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32)) == (3, 1)
 
 
 def grad_refactor_while(x):

@@ -904,7 +920,10 @@ def grad_refactor_while(x):
 
 
 def test_grad_refactor_9():
-    assert C.grad_all(grad_refactor_while)(3) == (6,)
+    @ms_function
+    def df_refactor_while(input_x):
+        return C.grad_all(grad_refactor_while)(input_x)
+    assert df_refactor_while(3) == (6,)
 
 
 def grad_refactor__while_1(x):

@@ -919,7 +938,7 @@ def grad_refactor__while_1(x):
 
 def test_grad_refactor_10():
     """ test_grad_while """
-    assert C.grad_all(grad_refactor__while_1)(5) == (60,)
+    assert C.grad_all(grad_refactor__while_1)(Tensor(5, dtype=ms.int32)) == (60,)
 
 
 def test_grad_refactor_11():

@@ -985,7 +1004,10 @@ def grad_refactor_14(a, b):
 
 
 def test_grad_refactor_14():
-    assert C.grad_all(grad_refactor_14)(2, 3) == (3, 9)
+    @ms_function
+    def df_refactor_14(x, y):
+        return C.grad_all(grad_refactor_14)(x, y)
+    assert df_refactor_14(2, 3) == (3, 9)
 
 
 # pylint: disable=using-constant-test

@@ -1009,111 +1031,3 @@ def test_grad_if_defer_inline():
     inp = Tensor(np.ones([128, 96]).astype(np.float32))
     grads = C.grad_all(network)(inp)
     assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)
-
-
-def test_bprop_with_wrong_output_num():
-    context.set_context(check_bprop=True)
-    class BpropWithWrongOutputNum(PrimitiveWithInfer):
-        @prim_attr_register
-        def __init__(self):
-            super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')
-
-        def __call__(self, x, y):
-            return x
-
-        def infer_shape(self, x_shape, yshape):
-            return x_shape
-
-        def infer_dtype(self, x_type, y_type):
-            return x_type
-
-    @bprop_getters.register(BpropWithWrongOutputNum)
-    def get_bprop_with_wrong_output_num(self):
-        """Generate bprop for BpropWithWrongOutputNum"""
-
-        def bprop(x, y, out, dout):
-            return (dout,)
-
-        return bprop
-
-    class BpropWithWrongOutputNumCell(nn.Cell):
-        def __init__(self):
-            super(BpropWithWrongOutputNumCell, self).__init__()
-
-        def construct(self, x, y):
-            return BpropWithWrongOutputNum()(x, y)
-
-    with pytest.raises(TypeError):
-        C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
-
-def test_bprop_with_wrong_output_type():
-    context.set_context(check_bprop=True)
-    class BpropWithWrongOutputType(PrimitiveWithInfer):
-        @prim_attr_register
-        def __init__(self):
-            super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')
-
-        def __call__(self, x):
-            return x
-
-        def infer_shape(self, x_shape):
-            return x_shape
-
-        def infer_dtype(self, x_type):
-            return x_type
-
-    @bprop_getters.register(BpropWithWrongOutputType)
-    def get_bprop_with_wrong_output_type(self):
-        """Generate bprop for BpropWithWrongOutputType"""
-
-        def bprop(x, out, dout):
-            return (1,)
-
-        return bprop
-
-    class BpropWithWrongOutputTypeCell(nn.Cell):
-        def __init__(self):
-            super(BpropWithWrongOutputTypeCell, self).__init__()
-
-        def construct(self, x):
-            return BpropWithWrongOutputType()(x)
-
-    with pytest.raises(TypeError):
-        C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
-
-
-def test_bprop_with_wrong_output_shape():
-    context.set_context(check_bprop=True)
-    class BpropWithWrongOutputShape(PrimitiveWithInfer):
-        @prim_attr_register
-        def __init__(self):
-            super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')
-
-        def __call__(self, x):
-            return x
-
-        def infer_shape(self, x_shape):
-            return x_shape
-
-        def infer_dtype(self, x_type):
-            return x_type
-
-    @bprop_getters.register(BpropWithWrongOutputShape)
-    def get_bprop_with_wrong_output_shape(self):
-        """Generate bprop for BpropWithWrongOutputShape"""
-        ones = Tensor(np.ones([2,]).astype(np.int32))
-
-        def bprop(x, out, dout):
-            return (ones,)
-
-        return bprop
-
-    class BpropWithWrongOutputShapeCell(nn.Cell):
-        def __init__(self):
-            super(BpropWithWrongOutputShapeCell, self).__init__()
-
-        def construct(self, x):
-            return BpropWithWrongOutputShape()(x)
-
-    with pytest.raises(TypeError):
-        C.grad_all(BpropWithWrongOutputShapeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))

@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 import numpy as np
+import pytest
 
 import mindspore.nn as nn
 import mindspore.ops.operations as P

@@ -154,22 +155,47 @@ def test_hook():
     print(loss_output.asnumpy().shape)
 
 
+bprop_debug = False
+
 class MulAdd(nn.Cell):
     def __init__(self):
         super(MulAdd, self).__init__()
 
     def construct(self, x, y):
-        return 2 * x + y
+        return 2 * x * x + y * y
 
     def bprop(self, x, y, out, dout):
-        assert (x == 1)
-        assert (y == 2)
-        assert (out == 4)
-        assert (dout == 1)
-        return 3 * dout, 2 * y
+        global bprop_debug
+        bprop_debug = True
+        return dout, 2 * y
 
 
 def test_custom_bprop():
     mul_add = MulAdd()
+    mul_add.bprop_debug = True
-    assert C.grad_all(mul_add)(1, 2) == (3, 4)
+    x = Tensor(np.array([1, 2, 3]).astype(np.int32))
+    y = Tensor(np.array([2, 3, 4]).astype(np.int32))
+    C.grad_all(mul_add)(x, y)
+    assert bprop_debug
+
 
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+
+    def construct(self, x, y):
+        return 2 * x * x + y * y
+
+def test_grad_all():
+    net = Net()
+    x = Tensor(np.array([1, 2, 3]).astype(np.int32))
+    y = Tensor(np.array([2, 3, 4]).astype(np.int32))
+    res = C.grad_all(net)(x, y)
+    print(res)
+
+def test_check_input():
+    net = Net()
+    x = np.array([1, 2, 3])
+    y = np.array([2, 3, 4])
+    with pytest.raises(TypeError):
+        net(x, y)

@@ -46,6 +46,7 @@ def test_InsertGradientOf_1():
         c = x * y
         return c
 
+    @ms_function
     def f(x, y):
         return C.grad_all(stop_test)(x, y)
 

@@ -80,6 +81,7 @@ def test_InsertGradientOf_2():
     def f(x, y):
         return clip_test(x, y)
 
+    @ms_function
    def fd(x, y):
         return C.grad_all(clip_test)(x, y)
 

@@ -16,6 +16,7 @@
 import numpy as np
 import pytest
 
+import mindspore as ms
 import mindspore.common.dtype as mstype
 import mindspore.nn as nn
 from mindspore import Parameter, ParameterTuple

@@ -81,16 +82,24 @@ def stop_test4(x, y):
     return e
 
 
+@ms_function
 def grad_stop_test(x, y):
     """ grad_stop_test """
     return C.grad_all(stop_test2)(x, y)
 
 
+@ms_function
 def grad_stop_test1(x, y):
     """ grad_stop_test1 """
     return C.grad_all(stop_test3)(x, y)
 
 
+@ms_function
+def grad_stop_test5(x, y):
+    """ grad_stop_test5 """
+    return C.grad_all(stop_test5)(x, y)
+
+
 def test_stop():
     """ test_stop """
     print("test_stop:", grad_stop_test(1, 1))

@@ -103,7 +112,7 @@ def test_stop1():
 
 def test_stop5():
     """ test_stop1 """
-    print("test_stop5:", C.grad_all(stop_test5)(2, 3))
+    print("test_stop5:", grad_stop_test5(2, 3))
 
 
 class GradWrap(nn.Cell):

@@ -247,7 +256,7 @@ def test_stop_gradient_4():
     def stop_test(x):
         return stop_gradient(x)
 
-    assert C.grad_all(stop_test)(1) == (0,)
+    assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (0,)
 
 
 def test_stop_gradient_5():

@@ -257,7 +266,7 @@ def test_stop_gradient_5():
         ret = x + y
         return ret
 
-    assert C.grad_all(stop_test)(1) == (1,)
+    assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (1,)
 
 
 def test_stop_gradient_6():

@@ -266,7 +275,7 @@ def test_stop_gradient_6():
         ret = stop_gradient(ret)
         return ret
 
-    assert C.grad_all(stop_test)(1, 3) == (0, 0)
+    assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (0, 0)
 
 
 class PrimWithMultiOutputs(PrimitiveWithInfer):