!22397 Add JVP for forward mode auto diff.
Merge pull request !22397 from LiangZhibo/fwd2
Commit 8c6d4a05fc
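A minimal usage sketch of the forward-mode interface this merge request adds, mirroring the graph-mode tests included below (the class name CubeNet is illustrative; the expected values are the ones asserted in those tests):

    import numpy as np
    import mindspore.nn as nn
    import mindspore.context as context
    from mindspore import Tensor
    from mindspore.ops.functional import jvp

    context.set_context(mode=context.GRAPH_MODE)

    class CubeNet(nn.Cell):
        def construct(self, x):
            return x ** 3

    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    primal, tangent = jvp(CubeNet(), x, v)
    # primal == x**3 == [[1, 8], [27, 64]]; tangent == 3*x**2 * v == [[3, 12], [27, 48]]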
@@ -438,6 +438,26 @@ REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) {
                            .def(py::init<bool>(), py::arg("reverse"));
                        }));
 
+bool CheckSequenceAllTensor(const abstract::AbstractTuplePtr &tuple) {
+  for (size_t i = 0; i < tuple->size(); ++i) {
+    if (!(*tuple)[i]->isa<abstract::AbstractUndetermined>() &&
+        !((*tuple)[i]->isa<abstract::AbstractTuple>() &&
+          CheckSequenceAllTensor((*tuple)[i]->cast<abstract::AbstractTuplePtr>()))) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool CheckTailGradFristSequence(const abstract::AbstractSequeuePtr &sequeue, bool enable_tuple_grad) {
+  return sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
+         ((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
+          (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[1]->BuildType() != nullptr &&
+           (*sequeue)[1]->BuildType()->isa<Number>()) ||
+          ((*sequeue)[1]->isa<abstract::AbstractTuple>() && enable_tuple_grad &&
+           CheckSequenceAllTensor((*sequeue)[1]->cast<abstract::AbstractTuplePtr>())));
+}
+
 FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const {
   MS_EXCEPTION_IF_NULL(sequeue);
@@ -457,10 +477,7 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &
   }
 
   if (tail_type_ == kGradFirst) {
-    if (sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
-        ((*sequeue)[1]->isa<abstract::AbstractUndetermined>() ||
-         (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && (*sequeue)[1]->BuildType() != nullptr &&
-          (*sequeue)[1]->BuildType()->isa<Number>()))) {
+    if (CheckTailGradFristSequence(sequeue, enable_tuple_grad_)) {
       ret->set_output(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(1))}));
     } else {
       ret->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
@@ -597,7 +614,7 @@ GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_
 }
 
 FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights,
-                                    const std::vector<AnfNodePtr> &forward_graph_params,
+                                    const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
                                     const std::vector<AnfNodePtr> &weight_args) {
   FuncGraphPtr k_child = std::make_shared<FuncGraph>();
   k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
@@ -620,13 +637,13 @@ FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weigh
   auto f_app = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(0))});
   auto bprop = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(1))});
 
-  GradByParameter(k_child, f_app, bprop, weights_node);
+  GradByParameter(k_child, f_app, bprop, weights_node, enable_tuple_grad);
   return k_child;
 }
 
 // Do grad by the parameter of GradOperation.
 void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
-                                    const AnfNodePtr &weights) {
+                                    const AnfNodePtr &weights, bool enable_tuple_grad) {
   MS_EXCEPTION_IF_NULL(k_child);
 
   AnfNodePtr bprop_arg = nullptr;
@@ -677,6 +694,7 @@ void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePt
   // b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...),
   // so obtain first input grad by setting tail_type of Tail to kGradFirst.
   TailPtr tail_grad_first = std::make_shared<Tail>("tail_grad_first", kGradFirst);
+  tail_grad_first->set_enable_tuple_grad(enable_tuple_grad);
   k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(tail_grad_first), b_app}));
 }
 
@@ -726,7 +744,7 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
   FuncGraphPtr k_child = nullptr;
   {
     TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
-    k_child = GetGrad(j, weights, forward_graph->parameters());
+    k_child = GetGrad(j, weights, forward_graph->parameters(), forward_graph->has_flag("enable_tuple_grad"));
   }
   grad_fg->set_output(NewValueNode(k_child));
 
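The "enable_tuple_grad" flag read in the hunk above originates on the Python side: the helper cells added later in this diff call add_flags(enable_tuple_grad=True), and Cell flags of that kind end up on the compiled forward FuncGraph that GenerateFuncGraph inspects (the propagation path is an assumption, not shown in this diff). A two-line sketch of that Python side:

    first_grad = _FirstGrad(fn)                   # inner reverse-mode cell from cell_grad.py below
    first_grad.add_flags(enable_tuple_grad=True)  # surfaces as forward_graph->has_flag("enable_tuple_grad") in C++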
@@ -111,9 +111,11 @@ class Tail : public MetaFuncGraph {
   FuncGraphPtr GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const;
 
   friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }
+  void set_enable_tuple_grad(bool enable_tuple_grad) { enable_tuple_grad_ = enable_tuple_grad; }
 
  private:
   TailType tail_type_;
+  bool enable_tuple_grad_;
 };
 using TailPtr = std::shared_ptr<Tail>;
@@ -145,7 +147,7 @@ class GradOperation : public MetaFuncGraph {
   MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
 
   FuncGraphPtr GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights,
-                       const std::vector<AnfNodePtr> &forward_graph_params,
+                       const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
                        const std::vector<AnfNodePtr> &weight_args = {});
 
   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
@@ -156,7 +158,7 @@ class GradOperation : public MetaFuncGraph {
 
  private:
   void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
-                       const AnfNodePtr &weights);
+                       const AnfNodePtr &weights, bool enable_tuple_grad);
 };
 using GradOperationPtr = std::shared_ptr<GradOperation>;
@@ -121,6 +121,7 @@ AbstractBasePtr ArgsToAbstract(const ValuePtr &value) {
   MS_EXCEPTION_IF_NULL(value);
   bool broaden = value->isa<MetaTensor>() ||
                  (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>());
 
   return abstract::FromValue(value, broaden);
 }
@@ -17,7 +17,7 @@ Neural Networks Cells.
 
 Pre-defined building blocks or computing units to construct neural networks.
 """
-from . import layer, loss, optim, metrics, wrap, probability, sparse, dynamic_lr
+from . import layer, loss, optim, metrics, wrap, grad, probability, sparse, dynamic_lr
 from .learning_rate_schedule import *
 from .dynamic_lr import *
 from .cell import Cell, GraphKernel, GraphCell
@@ -26,6 +26,7 @@ from .loss import *
 from .optim import *
 from .metrics import *
 from .wrap import *
+from .grad import *
 from .sparse import *
 
@@ -35,6 +36,7 @@ __all__.extend(loss.__all__)
 __all__.extend(optim.__all__)
 __all__.extend(metrics.__all__)
 __all__.extend(wrap.__all__)
+__all__.extend(grad.__all__)
 __all__.extend(sparse.__all__)
 __all__.extend(learning_rate_schedule.__all__)
 __all__.extend(dynamic_lr.__all__)
@@ -33,8 +33,7 @@ from ..common import dtype as mstype
 from ..common.api import _executor, _pynative_exec
 from ..common.parameter import Parameter, ParameterTuple
 from ..common.tensor import Tensor
-from ..ops.functional import cast
-from ..ops.operations import HookBackward
+from ..ops.operations import HookBackward, Cast
 from ..ops.primitive import Primitive
 from ..parallel._tensor import _load_tensor_by_layout
@@ -118,6 +117,7 @@ class Cell(Cell_):
         self._bprop_debug = False
         self.cell_type = None
         self._auto_parallel_compile_and_run = False
+        self.cast = Cast()
 
     def __getstate__(self):
         base = Cell_.__getstate__(self)
@@ -319,9 +319,9 @@ class Cell(Cell_):
             if isinstance(item, tuple):
                 res.append(self._cast_mixed_precision_inputs(item, dst_type))
             elif isinstance(item, float):
-                res.append(cast(item, dst_type))
+                res.append(self.cast(item, dst_type))
             elif hasattr(item, "dtype") and item.dtype in {mstype.float16, mstype.float32, mstype.float64}:
-                res.append(cast(item, dst_type))
+                res.append(self.cast(item, dst_type))
             else:
                 res.append(item)
         return tuple(res)
@@ -332,7 +332,7 @@ class Cell(Cell_):
             if isinstance(item, tuple):
                 res.append(self.cast_inputs(item, dst_type))
             else:
-                res.append(cast(item, dst_type))
+                res.append(self.cast(item, dst_type))
         return tuple(res)
 
     def do_parameter_broadcast(self):
@@ -0,0 +1,24 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Grad

Cells of grad function. Calculate the gradient of input network or function.
"""

from .cell_grad import Jvp


__all__ = ['Jvp']
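With the package __init__ above and the nn/__init__.py change earlier in this diff, Jvp should be importable from both places; a sketch of the two equivalent imports:

    from mindspore.nn import Jvp       # re-exported via `from .grad import *`
    from mindspore.nn.grad import Jvp  # direct import from the new subpackage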
@@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cell grad"""
from ..cell import Cell
from ...ops import composite as C
from ...ops.primitive import Primitive
from ...ops import operations as P
from ...common import dtype as mstype
from ...common.api import ms_function


class _FirstGrad(Cell):
    def __init__(self, fn):
        super(_FirstGrad, self).__init__()
        self.first_grad_op = C.GradOperation(sens_param=True, get_all=True)
        self.fn = fn

    def construct(self, u, first_grad_input):
        return self.first_grad_op(self.fn)(*first_grad_input, u)


class _FirstGradSingleValue(Cell):
    def __init__(self, fn):
        super(_FirstGradSingleValue, self).__init__()
        self.first_grad_single_value_op = C.GradOperation(sens_param=True)
        self.fn = fn

    def construct(self, u, first_grad_single_value_input):
        return self.first_grad_single_value_op(self.fn)(*first_grad_single_value_input, u)


class Jvp(Cell):
    """
    Compute the jacobian-vector-product of the given network. Jvp is equivalent to forward mode autodiff.

    Args:
        network (Cell): The network that takes Tensor inputs and returns a tuple of Tensors or a Tensor.

    Inputs:
        - **inputs** (Tensors) - The inputs to `net`.
        - **v** (tuple of Tensors or Tensor) - The vector for which the Jacobian vector product is computed.
          Must have the same size as the input of `network`.

    Outputs:
        A tuple with:
        net_output (Tuple(Tensor...)) - The output of `network(inputs)`.
        jvp (Tuple(Tensor...)) - The result of the jacobian vector product.
    """
    def __init__(self, fn):
        super(Jvp, self).__init__()
        self.fn = fn
        self.oneslike = P.OnesLike()
        self.first_grad = _FirstGrad(fn)
        self.first_grad.add_flags(enable_tuple_grad=True)
        self.first_grad_single_value = _FirstGradSingleValue(fn)
        self.first_grad_single_value.add_flags(enable_tuple_grad=True)
        self.second_grad_op = C.GradOperation(sens_param=True)
        self.issubclass_ = P.IsSubClass()
        self.typeof = Primitive('typeof')
        self.make_tuple = Primitive('MakeTuple')
        self.tuple_len = Primitive("tuple_len")

    @ms_function
    def construct(self, *total_input):
        jvp_input = total_input[0:-1]
        v = total_input[-1]
        output = self.fn(*jvp_input)

        if self.issubclass_(self.typeof(output), mstype.tuple_):
            u = self.make_tuple()
            for i in range(self.tuple_len(output)):
                u = u + self.make_tuple(self.oneslike(output[i]))
        else:
            u = self.oneslike(output)

        if self.tuple_len(jvp_input) == 1:
            second_gradient_net = self.second_grad_op(self.first_grad_single_value)
            gradient_output = second_gradient_net(u, jvp_input, v)
        else:
            second_gradient_net = self.second_grad_op(self.first_grad)
            gradient_output = second_gradient_net(u, jvp_input, v)
        return output, gradient_output
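Jvp.construct above obtains a forward-mode product from two reverse-mode passes: _FirstGrad computes the vector-Jacobian product J^T u of fn at the inputs (with cotangent u), and second_grad_op then differentiates that result with respect to u with sens v. Because u -> J^T u is linear, this second pass yields (J^T)^T v = J v, the Jacobian-vector product. A minimal NumPy sketch of that identity (not MindSpore code) for the elementwise f(x) = x**3 used in the tests, where the Jacobian is diagonal so J == J^T:

    import numpy as np

    x = np.array([[1., 2.], [3., 4.]], dtype=np.float32)
    v = np.array([[1., 2.], [3., 4.]], dtype=np.float32)

    def vjp(u):
        # reverse-mode pass of f(x) = x**3: gradient w.r.t. x with cotangent u, i.e. J^T u = 3*x**2 * u
        return 3.0 * x ** 2 * u

    # vjp is linear in u; since J is diagonal here, J v equals vjp(v) = 3*x**2 * v,
    # which is exactly the forward-mode tangent the second GradOperation recovers.
    jvp_of_cube = vjp(v)
    assert np.allclose(jvp_of_cube, np.array([[3., 24.], [81., 192.]], dtype=np.float32))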
@@ -23,6 +23,7 @@ from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCel
 from .grad_reducer import DistributedGradReducer
 from ..layer.timedistributed import TimeDistributed
 
 
 __all__ = [
     "TimeDistributed",
     "ForwardValueAndGrad",
@@ -20,6 +20,8 @@
 from mindspore.common._register_for_tensor import tensor_operator_registry
 from mindspore.ops import _constants
 from .primitive import Primitive
+from ..common import Tensor
+from ..nn.grad import Jvp
 from . import operations as P
 from .operations import _grad_ops
 from .composite import GradOperation
@@ -154,6 +156,63 @@ identity = P.identity()
 grad_first_parameter = GradOperation(get_all=False, get_by_list=False, sens_param=False)
 grad_all_parameters = GradOperation(get_all=True, get_by_list=False, sens_param=False)
 
 
+def _to_tuple(inp, arg_name):
+    """
+    Check whether input to jvp is valid and convert the input to tuple.
+    """
+    if isinstance(inp, list):
+        inp = tuple(inp)
+    inp_is_tuple = True
+    if not isinstance(inp, tuple):
+        inp = (inp,)
+        inp_is_tuple = False
+    for index, value in enumerate(inp):
+        if not isinstance(value, Tensor):
+            if inp_is_tuple:
+                raise TypeError(
+                    "The value at the index {} of {} must be a Tensor. But it's type is {}".format(index, arg_name,
+                                                                                                    type(value)))
+            raise TypeError("The {} must be a Tensor. But it's type is {}".format(arg_name, type(value)))
+    return inp
+
+
+def jvp(fn, jvp_input, v=None):
+    """
+    Function to compute the jacobian-vector-product of the given network. jvp is equivalent to forward mode autodiff.
+
+    Inputs:
+        - **network** (Cell): The network that takes Tensor inputs and returns a tuple of Tensors or a Tensor.
+        - **inputs** (Tensor or tuple/list of Tensors) - The inputs to `net`.
+        - **v** (tuple/list of Tensors or Tensor) - The vector for which the Jacobian vector product is computed.
+          Must have the same size as the input of `network`.
+
+    Outputs:
+        A tuple with:
+        net_output (Tuple(Tensor...)) - The output of `network(inputs)`.
+        jvp (Tuple(Tensor...)) - The result of the jacobian vector product.
+    """
+    inputs_tuple = _to_tuple(jvp_input, "input")
+    if v is not None:
+        v_tuple = _to_tuple(v, "v")
+        if len(v_tuple) != len(inputs_tuple):
+            raise ValueError("v is not the same size as the function inputs. The length of v is {}, while the length "
+                             "of function inputs is {}".format(len(v_tuple), len(inputs_tuple)))
+        for index, (v_value, inputs_value) in enumerate(zip(v_tuple, inputs_tuple)):
+            if v_value.shape != inputs_value.shape:
+                raise ValueError("The tensor shape at the index {} of v is not the same as the function inputs, it "
+                                 "should be {}, but got {}".format(index, inputs_value.shape, v_value.shape))
+    else:
+        if len(inputs_tuple) != 1 or inputs_tuple[0].shape != (1,):
+            raise ValueError("The vector v can only be None if the input is a Tensor with single element")
+        v = Tensor([1], dtype=inputs_tuple[0].dtype)
+
+    if len(inputs_tuple) != 1:
+        total_input = (*inputs_tuple, v_tuple)
+    else:
+        total_input = (jvp_input, v)
+    return Jvp(fn)(*total_input)
+
+
 def grad(fn, grad_first_param=False):
     """
     A wrapper function to generate the gradient function for the input function.
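A sketch of the default-v path handled above, matching test_jvp_single_input_single_tensor_default_v in the tests below (the class name Cube is illustrative): when v is omitted, jvp only accepts a single one-element Tensor and substitutes Tensor([1]) of the same dtype.

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.ops.functional import jvp

    class Cube(nn.Cell):
        def construct(self, x):
            return x ** 3

    primal, tangent = jvp(Cube(), Tensor(np.array([1]).astype(np.float32)))
    # primal == [1.], tangent == [3.] since d/dx x**3 = 3*x**2 evaluated at x = 1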
@@ -0,0 +1,242 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test jvp in graph mode"""

import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.functional import jvp

context.set_context(mode=context.GRAPH_MODE)


class SingleInputSingleOutputNet(nn.Cell):
    def construct(self, x):
        return x**3


class SingleInputMultipleOutputNet(nn.Cell):
    def construct(self, x):
        return x**3, 2*x


class MultipleInputSingleOutputNet(nn.Cell):
    def construct(self, x, y):
        return 2*x + 3*y


class MultipleInputMultipleOutputNet(nn.Cell):
    def construct(self, x, y):
        return 2*x, y**3


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_default_v_graph():
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    primal, grad = jvp(net, x, v)
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_custom_v_graph():
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
    primal, grad = jvp(net, x, v)
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_default_v_graph():
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputMultipleOutputNet()
    expect_primal_0 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
    expect_grad_0 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    expect_grad_1 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
    primal, grad = jvp(net, x, v)
    assert isinstance(primal, tuple)
    assert len(primal) == 2
    assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
    assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
    assert isinstance(grad, tuple)
    assert len(grad) == 2
    assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
    assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_custom_v_graph():
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputMultipleOutputNet()
    expect_primal_0 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
    expect_grad_0 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
    expect_grad_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
    primal, grad = jvp(net, x, v)
    assert isinstance(primal, tuple)
    assert len(primal) == 2
    assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
    assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
    assert isinstance(grad, tuple)
    assert len(grad) == 2
    assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
    assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_default_v_graph():
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = MultipleInputSingleOutputNet()
    expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
    expect_grad = Tensor(np.array([[5, 5], [5, 5]]).astype(np.float32))
    primal, grad = jvp(net, (x, y), (v, v))
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_custom_v_graph():
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = MultipleInputSingleOutputNet()
    expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
    expect_grad = Tensor(np.array([[5, 8], [11, 14]]).astype(np.float32))
    primal, grad = jvp(net, (x, y), (v1, v2))
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_default_v_graph():
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = MultipleInputMultipleOutputNet()
    expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
    expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
    expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    primal, grad = jvp(net, (x, y), (v, v))
    assert isinstance(primal, tuple)
    assert len(primal) == 2
    assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
    assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
    assert isinstance(grad, tuple)
    assert len(grad) == 2
    assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
    assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_custom_v_graph():
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = MultipleInputMultipleOutputNet()
    expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
    expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
    expect_grad_1 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
    primal, grad = jvp(net, (x, y), (v1, v2))
    assert isinstance(primal, tuple)
    assert len(primal) == 2
    assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
    assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
    assert isinstance(grad, tuple)
    assert len(grad) == 2
    assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
    assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_tensor_default_v():
    x = Tensor(np.array([1]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    expect_primal = Tensor(np.array([1]).astype(np.float32))
    expect_grad = Tensor(np.array([3]).astype(np.float32))
    primal, grad = jvp(net, x)
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_tensors_default_v():
    x = Tensor(np.array([1, 2]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    with pytest.raises(ValueError) as ex:
        jvp(net, x)
    assert "The vector v can only be None if the input is " in str(ex.value)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_tensors_wrong_shape_v():
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1], [4]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    with pytest.raises(ValueError) as ex:
        jvp(net, x, v)
    assert "The tensor shape at the index 0 of v" in str(ex.value)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_tensors_wrong_tuple_v():
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = (Tensor(np.array([[1], [4]]).astype(np.float32)), Tensor(np.array([[1], [4]]).astype(np.float32)))
    net = SingleInputSingleOutputNet()
    with pytest.raises(ValueError) as ex:
        jvp(net, x, v)
    assert "v is not the same size as the function inputs." in str(ex.value)
@@ -0,0 +1,193 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test jvp in pynative mode"""

import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.functional import jvp

context.set_context(mode=context.PYNATIVE_MODE)


class SingleInputSingleOutputNet(nn.Cell):
    def construct(self, x):
        return x**3


class SingleInputMultipleOutputNet(nn.Cell):
    def construct(self, x):
        return x**3, 2*x


class MultipleInputSingleOutputNet(nn.Cell):
    def construct(self, x, y):
        return 2*x + 3*y


class MultipleInputMultipleOutputNet(nn.Cell):
    def construct(self, x, y):
        return 2*x, y**3


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_default_v_pynative():
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    primal, grad = jvp(net, x, v)
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_custom_v_pynative():
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    expect_primal = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
    primal, grad = jvp(net, x, v)
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_default_v_pynative():
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputMultipleOutputNet()
    expect_primal_0 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
    expect_grad_0 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    expect_grad_1 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
    primal, grad = jvp(net, x, v)
    assert isinstance(primal, tuple)
    assert len(primal) == 2
    assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
    assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
    assert isinstance(grad, tuple)
    assert len(grad) == 2
    assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
    assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_custom_v_pynative():
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputMultipleOutputNet()
    expect_primal_0 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_primal_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
    expect_grad_0 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
    expect_grad_1 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
    primal, grad = jvp(net, x, v)
    assert isinstance(primal, tuple)
    assert len(primal) == 2
    assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
    assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
    assert isinstance(grad, tuple)
    assert len(grad) == 2
    assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
    assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_default_v_pynative():
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = MultipleInputMultipleOutputNet()
    expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
    expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
    expect_grad_1 = Tensor(np.array([[3, 12], [27, 48]]).astype(np.float32))
    primal, grad = jvp(net, (x, y), (v, v))
    assert isinstance(primal, tuple)
    assert len(primal) == 2
    assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
    assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
    assert isinstance(grad, tuple)
    assert len(grad) == 2
    assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
    assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_custom_v_pynative():
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = MultipleInputMultipleOutputNet()
    expect_primal_0 = Tensor(np.array([[2, 4], [6, 8]]).astype(np.float32))
    expect_primal_1 = Tensor(np.array([[1, 8], [27, 64]]).astype(np.float32))
    expect_grad_0 = Tensor(np.array([[2, 2], [2, 2]]).astype(np.float32))
    expect_grad_1 = Tensor(np.array([[3, 24], [81, 192]]).astype(np.float32))
    primal, grad = jvp(net, (x, y), (v1, v2))
    assert isinstance(primal, tuple)
    assert len(primal) == 2
    assert np.allclose(primal[0].asnumpy(), expect_primal_0.asnumpy())
    assert np.allclose(primal[1].asnumpy(), expect_primal_1.asnumpy())
    assert isinstance(grad, tuple)
    assert len(grad) == 2
    assert np.allclose(grad[0].asnumpy(), expect_grad_0.asnumpy())
    assert np.allclose(grad[1].asnumpy(), expect_grad_1.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_default_v_pynative():
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = MultipleInputSingleOutputNet()
    expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
    expect_grad = Tensor(np.array([[5, 5], [5, 5]]).astype(np.float32))
    primal, grad = jvp(net, (x, y), (v, v))
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_custom_v_pynative():
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = MultipleInputSingleOutputNet()
    expect_primal = Tensor(np.array([[5, 10], [15, 20]]).astype(np.float32))
    expect_grad = Tensor(np.array([[5, 8], [11, 14]]).astype(np.float32))
    primal, grad = jvp(net, (x, y), (v1, v2))
    assert np.allclose(primal.asnumpy(), expect_primal.asnumpy())
    assert np.allclose(grad.asnumpy(), expect_grad.asnumpy())