!22397 Add JVP for forward mode auto diff.

Merge pull request !22397 from LiangZhibo/fwd2
This commit is contained in:
i-robot 2021-08-26 11:28:22 +00:00 committed by Gitee
commit 8c6d4a05fc
11 changed files with 652 additions and 16 deletions

View File

@ -438,6 +438,26 @@ REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
.def(py::init<bool>(), py::arg("reverse"));
}));
bool CheckSequenceAllTensor(const abstract::AbstractTuplePtr &tuple) {
for (size_t i = 0; i < tuple->size(); ++i) {
if (!(*tuple)[i]->isa<abstract::AbstractUndetermined>() &&
!((*tuple)[i]->isa<abstract::AbstractTuple>() &&
CheckSequenceAllTensor((*tuple)[i]->cast<abstract::AbstractTuplePtr>()))) {
return false;
}
}
return true;
}
bool CheckTailGradFristSequence(const abstract::AbstractSequeuePtr &sequeue, bool enable_tuple_grad) {
return sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[1]->BuildType() != nullptr &&
(*sequeue)[1]->BuildType()->isa<Number>()) ||
((*sequeue)[1]->isa<abstract::AbstractTuple>() && enable_tuple_grad &&
CheckSequenceAllTensor((*sequeue)[1]->cast<abstract::AbstractTuplePtr>())));
}
FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const {
MS_EXCEPTION_IF_NULL(sequeue);
@ -457,10 +477,7 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &
}
if (tail_type_ == kGradFirst) {
if (sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[1]->BuildType() != nullptr &&
(*sequeue)[1]->BuildType()->isa<Number>()))) {
if (CheckTailGradFristSequence(sequeue, enable_tuple_grad_)) {
ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))}));
} else {
ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
@ -597,7 +614,7 @@ GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_
}
FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights,
const std::vector<AnfNodePtr> &forward_graph_params,
const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
const std::vector<AnfNodePtr> &weight_args) {
FuncGraphPtr k_child = std::make_shared<FuncGraph>();
k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
@ -620,13 +637,13 @@ FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weigh
auto f_app = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(0))});
auto bprop = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(1))});
GradByParameter(k_child, f_app, bprop, weights_node);
GradByParameter(k_child, f_app, bprop, weights_node, enable_tuple_grad);
return k_child;
}
// Do grad by the parameter of GradOperation.
void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
const AnfNodePtr &weights) {
const AnfNodePtr &weights, bool enable_tuple_grad) {
MS_EXCEPTION_IF_NULL(k_child);
AnfNodePtr bprop_arg = nullptr;
@ -677,6 +694,7 @@ void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePt
// b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...),
// so obtain first input grad by setting tail_type of Tail to kGradFirst.
TailPtr tail_grad_first = std::make_shared<Tail>("tail_grad_first", kGradFirst);
tail_grad_first->set_enable_tuple_grad(enable_tuple_grad);
k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(tail_grad_first), b_app}));
}
@ -726,7 +744,7 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
FuncGraphPtr k_child = nullptr;
{
TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
k_child = GetGrad(j, weights, forward_graph->parameters());
k_child = GetGrad(j, weights, forward_graph->parameters(), forward_graph->has_flag("enable_tuple_grad"));
}
grad_fg->set_output(NewValueNode(k_child));

View File

@ -111,9 +111,11 @@ class Tail : public MetaFuncGraph {
FuncGraphPtr GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const;
friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }
void set_enable_tuple_grad(bool enable_tuple_grad) { enable_tuple_grad_ = enable_tuple_grad; }
private:
TailType tail_type_;
bool enable_tuple_grad_;
};
using TailPtr = std::shared_ptr<Tail>;
@ -145,7 +147,7 @@ class GradOperation : public MetaFuncGraph {
MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
FuncGraphPtr GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights,
const std::vector<AnfNodePtr> &forward_graph_params,
const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
const std::vector<AnfNodePtr> &weight_args = {});
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
@ -156,7 +158,7 @@ class GradOperation : public MetaFuncGraph {
private:
void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
const AnfNodePtr &weights);
const AnfNodePtr &weights, bool enable_tuple_grad);
};
using GradOperationPtr = std::shared_ptr<GradOperation>;

View File

@ -121,6 +121,7 @@ AbstractBasePtr ArgsToAbstract(const ValuePtr &value) {
MS_EXCEPTION_IF_NULL(value);
bool broaden = value->isa<MetaTensor>() ||
(MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>());
return abstract::FromValue(value, broaden);
}

