forked from mindspore-Ecosystem/mindspore
modify file and clean src_thor
This commit is contained in:
parent
a90ee15937
commit
f127db121b
@@ -1,132 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset helper for the MindData dataset."""
from mindspore._checkparam import Validator
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _to_full_shapes
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
from mindspore.context import ParallelMode


def _send_data(dataset):
    """Engine dataset to write data to the tdt queue."""
    if not hasattr(dataset, '__has_sent__'):
        exec_dataset = dataset.__transfer_dataset__
        exec_dataset.send()
        dataset.__has_sent__ = True


class DatasetHelper:
    """
    Helper class for using the MindData dataset.

    Depending on the context, it changes the dataset iterator so that the same for-loop
    can be used in every context.

    Note:
        One iteration of DatasetHelper yields one epoch of data.

    Args:
        dataset (Dataset): The dataset.
        dataset_sink_mode (bool): If True, use GetNext to fetch the data; otherwise feed
            the data from the host. Default: True.
        iter_first_order (int): The number of iterations of the first-order subgraph.
            Default: 0.

    Examples:
        >>> dataset_helper = DatasetHelper(dataset)
        >>> for inputs in dataset_helper:
        >>>     outputs = network(*inputs)
    """

    def __init__(self, dataset, dataset_sink_mode=True, iter_first_order=0):
        Validator.check_bool(dataset_sink_mode)
        self.iter = _DatasetIterMSLoopSink(dataset, iter_first_order)

    def __iter__(self):
        return self.iter.__iter__()

    # A temporary solution for loop sink. Delete later.
    def types_shapes(self):
        """Get the types and shapes from the dataset on the current config."""
        return self.iter.types_shapes()

    def loop_size(self):
        """Get loop_size for every iteration."""
        return self.iter.loop_size


class _DatasetIter:
    """Base iterator for the dataset helper."""

    def __init__(self, dataset):
        self.loop_size = 1
        if not hasattr(dataset, '__transfer_dataset__'):
            if not hasattr(dataset, '__loop_size__'):
                self.loop_size = dataset.get_dataset_size()
            else:
                self.loop_size = dataset.__loop_size__
            dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.loop_size)

            if not hasattr(dataset, '__no_send__'):
                _send_data(dataset)
        else:
            _send_data(dataset)

        self.ind = 0
        self.dataset = dataset
        dataset_types, dataset_shapes = _get_types_and_shapes(dataset)
        self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes

    def __iter__(self):
        self.ind = 0
        return self

    def __next__(self):
        if self.ind >= self.loop_count:
            raise StopIteration()
        self.ind += 1
        return self.op()

    def types_shapes(self):
        return self.dataset_types, self.dataset_shapes

    def get_loop_count(self, dataset):
        loop_count = 1
        if hasattr(dataset, '__loop_size__'):
            loop_size = dataset.__loop_size__
            if dataset.get_dataset_size() % loop_size != 0:
                raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
                                 f'loop_size {loop_size} do not match.')
            loop_count = int(dataset.get_dataset_size() / loop_size)
        return loop_count


class _DatasetIterMSLoopSink(_DatasetIter):
    """Iterator for the context (device_target=Ascend)."""

    def __init__(self, dataset, iter_first_order):
        super(_DatasetIterMSLoopSink, self).__init__(dataset)
        loop_size = dataset.__loop_size__ + iter_first_order
        self.loop_count = int(dataset.get_dataset_size() / loop_size * 2)
        # When self._parallel_mode is semi_auto_parallel or auto_parallel, compile with the
        # complete tensor and slice the tensor to run. The batch dimension of the tensors used
        # for compilation is device_number times the batch dimension of the tensors used to
        # run. Only LoopSink is supported for now.
        if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            device_num = _get_device_num()
            self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)

        def op():
            return tuple()

        self.op = op
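For intuition about the loop arithmetic above, here is a minimal sketch with assumed numbers (a dataset of 5004 steps per epoch and a THOR update frequency of 278, so `dataset.__loop_size__` is 1 and `iter_first_order` is 277, as set up by the Model code later in this commit):

    # Hypothetical numbers; the arithmetic mirrors _DatasetIterMSLoopSink.__init__.
    dataset_size = 5004                                # assumed steps per epoch
    iter_second_order = 1                              # dataset.__loop_size__
    iter_first_order = 277                             # frequency - 1
    loop_size = iter_second_order + iter_first_order   # 278
    loop_count = int(dataset_size / loop_size * 2)     # 36

The factor of 2 reflects that each sink cycle yields two iterator steps, one for the second-order graph and one for the first-order graph, which is also why `op()` returns an empty tuple: the iterator drives graph launches rather than delivering host data.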
@@ -1,132 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluation metric."""

import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.communication.management import GlobalComm
from mindspore.ops import operations as P


class ClassifyCorrectCell(nn.Cell):
    r"""
    Cell that returns the correct count of the predictions in a classification network.

    This Cell accepts a network as its argument. It returns the correct count of the
    predictions so that the metrics can be calculated.

    Args:
        network (Cell): The network Cell.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tuple, containing a scalar correct count of the predictions.

    Examples:
        >>> # For a defined network Net without loss function
        >>> net = Net()
        >>> eval_net = nn.ClassifyCorrectCell(net)
    """

    def __init__(self, network):
        super(ClassifyCorrectCell, self).__init__(auto_prefix=False)
        self._network = network
        self.argmax = P.Argmax()
        self.equal = P.Equal()
        self.cast = P.Cast()
        self.reduce_sum = P.ReduceSum()
        self.allreduce = P.AllReduce(P.ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)

    def construct(self, data, label):
        outputs = self._network(data)
        y_pred = self.argmax(outputs)
        y_pred = self.cast(y_pred, mstype.int32)
        y_correct = self.equal(y_pred, label)
        y_correct = self.cast(y_correct, mstype.float32)
        y_correct = self.reduce_sum(y_correct)
        total_correct = self.allreduce(y_correct)
        return (total_correct,)


class DistAccuracy(nn.Metric):
    r"""
    Calculates the accuracy for classification data in distributed mode.

    The accuracy class creates two local variables, the correct number and the total number,
    that are used to compute the frequency with which predictions match labels. This frequency
    is ultimately returned as the accuracy: an idempotent operation that simply divides the
    correct number by the total number.

    .. math::

        \text{accuracy} = \frac{\text{true_positive} + \text{true_negative}}
        {\text{true_positive} + \text{true_negative} + \text{false_positive} + \text{false_negative}}

    Args:
        batch_size (int): Eval batch size.
        device_num (int): Number of devices used for evaluation.

    Examples:
        >>> y_correct = Tensor(np.array([20]))
        >>> metric = nn.DistAccuracy(batch_size=3, device_num=8)
        >>> metric.clear()
        >>> metric.update(y_correct)
        >>> accuracy = metric.eval()
    """

    def __init__(self, batch_size, device_num):
        super(DistAccuracy, self).__init__()
        self.clear()
        self.batch_size = batch_size
        self.device_num = device_num

    def clear(self):
        """Clears the internal evaluation result."""
        self._correct_num = 0
        self._total_num = 0

    def update(self, *inputs):
        """
        Updates the internal evaluation result :math:`y_{pred}` and :math:`y`.

        Args:
            inputs: Input `y_correct`. `y_correct` is a scalar Tensor: the correct prediction
                count gathered from all devices, as a float scalar.

        Raises:
            ValueError: If the number of inputs is not 1.
        """

        if len(inputs) != 1:
            raise ValueError('Distributed accuracy needs 1 input (y_correct), but got {}'.format(len(inputs)))
        y_correct = self._convert_data(inputs[0])
        self._correct_num += y_correct
        self._total_num += self.batch_size * self.device_num

    def eval(self):
        """
        Computes the accuracy.

        Returns:
            Float, the computed result.

        Raises:
            RuntimeError: If the sample size is 0.
        """

        if self._total_num == 0:
            raise RuntimeError('Accuracy can not be calculated, because the number of samples is 0.')
        return self._correct_num / self._total_num
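As a quick sanity check of the bookkeeping, a usage sketch (the 8 devices and eval batch size of 3 are assumed): each `update` call adds `batch_size * device_num` samples to the denominator, while the numerator accumulates the correct counts gathered by `ClassifyCorrectCell`.

    import numpy as np
    from mindspore import Tensor

    metric = DistAccuracy(batch_size=3, device_num=8)
    metric.clear()
    metric.update(Tensor(np.array([20])))  # 20 correct predictions gathered across devices
    metric.update(Tensor(np.array([18])))
    print(metric.eval())                   # (20 + 18) / (2 * 3 * 8) = 0.7916...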
@@ -1,736 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model."""

import numpy as np
from mindspore import context
from mindspore import log as logger
from mindspore import nn
from mindspore._c_expression import init_exec_dataset
from mindspore._checkparam import check_input_data, check_output_data, Validator
from mindspore.common import dtype as mstype
from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.tensor import Tensor
from mindspore.train.dataset_helper import connect_network_with_dataset
from mindspore.nn.metrics import Loss
from mindspore.nn.metrics import get_metrics
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
    _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
from mindspore.train import amp
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
from mindspore.context import ParallelMode

from .dataset_helper import DatasetHelper


def _convert_type(types):
    """
    Convert from numpy type to tensor type.

    Args:
        types (list): Numpy type list of the elements in the dataset.

    Returns:
        list, list of element types in the dataset.
    """
    ms_types = []
    for np_type in types:
        ms_type = pytype_to_dtype(np_type)
        ms_types.append(ms_type)
    return ms_types


def _get_types_and_shapes(dataset):
    """Get dataset types and shapes."""
    dataset_types = _convert_type(dataset.output_types())
    dataset_shapes = dataset.output_shapes()
    return dataset_types, dataset_shapes


def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
    """Initialize and execute the dataset graph."""
    batch_size = exec_dataset.get_batch_size()
    input_indexs = exec_dataset.input_indexs

    # transform data format
    dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
    init_exec_dataset(exec_dataset.__transfer_dataset__.queue_name,
                      dataset_size,
                      batch_size,
                      dataset_types,
                      dataset_shapes,
                      input_indexs,
                      phase=phase,
                      need_run=False)
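# A small assumed illustration of what _convert_type does above: pytype_to_dtype
# maps each numpy element type reported by the dataset to the matching MindSpore
# dtype, e.g.
#     [pytype_to_dtype(t) for t in (np.float32, np.int32)]
#     # -> [mindspore.float32, mindspore.int32]
# init_exec_dataset needs these dtypes (plus the shapes) to describe the
# device-side queue.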
class Model:
    """
    High-level API for training or testing.

    `Model` groups layers into an object with training and inference features.

    Args:
        network (Cell): The training or testing network.
        loss_fn (Cell): Objective function. If loss_fn is None, the network should contain the
            logic of loss and gradient calculation, and the parallel logic if needed.
            Default: None.
        optimizer (Cell): Optimizer for updating the weights. Default: None.
        metrics (Union[dict, set]): Dict or set of metrics to be evaluated by the model during
            training and testing, e.g. {'accuracy', 'recall'}. Default: None.
        eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn`
            are wrapped as `eval_network`. Default: None.
        eval_indexes (list): When `eval_network` is defined, if `eval_indexes` is None, all
            outputs of `eval_network` are passed to the metrics; otherwise `eval_indexes` must
            contain three elements: the positions of the loss value, the predicted value and
            the label. The loss value is passed to the `Loss` metric, the predicted value and
            label are passed to the other metrics. Default: None.
        amp_level (str): Option for the `level` argument of `mindspore.amp.build_train_network`,
            the level for mixed-precision training. Supports [O0, O2]. Default: "O0".

            - O0: Do not change.
            - O2: Cast the network to float16, keep batchnorm running in float32, and use a
              dynamic loss scale.

        loss_scale_manager (Union[None, LossScaleManager]): If None, do not scale the loss;
            otherwise scale the loss by the LossScaleManager. If set, it overwrites the level
            setting. It is a keyword argument, e.g. use `loss_scale_manager=None` to set the
            value.
        keep_batchnorm_fp32 (bool): Keep batchnorm running in `float32`. If set, it overwrites
            the level setting. Default: True.

    Examples:
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
        >>>         self.bn = nn.BatchNorm2d(64)
        >>>         self.relu = nn.ReLU()
        >>>         self.flatten = nn.Flatten()
        >>>         self.fc = nn.Dense(64*224*224, 12)  # padding=0
        >>>
        >>>     def construct(self, x):
        >>>         x = self.conv(x)
        >>>         x = self.bn(x)
        >>>         x = self.relu(x)
        >>>         x = self.flatten(x)
        >>>         out = self.fc(x)
        >>>         return out
        >>>
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
        >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
        >>> dataset = get_dataset()
        >>> model.train(2, dataset)
    """

    def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
                 eval_indexes=None, amp_level="O0", frequency=278, stop_epoch=100, **kwargs):
        self._network = network
        self._loss_fn = loss_fn
        self._optimizer = optimizer
        self._loss_scale_manager = None
        self._loss_scale_manager_set = False
        self._keep_bn_fp32 = True
        self._check_kwargs(kwargs)
        self._amp_level = amp_level
        self._process_amp_args(kwargs)
        self._parallel_mode = _get_parallel_mode()
        self._device_number = _get_device_num()
        self._global_rank = _get_global_rank()
        self._parameter_broadcast = _get_parameter_broadcast()
        self._frequency = frequency
        self._stop_epoch = stop_epoch
        self._has_do_dataset_init = False

        self._train_network = self._build_train_network()
        self._build_eval_network(metrics, eval_network, eval_indexes)
        self._build_predict_network()

    def _process_amp_args(self, kwargs):
        if self._amp_level == "O0":
            self._keep_bn_fp32 = False
        if 'keep_batchnorm_fp32' in kwargs:
            self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
        if 'loss_scale_manager' in kwargs:
            self._loss_scale_manager = kwargs['loss_scale_manager']
            self._loss_scale_manager_set = True

    def _check_kwargs(self, kwargs):
        for arg in kwargs:
            if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
                raise ValueError(f"Unsupported arg '{arg}'")

    def _build_train_network(self):
        """Build the train network."""
        network = self._network
        if self._optimizer:
            if self._loss_scale_manager_set:
                network = amp.build_train_network(network,
                                                  self._optimizer,
                                                  self._loss_fn,
                                                  level=self._amp_level,
                                                  loss_scale_manager=self._loss_scale_manager,
                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
            else:
                network = amp.build_train_network(network,
                                                  self._optimizer,
                                                  self._loss_fn,
                                                  level=self._amp_level,
                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
        elif self._loss_fn:
            network = nn.WithLossCell(network, self._loss_fn)
        # If needed, check the case where loss_fn is not None but optimizer is None.

        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            network.set_auto_parallel()
        return network

    def _build_eval_network(self, metrics, eval_network, eval_indexes):
        """Build the network for evaluation."""
        self._metric_fns = get_metrics(metrics)
        if not self._metric_fns:
            return

        if eval_network is not None:
            if eval_indexes is not None and not (isinstance(eval_indexes, list) and len(eval_indexes) == 3):
                raise ValueError("Eval_indexes must be a list or None. If eval_indexes is a "
                                 "list, its length must be three. But got {}".format(eval_indexes))

            self._eval_network = eval_network
            self._eval_indexes = eval_indexes
        else:
            if self._loss_fn is None:
                raise ValueError("loss_fn can not be None.")
            self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2")
            self._eval_indexes = [0, 1, 2]

        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            self._eval_network.set_auto_parallel()

    def _build_predict_network(self):
        """Build the network for prediction."""
        self._predict_network = self._network
        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            self._predict_network = _VirtualDatasetCell(self._network)
            self._predict_network.set_auto_parallel()

    def _clear_metrics(self):
        """Clear local metric values."""
        for metric in self._metric_fns.values():
            metric.clear()

    def _update_metrics(self, outputs):
        """Update local metric values."""
        if not isinstance(outputs, tuple):
            raise ValueError("The `outputs` is not a tuple.")

        if self._eval_indexes is not None and len(outputs) < 3:
            raise ValueError("The length of `outputs` must be greater than or equal to 3, "
                             "but got {}".format(len(outputs)))

        for metric in self._metric_fns.values():
            if self._eval_indexes is None:
                metric.update(*outputs)
            else:
                if isinstance(metric, Loss):
                    metric.update(outputs[self._eval_indexes[0]])
                else:
                    metric.update(outputs[self._eval_indexes[1]], outputs[self._eval_indexes[2]])

    def _get_metrics(self):
        """Get local metric values."""
        metrics = dict()
        for key, value in self._metric_fns.items():
            metrics[key] = value.eval()
        return metrics
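    # Sketch of how _update_metrics routes an output tuple with the default
    # eval_indexes [0, 1, 2] built by _build_eval_network (names assumed):
    #     outputs = (loss, logits, label)                    # from nn.WithEvalCell
    #     Loss metric   -> metric.update(loss)               # outputs[0]
    #     other metrics -> metric.update(logits, label)      # outputs[1], outputs[2]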
    def _get_scaling_sens(self):
        """Get the scaling sensitivity."""
        scaling_sens = 1
        if self._loss_scale_manager is not None:
            scaling_sens = self._loss_scale_manager.get_loss_scale()
        if self._parallel_mode == ParallelMode.DATA_PARALLEL:
            scaling_sens /= self._device_number
        return scaling_sens
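    # Worked example for _get_scaling_sens with assumed numbers: a loss scale of
    # 1024 from the manager and 8 devices in DATA_PARALLEL mode give
    #     scaling_sens = 1024 / 8 = 128.0
    # so each device applies one eighth of the global sensitivity.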
    def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, iter_first_order=1):
        """Initializes the dataset."""
        if dataset_sink_mode and not is_train:
            dataset.__loop_size__ = 1
        dataset_helper = DatasetHelper(dataset, dataset_sink_mode, iter_first_order)

        if dataset_sink_mode:
            network = connect_network_with_dataset(network, dataset_helper)
        network.set_train(is_train)
        network.phase = phase

        return dataset_helper, network

    def init(self, train_dataset=None, valid_dataset=None):
        """
        Initializes compute graphs and data graphs with sink mode.

        Note:
            The pre-init process only supports `GRAPH_MODE` and the `Ascend` target currently.

        Args:
            train_dataset (Dataset): A training dataset iterator. If `train_dataset` is
                defined, the training graphs will be initialized. Default: None.
            valid_dataset (Dataset): An evaluation dataset iterator. If `valid_dataset` is
                defined, the evaluation graphs will be initialized, and `metrics` in `Model`
                can not be None. Default: None.

        Examples:
            >>> train_dataset = get_train_dataset()
            >>> valid_dataset = get_valid_dataset()
            >>> net = Net()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
            >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'})
            >>> model.init(train_dataset, valid_dataset)
            >>> model.train(2, train_dataset)
            >>> model.eval(valid_dataset)
        """
        if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend":
            raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.')

        if not train_dataset and not valid_dataset:
            raise ValueError('Both train_dataset and valid_dataset can not be None or empty.')

        _device_number_check(self._parallel_mode, self._device_number)

        if train_dataset:
            _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
            self._train_network.set_train()
            self._train_network.phase = 'train'

            if self._parameter_broadcast:
                self._train_network.set_broadcast_flag()
            iter_first_order = self._frequency - 1
            iter_second_order = 1
            train_dataset.__loop_size__ = iter_second_order
            train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
                                                                        is_train=True,
                                                                        phase='train',
                                                                        dataset=train_dataset,
                                                                        dataset_sink_mode=True,
                                                                        iter_first_order=iter_first_order)
            self._train_network = train_network
            switch_branch_one = True
            index = 0
            for inputs in train_dataset_helper:
                if switch_branch_one:
                    self._train_network.add_flags_recursive(thor=True)
                    self._train_network.phase = 'train0'
                else:
                    self._train_network.add_flags_recursive(thor=False)
                    self._train_network.phase = 'train1'
                    if not self._has_do_dataset_init:
                        _exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset')
                        self._has_do_dataset_init = True
                switch_branch_one = not switch_branch_one
                self._train_network.compile(*inputs)
                if index >= 1:
                    break
                index += 1

        if valid_dataset:
            if not self._metric_fns:
                raise RuntimeError('If `valid_dataset` is defined, the metric fn can not be None or empty.')

            self._eval_network.set_train(False)
            self._eval_network.phase = 'eval'
            valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
                                                                       is_train=False,
                                                                       phase='eval',
                                                                       dataset=valid_dataset,
                                                                       dataset_sink_mode=True)
            self._eval_network = eval_network
            self._eval_network.add_flags_recursive(thor=False)
            for inputs in valid_dataset_helper:
                self._eval_network.compile(*inputs)
                break
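    # Condensed view of the pre-compilation loop in init() above (a sketch, not
    # extra behaviour): the dataset helper yields empty tuples, and the loop
    # compiles exactly two phases before breaking:
    #     iteration 0: thor=True,  phase 'train0' (second-order graph)
    #     iteration 1: thor=False, phase 'train1' (first-order graph, plus a
    #                  one-off _exec_datagraph for the 'train1_dataset' queue)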
    def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
        """
        Training.

        Args:
            epoch (int): Total number of iterations over the data.
            train_dataset (Dataset): A training dataset iterator. If there is no loss_fn, a
                tuple with multiple data items (data1, data2, data3, ...) will be returned and
                passed to the network. Otherwise, a tuple (data, label) will be returned, and
                the data and label are passed to the network and the loss function
                respectively.
            callbacks (list): List of callback objects which should be executed while
                training. Default: None.
            dataset_sink_mode (bool): Determines whether to pass the data through the dataset
                channel. Default: True. In pynative mode, training is performed without
                dataset sinking.
        """
        epoch = Validator.check_positive_int(epoch)
        self._train_network.set_train()

        if self._parameter_broadcast:
            self._train_network.set_broadcast_flag()

        # build callback list
        cb_params = _InternalCallbackParam()
        cb_params.train_network = self._train_network
        cb_params.epoch_num = epoch
        cb_params.batch_num = train_dataset.get_dataset_size()
        cb_params.mode = "train"
        cb_params.loss_fn = self._loss_fn
        cb_params.optimizer = self._optimizer
        cb_params.parallel_mode = self._parallel_mode
        cb_params.device_number = self._device_number
        cb_params.train_dataset = train_dataset
        cb_params.list_callback = callbacks

        with _CallbackManager(callbacks) as list_callback:
            if not dataset_sink_mode:
                self._train_process(epoch, train_dataset, list_callback, cb_params)
            elif context.get_context("mode") == context.PYNATIVE_MODE:
                logger.warning("The pynative mode cannot support dataset sink mode currently. "
                               "So the training process will be performed with dataset not sink.")
                self._train_process(epoch, train_dataset, list_callback, cb_params)
            else:
                self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)

    def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
        """
        Training process. The data is passed to the network through the dataset channel.

        Args:
            epoch (int): Total number of iterations over the data.
            train_dataset (Dataset): A training dataset iterator. If there is no loss_fn, a
                tuple with multiple data items (data1, data2, data3, ...) should be returned
                and passed to the network. Otherwise, a tuple (data, label) should be
                returned, and the data and label are passed to the network and the loss
                function respectively.
            list_callback (Callback): Executor of the callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
        """
        iter_first_order = self._frequency - 1
        iter_second_order = 1
        train_dataset.__loop_size__ = iter_second_order
        dataset_helper, train_network = self._exec_preprocess(self._train_network,
                                                              is_train=True,
                                                              phase='train',
                                                              dataset=train_dataset,
                                                              dataset_sink_mode=True,
                                                              iter_first_order=iter_first_order)
        self._train_network = train_network
        cb_params.train_network = self._train_network
        cb_params.cur_step_num = 0

        loop_size = dataset_helper.loop_size()
        run_context = RunContext(cb_params)
        list_callback.begin(run_context)

        # Used to stop training early, for example at a given time or step.
        should_stop = False
        switch_branch_one = True
        for i in range(epoch):
            cb_params.cur_epoch_num = i + 1
            list_callback.epoch_begin(run_context)

            # With data sinking, dataset_helper iterates only once per epoch;
            # otherwise it iterates epoch_size times.
            for inputs in dataset_helper:
                if switch_branch_one:
                    cb_params.cur_step_num += loop_size
                    self._train_network.add_flags_recursive(thor=True)
                    self._train_network.phase = 'train0'
                else:
                    cb_params.cur_step_num += iter_first_order
                    self._train_network.add_flags_recursive(thor=False)
                    self._train_network.phase = 'train1'
                    if not self._has_do_dataset_init:
                        _exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset')
                        self._has_do_dataset_init = True
                switch_branch_one = not switch_branch_one
                cb_params.train_dataset_element = inputs
                list_callback.step_begin(run_context)
                outputs = self._train_network(*inputs)
                cb_params.net_outputs = outputs
                list_callback.step_end(run_context)

            list_callback.epoch_end(run_context)
            should_stop = should_stop or run_context.get_stop_requested()
            if should_stop:
                break

        list_callback.end(run_context)
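    # Step accounting in _train_dataset_sink_process, assuming frequency=278
    # (loop_size=1, iter_first_order=277): the branches alternate, so four
    # iterator steps advance cur_step_num by
    #     1 + 277 + 1 + 277 = 556 = 2 * 278
    # i.e. two full THOR cycles of training steps.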
    def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
        """
        Training process. The data is passed to the network directly.

        Args:
            epoch (int): Total number of iterations over the data.
            train_dataset (Dataset): A training dataset iterator. If there is no loss_fn, a
                tuple with multiple data items (data1, data2, data3, ...) should be returned
                and passed to the network. Otherwise, a tuple (data, label) should be
                returned, and the data and label are passed to the network and the loss
                function respectively.
            list_callback (Callback): Executor of the callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
        """
        dataset_helper, _ = self._exec_preprocess(self._train_network,
                                                  is_train=True,
                                                  phase='train',
                                                  dataset=train_dataset,
                                                  dataset_sink_mode=False)
        cb_params.cur_step_num = 0
        run_context = RunContext(cb_params)
        list_callback.begin(run_context)
        # Used to stop training early, for example at a given time or step.
        should_stop = False

        for i in range(epoch):
            cb_params.cur_epoch_num = i + 1

            list_callback.epoch_begin(run_context)

            for next_element in dataset_helper:
                len_element = len(next_element)
                if self._loss_fn and len_element != 2:
                    raise ValueError("When loss_fn is not None, train_dataset should "
                                     "return two elements, but got {}".format(len_element))
                cb_params.cur_step_num += 1

                overflow = False
                if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
                    scaling_sens = self._get_scaling_sens()
                    next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)

                cb_params.train_dataset_element = next_element
                list_callback.step_begin(run_context)
                outputs = self._train_network(*next_element)
                cb_params.net_outputs = outputs
                if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
                    _, overflow, _ = outputs
                    overflow = np.all(overflow.asnumpy())
                    self._loss_scale_manager.update_loss_scale(overflow)

                list_callback.step_end(run_context)
                should_stop = should_stop or run_context.get_stop_requested()
                if should_stop:
                    break

            train_dataset.reset()

            list_callback.epoch_end(run_context)
            should_stop = should_stop or run_context.get_stop_requested()
            if should_stop:
                break

        list_callback.end(run_context)

    def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
        """
        Training API where the iteration is controlled by the Python front-end.

        In pynative mode, the training process is performed without dataset sinking.

        Note:
            CPU is not supported when dataset_sink_mode is True.
            If dataset_sink_mode is True, the epoch of training should be equal to the count
            of the repeat operation in the dataset processing. Otherwise, errors could occur
            since the amount of data is not what the training requires.
            If dataset_sink_mode is True, data will be sent to the device. If the device is
            Ascend, data features will be transferred one by one, and each transfer is
            limited to 256 MB.

        Args:
            epoch (int): Total number of iterations over the data.
            train_dataset (Dataset): A training dataset iterator. If there is no loss_fn, a
                tuple with multiple data items (data1, data2, data3, ...) should be returned
                and passed to the network. Otherwise, a tuple (data, label) should be
                returned, and the data and label are passed to the network and the loss
                function respectively.
            callbacks (list): List of callback objects which should be executed while
                training. Default: None.
            dataset_sink_mode (bool): Determines whether to pass the data through the dataset
                channel. Default: True. In pynative mode, training is performed without
                dataset sinking.

        Examples:
            >>> dataset = get_dataset()
            >>> net = Net()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
            >>> loss_scale_manager = FixedLossScaleManager()
            >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
            >>> model.train(2, dataset)
        """
        repeat_count = train_dataset.get_repeat_count()
        if epoch != repeat_count and dataset_sink_mode:
            logger.warning(f"The epoch_size {epoch} is not the same as the dataset repeat_count {repeat_count}")
        dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
        _device_number_check(self._parallel_mode, self._device_number)
        _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)

        self._train(epoch,
                    train_dataset,
                    callbacks=callbacks,
                    dataset_sink_mode=dataset_sink_mode)
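    # Usage sketch for train() in sink mode (get_dataset is the assumed helper
    # from the docstring): the dataset repeat count should match the epoch
    # argument, otherwise only a warning is logged:
    #     dataset = get_dataset().repeat(2)                # repeat_count == 2
    #     model.train(2, dataset, dataset_sink_mode=True)  # matches, no warning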
    def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
        """
        Evaluation. The data is passed to the network through the dataset channel.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            list_callback (ListCallback): Executor of the callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.

        Returns:
            Dict, the loss value and metrics values for the model in test mode.
        """
        run_context = RunContext(cb_params)
        dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
                                                             is_train=False,
                                                             phase='eval',
                                                             dataset=valid_dataset,
                                                             dataset_sink_mode=True)
        self._eval_network = eval_network
        cb_params.eval_network = self._eval_network
        list_callback.begin(run_context)

        for inputs in dataset_helper:
            cb_params.cur_step_num += 1
            list_callback.step_begin(run_context)

            outputs = self._eval_network(*inputs)

            cb_params.net_outputs = outputs
            list_callback.step_end(run_context)
            self._update_metrics(outputs)

        metrics = self._get_metrics()
        cb_params.metrics = metrics
        list_callback.end(run_context)

        return metrics

    def _eval_process(self, valid_dataset, list_callback=None, cb_params=None):
        """
        Evaluation. The data is passed to the network directly.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            list_callback (ListCallback): Executor of the callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.

        Returns:
            Dict, the loss value and metrics values for the model in test mode.
        """
        run_context = RunContext(cb_params)
        list_callback.begin(run_context)

        dataset_helper, _ = self._exec_preprocess(self._eval_network,
                                                  is_train=False,
                                                  phase='eval',
                                                  dataset=valid_dataset,
                                                  dataset_sink_mode=False)
        for next_element in dataset_helper:
            cb_params.cur_step_num += 1
            list_callback.step_begin(run_context)
            outputs = self._eval_network(*next_element)
            cb_params.net_outputs = outputs
            list_callback.step_end(run_context)
            self._update_metrics(outputs)

        metrics = self._get_metrics()
        cb_params.metrics = metrics
        list_callback.end(run_context)
        return metrics

    def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True):
        """
        Evaluation API where the iteration is controlled by the Python front-end.

        In pynative mode, the evaluation is performed without dataset sinking.

        Note:
            CPU is not supported when dataset_sink_mode is True.
            If dataset_sink_mode is True, data will be sent to the device. If the device is
            Ascend, data features will be transferred one by one, and each transfer is
            limited to 256 MB.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            callbacks (list): List of callback objects which should be executed while
                evaluating. Default: None.
            dataset_sink_mode (bool): Determines whether to pass the data through the dataset
                channel. Default: True.

        Returns:
            Dict, the loss value and metrics values for the model in test mode.

        Examples:
            >>> dataset = get_dataset()
            >>> net = Net()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
            >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
            >>> model.eval(dataset)
        """
        dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
        _device_number_check(self._parallel_mode, self._device_number)
        if not self._metric_fns:
            raise ValueError("The metric fn can not be None or empty.")

        cb_params = _InternalCallbackParam()
        cb_params.eval_network = self._eval_network
        cb_params.valid_dataset = valid_dataset
        cb_params.batch_num = valid_dataset.get_dataset_size()
        cb_params.mode = "eval"
        cb_params.cur_step_num = 0

        self._eval_network.set_train(mode=False)
        self._eval_network.phase = 'eval'

        self._clear_metrics()

        with _CallbackManager(callbacks) as list_callback:
            if dataset_sink_mode:
                return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
            return self._eval_process(valid_dataset, list_callback, cb_params)

    def predict(self, *predict_data):
        """
        Generates output predictions for the input samples.

        Data can be a single tensor, a list of tensors, or a tuple of tensors.

        Note:
            Batch data should be put together in one tensor.

        Args:
            predict_data (Tensor): Tensor of predict data; can be an array, a list or a tuple.

        Returns:
            Tensor, array(s) of predictions.

        Examples:
            >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
            >>> model = Model(Net())
            >>> model.predict(input_data)
        """
        self._predict_network.set_train(False)
        check_input_data(*predict_data, data_class=Tensor)
        result = self._predict_network(*predict_data)

        check_output_data(result)
        return result


__all__ = ["Model"]
@@ -1,618 +0,0 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import math
import numpy as np
from scipy.stats import truncnorm
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from src.model_utils.config import config


def conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
    fan_in = in_channel * kernel_size * kernel_size
    scale = 1.0
    scale /= max(1., fan_in)
    stddev = (scale ** 0.5) / .87962566103423978
    if config.net_name == "resnet152":
        stddev = (scale ** 0.5)
    mu, sigma = 0, stddev
    weight = truncnorm(-2, 2, loc=mu, scale=sigma).rvs(out_channel * in_channel * kernel_size * kernel_size)
    weight = np.reshape(weight, (out_channel, in_channel, kernel_size, kernel_size))
    return Tensor(weight, dtype=mstype.float32)
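# Worked numbers for conv_variance_scaling_initializer above (assumed
# in_channel=64, kernel_size=3): fan_in = 64 * 3 * 3 = 576, scale = 1 / 576,
#     stddev = sqrt(1 / 576) / 0.87962566103423978 ~= 0.0474
# The 0.8796... constant rescales the [-2, 2] truncated normal so that its
# actual standard deviation matches sqrt(scale).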
def _weight_variable(shape, factor=0.01):
    init_value = np.random.randn(*shape).astype(np.float32) * factor
    return Tensor(init_value)


def calculate_gain(nonlinearity, param=None):
    """calculate_gain"""
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    res = 0
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        res = 1
    elif nonlinearity == 'tanh':
        res = 5.0 / 3
    elif nonlinearity == 'relu':
        res = math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            neg_slope = 0.01
        elif (not isinstance(param, bool) and isinstance(param, int)) or isinstance(param, float):
            neg_slope = param
        else:
            raise ValueError("neg_slope {} is not a valid number".format(param))
        res = math.sqrt(2.0 / (1 + neg_slope ** 2))
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
    return res


def _calculate_fan_in_and_fan_out(tensor):
    """_calculate_fan_in_and_fan_out"""
    dimensions = len(tensor)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for a tensor with fewer than 2 dimensions")
    if dimensions == 2:  # Linear
        fan_in = tensor[1]
        fan_out = tensor[0]
    else:
        num_input_fmaps = tensor[1]
        num_output_fmaps = tensor[0]
        receptive_field_size = 1
        if dimensions > 2:
            receptive_field_size = tensor[2] * tensor[3]
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out


def _calculate_correct_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Unsupported mode {}, please use one of {}".format(mode, valid_modes))
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out


def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    fan = _calculate_correct_fan(inputs_shape, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    return np.random.normal(0, std, size=inputs_shape).astype(np.float32)


def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
    fan = _calculate_correct_fan(inputs_shape, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)
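# Fan computation sketch for a 7x7 stem convolution with an assumed weight
# shape of (64, 3, 7, 7): fan_in = 3 * 7 * 7 = 147, fan_out = 64 * 7 * 7 = 3136.
# kaiming_normal(mode='fan_out', nonlinearity='relu') then draws from a normal
# distribution with
#     std = gain / sqrt(fan_out) = sqrt(2) / sqrt(3136) ~= 0.0253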
def _conv3x3(in_channel, out_channel, stride=1, use_se=False, res_base=False):
    if use_se:
        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
    else:
        weight_shape = (out_channel, in_channel, 3, 3)
        weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
        if config.net_name == "resnet152":
            weight = _weight_variable(weight_shape)
    if res_base:
        return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
                         padding=1, pad_mode='pad', weight_init=weight)
    return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
                     padding=0, pad_mode='same', weight_init=weight)


def _conv1x1(in_channel, out_channel, stride=1, use_se=False, res_base=False):
    if use_se:
        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
    else:
        weight_shape = (out_channel, in_channel, 1, 1)
        weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
        if config.net_name == "resnet152":
            weight = _weight_variable(weight_shape)
    if res_base:
        return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
                         padding=0, pad_mode='pad', weight_init=weight)
    return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
                     padding=0, pad_mode='same', weight_init=weight)


def _conv7x7(in_channel, out_channel, stride=1, use_se=False, res_base=False):
    if use_se:
        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
    else:
        weight_shape = (out_channel, in_channel, 7, 7)
        weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
        if config.net_name == "resnet152":
            weight = _weight_variable(weight_shape)
    if res_base:
        return nn.Conv2d(in_channel, out_channel,
                         kernel_size=7, stride=stride, padding=3, pad_mode='pad', weight_init=weight)
    return nn.Conv2d(in_channel, out_channel,
                     kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _bn(channel, res_base=False):
    if res_base:
        return nn.BatchNorm2d(channel, eps=1e-5, momentum=0.1,
                              gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                          gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _bn_last(channel):
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                          gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _fc(in_channel, out_channel, use_se=False):
    if use_se:
        weight = np.random.normal(loc=0, scale=0.01, size=out_channel * in_channel)
        weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=mstype.float32)
    else:
        weight_shape = (out_channel, in_channel)
        weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
        if config.net_name == "resnet152":
            weight = _weight_variable(weight_shape)
    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)


class ResidualBlock(nn.Cell):
    """
    ResNet V1 residual block definition.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        use_se (bool): Enable SE-ResNet50 net. Default: False.
        se_block (bool): Use the SE block in the SE-ResNet50 net. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    expansion = 4

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1,
                 use_se=False, se_block=False):
        super(ResidualBlock, self).__init__()
        self.stride = stride
        self.use_se = use_se
        self.se_block = se_block
        channel = out_channel // self.expansion
        self.conv1 = _conv1x1(in_channel, channel, stride=1, use_se=self.use_se)
        self.bn1 = _bn(channel)
        if self.use_se and self.stride != 1:
            self.e2 = nn.SequentialCell([_conv3x3(channel, channel, stride=1, use_se=True), _bn(channel),
                                         nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')])
        else:
            self.conv2 = _conv3x3(channel, channel, stride=stride, use_se=self.use_se)
            self.bn2 = _bn(channel)

        self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
        self.bn3 = _bn(out_channel)
        if config.optimizer == "Thor" or config.net_name == "resnet152":
            self.bn3 = _bn_last(out_channel)
        if self.se_block:
            self.se_global_pool = ops.ReduceMean(keep_dims=False)
            self.se_dense_0 = _fc(out_channel, int(out_channel / 4), use_se=self.use_se)
            self.se_dense_1 = _fc(int(out_channel / 4), out_channel, use_se=self.use_se)
            self.se_sigmoid = nn.Sigmoid()
            self.se_mul = ops.Mul()
        self.relu = nn.ReLU()

        self.down_sample = False

        if stride != 1 or in_channel != out_channel:
            self.down_sample = True
        self.down_sample_layer = None

        if self.down_sample:
            if self.use_se:
                if stride == 1:
                    self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel,
                                                                         stride, use_se=self.use_se),
                                                                _bn(out_channel)])
                else:
                    self.down_sample_layer = nn.SequentialCell([nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same'),
                                                                _conv1x1(in_channel, out_channel, 1,
                                                                         use_se=self.use_se), _bn(out_channel)])
            else:
                self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
                                                                     use_se=self.use_se), _bn(out_channel)])

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        if self.use_se and self.stride != 1:
            out = self.e2(out)
        else:
            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.se_block:
            out_se = out
            out = self.se_global_pool(out, (2, 3))
            out = self.se_dense_0(out)
            out = self.relu(out)
            out = self.se_dense_1(out)
            out = self.se_sigmoid(out)
            out = ops.reshape(out, ops.shape(out) + (1, 1))
            out = self.se_mul(out, out_se)

        if self.down_sample:
            identity = self.down_sample_layer(identity)

        out = out + identity
        out = self.relu(out)

        return out
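# Channel bookkeeping for the bottleneck above, using the docstring example
# ResidualBlock(3, 256, stride=2) with expansion = 4:
#     channel = 256 // 4 = 64
#     conv1: 1x1,  3 -> 64     conv2: 3x3, 64 -> 64 (stride 2)
#     conv3: 1x1, 64 -> 256    down_sample: 1x1, 3 -> 256 (stride 2)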
||||
class ResidualBlockBase(nn.Cell):
|
||||
"""
|
||||
ResNet V1 residual block definition.
|
||||
|
||||
Args:
|
||||
in_channel (int): Input channel.
|
||||
out_channel (int): Output channel.
|
||||
stride (int): Stride size for the first convolutional layer. Default: 1.
|
||||
use_se (bool): Enable SE-ResNet50 net. Default: False.
|
||||
se_block(bool): Use se block in SE-ResNet50 net. Default: False.
|
||||
res_base (bool): Enable parameter setting of resnet18. Default: True.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
>>> ResidualBlockBase(3, 256, stride=2)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
stride=1,
|
||||
use_se=False,
|
||||
se_block=False,
|
||||
res_base=True):
|
||||
super(ResidualBlockBase, self).__init__()
|
||||
self.res_base = res_base
|
||||
self.conv1 = _conv3x3(in_channel, out_channel, stride=stride, res_base=self.res_base)
|
||||
self.bn1d = _bn(out_channel)
|
||||
self.conv2 = _conv3x3(out_channel, out_channel, stride=1, res_base=self.res_base)
|
||||
self.bn2d = _bn(out_channel)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
self.down_sample = False
|
||||
if stride != 1 or in_channel != out_channel:
|
||||
self.down_sample = True
|
||||
|
||||
self.down_sample_layer = None
|
||||
if self.down_sample:
|
||||
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
|
||||
use_se=use_se, res_base=self.res_base),
|
||||
_bn(out_channel, res_base)])
|
||||
|
||||
def construct(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1d(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2d(out)
|
||||
|
||||
if self.down_sample:
|
||||
identity = self.down_sample_layer(identity)
|
||||
|
||||
out = out + identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Cell):
    """
    ResNet architecture.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of blocks in different layers.
        in_channels (list): Input channel in each layer.
        out_channels (list): Output channel in each layer.
        strides (list): Stride size in each layer.
        num_classes (int): The number of classes that the training images belong to.
        use_se (bool): Enable SE-ResNet50 net. Default: False.
        se_block (bool): Use SE block in SE-ResNet50 net in layer 3 and layer 4. Default: False.
        res_base (bool): Enable parameter setting of resnet18. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResNet(ResidualBlock,
        >>>        [3, 4, 6, 3],
        >>>        [64, 256, 512, 1024],
        >>>        [256, 512, 1024, 2048],
        >>>        [1, 2, 2, 2],
        >>>        10)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes,
                 use_se=False,
                 res_base=False):
        super(ResNet, self).__init__()

        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("the length of layer_nums, in_channels and out_channels must all be 4!")
        self.use_se = use_se
        self.res_base = res_base
        self.se_block = False

        self.conv1 = _conv7x7(3, 64, stride=2, res_base=self.res_base)
        self.bn1 = _bn(64, self.res_base)
        self.relu = ops.ReLU()

        if self.res_base:
            self.pad = nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)))
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="valid")
        else:
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0],
                                       use_se=self.use_se)
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1],
                                       use_se=self.use_se)
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2],
                                       use_se=self.use_se,
                                       se_block=self.se_block)
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3],
                                       use_se=self.use_se,
                                       se_block=self.se_block)

        self.mean = ops.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = _fc(out_channels[3], num_classes, use_se=self.use_se)

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
        """
        Make stage network of ResNet.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the first convolutional layer.
            use_se (bool): Enable SE-ResNet50 net. Default: False.
            se_block (bool): Use SE block in SE-ResNet50 net. Default: False.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
        layers = []

        resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se)
        layers.append(resnet_block)
        if se_block:
            for _ in range(1, layer_num - 1):
                resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
                layers.append(resnet_block)
            resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block)
            layers.append(resnet_block)
        else:
            for _ in range(1, layer_num):
                resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
                layers.append(resnet_block)
        return nn.SequentialCell(layers)
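
    # Illustrative note (editor's addition): _make_layer stacks `layer_num` blocks and
    # downsamples exactly once per stage -- only the first block receives the given
    # stride, while every following block maps out_channel to out_channel at stride 1.
    # When se_block is set, only the last block of the stage is built with the SE flag.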

    def construct(self, x):
        if self.use_se:
            x = self.conv1_0(x)
            x = self.bn1_0(x)
            x = self.relu(x)
            x = self.conv1_1(x)
            x = self.bn1_1(x)
            x = self.relu(x)
            x = self.conv1_2(x)
        else:
            x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if self.res_base:
            x = self.pad(x)
        c1 = self.maxpool(x)

        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        out = self.mean(c5, (2, 3))
        out = self.flatten(out)
        out = self.end_point(out)

        return out


def resnet18(class_num=10):
    """
    Get ResNet18 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet18 neural network.

    Examples:
        >>> net = resnet18(10)
    """
    return ResNet(ResidualBlockBase,
                  [2, 2, 2, 2],
                  [64, 64, 128, 256],
                  [64, 128, 256, 512],
                  [1, 2, 2, 2],
                  class_num,
                  res_base=True)
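
# Editor's sketch (not part of the original file): a minimal smoke test of the
# factory, assuming an already-configured MindSpore context and the standard
# ImageNet input layout.
#
#   >>> import numpy as np
#   >>> from mindspore import Tensor
#   >>> net = resnet18(class_num=10)
#   >>> x = Tensor(np.random.randn(2, 3, 224, 224).astype(np.float32))
#   >>> net(x).shape
#   (2, 10)
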
def resnet34(class_num=10):
    """
    Get ResNet34 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet34 neural network.

    Examples:
        >>> net = resnet34(10)
    """
    return ResNet(ResidualBlockBase,
                  [3, 4, 6, 3],
                  [64, 64, 128, 256],
                  [64, 128, 256, 512],
                  [1, 2, 2, 2],
                  class_num,
                  res_base=True)


def resnet50(class_num=10):
    """
    Get ResNet50 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = resnet50(10)
    """
    return ResNet(ResidualBlock,
                  [3, 4, 6, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num)
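
# Illustrative note (editor's addition): resnet18/34 above use the basic two-conv
# block (ResidualBlockBase) with equal stage input/output widths, whereas resnet50
# and the deeper variants below use the bottleneck ResidualBlock, whose stage
# outputs are 4x the inputs ([256, 512, 1024, 2048] vs [64, 256, 512, 1024]).
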
def se_resnet50(class_num=1001):
    """
    Get SE-ResNet50 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of SE-ResNet50 neural network.

    Examples:
        >>> net = se_resnet50(1001)
    """
    return ResNet(ResidualBlock,
                  [3, 4, 6, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num,
                  use_se=True)


def resnet101(class_num=1001):
    """
    Get ResNet101 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet101 neural network.

    Examples:
        >>> net = resnet101(1001)
    """
    return ResNet(ResidualBlock,
                  [3, 4, 23, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num)


def resnet152(class_num=1001):
    """
    Get ResNet152 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet152 neural network.

    Examples:
        >>> net = resnet152(1001)
    """
    return ResNet(ResidualBlock,
                  [3, 8, 36, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num)
@@ -1,508 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""THOR"""
import numpy as np
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.nn.layer import DenseThor, Conv2dThor, EmbeddingThor
from mindspore.nn.wrap import DistributedGradReducer
from mindspore.train.train_thor.convert_utils import ConvertNetUtils
from mindspore.parallel._auto_parallel_context import auto_parallel_context

# Enumerated layer types
Other = -1
Conv = 1
FC = 2
Embedding = 3
LayerNorm = 4
BatchNorm = 5

_momentum_opt = C.MultitypeFuncGraph("momentum_opt")

op_add = P.AddN()
apply_decay = C.MultitypeFuncGraph("apply_decay")


@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Get grad with weight_decay."""
    if if_apply:
        return op_add((weight * weight_decay, gradient))
    return gradient


@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
    """Apply momentum optimizer to the weight parameter using Tensor."""
    success = True
    success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
    return success


C0 = 16


def caculate_device_shape(matrix_dim, channel, is_A):
    """Calculate the 4-D fractal device shape and padded dimension of a matrix."""
    ll = (0)
    if is_A:
        if channel // C0 == 0:
            matrix_dim = (matrix_dim / channel) * C0
        ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
    else:
        ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
    return ll

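# Illustrative note (editor's addition): C0 = 16 is the cube-unit size of the Ascend
# fractal layout, so an n x n matrix is stored as an (n/16, n/16, 16, 16) block grid.
# For example, caculate_device_shape(256, 32, True) returns ((16, 16, 16, 16), 256);
# when channel < C0 the A-matrix dimension is first rescaled to (dim / channel) * 16
# so it tiles evenly into 16 x 16 cubes.
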
def caculate_matmul_shape(matrix_A_dim, matrix_G_dim, split_dim):
    """get matmul shape"""
    split_dimA = split_dim
    split_dimG = split_dim
    if matrix_A_dim % split_dim == 0:
        batch_w = matrix_A_dim // split_dim
    else:
        if matrix_A_dim < split_dim:
            batch_w = 1
            split_dimA = matrix_A_dim
        else:
            batch_w = matrix_A_dim // split_dim + 1

    if matrix_G_dim % split_dim == 0:
        batch_h = matrix_G_dim // split_dim
    else:
        if matrix_G_dim < split_dim:
            batch_h = 1
            split_dimG = matrix_G_dim
        else:
            batch_h = matrix_G_dim // split_dim + 1
    matrix_A_shape = (batch_h, batch_w, split_dimA, split_dimA)
    matrix_G_shape = (batch_h, split_dimG, split_dimG)
    return matrix_A_shape, matrix_G_shape

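# Worked example (editor's addition): with split_dim = 128, matrix_A_dim = 300 and
# matrix_G_dim = 256 give batch_w = 300 // 128 + 1 = 3 and batch_h = 256 // 128 = 2,
# so caculate_matmul_shape(300, 256, 128) == ((2, 3, 128, 128), (2, 128, 128)):
# the covariance factors are split into 128 x 128 diagonal blocks for the batched
# Cholesky and matmul kernels.
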
def find_net_layertype_recur(net, layertype_map):
    """get net layer type recursively."""
    cells = net.name_cells()
    for name in cells:
        subcell = cells[name]
        if subcell == net:
            continue
        elif isinstance(subcell, Conv2dThor):
            layertype_map.append(Conv)
        elif isinstance(subcell, DenseThor):
            layertype_map.append(FC)
        elif isinstance(subcell, EmbeddingThor):
            layertype_map.append(Embedding)
        elif isinstance(subcell, nn.LayerNorm):
            layertype_map.append(LayerNorm)
        elif isinstance(subcell, nn.BatchNorm2d):
            layertype_map.append(BatchNorm)
        elif isinstance(subcell, (nn.Conv2d, nn.Dense, nn.Embedding, nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose,
                                  nn.BatchNorm1d, nn.GroupNorm, nn.GlobalBatchNorm)):
            layertype_map.append(Other)
        else:
            find_net_layertype_recur(subcell, layertype_map)


def get_net_layertype_mask(net):
    """Get the layer-type mask of the whole network."""
    layertype_map = []
    find_net_layertype_recur(net, layertype_map)
    return layertype_map


def get_layer_counter(layer_type, layer_counter, params, idx):
    """get layer counter"""
    if layer_type in [Conv, FC, LayerNorm, BatchNorm]:
        if layer_type in [LayerNorm, BatchNorm]:
            if "beta" in params[idx].name.lower():
                layer_counter = layer_counter + 1
        else:
            if "bias" in params[idx].name.lower():
                layer_counter = layer_counter + 1
            else:
                if idx < len(params) - 1 and "bias" not in params[idx + 1].name.lower():
                    layer_counter = layer_counter + 1
    else:
        layer_counter = layer_counter + 1
    return layer_counter


def THOR(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
         use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None):
    """Convert the network to a THOR net and build the Ascend THOR optimizer."""
    context.set_context(max_call_depth=10000)
    ConvertNetUtils().convert_to_thor_net(net)

    return THOR_Ascend(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size, decay_filter,
                       split_indices=split_indices)

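# Illustrative note (editor's addition): ConvertNetUtils swaps ordinary
# Conv2d/Dense/Embedding layers for their Thor counterparts, which accumulate the
# input/output covariance statistics (matrix_a, matrix_g) that THOR_Ascend inverts
# to precondition gradients. A typical call, mirroring the test harness shipped with
# this commit, would be:
#
#   >>> opt = THOR(net, Tensor(lr), Tensor(damping), momentum=0.9, batch_size=32)
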
class THOR_Ascend(Optimizer):
    """THOR"""

    def __init__(self, net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
                 decay_filter=lambda x: x.name not in [], split_indices=None):
        params = filter(lambda x: x.requires_grad, net.get_parameters())
        super(THOR_Ascend, self).__init__(learning_rate, params, weight_decay, loss_scale)
        if isinstance(momentum, float) and momentum < 0.0:
            raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
        self.params = self.parameters
        self.moments = self.params.clone(prefix="moments", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyMomentum()
        self.net = net
        self.matrix_A_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
        self.matrix_G_cov = ParameterTuple(filter(lambda x: 'matrix_g' in x.name, net.get_parameters()))
        self.A_normalizer = ParameterTuple(filter(lambda x: 'a_normalizer' in x.name, net.get_parameters()))
        self.G_normalizer = ParameterTuple(filter(lambda x: 'g_normalizer' in x.name, net.get_parameters()))
        self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast()
        self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft()
        self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight()
        self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul()
        self.transpose = P.Transpose()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.mul = P.Mul()

        self.C0 = 16
        self.matrix_A_dim = ()
        self.padA_flag = ()
        self.device_shape_pad_flag = ()
        self.diag_block_dim = 128
        self.matrix_A = ()
        self.matrix_G = ()
        print("matrix_a_cov len is", len(self.matrix_A_cov))
        self.thor_layer_count = 0
        self.conv_layer_count = 0
        self.weight_fim_idx_map = ()
        self.weight_conv_idx_map = ()
        self.weight_layerType_idx_map = ()
        self._process_matrix_init_and_weight_idx_map(self.net)

        self.matrix_A = ParameterTuple(self.matrix_A)
        self.matrix_G = ParameterTuple(self.matrix_G)
        self.matrix_max_inv = ()
        for i in range(len(self.matrix_A)):
            self.matrix_max_inv = self.matrix_max_inv + (
                Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),)
        self.log = P.Log()
        self.exp = P.Exp()
        self.sqrt = P.Sqrt()
        self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
        self.assign = P.Assign()
        self.cast = P.Cast()
        self.thor = True
        self.weight_decay = weight_decay * loss_scale
        self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
        self.damping = damping
        self.gather = P.GatherV2()
        self.one = Tensor(1, mstype.int32)
        self.batch_size = Tensor(batch_size, mstype.float32)
        self.loss_scale = Tensor(1 / (loss_scale * loss_scale), mstype.float32)
        self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32)
        self.axis = 0
        self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
        self.eye = P.Eye()
        self.cholesky = P.CusCholeskyTrsm()
        self.vector_matmul = P.CusBatchMatMul()
        self.fused_abs_max2 = P.CusFusedAbsMax1()
        self.matrix_combine = P.CusMatrixCombine()
        self.slice = P.Slice()
        self.expand = P.ExpandDims()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.square = P.Square()
        self.inv = P.Inv()
        self.matmul = P.MatMul()

        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        if self.is_distributed:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.split_indices = split_indices
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum2")
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum4")
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
            auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8")
            self.grad_reducer_Amax = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=2)
            self.grad_reducer_Gmax = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=4)
            self.grad_reducer_A = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=6)
            self.grad_reducer_G = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=8)

    def _process_matrix_init_and_weight_idx_map(self, net):
        """process matrix init shape, and get weight idx map"""
        layer_type_map = get_net_layertype_mask(net)
        layer_counter = 0
        for idx in range(len(self.params)):
            layer_type = layer_type_map[layer_counter]
            weight = self.params[idx]
            weight_shape = self.shape(weight)
            if layer_type == Conv and "bias" not in self.params[idx].name.lower():
                in_channels = weight_shape[1]
                out_channels = weight_shape[0]
                matrix_A_dim = in_channels * weight_shape[2] * weight_shape[3]
                matrix_G_dim = out_channels
                matrix_A_device_shape, matrix_A_device_dim = caculate_device_shape(matrix_A_dim, in_channels, True)
                matrix_G_device_shape, matrix_G_device_dim = caculate_device_shape(matrix_G_dim, in_channels, False)
                matrix_A_inv = Parameter(
                    Tensor(np.reshape(np.identity(matrix_A_device_dim).astype(np.float16), matrix_A_device_shape)),
                    name='matrix_A_inv_' + str(self.thor_layer_count), requires_grad=False)
                matrix_G_inv = Parameter(
                    Tensor(np.reshape(np.identity(matrix_G_device_dim).astype(np.float16), matrix_G_device_shape)),
                    name="matrix_G_inv_" + str(self.thor_layer_count), requires_grad=False)
                self.matrix_A = self.matrix_A + (matrix_A_inv,)
                self.matrix_G = self.matrix_G + (matrix_G_inv,)
                self.matrix_A_dim = self.matrix_A_dim + (matrix_A_dim,)
                padA_flag = False
                if (matrix_A_dim // self.diag_block_dim) * self.diag_block_dim != matrix_A_dim \
                        and matrix_A_dim > self.diag_block_dim:
                    padA_flag = True
                self.padA_flag = self.padA_flag + (padA_flag,)
                device_shape_pad_flag = False
                if matrix_A_dim != matrix_A_device_dim:
                    device_shape_pad_flag = True
                self.device_shape_pad_flag = self.device_shape_pad_flag + (device_shape_pad_flag,)
            elif layer_type == FC and "bias" not in self.params[idx].name.lower():
                out_channels = weight_shape[0]
                if out_channels == 1001:
                    fc_matrix_A = Parameter(Tensor(np.zeros([128, 128, 16, 16]).astype(np.float16)),
                                            name='matrix_A_inv_' + str(self.thor_layer_count),
                                            requires_grad=False)
                    fc_matrix_G = Parameter(Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16)),
                                            name="matrix_G_inv_" + str(self.thor_layer_count),
                                            requires_grad=False)
                    self.matrix_A = self.matrix_A + (fc_matrix_A,)
                    self.matrix_G = self.matrix_G + (fc_matrix_G,)

            if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
                self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
                self.weight_layerType_idx_map = self.weight_layerType_idx_map + (layer_type,)
                self.thor_layer_count = self.thor_layer_count + 1
                if layer_type == Conv:
                    self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
                    self.conv_layer_count = self.conv_layer_count + 1
                else:
                    self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
            else:
                self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,)
                self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
                self.weight_layerType_idx_map = self.weight_layerType_idx_map + (Other,)
            # bert.cls1.output_bias: not a network layer, only a trainable param
            if "output_bias" not in self.params[idx].name.lower():
                layer_counter = get_layer_counter(layer_type, layer_counter, self.params, idx)

    def _get_Ainv_Ginv_Amax_Gmax_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
                                      matrix_a_max_allreduce, matrix_g_max_allreduce):
        """get matrixA inverse list, matrixG inverse list, matrixA_max list, matrixG_max list"""
        for i in range(0, 160, 3):
            thor_layer_count = self.weight_fim_idx_map[i]
            conv_layer_count = self.weight_conv_idx_map[i]
            layer_type = self.weight_layerType_idx_map[i]
            if layer_type in [Conv, FC, Embedding]:
                g = gradients[i]
                matrix_A = self.matrix_A_cov[thor_layer_count]
                matrix_G = self.matrix_G_cov[thor_layer_count]
                matrix_A = F.depend(matrix_A, g)
                matrix_G = F.depend(matrix_G, g)
                A_shape = self.shape(matrix_A)
                A_eye = self.eye(A_shape[0], A_shape[0], mstype.float32)
                G_shape = self.shape(matrix_G)
                G_eye = self.eye(G_shape[0], G_shape[0], mstype.float32)
                if layer_type == Conv:
                    A_normalizer = self.A_normalizer[conv_layer_count]
                    G_normalizer = self.G_normalizer[conv_layer_count]
                    A_normalizer = F.depend(A_normalizer, g)
                    G_normalizer = F.depend(G_normalizer, g)
                    dampingA = self.mul(damping_step, self.batch_size / A_normalizer)
                    dampingG = self.mul(damping_step, self.batch_size / G_normalizer)
                    dampingA = self.sqrt(dampingA)
                    matrix_A = matrix_A + dampingA * A_eye
                    matrix_A_inv = self.cholesky(matrix_A)
                    matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv)
                    A_max = P.CusFusedAbsMax1([self.matrix_A_dim[conv_layer_count],
                                               self.matrix_A_dim[conv_layer_count]])(matrix_A_inv)
                    A_max = self.fused_abs_max2(A_max)
                    matrix_A_inv = self.matrix_combine(matrix_A_inv)
                    if self.padA_flag[conv_layer_count]:
                        matrix_A_inv = self.slice(matrix_A_inv, (0, 0), (self.matrix_A_dim[conv_layer_count],
                                                                         self.matrix_A_dim[conv_layer_count]))
                    if self.device_shape_pad_flag[conv_layer_count]:
                        weight = self.params[i]
                        weight_shape = self.shape(weight)
                        kernel_hw = weight_shape[2] * weight_shape[3]
                        in_channels = weight_shape[1]
                        matrix_A_inv = self.reshape(matrix_A_inv, (kernel_hw, in_channels, kernel_hw, in_channels))
                        matrix_A_inv = P.Pad(((0, 0), (0, self.C0 - in_channels), (0, 0),
                                              (0, self.C0 - in_channels)))(matrix_A_inv)
                    matrix_A_inv_shape = self.shape(self.matrix_A[thor_layer_count])
                    matrix_A_device_temp_shape = (matrix_A_inv_shape[0], matrix_A_inv_shape[2],
                                                  matrix_A_inv_shape[1], matrix_A_inv_shape[3])
                    matrix_A_inv = self.reshape(matrix_A_inv, matrix_A_device_temp_shape)
                    matrix_A_inv = self.transpose(matrix_A_inv, (2, 0, 1, 3))

                    dampingG = self.sqrt(dampingG)
                    matrix_G = self.mul(matrix_G, self.loss_scale)
                    matrix_G = self.mul(matrix_G, self.batch_size_scale)
                    matrix_G = matrix_G + dampingG * G_eye
                    matrix_G_inv = self.cholesky(matrix_G)
                    matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv)
                    G_max = self.fused_abs_max2(matrix_G_inv)
                    G_max = self.fused_abs_max2(G_max)
                    matrix_G_inv = self.matrix_combine(matrix_G_inv)
                    matrix_G_inv_shape = self.shape(self.matrix_G[thor_layer_count])
                    matrix_G_device_temp_shape = (matrix_G_inv_shape[0], matrix_G_inv_shape[2],
                                                  matrix_G_inv_shape[1], matrix_G_inv_shape[3])
                    matrix_G_inv = self.reshape(matrix_G_inv, matrix_G_device_temp_shape)
                    matrix_G_inv = self.transpose(matrix_G_inv, (2, 0, 1, 3))

                    A_max = F.depend(A_max, g)
                    G_max = F.depend(G_max, g)
                    matrix_a_allreduce = matrix_a_allreduce + (matrix_A_inv,)
                    matrix_g_allreduce = matrix_g_allreduce + (matrix_G_inv,)
                    matrix_a_max_allreduce = matrix_a_max_allreduce + (A_max,)
                    matrix_g_max_allreduce = matrix_g_max_allreduce + (G_max,)
                elif layer_type == FC:
                    damping = self.sqrt(damping_step)
                    matrix_A = matrix_A + damping * A_eye
                    matrix_A_inv = self.cholesky(matrix_A)
                    matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv)
                    matrix_G = self.mul(matrix_G, self.loss_scale)
                    matrix_G = self.mul(matrix_G, self.batch_size_scale)
                    matrix_G = matrix_G + damping * G_eye
                    matrix_G_inv = self.cholesky(matrix_G)
                    matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv)

                    matrix_A_inv_max = self.fused_abs_max2(matrix_A_inv)
                    A_max = self.fused_abs_max2(matrix_A_inv_max)
                    matrix_A_inv = self.matrix_combine(matrix_A_inv)
                    matrix_A_inv_shape = self.shape(matrix_A_inv)
                    matrix_A_inv = self.reshape(matrix_A_inv,
                                                (matrix_A_inv_shape[0] / 16, 16,
                                                 matrix_A_inv_shape[0] / 16, 16))
                    matrix_A_inv = self.transpose(matrix_A_inv, (2, 0, 1, 3))
                    matrix_G_inv_max = P.CusFusedAbsMax1([1001, 1001])(matrix_G_inv)
                    G_max = self.fused_abs_max2(matrix_G_inv_max)
                    matrix_G_inv = self.matrix_combine(matrix_G_inv)
                    matrix_G_inv = self.slice(matrix_G_inv, (0, 0), (1001, 1001))
                    matrix_G_inv = P.Pad(((0, 7), (0, 7)))(matrix_G_inv)
                    matrix_G_inv_shape = self.shape(matrix_G_inv)
                    matrix_G_inv = self.reshape(matrix_G_inv,
                                                (matrix_G_inv_shape[0] / 16, 16,
                                                 matrix_G_inv_shape[0] / 16, 16))
                    matrix_G_inv = self.transpose(matrix_G_inv, (2, 0, 1, 3))
                    A_max = F.depend(A_max, g)
                    G_max = F.depend(G_max, g)
                    matrix_a_max_allreduce = matrix_a_max_allreduce + (A_max,)
                    matrix_g_max_allreduce = matrix_g_max_allreduce + (G_max,)

                    matrix_a_allreduce = matrix_a_allreduce + (matrix_A_inv,)
                    matrix_g_allreduce = matrix_g_allreduce + (matrix_G_inv,)
        return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce

    def _get_second_gradients(self, new_grads, damping_step, gradients):
        """get second gradients for thor"""
        params_len = len(self.params)
        for i in range(0, params_len - 1, 3):
            g = gradients[i]
            thor_layer_count = self.weight_fim_idx_map[i]
            layer_type = self.weight_layerType_idx_map[i]
            matrix_A = self.matrix_A[thor_layer_count]
            matrix_G = self.matrix_G[thor_layer_count]
            matrix_max = self.matrix_max_inv[thor_layer_count]
            if layer_type == FC:
                g = self.cube_matmul_left_fc(matrix_G, g)
                g = self.cube_matmul_right_fc(g, matrix_A, matrix_max)
                new_grads = new_grads + (g, gradients[i + 1])
            elif layer_type == Conv:
                g = self.cube_matmul_left(matrix_G, g)
                g = self.cube_matmul_right_mul(g, matrix_A, matrix_max)
                new_grads = new_grads + (g, gradients[i + 1], gradients[i + 2])
        return new_grads
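
    # Illustrative note (editor's addition): THOR approximates each layer's Fisher
    # information by a Kronecker product of the input covariance A and the output-
    # gradient covariance G, so the preconditioned gradient is effectively
    # G^-1 @ grad @ A^-1, realised by the two cube matmuls above and in construct
    # below; matrix_max_inv undoes the abs-max normalisation applied to the inverses.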
    def construct(self, gradients):
        params = self.params
        moments = self.moments
        damping_step = self.gather(self.damping, self.cov_step, self.axis)
        damping_step = self.cast(damping_step, mstype.float32)
        if self.thor:
            matrix_A_allreduce = ()
            matrix_G_allreduce = ()
            matrix_A_max_allreduce = ()
            matrix_G_max_allreduce = ()
            matrix_A_allreduce, matrix_G_allreduce, matrix_A_max_allreduce, matrix_G_max_allreduce = \
                self._get_Ainv_Ginv_Amax_Gmax_list(gradients, damping_step, matrix_A_allreduce, matrix_G_allreduce,
                                                   matrix_A_max_allreduce, matrix_G_max_allreduce)
            if self.is_distributed:
                matrix_A_allreduce = self.grad_reducer_A(matrix_A_allreduce)
                matrix_G_allreduce = self.grad_reducer_G(matrix_G_allreduce)
                matrix_A_max_allreduce = self.grad_reducer_Amax(matrix_A_max_allreduce)
                matrix_G_max_allreduce = self.grad_reducer_Gmax(matrix_G_max_allreduce)

            new_grads = ()
            for i in range(0, 160, 3):
                g = gradients[i]
                thor_layer_count = self.weight_fim_idx_map[i]
                conv_layer_count = self.weight_conv_idx_map[i]
                layer_type = self.weight_layerType_idx_map[i]
                temp_a = matrix_A_allreduce[thor_layer_count]
                temp_g = matrix_G_allreduce[thor_layer_count]
                matrix_A_inv_max = self.log(matrix_A_max_allreduce[thor_layer_count])
                matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
                matrix_A_inv_max = self.exp(matrix_A_inv_max)
                temp_a = self.mul(temp_a, matrix_A_inv_max)
                matrix_G_inv_max = self.log(matrix_G_max_allreduce[thor_layer_count])
                matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
                matrix_G_inv_max = self.exp(matrix_G_inv_max)
                temp_g = self.mul(temp_g, matrix_G_inv_max)
                temp_max = self.mul(matrix_A_max_allreduce[thor_layer_count],
                                    matrix_G_max_allreduce[thor_layer_count])
                temp_a = self.cast(temp_a, mstype.float16)
                temp_g = self.cast(temp_g, mstype.float16)
                if layer_type == FC:
                    g = self.cube_matmul_left_fc(temp_g, g)
                    g = self.cube_matmul_right_fc(g, temp_a, temp_max)
                elif layer_type == Conv:
                    A_normalizer = self.A_normalizer[conv_layer_count]
                    A_normalizer = F.depend(A_normalizer, g)
                    temp_max = self.mul(temp_max, self.batch_size / A_normalizer)
                    g = self.cube_matmul_left(temp_g, g)
                    g = self.cube_matmul_right_mul(g, temp_a, temp_max)
                self.assign(self.matrix_A[thor_layer_count], temp_a)
                self.assign(self.matrix_G[thor_layer_count], temp_g)
                self.assign(self.matrix_max_inv[thor_layer_count], temp_max)
                if i == 159:
                    new_grads = new_grads + (g, gradients[i + 1])
                else:
                    new_grads = new_grads + (g, gradients[i + 1], gradients[i + 2])
            gradients = new_grads
        else:
            new_grads = ()
            gradients = self._get_second_gradients(new_grads, damping_step, gradients)

        self.cov_step = self.cov_step + self.one
        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()
        success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
        return success
@@ -0,0 +1,391 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""train and evaluate resnet50 network on imagenet dataset"""

import os
import time
from multiprocessing import Process, Queue
import pytest
import numpy as np

from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.communication.management import init
from mindspore.context import ParallelMode
from mindspore.train.callback import Callback
from mindspore.train.model import Model
from mindspore.train.train_thor import ConvertModelUtils
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.nn.optim import thor
import mindspore.dataset as ds
import mindspore.nn as nn

from tests.st.networks.models.resnet50.src.metric import DistAccuracy, ClassifyCorrectCell
from tests.st.networks.models.resnet50.src.dataset import create_dataset
from tests.st.networks.models.resnet50.src.lr_generator import get_learning_rate
from tests.st.networks.models.resnet50.src.config import config
from tests.st.networks.models.resnet50.src.CrossEntropySmooth import CrossEntropySmooth
from tests.st.networks.models.resnet50.src_thor.config import config as thor_config
from tests.st.networks.models.resnet50.src_thor.dataset import create_dataset2 as create_dataset_thor
from tests.st.networks.models.resnet50.src.resnet import resnet50

MINDSPORE_HCCL_CONFIG_PATH = "/home/workspace/mindspore_config/hccl/rank_tabel_4p/rank_table_4p_1.json"
MINDSPORE_HCCL_CONFIG_PATH_2 = "/home/workspace/mindspore_config/hccl/rank_tabel_4p/rank_table_4p_2.json"
dataset_path = "/home/workspace/mindspore_dataset/imagenet/imagenet_original/train"
eval_path = "/home/workspace/mindspore_dataset/imagenet/imagenet_original/val"

np.random.seed(1)
ds.config.set_seed(1)
os.environ['GLOG_v'] = str(2)


def get_thor_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch, decay_epochs=100):
    """get_model_lr"""
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    for i in range(total_steps):
        epoch = (i + 1) / steps_per_epoch
        base = (1.0 - float(epoch) / total_epochs) ** decay
        lr_local = lr_init * base
        if epoch >= decay_epochs:
            lr_local = lr_local * 0.5
        if epoch >= decay_epochs + 1:
            lr_local = lr_local * 0.5
        lr_each_step.append(lr_local)
    current_step = global_step
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    learning_rate = lr_each_step[current_step:]
    return learning_rate

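# Illustrative note (editor's addition): this is a polynomial-decay schedule,
# lr(epoch) = lr_init * (1 - epoch / total_epochs) ** decay, halved once when
# epoch reaches decay_epochs and halved again one epoch later. The call used
# below, get_thor_lr(0, 0.05803, 4.04839, 53, 5004, decay_epochs=39), starts at
# roughly 0.058 and decays toward zero over 53 * 5004 steps.
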
def get_thor_damping(global_step, damping_init, decay_rate, total_epochs, steps_per_epoch):
    """get_model_damping"""
    damping_each_step = []
    total_steps = steps_per_epoch * total_epochs
    for step in range(total_steps):
        epoch = (step + 1) / steps_per_epoch
        damping_here = damping_init * (decay_rate ** (epoch / 10))
        damping_each_step.append(damping_here)
    current_step = global_step
    damping_each_step = np.array(damping_each_step).astype(np.float32)
    damping_now = damping_each_step[current_step:]
    return damping_now


class LossGet(Callback):
    def __init__(self, per_print_times, data_size):
        super(LossGet, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times
        self._loss = 0.0
        self.data_size = data_size
        self._epoch = 0

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs
        self._epoch = cb_params.cur_epoch_num
        if isinstance(loss, (tuple, list)):
            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                loss = loss[0]

        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
            loss = np.mean(loss.asnumpy())

        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training."
                             .format(cb_params.cur_epoch_num, cur_step_in_epoch))
        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
            self._loss = loss
            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)

    def epoch_begin(self, run_context):
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        self._per_step_mseconds = epoch_mseconds / self.data_size

    def get_loss(self):
        return self._loss

    def get_per_step_time(self):
        return self._per_step_mseconds

    def get_epoch(self):
        return self._epoch


def train_and_eval(device_id, epoch_size, model, dataset, loss_cb, eval_dataset, q):
    print("run_start", device_id)
    eval_interval = config.eval_interval
    step_size = dataset.get_dataset_size()
    acc = 0.0
    time_cost = 0.0
    for epoch_idx in range(0, int(epoch_size / eval_interval)):
        model.train(1, dataset, callbacks=loss_cb)
        eval_start = time.time()
        output = model.eval(eval_dataset)
        eval_cost = (time.time() - eval_start) * 1000
        acc = float(output["acc"])
        time_cost = loss_cb.get_per_step_time()
        loss = loss_cb.get_loss()
        print("the {} epoch's resnet result:\n "
              "device{}, training loss {}, acc {}, "
              "training per step cost {:.2f} ms, eval cost {:.2f} ms, "
              "total_cost {:.2f} ms".format(epoch_idx, device_id,
                                            loss, acc, time_cost,
                                            eval_cost,
                                            time_cost * step_size + eval_cost))
    q.put({'acc': acc, 'cost': time_cost})

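# Illustrative note (editor's addition): each device trains in its own OS process
# and reports its results through the multiprocessing Queue, so the test body
# (test_resnet_imagenet_and_thor_4p below) can aggregate accuracy and per-step
# time across devices without any shared state.
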
def train_process(q, device_id, epoch_size, device_num, enable_hccl):
    os.system("mkdir " + str(device_id))
    os.chdir(str(device_id))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
    os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = MINDSPORE_HCCL_CONFIG_PATH
    os.environ['RANK_ID'] = str(device_id)
    os.environ['RANK_SIZE'] = str(device_num)
    if enable_hccl:
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True, all_reduce_fusion_config=[107, 160])
        init()

    # network
    net = resnet50(class_num=config.class_num)

    # evaluation network
    dist_eval_network = ClassifyCorrectCell(net)

    if not config.use_label_smooth:
        config.label_smooth_factor = 0.0

    # loss
    loss = CrossEntropySmooth(sparse=True, reduction="mean",
                              smooth_factor=config.label_smooth_factor, num_classes=config.class_num)

    # train dataset
    dataset = create_dataset(dataset_path=dataset_path, do_train=True, repeat_num=1, batch_size=config.batch_size)
    step_size = dataset.get_dataset_size()

    # evaluation dataset
    eval_dataset = create_dataset(dataset_path=eval_path, do_train=False,
                                  repeat_num=1, batch_size=config.eval_batch_size)

    # loss scale
    loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)

    # learning rate
    lr = Tensor(get_learning_rate(lr_init=config.lr_init, lr_end=0.0, lr_max=config.lr_max,
                                  warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size,
                                  steps_per_epoch=step_size, lr_decay_mode=config.lr_decay_mode))

    # optimizer
    decayed_params = []
    no_decayed_params = []
    for param in net.trainable_params():
        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
            decayed_params.append(param)
        else:
            no_decayed_params.append(param)

    group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
                    {'params': no_decayed_params, 'weight_decay': 0.0},
                    {'order_params': net.trainable_params()}]

    if config.use_lars:
        momentum = nn.Momentum(group_params, lr, config.momentum,
                               loss_scale=config.loss_scale, use_nesterov=config.use_nesterov)
        opt = nn.LARS(momentum, epsilon=config.lars_epsilon, coefficient=config.lars_coefficient,
                      lars_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name)
    else:
        opt = nn.Momentum(group_params, lr, config.momentum,
                          loss_scale=config.loss_scale, use_nesterov=config.use_nesterov)

    # model
    model = Model(net, loss_fn=loss, optimizer=opt,
                  loss_scale_manager=loss_scale, amp_level="O2", keep_batchnorm_fp32=False,
                  metrics={'acc': DistAccuracy(batch_size=config.eval_batch_size, device_num=device_num)},
                  eval_network=dist_eval_network)

    # callbacks
    loss_cb = LossGet(1, step_size)
    train_and_eval(device_id, epoch_size, model, dataset, loss_cb, eval_dataset, q)


def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
    os.system("mkdir " + str(device_id))
    os.chdir(str(device_id))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(device_id=device_id)
    os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = MINDSPORE_HCCL_CONFIG_PATH_2
    os.environ['RANK_ID'] = str(device_id - 4)
    os.environ['RANK_SIZE'] = str(device_num)
    if enable_hccl:
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True, all_reduce_fusion_config=[85, 160])
        init()

    # network
    net = resnet50(thor_config.class_num)

    if not thor_config.label_smooth:
        thor_config.label_smooth_factor = 0.0

    # loss
    loss = CrossEntropySmooth(sparse=True, reduction="mean", smooth_factor=thor_config.label_smooth_factor,
                              num_classes=thor_config.class_num)

    # train dataset
    dataset = create_dataset_thor(dataset_path=dataset_path, do_train=True,
                                  batch_size=thor_config.batch_size, train_image_size=thor_config.train_image_size,
                                  eval_image_size=thor_config.eval_image_size, target="Ascend",
                                  distribute=True)
    step_size = dataset.get_dataset_size()

    # loss scale
    loss_scale = FixedLossScaleManager(thor_config.loss_scale, drop_overflow_update=False)

    # learning rate
    lr = get_thor_lr(0, 0.05803, 4.04839, 53, 5004, decay_epochs=39)
    damping = get_thor_damping(0, 0.02714, 0.50036, 70, 5004)

    # optimizer
    split_indices = [26, 53]
    opt = thor(net, Tensor(lr), Tensor(damping), thor_config.momentum, thor_config.weight_decay,
               thor_config.loss_scale, thor_config.batch_size, split_indices=split_indices,
               frequency=thor_config.frequency)

    # evaluation network
    dist_eval_network = ClassifyCorrectCell(net)

    # model
    model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale,
                  metrics={'acc': DistAccuracy(batch_size=thor_config.eval_batch_size, device_num=device_num)},
                  amp_level="O2", keep_batchnorm_fp32=False,
                  eval_network=dist_eval_network)

    model = ConvertModelUtils().convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=opt,
                                                      loss_scale_manager=loss_scale, metrics={'acc'},
                                                      amp_level="O2", keep_batchnorm_fp32=False)

    # callbacks
    loss_cb = LossGet(1, step_size)

    # train and eval
    print("run_start", device_id)
    model.train(2, dataset, callbacks=loss_cb,
                sink_size=dataset.get_dataset_size(), dataset_sink_mode=True)
    time_cost = loss_cb.get_per_step_time()
    loss = loss_cb.get_loss()
    epoch_idx = loss_cb.get_epoch()
    print("the {} epoch's resnet result:\n "
          "device{}, training loss {}, "
          "training per step cost {:.2f} ms, total_cost {:.2f} ms".format(epoch_idx, device_id,
                                                                          loss, time_cost, time_cost * step_size))
    q.put({'loss': loss, 'cost': time_cost})

def resnet_end(device_num, q):
    acc = 0.0
    cost = 0.0
    for i in range(device_num):
        assert not q.empty()
        output = q.get()
        acc += output['acc']
        cost += output['cost']
    acc = acc / device_num
    cost = cost / device_num

    for i in range(device_num):
        os.system("rm -rf " + str(i))
    print("End training...")
    assert acc > 0.1
    assert cost < 26


def thor_end(device_num, q):
    thor_loss = 0.0
    thor_cost = 0.0
    for i in range(device_num):
        output = q.get()
        thor_loss += output['loss']
        thor_cost += output['cost']
    thor_loss = thor_loss / device_num
    thor_cost = thor_cost / device_num

    for i in range(device_num):
        os.system("rm -rf " + str(i))
    print("End training...")
    assert thor_loss < 7
    assert thor_cost < 30


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_resnet_imagenet_and_thor_4p():
    """
    Feature: Resnet50 network.
    Description: Train and evaluate resnet50 network on imagenet dataset.
    Expectation: accuracy > 0.1, time cost < 26.
    """
    context.set_context(enable_graph_kernel=False, enable_sparse=False)
    context.reset_auto_parallel_context()
    context.reset_ps_context()

    q = Queue()
    q2 = Queue()
    device_num = 4
    epoch_size = 2
    epoch_size_2 = 1
    enable_hccl = True
    process = []
    process2 = []
    for i in range(device_num):
        device_id = i
        process.append(Process(target=train_process,
                               args=(q, device_id, epoch_size, device_num, enable_hccl)))
        process2.append(Process(target=train_process_thor,
                                args=(q2, device_id + 4, epoch_size_2, device_num, enable_hccl)))
    cpu_count = os.cpu_count()
    half_cpu_count = cpu_count // 2
    each_cpu_count = cpu_count // device_num
    for i in range(device_num):
        process[i].start()
        process2[i].start()
        if each_cpu_count > 1:
            cpu_start = each_cpu_count * i
            cpu_end = each_cpu_count * (i + 1)
            process_cpu = [x for x in range(cpu_start, cpu_end)]
            process2_cpu = [x for x in range(cpu_start + half_cpu_count, cpu_end + half_cpu_count)]
            pid1 = process[i].pid
            pid2 = process2[i].pid
            os.sched_setaffinity(pid1, set(process_cpu))
            os.sched_setaffinity(pid2, set(process2_cpu))
    print("Waiting for all subprocesses done...")

    for i in range(device_num):
        process[i].join()
        process2[i].join()
    # resnet
    resnet_end(device_num, q)
    # thor
    thor_end(device_num, q2)