forked from mindspore-Ecosystem/mindspore
add-new-interface-forward-value-and-grad
parent 2e71163539
commit dd36171976
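Below is a minimal usage sketch of the new ForwardValueAndGrad interface introduced by this commit. It mirrors the tests added in this diff; the small Dense stand-in network, input shapes, and optimizer hyperparameters are illustrative assumptions, not part of the change itself.

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, ParameterTuple
from mindspore.nn import WithLossCell
from mindspore.nn.optim import Momentum

context.set_context(mode=context.PYNATIVE_MODE)

# Stand-in network; any Cell with a single output works here.
net = nn.Dense(32, 10)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
net_with_criterion.set_train()

weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
optimizer = Momentum(weights, 0.1, 0.9)

# The new cell returns the forward value and the gradients in a single call.
forward_value_and_grad = nn.ForwardValueAndGrad(network=net_with_criterion, weights=weights,
                                                get_by_list=True, sens_param=True)

data = Tensor(np.ones([32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([32]).astype(np.int32))
loss, grads = forward_value_and_grad(data, label, 1.0)  # 1.0 is the sens value
optimizer(grads)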
@@ -738,27 +738,7 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
        inputs.emplace_back(input_node);
      }
    }

    auto const_input_index = prim->get_const_input_indexes();
    bool have_const_input = !const_input_index.empty();
    bool is_const_prim = prim->is_const_prim();
    MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
                  << prim->is_const_prim();
    bool is_const_input =
      have_const_input && std::find(const_input_index.begin(), const_input_index.end(), i) != const_input_index.end();
    if (abs == nullptr || is_const_prim || is_const_input) {
      MS_LOG(DEBUG) << "MakeCnode get node no in map " << id;
      ValuePtr input_value = PyAttrValue(obj);
      abs = input_value->ToAbstract();
      if (!is_const_prim && !is_const_input) {
        auto config = abstract::AbstractBase::kBroadenTensorOnly;
        abs = abs->Broaden(config);
        MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config;
      }
      node_abs_map_[id] = abs;
    }

    (*args_spec_list).emplace_back(abs);
    (*args_spec_list).emplace_back(CheckConstValue(prim, obj, abs, id, i));
  }

  CNodePtr cnode = nullptr;

@@ -770,6 +750,34 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
  return cnode;
}

abstract::AbstractBasePtr PynativeExecutor::CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
                                                            const abstract::AbstractBasePtr &abs, const std::string &id,
                                                            size_t index) {
  MS_EXCEPTION_IF_NULL(prim);
  auto const_input_index = prim->get_const_input_indexes();
  bool have_const_input = !const_input_index.empty();
  bool is_const_prim = prim->is_const_prim();
  auto new_abs = abs;
  MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value "
                << prim->is_const_prim();
  bool is_const_input =
    have_const_input && std::find(const_input_index.begin(), const_input_index.end(), index) != const_input_index.end();
  if (abs == nullptr || is_const_prim || is_const_input) {
    MS_LOG(DEBUG) << "MakeCnode get node no in map " << id;
    ValuePtr input_value = PyAttrValue(obj);
    MS_EXCEPTION_IF_NULL(input_value);
    new_abs = input_value->ToAbstract();
    if (!is_const_prim && !is_const_input) {
      auto config = abstract::AbstractBase::kBroadenTensorOnly;
      MS_EXCEPTION_IF_NULL(new_abs);
      new_abs = new_abs->Broaden(config);
      MS_LOG(DEBUG) << "Broaden for " << prim->ToString() << " " << config;
    }
    node_abs_map_[id] = new_abs;
  }
  return new_abs;
}

void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
                                           const abstract::AbstractBasePtrList &args_spec_list, bool *is_find) {
  MS_EXCEPTION_IF_NULL(is_find);

@@ -1004,6 +1012,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
      return free_param;
    }
    node = graph_info->node_map.at(obj_id).first;
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Get input param node " << node->ToString() << " " << obj_id;
    return node;
  }

@@ -2008,9 +2017,14 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar
  top_cell_id_ = cell_id;
  in_grad_process_ = true;
  // update forward already run flag with previous top cell
  std::string input_args_id;
  for (size_t i = 0; i < args.size(); ++i) {
    input_args_id = input_args_id + GetId(args[i]) + "_";
  }
  auto pre_top_cell = GetTopCell(cell_id);
  if (pre_top_cell != nullptr) {
    pre_top_cell->forward_already_run = true;
    pre_top_cell->input_args_id = input_args_id;
  }
  auto df_builder = std::make_shared<FuncGraph>();
  auto graph_info = std::make_shared<GraphInfo>(cell_id);

@@ -2019,6 +2033,7 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar
  resource->results()[pipeline::kPynativeGraphId] = graph_id_++;
  auto top_cell_info = std::make_shared<TopCellInfo>(true, resource, df_builder, cell_id);
  top_cell_info->forward_already_run = true;
  top_cell_info->input_args_id = input_args_id;
  if (!IsTopestGraph(cell_id)) {
    top_cell_info->top_cell_index = cell_graph_list_.size();
    top_cell_index_ = top_cell_info->top_cell_index;

@@ -2862,11 +2877,24 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args &
}

py::object PynativeExecutor::CheckAlreadyRun(const py::object &cell, const py::args &args) {
  const auto &cell_id = GetCellId(cell, args);
  auto top_cell = GetTopCell(cell_id);
  bool forward_run = false;
  const auto &cell_id = GetCellId(cell, args);
  // Check whether the top cell has already run.
  std::string input_args_id;
  for (size_t i = 0; i < args.size(); ++i) {
    input_args_id = input_args_id + GetId(args[i]) + "_";
  }
  auto top_cell = GetTopCell(cell_id);
  if (top_cell != nullptr) {
    forward_run = top_cell->forward_already_run;
    if (!top_cell->input_args_id.empty() && top_cell->input_args_id != input_args_id && top_cell->forward_already_run &&
        CheckDynamicCell(cell_id)) {
      MS_LOG(WARNING) << "The construct of running cell is dynamic and the input info of this cell has changed, "
                         "forward process will run again";
      top_cell->forward_already_run = false;
      top_cell->input_args_id = input_args_id;
    } else {
      forward_run = top_cell->forward_already_run;
    }
    if (forward_run) {
      top_cell_index_ = top_cell->top_cell_index;
    }

@@ -107,6 +107,7 @@ class TopCellInfo {
  std::string cell_id;
  std::string sens_id;
  std::string weights_id;
  std::string input_args_id;
};

using GraphInfoPtr = std::shared_ptr<GraphInfo>;

@@ -209,6 +210,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
  AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
  AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
                       abstract::AbstractBasePtrList *args_spec_list);
  abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
                                            const abstract::AbstractBasePtr &abs, const std::string &id, size_t index);
  void GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
                           bool *is_find);
  void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode);

@@ -307,6 +307,23 @@ class Cell(Cell_):
            res.append(cast(item, dst_type))
        return tuple(res)

    def do_parameter_broadcast(self):
        if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
            if not self.parameter_broadcast_done:
                _pynative_exec.parameter_broadcast(self, self.phase, self._auto_parallel_mode)
                self.parameter_broadcast_done = True

    def run_construct(self, cast_inputs, kwargs):
        if self.enable_hook:
            _pynative_exec.enter_construct(self)
            output = self._hook_construct(*cast_inputs, **kwargs)
            _pynative_exec.leave_construct(self)
        else:
            _pynative_exec.enter_construct(self)
            output = self.construct(*cast_inputs, **kwargs)
            _pynative_exec.leave_construct(self)
        return output

    def __call__(self, *inputs, **kwargs):
        if self.__class__.construct is Cell.construct:
            logger.warning(f"The '{self.__class__}' does not override the method 'construct', "

@@ -324,11 +341,7 @@ class Cell(Cell_):
            out = self.compile_and_run(*inputs)
            return out

        if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
            if not self.parameter_broadcast_done:
                _pynative_exec.parameter_broadcast(self, self.phase, self._auto_parallel_mode)
                self.parameter_broadcast_done = True

        self.do_parameter_broadcast()
        for item in inputs:
            if isinstance(item, numpy.ndarray):
                raise TypeError("cell inputs should not be numpy array.")

@@ -349,14 +362,7 @@ class Cell(Cell_):
            cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float32)
        if not cast_inputs:
            cast_inputs = inputs
        if self.enable_hook:
            _pynative_exec.enter_construct(self)
            output = self._hook_construct(*cast_inputs, **kwargs)
            _pynative_exec.leave_construct(self)
        else:
            _pynative_exec.enter_construct(self)
            output = self.construct(*cast_inputs, **kwargs)
            _pynative_exec.leave_construct(self)
        output = self.run_construct(cast_inputs, kwargs)
        if isinstance(output, Parameter):
            output = output.data
        if self.requires_grad is True:

@@ -17,7 +17,7 @@ Wrap cells for networks.

Use the Wrapper to combine the loss or build the training steps.
"""
from .cell_wrapper import TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, \
from .cell_wrapper import ForwardValueAndGrad, TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, \
    ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple
from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
from .grad_reducer import DistributedGradReducer

@@ -25,6 +25,7 @@ from ..layer.timedistributed import TimeDistributed

__all__ = [
    "TimeDistributed",
    "ForwardValueAndGrad",
    "TrainOneStepCell",
    "WithLossCell",
    "WithGradCell",

@@ -13,9 +13,12 @@
# limitations under the License.
# ============================================================================
"""Cell_wrapper."""
from types import FunctionType, MethodType

from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                       _get_parallel_mode)
from mindspore.context import ParallelMode
from ...common.tensor import Tensor
from ...common import dtype as mstype
from ...common.parameter import Parameter, ParameterTuple
from ...ops import composite as C

@@ -174,6 +177,107 @@ class WithGradCell(Cell):
        return grads


class ForwardValueAndGrad(Cell):
    r"""
    Network training package class.

    Including the network and a gradient function. The resulting Cell is trained with input '\*inputs'.
    The backward graph will be created in the gradient function to calculate the gradient.

    Args:
        network (Cell): The training network. The network only supports a single output.
        weights (ParameterTuple): The parameters of the training network for which the gradient needs to be
            calculated. Default: None.
        get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
        get_by_list (bool): If True, get all the gradients with respect to Parameter variables.
            If get_all and get_by_list are both False, get the gradient with respect to first input.
            If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter variables
            at the same time in the form of ((gradients with respect to inputs),
            (gradients with respect to parameters)). Default: False.
        sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
            If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
            Default: False.
            If sens_param is True, a sensitivity (gradient with respect to output) needs to be transferred through
            a positional argument or a keyword argument. If the value is transferred through the keyword argument,
            the key must be 'sens'.

    Inputs:
        - **(\*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
        - **sens** (Number) - The scaling number to be filled as the input of backpropagation. Default value is 1.0.

    Outputs:
        - **forward value** (a scalar Tensor with shape :math:`()`) - The result of network forward running.
        - **gradients** (tuple(tensor)) - The gradients of network parameters and inputs.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> inputs = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32))
        >>> labels = Tensor(np.ones([32]).astype(np.int32))
        >>> net = Net()
        >>> weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> # 1) Using the existing WithLossCell provided
        >>> loss_net = nn.WithLossCell(net, loss_fn)
        >>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True, sens_param=True)
        >>> loss, grads = forward_value_and_grad(inputs, labels, 1.0)
        >>>
        >>> # 2) Using a user-defined WithLossCell
        >>> class MyWithLossCell(Cell):
        ...     def __init__(self, backbone, loss_fn):
        ...         super(MyWithLossCell, self).__init__(auto_prefix=False)
        ...         self._backbone = backbone
        ...         self._loss_fn = loss_fn
        ...
        ...     def construct(self, x, y, label):
        ...         out = self._backbone(x, y)
        ...         return self._loss_fn(out, label)
        ...
        ...     @property
        ...     def backbone_network(self):
        ...         return self._backbone
        ...
        >>> loss_net = MyWithLossCell(net, loss_fn)
        >>> forward_value_and_grad = nn.ForwardValueAndGrad(loss_net, weights=weights, get_by_list=True, sens_param=True)
        >>> loss, grads = forward_value_and_grad(inputs, labels, 1.0)
    """

    def __init__(self, network, weights=None, get_all=False, get_by_list=False, sens_param=False):
        super(ForwardValueAndGrad, self).__init__(auto_prefix=False)
        if not isinstance(network, (Cell, FunctionType, MethodType)):
            raise TypeError(f"The type of training network should be cell, function type or method type, "
                            f"but got '{type(network)}'")
        if get_by_list and not isinstance(weights, ParameterTuple):
            raise TypeError(f"When get_by_list is set to True, the parameters of training network should be "
                            f"ParameterTuple type, but got '{type(weights)}'")
        if get_by_list is not True and weights is not None:
            raise TypeError(f"When get_by_list is set to False, the parameters of training network should be "
                            f"NoneType, but got '{type(weights)}'")
        self.network = network
        self.network.set_grad()
        self.weights = weights
        self.get_all = get_all
        self.get_by_list = get_by_list
        self.sens_param = sens_param
        self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param)

    def construct(self, *inputs):
        weights = self.weights
        if self.sens_param:
            sens = inputs[-1]
            inputs = inputs[:-1]
        else:
            sens = None
        loss = self.network(*inputs)
        if self.sens_param:
            if not isinstance(sens, Tensor):
                sens = P.Fill()(P.DType()(loss), P.Shape()(loss), sens)
            grads = self.grad(self.network, weights)(*inputs, sens)
        else:
            grads = self.grad(self.network, weights)(*inputs)
        return loss, grads


class TrainOneStepCell(Cell):
    r"""
    Network training package class.

@@ -22,10 +22,10 @@ import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import Tensor, ParameterTuple
from mindspore import amp
from mindspore.nn import Dense
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn import TrainOneStepCell, WithLossCell, ForwardValueAndGrad
from mindspore.nn.cell import Cell
from mindspore.nn.layer.basic import Flatten
from mindspore.nn.layer.conv import Conv2d

@@ -33,6 +33,7 @@ from mindspore.nn.layer.normalization import BatchNorm2d
from mindspore.nn.layer.pooling import MaxPool2d
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations import Add

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

@@ -399,3 +400,53 @@ def test_trainTensor_amp(num_classes=10, epoch=18, batch_size=16):
    assert (losses[-1][0].asnumpy() < 1)
    assert not losses[-1][1].asnumpy()
    assert (losses[-1][2].asnumpy() > 1)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_trainTensor_with_new_interface(num_classes=10, epoch=8, batch_size=1):
    net = resnet50(num_classes)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_with_criterion = WithLossCell(net, criterion)
    net_with_criterion.set_train()

    weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
    optimizer = Momentum(weights, 0.1, 0.9)

    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
    losses = []
    for i in range(0, epoch):
        data = Tensor(np.ones([batch_size, 3, 224, 224]
                              ).astype(np.float32) * 0.01)
        label = Tensor(np.ones([batch_size]).astype(np.int32))
        loss, grads = train_network(data, label, 1.0)
        grads = F.identity(grads)
        optimizer(grads)
        losses.append(loss)
    assert (losses[-1].asnumpy() < 0.8)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_big_batchSize_with_new_interface(num_classes=10, epoch=8, batch_size=338):
    net = resnet50(num_classes)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_with_criterion = WithLossCell(net, criterion)
    net_with_criterion.set_train()

    weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
    optimizer = Momentum(weights, 0.1, 0.9)

    train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
    losses = []
    for i in range(0, epoch):
        data = Tensor(np.ones([batch_size, 3, 224, 224]
                              ).astype(np.float32) * 0.01)
        label = Tensor(np.ones([batch_size]).astype(np.int32))
        loss, grads = train_network(data, label, 1.0)
        grads = F.identity(grads)
        optimizer(grads)
        losses.append(loss)
    assert (losses[-1].asnumpy() < 0.8)

@@ -164,3 +164,40 @@ def test_ascend_pynative_lenet():
        print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
    assert loss_output.asnumpy() < 0.004
    assert loss_output.asnumpy() > 0.003


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_lenet_with_new_interface():
    context.set_context(mode=context.PYNATIVE_MODE)

    epoch_size = 20
    batch_size = 32
    inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32))
    labels = Tensor(np.ones([batch_size]).astype(np.int32))

    net = LeNet()
    criterion = CrossEntropyLoss()
    net_with_criterion = WithLossCell(net, criterion)
    net_with_criterion.set_train()

    weights = ParameterTuple(filter(lambda x: x.requires_grad, net.get_parameters()))
    optimizer = Momentum(weights, 0.1, 0.9)

    forward_value_and_grad = nn.ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True)
    total_time = 0
    for epoch in range(0, epoch_size):
        start_time = time.time()
        loss_output, grads = forward_value_and_grad(inputs, labels)
        optimizer(grads)
        end_time = time.time()
        cost_time = end_time - start_time
        total_time = total_time + cost_time

        print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
    assert loss_output.asnumpy() < 0.005
    assert loss_output.asnumpy() > 0.003