View File

@ -17,7 +17,7 @@ Neural Networks Cells.
Pre-defined building blocks or computing units to construct neural networks.
"""
from . import layer, loss, optim, metrics, wrap, probability, sparse, dynamic_lr
from . import layer, loss, optim, metrics, wrap, grad, probability, sparse, dynamic_lr
from .learning_rate_schedule import *
from .dynamic_lr import *
from .cell import Cell, GraphKernel, GraphCell
@ -26,6 +26,7 @@ from .loss import *
from .optim import *
from .metrics import *
from .wrap import *
from .grad import *
from .sparse import *
@ -35,6 +36,7 @@ __all__.extend(loss.__all__)
__all__.extend(optim.__all__)
__all__.extend(metrics.__all__)
__all__.extend(wrap.__all__)
__all__.extend(grad.__all__)
__all__.extend(sparse.__all__)
__all__.extend(learning_rate_schedule.__all__)
__all__.extend(dynamic_lr.__all__)

View File

@ -33,8 +33,7 @@ from ..common import dtype as mstype
from ..common.api import _executor, _pynative_exec
from ..common.parameter import Parameter, ParameterTuple
from ..common.tensor import Tensor
from ..ops.functional import cast
from ..ops.operations import HookBackward
from ..ops.operations import HookBackward, Cast
from ..ops.primitive import Primitive
from ..parallel._tensor import _load_tensor_by_layout
@ -118,6 +117,7 @@ class Cell(Cell_):
self._bprop_debug = False
self.cell_type = None
self._auto_parallel_compile_and_run = False
self.cast = Cast()
def __getstate__(self):
base = Cell_.__getstate__(self)
@ -319,9 +319,9 @@ class Cell(Cell_):
if isinstance(item, tuple):
res.append(self._cast_mixed_precision_inputs(item, dst_type))
elif isinstance(item, float):
res.append(cast(item, dst_type))
res.append(self.cast(item, dst_type))
elif hasattr(item, "dtype") and item.dtype in {mstype.float16, mstype.float32, mstype.float64}:
res.append(cast(item, dst_type))
res.append(self.cast(item, dst_type))
else:
res.append(item)
return tuple(res)
@ -332,7 +332,7 @@ class Cell(Cell_):
if isinstance(item, tuple):
res.append(self.cast_inputs(item, dst_type))
else:
res.append(cast(item, dst_type))
res.append(self.cast(item, dst_type))
return tuple(res)
def do_parameter_broadcast(self):

View File

@ -0,0 +1,24 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Grad
Cells of grad function. Calculate the gradient of input network or function.
"""
from .cell_grad import Jvp
__all__ = ['Jvp']

View File

