!22397 Add JVP for forward mode auto diff.

Merge pull request !22397 from LiangZhibo/fwd2
i-robot 2021-08-26 11:28:22 +00:00 committed by Gitee
commit 8c6d4a05fc
11 changed files with 652 additions and 16 deletions
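
For orientation, the Jacobian-vector product introduced by this change is the forward-mode directional derivative: for a network f, an input x and a tangent vector v of the same shape, jvp returns the pair (f(x), J_f(x) * v). A rough NumPy sketch of the numerics the new tests below expect for an elementwise cube network (the array values mirror the tests; nothing here is part of the change itself):

import numpy as np

x = np.array([[1., 2.], [3., 4.]], dtype=np.float32)
v = np.array([[1., 2.], [3., 4.]], dtype=np.float32)

primal = x ** 3            # f(x), the primal network output
tangent = 3 * x ** 2 * v   # J_f(x) @ v for an elementwise function

# primal  -> [[ 1.  8.], [27. 64.]]
# tangent -> [[ 3. 24.], [81. 192.]]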

View File

@@ -438,6 +438,26 @@ REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
                           .def(py::init<bool>(), py::arg("reverse"));
                       }));

+bool CheckSequenceAllTensor(const abstract::AbstractTuplePtr &tuple) {
+  for (size_t i = 0; i < tuple->size(); ++i) {
+    if (!(*tuple)[i]->isa<abstract::AbstractUndetermined>() &&
+        !((*tuple)[i]->isa<abstract::AbstractTuple>() &&
+          CheckSequenceAllTensor((*tuple)[i]->cast<abstract::AbstractTuplePtr>()))) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool CheckTailGradFristSequence(const abstract::AbstractSequeuePtr &sequeue, bool enable_tuple_grad) {
+  return sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
+         ((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
+          (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[1]->BuildType() != nullptr &&
+           (*sequeue)[1]->BuildType()->isa<Number>()) ||
+          ((*sequeue)[1]->isa<abstract::AbstractTuple>() && enable_tuple_grad &&
+           CheckSequenceAllTensor((*sequeue)[1]->cast<abstract::AbstractTuplePtr>())));
+}
+
 FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const {
   MS_EXCEPTION_IF_NULL(sequeue);
@@ -457,10 +477,7 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &
   }

   if (tail_type_ == kGradFirst) {
-    if (sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
-        ((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
-         (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[1]->BuildType() != nullptr &&
-          (*sequeue)[1]->BuildType()->isa<Number>()))) {
+    if (CheckTailGradFristSequence(sequeue, enable_tuple_grad_)) {
       ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))}));
     } else {
       ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
@@ -597,7 +614,7 @@ GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_
 }

 FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights,
-                                    const std::vector<AnfNodePtr> &forward_graph_params,
+                                    const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
                                     const std::vector<AnfNodePtr> &weight_args) {
   FuncGraphPtr k_child = std::make_shared<FuncGraph>();
   k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
@@ -620,13 +637,13 @@ FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weigh
   auto f_app = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(0))});
   auto bprop = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(1))});

-  GradByParameter(k_child, f_app, bprop, weights_node);
+  GradByParameter(k_child, f_app, bprop, weights_node, enable_tuple_grad);
   return k_child;
 }

 // Do grad by the parameter of GradOperation.
 void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
-                                    const AnfNodePtr &weights) {
+                                    const AnfNodePtr &weights, bool enable_tuple_grad) {
   MS_EXCEPTION_IF_NULL(k_child);
   AnfNodePtr bprop_arg = nullptr;
@@ -677,6 +694,7 @@ void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePt
   // b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...),
   // so obtain first input grad by setting tail_type of Tail to kGradFirst.
   TailPtr tail_grad_first = std::make_shared<Tail>("tail_grad_first", kGradFirst);
+  tail_grad_first->set_enable_tuple_grad(enable_tuple_grad);
   k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(tail_grad_first), b_app}));
 }
@@ -726,7 +744,7 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
   FuncGraphPtr k_child = nullptr;
   {
     TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
-    k_child = GetGrad(j, weights, forward_graph->parameters());
+    k_child = GetGrad(j, weights, forward_graph->parameters(), forward_graph->has_flag("enable_tuple_grad"));
   }
   grad_fg->set_output(NewValueNode(k_child));
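
Taken together with the header change below, these hunks thread a new enable_tuple_grad flag from the forward graph through GradOperation::GetGrad and GradByParameter into Tail, so that the kGradFirst path can return a tuple-of-tensors gradient for the first input instead of an empty tuple. On the Python side the flag is set with add_flags on the inner grad cells; a minimal sketch of that pattern, following the new cell_grad.py further down (assuming a cell class such as _FirstGrad is in scope):

# Marks the cell's compiled graph so GenerateFuncGraph sees
# has_flag("enable_tuple_grad") and Tail keeps a tuple-of-tensor gradient.
first_grad = _FirstGrad(fn)
first_grad.add_flags(enable_tuple_grad=True)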

View File

@@ -111,9 +111,11 @@ class Tail : public MetaFuncGraph {
   FuncGraphPtr GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const;

   friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }
+  void set_enable_tuple_grad(bool enable_tuple_grad) { enable_tuple_grad_ = enable_tuple_grad; }

  private:
   TailType tail_type_;
+  bool enable_tuple_grad_;
 };

 using TailPtr = std::shared_ptr<Tail>;
@@ -145,7 +147,7 @@ class GradOperation : public MetaFuncGraph {
   MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)

   FuncGraphPtr GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights,
-                       const std::vector<AnfNodePtr> &forward_graph_params,
+                       const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
                        const std::vector<AnfNodePtr> &weight_args = {});
   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
@@ -156,7 +158,7 @@ class GradOperation : public MetaFuncGraph {
  private:
   void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
-                       const AnfNodePtr &weights);
+                       const AnfNodePtr &weights, bool enable_tuple_grad);
 };

 using GradOperationPtr = std::shared_ptr<GradOperation>;

View File

@@ -121,6 +121,7 @@ AbstractBasePtr ArgsToAbstract(const ValuePtr &value) {
   MS_EXCEPTION_IF_NULL(value);
   bool broaden = value->isa<MetaTensor>() ||
                  (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>());
   return abstract::FromValue(value, broaden);
 }

View File

@@ -17,7 +17,7 @@ Neural Networks Cells.
 Pre-defined building blocks or computing units to construct neural networks.
 """
-from . import layer, loss, optim, metrics, wrap, probability, sparse, dynamic_lr
+from . import layer, loss, optim, metrics, wrap, grad, probability, sparse, dynamic_lr
 from .learning_rate_schedule import *
 from .dynamic_lr import *
 from .cell import Cell, GraphKernel, GraphCell
@@ -26,6 +26,7 @@ from .loss import *
 from .optim import *
 from .metrics import *
 from .wrap import *
+from .grad import *
 from .sparse import *
@@ -35,6 +36,7 @@ __all__.extend(loss.__all__)
 __all__.extend(optim.__all__)
 __all__.extend(metrics.__all__)
 __all__.extend(wrap.__all__)
+__all__.extend(grad.__all__)
 __all__.extend(sparse.__all__)
 __all__.extend(learning_rate_schedule.__all__)
 __all__.extend(dynamic_lr.__all__)

View File

@@ -33,8 +33,7 @@ from ..common import dtype as mstype
 from ..common.api import _executor, _pynative_exec
 from ..common.parameter import Parameter, ParameterTuple
 from ..common.tensor import Tensor
-from ..ops.functional import cast
-from ..ops.operations import HookBackward
+from ..ops.operations import HookBackward, Cast
 from ..ops.primitive import Primitive
 from ..parallel._tensor import _load_tensor_by_layout
@@ -118,6 +117,7 @@ class Cell(Cell_):
         self._bprop_debug = False
         self.cell_type = None
         self._auto_parallel_compile_and_run = False
+        self.cast = Cast()

     def __getstate__(self):
         base = Cell_.__getstate__(self)
@@ -319,9 +319,9 @@ class Cell(Cell_):
             if isinstance(item, tuple):
                 res.append(self._cast_mixed_precision_inputs(item, dst_type))
             elif isinstance(item, float):
-                res.append(cast(item, dst_type))
+                res.append(self.cast(item, dst_type))
             elif hasattr(item, "dtype") and item.dtype in {mstype.float16, mstype.float32, mstype.float64}:
-                res.append(cast(item, dst_type))
+                res.append(self.cast(item, dst_type))
             else:
                 res.append(item)
         return tuple(res)
@@ -332,7 +332,7 @@ class Cell(Cell_):
             if isinstance(item, tuple):
                 res.append(self.cast_inputs(item, dst_type))
             else:
-                res.append(cast(item, dst_type))
+                res.append(self.cast(item, dst_type))
         return tuple(res)

     def do_parameter_broadcast(self):

View File

@@ -0,0 +1,24 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Grad
Cells for gradient functions. Calculate the gradient of the input network or function.
"""
from .cell_grad import Jvp
__all__ = ['Jvp']

View File

@@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cell grad"""
from ..cell import Cell
from ...ops import composite as C
from ...ops.primitive import Primitive
from ...ops import operations as P
from ...common import dtype as mstype
from ...common.api import ms_function


class _FirstGrad(Cell):
    def __init__(self, fn):
        super(_FirstGrad, self).__init__()
        self.first_grad_op = C.GradOperation(sens_param=True, get_all=True)
        self.fn = fn

    def construct(self, u, first_grad_input):
        return self.first_grad_op(self.fn)(*first_grad_input, u)


class _FirstGradSingleValue(Cell):
    def __init__(self, fn):
        super(_FirstGradSingleValue, self).__init__()
        self.first_grad_single_value_op = C.GradOperation(sens_param=True)
        self.fn = fn

    def construct(self, u, first_grad_single_value_input):
        return self.first_grad_single_value_op(self.fn)(*first_grad_single_value_input, u)


class Jvp(Cell):
    """
    Compute the Jacobian-vector product of the given network. Jvp is equivalent to forward-mode autodiff.

    Args:
        network (Cell): The network that takes Tensor inputs and returns a tuple of Tensors or a Tensor.

    Inputs:
        - **inputs** (Tensors) - The inputs to `network`.
        - **v** (tuple of Tensors or Tensor) - The vector for which the Jacobian-vector product is computed.
          Must have the same size as the input of `network`.

    Outputs:
        A tuple with:

        net_output (Tuple(Tensor...)) - The output of `network(inputs)`.
        jvp (Tuple(Tensor...)) - The result of the Jacobian-vector product.
    """

    def __init__(self, fn):
        super(Jvp, self).__init__()
        self.fn = fn
        self.oneslike = P.OnesLike()
        self.first_grad = _FirstGrad(fn)
        self.first_grad.add_flags(enable_tuple_grad=True)
        self.first_grad_single_value = _FirstGradSingleValue(fn)
        self.first_grad_single_value.add_flags(enable_tuple_grad=True)
        self.second_grad_op = C.GradOperation(sens_param=True)
        self.issubclass_ = P.IsSubClass()
        self.typeof = Primitive('typeof')
        self.make_tuple = Primitive('MakeTuple')
        self.tuple_len = Primitive("tuple_len")

    @ms_function
    def construct(self, *total_input):
        jvp_input = total_input[0:-1]
        v = total_input[-1]
        output = self.fn(*jvp_input)
        if self.issubclass_(self.typeof(output), mstype.tuple_):
            u = self.make_tuple()
            for i in range(self.tuple_len(output)):
                u = u + self.make_tuple(self.oneslike(output[i]))
        else:
            u = self.oneslike(output)
        if self.tuple_len(jvp_input) == 1:
            second_gradient_net = self.second_grad_op(self.first_grad_single_value)
            gradient_output = second_gradient_net(u, jvp_input, v)
        else:
            second_gradient_net = self.second_grad_op(self.first_grad)
            gradient_output = second_gradient_net(u, jvp_input, v)
        return output, gradient_output
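
A minimal usage sketch for the new Jvp cell (CubeNet is illustrative and mirrors the single-input network used in the tests; the tangent vector is passed as the last positional argument):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

class CubeNet(nn.Cell):
    def construct(self, x):
        return x ** 3

x = Tensor(np.array([[1., 2.], [3., 4.]], dtype=np.float32))
v = Tensor(np.ones((2, 2), dtype=np.float32))
# Returns the primal output x ** 3 and the JVP 3 * x ** 2 * v.
output, jvp_value = nn.Jvp(CubeNet())(x, v)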

View File

@@ -23,6 +23,7 @@ from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCel
 from .grad_reducer import DistributedGradReducer
 from ..layer.timedistributed import TimeDistributed

 __all__ = [
     "TimeDistributed",
     "ForwardValueAndGrad",

View File

@@ -20,6 +20,8 @@
 from mindspore.common._register_for_tensor import tensor_operator_registry
 from mindspore.ops import _constants
 from .primitive import Primitive
+from ..common import Tensor
+from ..nn.grad import Jvp
 from . import operations as P
 from .operations import _grad_ops
 from .composite import GradOperation
@@ -154,6 +156,63 @@ identity = P.identity()
 grad_first_parameter = GradOperation(get_all=False, get_by_list=False, sens_param=False)
 grad_all_parameters = GradOperation(get_all=True, get_by_list=False, sens_param=False)
+
+
+def _to_tuple(inp, arg_name):
+    """
+    Check whether the input to jvp is valid and convert the input to a tuple.
+    """
+    if isinstance(inp, list):
+        inp = tuple(inp)
+    inp_is_tuple = True
+    if not isinstance(inp, tuple):
+        inp = (inp,)
+        inp_is_tuple = False
+    for index, value in enumerate(inp):
+        if not isinstance(value, Tensor):
+            if inp_is_tuple:
+                raise TypeError(
+                    "The value at the index {} of {} must be a Tensor. But it's type is {}".format(index, arg_name,
+                                                                                                   type(value)))
+            raise TypeError("The {} must be a Tensor. But it's type is {}".format(arg_name, type(value)))
+    return inp
+
+
+def jvp(fn, jvp_input, v=None):
+    """
+    Function to compute the Jacobian-vector product of the given network. jvp is equivalent to forward-mode autodiff.
+
+    Inputs:
+        - **fn** (Cell) - The network that takes Tensor inputs and returns a tuple of Tensors or a Tensor.
+        - **jvp_input** (Tensor or tuple/list of Tensors) - The inputs to `fn`.
+        - **v** (tuple/list of Tensors or Tensor) - The vector for which the Jacobian-vector product is computed.
+          Must have the same size as the inputs of `fn`.
+
+    Outputs:
+        A tuple with:
+
+        net_output (Tuple(Tensor...)) - The output of `fn(jvp_input)`.
+        jvp (Tuple(Tensor...)) - The result of the Jacobian-vector product.
+    """
+    inputs_tuple = _to_tuple(jvp_input, "input")
+    if v is not None:
+        v_tuple = _to_tuple(v, "v")
+        if len(v_tuple) != len(inputs_tuple):
+            raise ValueError("v is not the same size as the function inputs. The length of v is {}, while the length"
+                             " of function inputs is {}".format(len(v_tuple), len(inputs_tuple)))
+        for index, (v_value, inputs_value) in enumerate(zip(v_tuple, inputs_tuple)):
+            if v_value.shape != inputs_value.shape:
+                raise ValueError("The tensor shape at the index {} of v is not the same as the function inputs, it"
+                                 " should be {}, but got {}".format(index, inputs_value.shape, v_value.shape))
+    else:
+        if len(inputs_tuple) != 1 or inputs_tuple[0].shape != (1,):
+            raise ValueError("The vector v can only be None if the input is a Tensor with single element")
+        v = Tensor([1], dtype=inputs_tuple[0].dtype)
+    if len(inputs_tuple) != 1:
+        total_input = (*inputs_tuple, v_tuple)
+    else:
+        total_input = (jvp_input, v)
+    return Jvp(fn)(*total_input)
+
+
 def grad(fn, grad_first_param=False):
     """
     A wrapper function to generate the gradient function for the input function.
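
For reference, a compact usage sketch of the functional form added above, in the multiple-input style exercised by the new tests (AddNet is illustrative):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.functional import jvp

class AddNet(nn.Cell):
    def construct(self, x, y):
        return 2 * x + 3 * y

x = Tensor(np.array([[1., 2.], [3., 4.]], dtype=np.float32))
y = Tensor(np.array([[1., 2.], [3., 4.]], dtype=np.float32))
v = Tensor(np.ones((2, 2), dtype=np.float32))
# primal = 2*x + 3*y; tangent = 2*v + 3*v = 5*v for this linear network.
primal, tangent = jvp(AddNet(), (x, y), (v, v))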

View File

@@ -0,0 +1,242 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test jvp in graph mode"""
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.functional import jvp
context.set_context(mode=context.GRAPH_MODE)
class SingleInputSingleOutputNet(nn.Cell):
    def construct(self, x):
        return x**3


class SingleInputMultipleOutputNet(nn.Cell):
    def construct(self, x):
        return x**3, 2*x


class MultipleInputSingleOutputNet(nn.Cell):
    def construct(self, x, y):
        return 2*x + 3*y


class MultipleInputMultipleOutputNet(nn.Cell):
    def construct(self, x, y):
        return 2*x, y**3
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_default_v_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = SingleInputSingleOutputNet()
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
primal, grad = jvp(net, x, v)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_custom_v_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = SingleInputSingleOutputNet()
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
primal, grad = jvp(net, x, v)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_default_v_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = SingleInputMultipleOutputNet()
expect_primal_0 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
primal, grad = jvp(net, x, v)
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_custom_v_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = SingleInputMultipleOutputNet()
expect_primal_0 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
primal, grad = jvp(net, x, v)
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_default_v_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = MultipleInputSingleOutputNet()
expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
expect_grad = Tensor(np.array([[5, 5], [5, 5]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v, v))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_custom_v_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = MultipleInputSingleOutputNet()
expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
expect_grad = Tensor(np.array([[5, 8], [11, 14]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v1, v2))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_default_v_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = MultipleInputMultipleOutputNet()
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v, v))
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_custom_v_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = MultipleInputMultipleOutputNet()
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v1, v2))
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_tensor_default_v():
x = Tensor(np.array([1]).astype(np.float32))
net = SingleInputSingleOutputNet()
expect_primal = Tensor(np.array([1]).astype(np.float32))
expect_grad = Tensor(np.array([3]).astype(np.float32))
primal, grad = jvp(net, x)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_tensors_default_v():
x = Tensor(np.array([1, 2]).astype(np.float32))
net = SingleInputSingleOutputNet()
with pytest.raises(ValueError) as ex:
jvp(net, x)
assert "The vector v can only be None if the input is " in str(ex.value)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_tensors_wrong_shape_v():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1], [4]]).astype(np.float32))
net = SingleInputSingleOutputNet()
with pytest.raises(ValueError) as ex:
jvp(net, x, v)
assert "The tensor shape at the index 0 of v" in str(ex.value)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_tensors_wrong_tuple_v():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = (Tensor(np.array([[1], [4]]).astype(np.float32)), Tensor(np.array([[1], [4]]).astype(np.float32)))
net = SingleInputSingleOutputNet()
with pytest.raises(ValueError) as ex:
jvp(net, x, v)
assert "v is not the same size as the function inputs." in str(ex.value)

View File

@@ -0,0 +1,193 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test jvp in pynative mode """
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.functional import jvp
context.set_context(mode=context.PYNATIVE_MODE)
class SingleInputSingleOutputNet(nn.Cell):
    def construct(self, x):
        return x**3


class SingleInputMultipleOutputNet(nn.Cell):
    def construct(self, x):
        return x**3, 2*x


class MultipleInputSingleOutputNet(nn.Cell):
    def construct(self, x, y):
        return 2*x + 3*y


class MultipleInputMultipleOutputNet(nn.Cell):
    def construct(self, x, y):
        return 2*x, y**3
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_default_v_pynative():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = SingleInputSingleOutputNet()
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
primal, grad = jvp(net, x, v)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_custom_v_pynative():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = SingleInputSingleOutputNet()
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
primal, grad = jvp(net, x, v)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_default_v_pynative():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = SingleInputMultipleOutputNet()
expect_primal_0 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
primal, grad = jvp(net, x, v)
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_custom_v_pynative():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = SingleInputMultipleOutputNet()
expect_primal_0 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
primal, grad = jvp(net, x, v)
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_default_v_pynative():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = MultipleInputMultipleOutputNet()
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v, v))
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_custom_v_pynative():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = MultipleInputMultipleOutputNet()
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v1, v2))
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_default_v_pynative():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = MultipleInputSingleOutputNet()
expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
expect_grad = Tensor(np.array([[5, 5], [5, 5]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v, v))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_custom_v_pynative():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = MultipleInputSingleOutputNet()
expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
expect_grad = Tensor(np.array([[5, 8], [11, 14]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v1, v2))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())