@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cell grad"""
from ..cell import Cell
from ...ops import composite as C
from ...ops.primitive import Primitive
from ...ops import operations as P
from ...common import dtype as mstype
from ...common.api import ms_function
class _FirstGrad(Cell):
def __init__(self, fn):
super(_FirstGrad, self).__init__()
self.first_grad_op = C.GradOperation(sens_param=True, get_all=True)
self.fn = fn
def construct(self, u, first_grad_input):
return self.first_grad_op(self.fn)(*first_grad_input, u)
class _FirstGradSingleValue(Cell):
def __init__(self, fn):
super(_FirstGradSingleValue, self).__init__()
self.first_grad_single_value_op = C.GradOperation(sens_param=True)
self.fn = fn
def construct(self, u, first_grad_single_value_input):
return self.first_grad_single_value_op(self.fn)(*first_grad_single_value_input, u)
class Jvp(Cell):
"""
Compute the jacobian-vector-product of the given network. Jvp is equivalent to forward mode autodiff.
Args:
network (Cell): The network that takes Tensor inputs and returns a tuple of Tensors or a Tensor.
Inputs:
- **inputs** (Tensors) - The inputs to `net`.
- **v** (tuple of Tensors or Tensor) - The vector for which the Jacobian vector product is computed.
Must have the same size as the input of `network`.
Outputs:
A tuple with:
net_output (Tuple(Tensor...)) - The output of `network(inputs)`.
jvp (Tuple(Tensor...)) - The result of the jacobian vector product.
"""
def __init__(self, fn):
super(Jvp, self).__init__()
self.fn = fn
self.oneslike = P.OnesLike()
self.first_grad = _FirstGrad(fn)
self.first_grad.add_flags(enable_tuple_grad=True)
self.first_grad_single_value = _FirstGradSingleValue(fn)
self.first_grad_single_value.add_flags(enable_tuple_grad=True)
self.second_grad_op = C.GradOperation(sens_param=True)
self.issubclass_ = P.IsSubClass()
self.typeof = Primitive('typeof')
self.make_tuple = Primitive('MakeTuple')
self.tuple_len = Primitive("tuple_len")
@ms_function
def construct(self, *total_input):
jvp_input = total_input[0:-1]
v = total_input[-1]
output = self.fn(*jvp_input)
if self.issubclass_(self.typeof(output), mstype.tuple_):
u = self.make_tuple()
for i in range(self.tuple_len(output)):
u = u + self.make_tuple(self.oneslike(output[i]))
else:
u = self.oneslike(output)
if self.tuple_len(jvp_input) == 1:
second_gradient_net = self.second_grad_op(self.first_grad_single_value)
gradient_output = second_gradient_net(u, jvp_input, v)
else:
second_gradient_net = self.second_grad_op(self.first_grad)
gradient_output = second_gradient_net(u, jvp_input, v)
return output, gradient_output

View File

@ -23,6 +23,7 @@ from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCel
from .grad_reducer import DistributedGradReducer
from ..layer.timedistributed import TimeDistributed
__all__ = [
"TimeDistributed",
"ForwardValueAndGrad",

View File

@ -20,6 +20,8 @@
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore.ops import _constants
from .primitive import Primitive
from ..common import Tensor
from ..nn.grad import Jvp
from . import operations as P
from .operations import _grad_ops
from .composite import GradOperation
@ -154,6 +156,63 @@ identity = P.identity()
grad_first_parameter = GradOperation(get_all=False, get_by_list=False, sens_param=False)
grad_all_parameters = GradOperation(get_all=True, get_by_list=False, sens_param=False)
def _to_tuple(inp, arg_name):
"""
Check whether input to jvp is valid and convert the input to tuple.
"""
if isinstance(inp, list):
inp = tuple(inp)
inp_is_tuple = True
if not isinstance(inp, tuple):
inp = (inp,)
inp_is_tuple = False
for index, value in enumerate(inp):
if not isinstance(value, Tensor):
if inp_is_tuple:
raise TypeError(
"The value at the index {} of {} must be a Tensor. But it's type is {}".format(index, arg_name,
type(value)))
raise TypeError("The {} must be a Tensor. But it's type is {}".format(arg_name, type(value)))
return inp
def jvp(fn, jvp_input, v=None):
"""
Function to compute the jacobian-vector-product of the given network. jvp is equivalent to forward mode autodiff.
Inputs:
- **network** (Cell): The network that takes Tensor inputs and returns a tuple of Tensors or a Tensor.
- **inputs** (Tensor or tuple/list of Tensors) - The inputs to `net`.
- **v** (tuple/list of Tensors or Tensor) - The vector for which the Jacobian vector product is computed.
Must have the same size as the input of `network`.
Outputs:
A tuple with:
net_output (Tuple(Tensor...)) - The output of `network(inputs)`.
jvp (Tuple(Tensor...)) - The result of the jacobian vector product.
"""
inputs_tuple = _to_tuple(jvp_input, "input")
if v is not None:
v_tuple = _to_tuple(v, "v")
if len(v_tuple) != len(inputs_tuple):
raise ValueError("v is not the same size as the function inputs. The length of v is {}, while the length"
"of function inputs is {}".format(len(v_tuple), len(inputs_tuple)))
for index, (v_value, inputs_value) in enumerate(zip(v_tuple, inputs_tuple)):
if v_value.shape != inputs_value.shape:
raise ValueError("The tensor shape at the index {} of v is not the same as the function inputs, it"
"should be {}, but got {}".format(index, inputs_value.shape, v_value.shape))
else:
if len(inputs_tuple) != 1 or inputs_tuple[0].shape != (1,):
raise ValueError("The vector v can only be None if the input is a Tensor with single element")
v = Tensor([1], dtype=inputs_tuple[0].dtype)
if len(inputs_tuple) != 1:
total_input = (*inputs_tuple, v_tuple)
else:
total_input = (jvp_input, v)
return Jvp(fn)(*total_input)
def grad(fn, grad_first_param=False):
"""
A wrapper function to generate the gradient function for the input function.

View File

@ -0,0 +1,242 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test jvp in graph mode"""
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.functional import jvp
context.set_context(mode=context.GRAPH_MODE)
class SingleInputSingleOutputNet(nn.Cell):
def construct(self, x):
return x**3
class SingleInputMultipleOutputNet(nn.Cell):
def construct(self, x):
return x**3, 2*x
class MultipleInputSingleOutputNet(nn.Cell):
def construct(self, x, y):
return 2*x + 3*y
class MultipleInputMultipleOutputNet(nn.Cell):
def construct(self, x, y):
return 2*x, y**3
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_default_v_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = SingleInputSingleOutputNet()
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
primal, grad = jvp(net, x, v)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_custom_v_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = SingleInputSingleOutputNet()
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
primal, grad = jvp(net, x, v)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_default_v_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = SingleInputMultipleOutputNet()
expect_primal_0 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
primal, grad = jvp(net, x, v)
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_custom_v_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = SingleInputMultipleOutputNet()
expect_primal_0 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
primal, grad = jvp(net, x, v)
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_default_v_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = MultipleInputSingleOutputNet()
expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
expect_grad = Tensor(np.array([[5, 5], [5, 5]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v, v))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_custom_v_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = MultipleInputSingleOutputNet()
expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
expect_grad = Tensor(np.array([[5, 8], [11, 14]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v1, v2))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_default_v_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = MultipleInputMultipleOutputNet()
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v, v))
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_custom_v_graph():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = MultipleInputMultipleOutputNet()
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v1, v2))
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_tensor_default_v():
x = Tensor(np.array([1]).astype(np.float32))
net = SingleInputSingleOutputNet()
expect_primal = Tensor(np.array([1]).astype(np.float32))
expect_grad = Tensor(np.array([3]).astype(np.float32))
primal, grad = jvp(net, x)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_tensors_default_v():
x = Tensor(np.array([1, 2]).astype(np.float32))
net = SingleInputSingleOutputNet()
with pytest.raises(ValueError) as ex:
jvp(net, x)
assert "The vector v can only be None if the input is " in str(ex.value)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_tensors_wrong_shape_v():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1], [4]]).astype(np.float32))
net = SingleInputSingleOutputNet()
with pytest.raises(ValueError) as ex:
jvp(net, x, v)
assert "The tensor shape at the index 0 of v" in str(ex.value)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_tensors_wrong_tuple_v():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = (Tensor(np.array([[1], [4]]).astype(np.float32)), Tensor(np.array([[1], [4]]).astype(np.float32)))
net = SingleInputSingleOutputNet()
with pytest.raises(ValueError) as ex:
jvp(net, x, v)
assert "v is not the same size as the function inputs." in str(ex.value)

View File

@ -0,0 +1,193 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test jvp in pynative mode """
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.functional import jvp
context.set_context(mode=context.PYNATIVE_MODE)
class SingleInputSingleOutputNet(nn.Cell):
def construct(self, x):
return x**3
class SingleInputMultipleOutputNet(nn.Cell):
def construct(self, x):
return x**3, 2*x
class MultipleInputSingleOutputNet(nn.Cell):
def construct(self, x, y):
return 2*x + 3*y
class MultipleInputMultipleOutputNet(nn.Cell):
def construct(self, x, y):
return 2*x, y**3
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_default_v_pynative():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = SingleInputSingleOutputNet()
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
primal, grad = jvp(net, x, v)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_custom_v_pynative():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = SingleInputSingleOutputNet()
expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
primal, grad = jvp(net, x, v)
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_default_v_pynative():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = SingleInputMultipleOutputNet()
expect_primal_0 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
primal, grad = jvp(net, x, v)
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_custom_v_pynative():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = SingleInputMultipleOutputNet()
expect_primal_0 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
primal, grad = jvp(net, x, v)
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_default_v_pynative():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = MultipleInputMultipleOutputNet()
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v, v))
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_custom_v_pynative():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = MultipleInputMultipleOutputNet()
expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
expect_grad_1 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v1, v2))
assert isinstance(primal, tuple)
assert len(primal) == 2
assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
assert isinstance(grad, tuple)
assert len(grad) == 2
assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_default_v_pynative():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
net = MultipleInputSingleOutputNet()
expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
expect_grad = Tensor(np.array([[5, 5], [5, 5]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v, v))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_custom_v_pynative():
x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
net = MultipleInputSingleOutputNet()
expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
expect_grad = Tensor(np.array([[5, 8], [11, 14]]).astype(np.float32))
primal, grad = jvp(net, (x, y), (v1, v2))
assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())