From 429c88a46210a29f3bce27bc2df587a89f506cde Mon Sep 17 00:00:00 2001 From: zhaoting Date: Mon, 8 Jun 2020 15:07:11 +0800 Subject: [PATCH] add resnet50 imagenet st --- tests/st/networks/__init__.py | 0 tests/st/networks/models/__init__.py | 0 tests/st/networks/models/resnet50/__init__.py | 0 .../networks/models/resnet50/src/__init__.py | 0 .../st/networks/models/resnet50/src/config.py | 47 ++ .../networks/models/resnet50/src/dataset.py | 79 ++ .../models/resnet50/src/lr_generator.py | 87 ++ .../st/networks/models/resnet50/src/metric.py | 132 ++++ .../models/resnet50/src_thor/__init__.py | 0 .../models/resnet50/src_thor/config.py | 39 + .../models/resnet50/src_thor/dataset.py | 82 ++ .../resnet50/src_thor/dataset_helper.py | 120 +++ .../resnet50/src_thor/grad_reducer_thor.py | 184 +++++ .../models/resnet50/src_thor/lr_generator.py | 88 +++ .../models/resnet50/src_thor/metric.py | 132 ++++ .../models/resnet50/src_thor/model_thor.py | 743 ++++++++++++++++++ .../models/resnet50/src_thor/resnet.py | 359 +++++++++ .../networks/models/resnet50/src_thor/thor.py | 201 +++++ .../models/resnet50/src_thor/thor_layer.py | 481 ++++++++++++ .../models/resnet50/test_resnet50_imagenet.py | 385 +++++++++ tests/st/tbe_networks/test_resnet_cifar_1p.py | 7 +- tests/st/tbe_networks/test_resnet_cifar_8p.py | 6 +- 22 files changed, 3161 insertions(+), 11 deletions(-) create mode 100644 tests/st/networks/__init__.py create mode 100644 tests/st/networks/models/__init__.py create mode 100644 tests/st/networks/models/resnet50/__init__.py create mode 100644 tests/st/networks/models/resnet50/src/__init__.py create mode 100755 tests/st/networks/models/resnet50/src/config.py create mode 100755 tests/st/networks/models/resnet50/src/dataset.py create mode 100755 tests/st/networks/models/resnet50/src/lr_generator.py create mode 100644 tests/st/networks/models/resnet50/src/metric.py create mode 100644 tests/st/networks/models/resnet50/src_thor/__init__.py create mode 100644 
tests/st/networks/models/resnet50/src_thor/config.py create mode 100644 tests/st/networks/models/resnet50/src_thor/dataset.py create mode 100644 tests/st/networks/models/resnet50/src_thor/dataset_helper.py create mode 100644 tests/st/networks/models/resnet50/src_thor/grad_reducer_thor.py create mode 100644 tests/st/networks/models/resnet50/src_thor/lr_generator.py create mode 100644 tests/st/networks/models/resnet50/src_thor/metric.py create mode 100644 tests/st/networks/models/resnet50/src_thor/model_thor.py create mode 100644 tests/st/networks/models/resnet50/src_thor/resnet.py create mode 100644 tests/st/networks/models/resnet50/src_thor/thor.py create mode 100644 tests/st/networks/models/resnet50/src_thor/thor_layer.py create mode 100644 tests/st/networks/models/resnet50/test_resnet50_imagenet.py diff --git a/tests/st/networks/__init__.py b/tests/st/networks/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/st/networks/models/__init__.py b/tests/st/networks/models/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/st/networks/models/resnet50/__init__.py b/tests/st/networks/models/resnet50/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/st/networks/models/resnet50/src/__init__.py b/tests/st/networks/models/resnet50/src/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/st/networks/models/resnet50/src/config.py b/tests/st/networks/models/resnet50/src/config.py new file mode 100755 index 00000000000..fbb3e83ba35 --- /dev/null +++ b/tests/st/networks/models/resnet50/src/config.py @@ -0,0 +1,47 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# Network settings for the ResNet-50 ImageNet run, wrapped in an EasyDict so
# that entries can be read as attributes (config.batch_size, ...).
# Keys are listed in the same order as before.
config = ed(dict(
    class_num=1001,
    batch_size=32,
    eval_interval=1,
    eval_batch_size=50,
    loss_scale=1024,
    momentum=0.9,
    weight_decay=1e-4,
    use_nesterov=True,
    epoch_size=90,
    pretrained_epoch_size=1,
    buffer_size=1000,
    image_height=224,
    image_width=224,
    save_checkpoint=False,
    save_checkpoint_epochs=5,
    keep_checkpoint_max=10,
    save_checkpoint_path="./",
    warmup_epochs=0,
    lr_decay_mode="cosine",
    use_label_smooth=True,
    label_smooth_factor=0.1,
    lr_init=0,
    lr_max=0.1,
    use_lars=True,
    lars_epsilon=1e-8,
    lars_coefficient=0.001,
))
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
    """
    Create a train or eval dataset.

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32

    Returns:
        dataset

    Note:
        The device count and the shard index are read from the RANK_SIZE and
        RANK_ID environment variables. When a variable is not set, a single
        device with rank 0 is assumed.
    """
    # os.getenv returns None for an unset variable and int(None) raises
    # TypeError, so give both lookups an explicit single-device default.
    device_num = int(os.getenv("RANK_SIZE", "1"))
    rank_id = int(os.getenv("RANK_ID", "0"))

    # Only a multi-device run shards the image folder across ranks.
    reader_args = {"num_parallel_workers": 8, "shuffle": True}
    if device_num != 1:
        reader_args.update(num_shards=device_num, shard_id=rank_id)
    ds = de.ImageFolderDatasetV2(dataset_path, **reader_args)

    image_size = 224
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # define map operations
    if do_train:
        trans = [
            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
    else:
        trans = [
            C.Decode(),
            C.Resize((256, 256)),
            C.CenterCrop(image_size),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
    ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)
    return ds
def get_learning_rate(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
    """
    Build the per-step learning rate schedule.

    Args:
        lr_init(float): init learning rate
        lr_end(float): end learning rate (only read by the default mode)
        lr_max(float): max learning rate
        warmup_epochs(int): number of warmup epochs
        total_epochs(int): total epoch of training
        steps_per_epoch(int): steps of one epoch
        lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or default

    Returns:
        np.array, float32 learning rate for each of the
        ``steps_per_epoch * total_epochs`` steps
    """
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs

    if lr_decay_mode == 'steps':
        # Piecewise-constant schedule; warmup, lr_init and lr_end are not used here.
        first_cut, second_cut, third_cut = 0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps

        def rate_at(step):
            if step < first_cut:
                return lr_max
            if step < second_cut:
                return lr_max * 0.1
            if step < third_cut:
                return lr_max * 0.01
            return lr_max * 0.001

    elif lr_decay_mode == 'poly':
        # Linear warmup followed by a squared decay that is floored at zero.
        warmup_slope = 0
        if warmup_steps != 0:
            warmup_slope = (float(lr_max) - float(lr_init)) / float(warmup_steps)

        def rate_at(step):
            if step < warmup_steps:
                return float(lr_init) + warmup_slope * float(step)
            base = (1.0 - (float(step) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
            rate = float(lr_max) * base * base
            return 0.0 if rate < 0.0 else rate

    elif lr_decay_mode == 'cosine':
        decay_steps = total_steps - warmup_steps

        def rate_at(step):
            if step < warmup_steps:
                warmup_slope = (float(lr_max) - float(lr_init)) / float(warmup_steps)
                return float(lr_init) + warmup_slope * (step + 1)
            linear_decay = (total_steps - step) / decay_steps
            cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * step / decay_steps))
            return lr_max * (linear_decay * cosine_decay + 0.00001)

    else:
        # Default: linear warmup from lr_init to lr_max, then linear decay towards lr_end.
        def rate_at(step):
            if step < warmup_steps:
                return lr_init + (lr_max - lr_init) * step / warmup_steps
            return lr_max - (lr_max - lr_end) * (step - warmup_steps) / (total_steps - warmup_steps)

    schedule = [rate_at(step) for step in range(total_steps)]
    return np.array(schedule).astype(np.float32)
class ClassifyCorrectCell(nn.Cell):
    r"""
    Cell that returns the correct count of the prediction in a classification network.

    This Cell accepts a network as argument, runs it on the input data, compares
    the argmax of the outputs with the labels and all-reduces (sums) the number
    of matches over ``GlobalComm.WORLD_COMM_GROUP``.

    Args:
        network (Cell): The network Cell.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tuple, containing a scalar correct count of the prediction

    Examples:
        >>> # For a defined network Net without loss function
        >>> net = Net()
        >>> eval_net = ClassifyCorrectCell(net)
    """

    def __init__(self, network):
        super(ClassifyCorrectCell, self).__init__(auto_prefix=False)
        self._network = network
        self.argmax = P.Argmax()
        self.equal = P.Equal()
        self.cast = P.Cast()
        self.reduce_sum = P.ReduceSum()
        self.allreduce = P.AllReduce(P.ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)

    def construct(self, data, label):
        outputs = self._network(data)
        # Predicted class index, cast to int32 before it is compared with the labels.
        y_pred = self.argmax(outputs)
        y_pred = self.cast(y_pred, mstype.int32)
        y_correct = self.equal(y_pred, label)
        # Boolean matches -> float32 so that they can be summed.
        y_correct = self.cast(y_correct, mstype.float32)
        y_correct = self.reduce_sum(y_correct)
        # Sum the local count over the world communication group.
        total_correct = self.allreduce(y_correct)
        return (total_correct,)


class DistAccuracy(nn.Metric):
    r"""
    Calculates the accuracy for classification data in distributed mode.
    The accuracy class creates two local variables, correct number and total number that are used to compute the
    frequency with which predictions matches labels. This frequency is ultimately returned as the accuracy: an
    idempotent operation that simply divides correct number by total number.

    .. math::

        \text{accuracy} =\frac{\text{true_positive} + \text{true_negative}}

        {\text{true_positive} + \text{true_negative} + \text{false_positive} + \text{false_negative}}

    Args:
        batch_size (int): Value multiplied with `device_num` and added to the total sample count on every update.
        device_num (int): Value multiplied with `batch_size` and added to the total sample count on every update.

    Examples:
        >>> y_correct = Tensor(np.array([20]))
        >>> metric = DistAccuracy(batch_size=3, device_num=8)
        >>> metric.clear()
        >>> metric.update(y_correct)
        >>> accuracy = metric.eval()
    """

    def __init__(self, batch_size, device_num):
        super(DistAccuracy, self).__init__()
        self.clear()
        self.batch_size = batch_size
        self.device_num = device_num

    def clear(self):
        """Clears the internal evaluation result."""
        self._correct_num = 0
        self._total_num = 0

    def update(self, *inputs):
        """
        Updates the internal evaluation result with one more correct count.

        Args:
            inputs: Input `y_correct`. `y_correct` is a `scalar Tensor`.
                `y_correct` is the right prediction count that gathered from all devices
                it's a scalar in float type

        Raises:
            ValueError: If the number of the input is not 1.
        """

        if len(inputs) != 1:
            raise ValueError('Distribute accuracy needs 1 input (y_correct), but got {}'.format(len(inputs)))
        # _convert_data is inherited from nn.Metric (not visible in this file).
        y_correct = self._convert_data(inputs[0])
        self._correct_num += y_correct
        # Every update accounts for one batch on each device.
        self._total_num += self.batch_size * self.device_num

    def eval(self):
        """
        Computes the accuracy.

        Returns:
            Float, the computed result.

        Raises:
            RuntimeError: If the sample size is 0.
        """

        if self._total_num == 0:
            raise RuntimeError('Accuracy can not be calculated, because the number of samples is 0.')
        return self._correct_num / self._total_num
# Network settings for the THOR ResNet-50 ImageNet run, wrapped in an EasyDict
# so that entries can be read as attributes (config.batch_size, ...).
# Keys are listed in the same order as before.
config = ed(dict(
    class_num=1000,
    batch_size=32,
    loss_scale=128,
    momentum=0.9,
    weight_decay=5e-4,
    epoch_size=45,
    buffer_size=1000,
    image_height=224,
    image_width=224,
    save_checkpoint=True,
    save_checkpoint_steps=5004,
    keep_checkpoint_max=20,
    save_checkpoint_path="./",
    label_smooth=1,
    label_smooth_factor=0.1,
    frequency=834,
    eval_interval=1,
    eval_batch_size=32,
))
# Fix the dataset seed once at import time.
dataset.config.set_seed(1)


def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
    """
    Create a train or eval dataset.

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32

    Returns:
        dataset

    Note:
        The device count and the shard index are read from the RANK_SIZE and
        RANK_ID environment variables. When a variable is not set, a single
        device with rank 0 is assumed.
    """
    # os.getenv returns None for an unset variable and int(None) raises
    # TypeError, so give both lookups an explicit single-device default.
    device_num = int(os.getenv("RANK_SIZE", "1"))
    rank_id = int(os.getenv("RANK_ID", "0"))

    # Only a multi-device run shards the image folder across ranks.
    reader_args = {"num_parallel_workers": 8, "shuffle": True}
    if device_num != 1:
        reader_args.update(num_shards=device_num, shard_id=rank_id)
    ds = de.ImageFolderDatasetV2(dataset_path, **reader_args)

    image_size = 224
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # define map operations
    if do_train:
        trans = [
            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
    else:
        trans = [
            C.Decode(),
            C.Resize((256, 256)),
            C.CenterCrop(image_size),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
    ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)
    return ds
class DatasetHelper:
    """
    Help function to use the Minddata dataset.

    Wraps the dataset in a loop-sink iterator so that the caller can use one
    plain for loop.

    Note:
        The iter of DatasetHelper will give one epoch data.

    Args:
        dataset (DataSet): The dataset.
        dataset_sink_mode (bool): Validated with check_bool; the loop-sink iterator is used regardless of
                                  its value. Default: True.
        iter_first_order (int): The iteration of first-order subgraph.
                                Default: 0.

    Examples:
        >>> dataset_helper = DatasetHelper(dataset)
        >>> for inputs in dataset_helper:
        >>>     outputs = network(*inputs)
    """

    def __init__(self, dataset, dataset_sink_mode=True, iter_first_order=0):
        check_bool(dataset_sink_mode)
        self.iter = _DatasetIterMSLoopSink(dataset, iter_first_order)

    def __iter__(self):
        return iter(self.iter)

    # A temp solution for loop sink. Delete later
    def types_shapes(self):
        """Return the (types, shapes) pair held by the wrapped iterator."""
        return self.iter.types_shapes()

    def loop_size(self):
        """Return the loop_size recorded by the wrapped iterator."""
        return self.iter.loop_size


class _DatasetIter:
    """Base iterator; subclasses are expected to provide ``loop_count`` and ``op``."""

    def __init__(self, dataset):
        self.loop_size = 1
        needs_init = not hasattr(dataset, '__ME_INITED__')
        if needs_init:
            # Prefer an explicit __loop_size__; otherwise sink the whole dataset at once.
            if hasattr(dataset, '__loop_size__'):
                self.loop_size = dataset.__loop_size__
            else:
                self.loop_size = dataset.get_dataset_size()
            # Mark the dataset with the queue name so that it is not initialised again.
            dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name

        self.ind = 0
        self.dataset = dataset
        self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)

    def __iter__(self):
        # Restart the counter every time iteration begins.
        self.ind = 0
        return self

    def __next__(self):
        if self.ind >= self.loop_count:
            raise StopIteration()
        self.ind += 1
        return self.op()

    def types_shapes(self):
        return self.dataset_types, self.dataset_shapes

    def get_loop_count(self, dataset):
        """Return dataset size divided by ``__loop_size__``, or 1 when the dataset has no ``__loop_size__``."""
        if not hasattr(dataset, '__loop_size__'):
            return 1
        loop_size = dataset.__loop_size__
        if dataset.get_dataset_size() % loop_size != 0:
            raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
                             f'loop_size {loop_size} are not matched.')
        return int(dataset.get_dataset_size() / loop_size)


class _DatasetIterMSLoopSink(_DatasetIter):
    """Iter for context (device_target=Ascend)"""

    def __init__(self, dataset, iter_first_order):
        super().__init__(dataset)
        sink_size = dataset.__loop_size__ + iter_first_order
        self.loop_count = int(dataset.get_dataset_size() / sink_size * 2)
        # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
        # compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
        # times the batch dimension of tensors for run. Now only support LoopSink.
        parallel_mode = _get_parallel_mode()
        if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            self.dataset_shapes = _to_full_shapes(self.dataset_shapes, _get_device_num())

        def _empty_inputs():
            # Each iteration yields an empty tuple.
            return tuple()

        self.op = _empty_inputs
# Multitype dispatch graph; the registrations below choose the all-reduce
# variant from the argument types.
reduce_opt = C.MultitypeFuncGraph("reduce_opt")

# Module-level AllReduce primitive shared by the registered functions below.
# It is rebuilt by _init_optimizer_allreduce.
_all_reduce_A = AllReduce()


def _init_optimizer_allreduce(group):
    """
    Rebuild the shared AllReduce primitive and tag it with a fusion attribute.

    Args:
        group (int): value stored in the primitive's 'fusion' attribute.

    Note:
        This rebinds the module-level `_all_reduce_A`, so every call replaces
        the primitive used by all functions registered on `reduce_opt`.
    """
    global _all_reduce_A
    _all_reduce_A = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
    _all_reduce_A.add_prim_attr('fusion', group)


@reduce_opt.register("Function", "Number", "Tensor")
def _tensors_allreduce_mean(mul, degree, grad):
    """
    All-reduce the gradient and multiply the result by 1 / degree.

    Args:
        mul (Function): multiplication operator applied to the reduced gradient.
        degree (Number): the mean coefficient.
        grad (Tensor): The gradient tensor before operation.

    Returns:
        Tensor, the gradient tensor after operation.
    """
    degree = F.scalar_cast(degree, F.dtype(grad))
    grad = _all_reduce_A(grad)
    cast_op = P.Cast()
    return mul(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad)))


@reduce_opt.register("Bool", "Tensor")
def _tensors_allreduce(allreduce_filter, grad):
    """
    All-reduce the gradient only when its filter flag is true.

    Args:
        allreduce_filter (Bool): whether this gradient should be all-reduced.
        grad (Tensor): The gradient tensor before operation.

    Returns:
        Tensor, the reduced gradient, or the unchanged gradient when the flag is false.
    """
    if allreduce_filter:
        return _all_reduce_A(grad)
    return grad


_get_datatype = C.MultitypeFuncGraph("_get_datatype")


@_get_datatype.register("Tensor")
def _tensors_get_datatype(grad):
    """
    Acquire gradient datatype.

    Args:
        grad (Tensor): The gradient tensor before operation.

    Returns:
        mstype, the datatype of gradient.
    """
    return F.dtype(grad)


_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")


@_cast_datatype.register("TypeType", "Tensor")
def _tensors_cast_datatype(datatype, grad):
    """
    Cast gradient to datatype.

    Args:
        datatype (mstype): the destination datatype of gradient.
        grad (Tensor): The gradient tensor before operation.

    Returns:
        Tensor, the gradient tensor after operation.
    """
    return F.cast(grad, datatype)


class DistributedGradReducerThor(Cell):
    """
    A distributed gradient reducer.

    Constructs a gradient reducer Cell, which applies communication and average operations on
    single-process gradient values.

    Args:
        parameters (list): the parameters to be updated.
        group (int): the different group to allreduce; stored as the 'fusion' attribute of the AllReduce primitive.
        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. Default: True.
        degree (int): The mean coefficient. Usually it equals to device number. When None, the value returned by
            get_group_size() is used. Default: None.

    Raises:
        ValueError: If degree is given and is not an int or is not greater than 0.

    Examples:
        >>> # `optimizer` is an already constructed optimizer, `grads` a tuple of gradient tensors.
        >>> grad_reducer = DistributedGradReducerThor(optimizer.parameters, 0, True, 8)
        >>> grads = grad_reducer(grads)
    """

    def __init__(self, parameters, group, mean=True, degree=None):
        super(DistributedGradReducerThor, self).__init__(auto_prefix=False)
        self.hyper_map = C.HyperMap()
        self.mul = P.Mul()
        if degree is None:
            self.degree = get_group_size()
        else:
            if not isinstance(degree, int) or degree <= 0:
                raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int")
            self.degree = degree
        self.mean = mean
        # One flag per parameter: True when the parameter is not layerwise parallel.
        self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
        # Rebinds the module-level AllReduce primitive with this reducer's group.
        _init_optimizer_allreduce(group)

    def construct(self, grads):
        # In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
        # result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
        # and cast back after the operation.
        datatypes = self.hyper_map(F.partial(_get_datatype), grads)
        grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)

        if self.mean:
            # NOTE(review): the mean branch does not pass allreduce_filter, so every gradient is reduced.
            new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads)
        else:
            new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads)

        # Restore each gradient's original datatype.
        new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
        return new_grad
def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
    """
    Build the per-step learning rate schedule.

    Args:
        lr_init(float): learning rate at the first warmup step
        lr_end(float): final learning rate; only read by the default (linear) mode
        lr_max(float): peak learning rate
        warmup_epochs(int): number of warmup epochs (ignored by the 'steps' mode)
        total_epochs(int): total number of training epochs
        steps_per_epoch(int): number of steps in one epoch
        lr_decay_mode(string): 'steps', 'poly' or 'cosine'; any other value selects
            linear warmup followed by linear decay towards lr_end

    Returns:
        np.array, float32 array holding one learning rate per training step
    """
    n_total = steps_per_epoch * total_epochs
    n_warmup = steps_per_epoch * warmup_epochs

    if lr_decay_mode == 'steps':
        # Piecewise-constant schedule: drop by 10x at 30%, 60% and 80% of training.
        first_cut, second_cut, third_cut = 0.3 * n_total, 0.6 * n_total, 0.8 * n_total

        def rate_at(step):
            if step < first_cut:
                return lr_max
            if step < second_cut:
                return lr_max * 0.1
            if step < third_cut:
                return lr_max * 0.01
            return lr_max * 0.001

    elif lr_decay_mode == 'poly':
        # Linear warmup, then quadratic decay of lr_max, clamped at zero.
        if n_warmup != 0:
            warmup_slope = (float(lr_max) - float(lr_init)) / float(n_warmup)
        else:
            warmup_slope = 0

        def rate_at(step):
            if step < n_warmup:
                return float(lr_init) + warmup_slope * float(step)
            remaining = (1.0 - (float(step) - float(n_warmup)) / (float(n_total) - float(n_warmup)))
            value = float(lr_max) * remaining * remaining
            if value < 0.0:
                value = 0.0
            return value

    elif lr_decay_mode == 'cosine':
        # Warmup reaches lr_max on its last step; afterwards a linear term is
        # multiplied by a cosine term and offset by a small constant.
        span = n_total - n_warmup

        def rate_at(step):
            if step < n_warmup:
                slope = (float(lr_max) - float(lr_init)) / float(n_warmup)
                return float(lr_init) + slope * (step + 1)
            linear_part = (n_total - step) / span
            cosine_part = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * step / span))
            scale = linear_part * cosine_part + 0.00001
            return lr_max * scale

    else:
        # Default: linear warmup from lr_init, then linear decay from lr_max to lr_end.
        def rate_at(step):
            if step < n_warmup:
                return lr_init + (lr_max - lr_init) * step / n_warmup
            return lr_max - (lr_max - lr_end) * (step - n_warmup) / (n_total - n_warmup)

    schedule = [rate_at(step) for step in range(n_total)]
    return np.array(schedule).astype(np.float32)
class DistAccuracy(nn.Metric):
    r"""
    Accuracy metric for classification evaluated on several devices.

    Two running totals are kept: the number of correct predictions and the number of
    samples seen. Each `update` call adds the supplied correct count to the former and
    `batch_size * device_num` to the latter; `eval` returns their ratio.

    .. math::

        \text{accuracy} =\frac{\text{true_positive} + \text{true_negative}}

        {\text{true_positive} + \text{true_negative} + \text{false_positive} + \text{false_negative}}

    Args:
        batch_size (int): eval batch size.
        device_num (int): device number to eval.

    Examples:
        >>> y_correct = Tensor(np.array([20]))
        >>> metric = DistAccuracy(batch_size=3, device_num=8)
        >>> metric.clear()
        >>> metric.update(y_correct)
        >>> accuracy = metric.eval()
    """

    def __init__(self, batch_size, device_num):
        super().__init__()
        self.clear()
        self.batch_size = batch_size
        self.device_num = device_num

    def clear(self):
        """Reset both running totals to zero."""
        self._correct_num = 0
        self._total_num = 0

    def update(self, *inputs):
        """
        Accumulate one evaluation step.

        Args:
            inputs: Exactly one value, `y_correct`, the count of right predictions for
                the step. It is passed through `_convert_data` before being added.

        Raises:
            ValueError: If the number of the input is not 1.
        """
        received = len(inputs)
        if received != 1:
            raise ValueError('Distribute accuracy needs 1 input (y_correct), but got {}'.format(received))
        correct = self._convert_data(inputs[0])
        self._correct_num += correct
        # Every step is counted as one full batch on every device.
        self._total_num += self.batch_size * self.device_num

    def eval(self):
        """
        Return the accumulated accuracy.

        Returns:
            Float, correct count divided by sample count.

        Raises:
            RuntimeError: If the sample size is 0.
        """
        seen = self._total_num
        if seen == 0:
            raise RuntimeError('Accuracy can not be calculated, because the number of samples is 0.')
        return self._correct_num / seen
def _convert_type(types):
    """
    Translate dataset element types with `pytype_to_dtype`.

    Args:
        types (list): Numpy type list of element in dataset.

    Returns:
        list, one converted type per entry of `types`, in the same order.
    """
    return [pytype_to_dtype(element_type) for element_type in types]
def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
             eval_indexes=None, amp_level="O0", frequency=278, stop_epoch=100, **kwargs):
    """
    Store the configuration and build the train, eval and predict networks.

    Args:
        network (Cell): The training or testing network.
        loss_fn (Cell): Objective function. Default: None.
        optimizer (Cell): Optimizer for updating the weights. Default: None.
        metrics (Union[dict, set]): Metrics handed to `_build_eval_network`. Default: None.
        eval_network (Cell): Network for evaluation. Default: None.
        eval_indexes (list): Positions of loss, prediction and label in the
            eval network outputs. Default: None.
        amp_level (str): Mixed precision level, read by `_process_amp_args` and
            `_build_train_network`. Default: "O0".
        frequency (int): Stored as `_frequency`; the sink-mode training paths use
            `frequency - 1` as `iter_first_order`. Default: 278.
        stop_epoch (int): Stored as `_stop_epoch`; it is not read in the part of
            this file visible here. Default: 100.
        kwargs: Only `loss_scale_manager` and `keep_batchnorm_fp32` are accepted.

    Raises:
        ValueError: If `kwargs` contains any other key (raised by `_check_kwargs`).
    """
    self._network = network
    self._loss_fn = loss_fn
    self._optimizer = optimizer
    # Defaults; `_process_amp_args` below may override all three.
    self._loss_scale_manager = None
    self._loss_scale_manager_set = False
    self._keep_bn_fp32 = True
    # Validate the keyword arguments before any of them is consumed.
    self._check_kwargs(kwargs)
    # `_process_amp_args` reads `self._amp_level`, so it must be assigned first.
    self._amp_level = amp_level
    self._process_amp_args(kwargs)
    self._parallel_mode = _get_parallel_mode()
    self._device_number = _get_device_num()
    self._global_rank = _get_global_rank()
    self._parameter_broadcast = _get_parameter_broadcast()
    self._frequency = frequency
    self._stop_epoch = stop_epoch
    # Set to True after the 'train1_dataset' data graph has been initialised once.
    self._has_do_dataset_init = False

    # The builders rely on the attributes assigned above.
    self._train_network = self._build_train_network()
    self._build_eval_network(metrics, eval_network, eval_indexes)
    self._build_predict_network()
If eval_indexes is a list, length of it \ + must be three. But got {}".format(eval_indexes)) + + self._eval_network = eval_network + self._eval_indexes = eval_indexes + else: + if self._loss_fn is None: + raise ValueError("loss_fn can not be None.") + self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2") + self._eval_indexes = [0, 1, 2] + + if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + self._eval_network.set_auto_parallel() + + def _build_predict_network(self): + """Build the network for prediction.""" + self._predict_network = self._network + if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + self._predict_network = _VirtualDatasetCell(self._network) + self._predict_network.set_auto_parallel() + + def _clear_metrics(self): + """Clear metrics local values.""" + for metric in self._metric_fns.values(): + metric.clear() + + def _update_metrics(self, outputs): + """Update metrics local values.""" + if not isinstance(outputs, tuple): + raise ValueError("The `outputs` is not tuple.") + + if self._eval_indexes is not None and len(outputs) < 3: + raise ValueError("The length of `outputs` must be greater than or equal to 3, \ + but got {}".format(len(outputs))) + + for metric in self._metric_fns.values(): + if self._eval_indexes is None: + metric.update(*outputs) + else: + if isinstance(metric, Loss): + metric.update(outputs[self._eval_indexes[0]]) + else: + metric.update(outputs[self._eval_indexes[1]], outputs[self._eval_indexes[2]]) + + def _get_metrics(self): + """Get metrics local values.""" + metrics = dict() + for key, value in self._metric_fns.items(): + metrics[key] = value.eval() + return metrics + + def _get_scaling_sens(self): + """get the scaling sens""" + scaling_sens = 1 + if self._loss_scale_manager is not None: + scaling_sens = self._loss_scale_manager.get_loss_scale() + if self._parallel_mode == ParallelMode.DATA_PARALLEL: + 
scaling_sens /= self._device_number + return scaling_sens + + def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, iter_first_order=1): + """Initializes dataset.""" + need_wrap = False + if dataset_sink_mode: + # remove later to deal with loop sink + if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \ + and not context.get_context("enable_ge"): + need_wrap = True + + if not is_train: + dataset.__loop_size__ = 1 + + dataset_helper = DatasetHelper(dataset, dataset_sink_mode, iter_first_order) + + # remove later to deal with loop sink + if need_wrap: + network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__) + network.set_train(is_train) + network.phase = phase + + return dataset_helper, network + + def init(self, train_dataset=None, valid_dataset=None): + """ + Initializes compute graphs and data graphs with sink mode. + + Note: + Pre-init process only supports `GRAPH_MODE` and `Ascend` target currently. + + Args: + train_dataset (Dataset): A training dataset iterator. If define `train_dataset`, training graphs will be + initialized. Default: None. + valid_dataset (Dataset): A evaluating dataset iterator. If define `valid_dataset`, evaluation graphs will + be initialized, and `metrics` in `Model` can not be None. Default: None. 
+ + Examples: + >>> train_dataset = get_train_dataset() + >>> valid_dataset = get_valid_dataset() + >>> net = Net() + >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) + >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) + >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'}) + >>> model.init(train_dataset, valid_dataset) + >>> model.train(2, train_dataset) + >>> model.eval(valid_dataset) + """ + if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend": + raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.') + + if not train_dataset and not valid_dataset: + raise ValueError('Both train_dataset and valid_dataset can not be None or empty.') + + _device_number_check(self._parallel_mode, self._device_number) + + if train_dataset: + _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast) + self._train_network.set_train() + self._train_network.phase = 'train' + + if self._parameter_broadcast: + self._train_network.set_broadcast_flag() + iter_first_order = self._frequency - 1 + iter_second_order = 1 + train_dataset.__loop_size__ = iter_second_order + train_dataset_helper, train_network = self._exec_preprocess(self._train_network, + is_train=True, + phase='train', + dataset=train_dataset, + dataset_sink_mode=True, + iter_first_order=iter_first_order) + self._train_network = train_network + switch_branch_one = True + index = 0 + for inputs in train_dataset_helper: + if switch_branch_one: + self._train_network.add_flags_recursive(thor=True) + self._train_network.phase = 'train0' + else: + self._train_network.add_flags_recursive(thor=False) + self._train_network.phase = 'train1' + if not self._has_do_dataset_init: + _exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset') + self._has_do_dataset_init = True + switch_branch_one = not switch_branch_one + 
def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
    """
    Prepare callback parameters and dispatch to the matching training loop.

    Args:
        epoch (int): Total number of iterations on the data; validated with
            `check_int_positive`.
        train_dataset (Dataset): A training dataset iterator.
        callbacks (list): Callback objects to run while training. Default: None.
        dataset_sink_mode (bool): Whether to pass the data through the dataset
            channel. In pynative mode a warning is logged and training falls back
            to the non-sink loop. Default: True.
    """
    epoch = check_int_positive(epoch)
    self._train_network.set_train()
    if self._parameter_broadcast:
        self._train_network.set_broadcast_flag()

    # Gather everything the callbacks may want to inspect.
    callback_runner = _build_callbacks(callbacks)
    params = _InternalCallbackParam()
    params.train_network = self._train_network
    params.epoch_num = epoch
    params.batch_num = train_dataset.get_dataset_size()
    params.mode = "train"
    params.loss_fn = self._loss_fn
    params.optimizer = self._optimizer
    params.parallel_mode = self._parallel_mode
    params.device_number = self._device_number
    params.train_dataset = train_dataset
    params.list_callback = callback_runner

    # Sink mode is only used when requested and the context is not pynative.
    if dataset_sink_mode and context.get_context("mode") != context.PYNATIVE_MODE:
        self._train_dataset_sink_process(epoch, train_dataset, callback_runner, params)
        return

    if dataset_sink_mode:
        logger.warning("The pynative mode cannot support dataset sink mode currently."
                       "So the training process will be performed with dataset not sink.")
    self._train_process(epoch, train_dataset, callback_runner, params)
+ """ + iter_first_order = self._frequency - 1 + iter_second_order = 1 + train_dataset.__loop_size__ = iter_second_order + dataset_helper, train_network = self._exec_preprocess(self._train_network, + is_train=True, + phase='train', + dataset=train_dataset, + dataset_sink_mode=True, + iter_first_order=iter_first_order) + self._train_network = train_network + cb_params.train_network = self._train_network + cb_params.cur_step_num = 0 + + loop_size = dataset_helper.loop_size() + run_context = RunContext(cb_params) + list_callback.begin(run_context) + + # used to stop training for early stop, such as stopAtTIme or stopATStep + should_stop = False + switch_branch_one = True + for i in range(epoch): + cb_params.cur_epoch_num = i + 1 + list_callback.epoch_begin(run_context) + + # for data sink dataset_helper only iter once, other wise iter epoch_size times. + for inputs in dataset_helper: + list_callback.step_begin(run_context) + if switch_branch_one: + cb_params.cur_step_num += loop_size + self._train_network.add_flags_recursive(thor=True) + self._train_network.phase = 'train0' + else: + cb_params.cur_step_num += iter_first_order + self._train_network.add_flags_recursive(thor=False) + self._train_network.phase = 'train1' + if not self._has_do_dataset_init: + _exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset') + self._has_do_dataset_init = True + switch_branch_one = not switch_branch_one + outputs = self._train_network(*inputs) + cb_params.net_outputs = outputs + list_callback.step_end(run_context) + + list_callback.epoch_end(run_context) + should_stop = should_stop or run_context.get_stop_requested() + if should_stop: + break + + list_callback.end(run_context) + + def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None): + """ + Training process. The data would be passed to network directly. + + Args: + epoch (int): Total number of iterations on the data. + train_dataset (Dataset): A training dataset iterator. 
If there is no + loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be + returned and passed to the network. Otherwise, a tuple (data, label) should + be returned, and the data and label are passed to the network and loss + function respectively. + list_callback (_ListCallback): Executor of callback list. Default: None. + cb_params (_InternalCallbackParam): Callback parameters. Default: None. + """ + dataset_helper, _ = self._exec_preprocess(self._train_network, + is_train=True, + phase='train', + dataset=train_dataset, + dataset_sink_mode=False) + cb_params.cur_step_num = 0 + run_context = RunContext(cb_params) + list_callback.begin(run_context) + # used to stop training for early stop, such as stopAtTIme or stopATStep + should_stop = False + + for i in range(epoch): + cb_params.cur_epoch_num = i + 1 + + list_callback.epoch_begin(run_context) + + for next_element in dataset_helper: + len_element = len(next_element) + if self._loss_fn and len_element != 2: + raise ValueError("when loss_fn is not None, train_dataset should" + "return two elements, but got {}".format(len_element)) + cb_params.cur_step_num += 1 + list_callback.step_begin(run_context) + + overflow = False + if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(): + scaling_sens = self._get_scaling_sens() + next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),) + + outputs = self._train_network(*next_element) + cb_params.net_outputs = outputs + if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(): + _, overflow, _ = outputs + overflow = np.all(overflow.asnumpy()) + self._loss_scale_manager.update_loss_scale(overflow) + + list_callback.step_end(run_context) + should_stop = should_stop or run_context.get_stop_requested() + if should_stop: + break + + train_dataset.reset() + + list_callback.epoch_end(run_context) + should_stop = should_stop or run_context.get_stop_requested() + if should_stop: + break + + 
def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
    """
    Training API where the iteration is controlled by python front-end.

    In pynative mode the training process is performed without dataset sink.

    Note:
        CPU is not supported when dataset_sink_mode is true.
        If dataset_sink_mode is True, epoch of training should be equal to the count of
        repeat operation in dataset processing; a warning is logged when they differ.
        If dataset_sink_mode is True, data will be sent to device. If device is Ascend,
        features of data will be transferred one by one. The limitation of data
        transmission per time is 256M.

    Args:
        epoch (int): Total number of iterations on the data.
        train_dataset (Dataset): A training dataset iterator. If there is no loss_fn,
            a tuple with multiple data (data1, data2, data3, ...) should be returned and
            passed to the network. Otherwise a tuple (data, label) should be returned,
            and the data and label are passed to the network and loss function
            respectively.
        callbacks (list): Callback objects to execute while training. Default: None.
        dataset_sink_mode (bool): Whether to pass the data through the dataset channel.
            Default: True.

    Examples:
        >>> dataset = get_dataset()
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
        >>> loss_scale_manager = FixedLossScaleManager()
        >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
        >>> model.train(2, dataset)
    """
    repeat_count = train_dataset.get_repeat_count()
    epoch_mismatch = epoch != repeat_count and dataset_sink_mode is True
    if epoch_mismatch:
        logger.warning(f"The epoch_size {epoch} is not the same with dataset repeat_count {repeat_count}")

    # Argument and parallel-setup validation happens before any training work.
    check_bool(dataset_sink_mode)
    _device_number_check(self._parallel_mode, self._device_number)
    _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)

    self._train(epoch, train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)
+ """ + run_context = RunContext(cb_params) + + dataset_helper, eval_network = self._exec_preprocess(self._eval_network, + is_train=False, + phase='eval', + dataset=valid_dataset, + dataset_sink_mode=True) + self._eval_network = eval_network + cb_params.eval_network = self._eval_network + list_callback.begin(run_context) + + for inputs in dataset_helper: + cb_params.cur_step_num += 1 + list_callback.step_begin(run_context) + + outputs = self._eval_network(*inputs) + + cb_params.net_outputs = outputs + list_callback.step_end(run_context) + self._update_metrics(outputs) + + metrics = self._get_metrics() + cb_params.metrics = metrics + list_callback.end(run_context) + + return metrics + + def _eval_process(self, valid_dataset, list_callback=None, cb_params=None): + """ + Evaluation. The data would be passed to network directly. + + Args: + valid_dataset (Dataset): Dataset to evaluate the model. + list_callback (ListCallback): Executor of callback list. Default: None. + cb_params (_InternalCallbackParam): Callback parameters. Default: None. + + Returns: + Dict, returns the loss value & metrics values for the model in test mode. + """ + run_context = RunContext(cb_params) + list_callback.begin(run_context) + + dataset_helper, _ = self._exec_preprocess(self._eval_network, + is_train=False, + phase='eval', + dataset=valid_dataset, + dataset_sink_mode=False) + for next_element in dataset_helper: + cb_params.cur_step_num += 1 + list_callback.step_begin(run_context) + outputs = self._eval_network(*next_element) + cb_params.net_outputs = outputs + list_callback.step_end(run_context) + self._update_metrics(outputs) + + metrics = self._get_metrics() + cb_params.metrics = metrics + list_callback.end(run_context) + return metrics + + def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True): + """ + Evaluation API where the iteration is controlled by python front-end. + + Configure to pynative mode, the evaluation will be performed with dataset non-sink mode. 
+ + Note: + CPU is not supported when dataset_sink_mode is true. + If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features + of data will be transferred one by one. The limitation of data transmission per time is 256M. + + Args: + valid_dataset (Dataset): Dataset to evaluate the model. + callbacks (list): List of callback object. Callbacks which should be excuted + while training. Default: None. + dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True. + + Returns: + Dict, returns the loss value & metrics values for the model in test mode. + + Examples: + >>> dataset = get_dataset() + >>> net = Net() + >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) + >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'}) + >>> model.eval(dataset) + """ + check_bool(dataset_sink_mode) + _device_number_check(self._parallel_mode, self._device_number) + if not self._metric_fns: + raise ValueError("metric fn can not be None or empty.") + + list_callback = _build_callbacks(callbacks) + cb_params = _InternalCallbackParam() + cb_params.eval_network = self._eval_network + cb_params.valid_dataset = valid_dataset + cb_params.batch_num = valid_dataset.get_dataset_size() + cb_params.mode = "eval" + cb_params.cur_step_num = 0 + + self._eval_network.set_train(mode=False) + self._eval_network.phase = 'eval' + + self._clear_metrics() + + if dataset_sink_mode: + return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) + return self._eval_process(valid_dataset, list_callback, cb_params) + + def predict(self, *predict_data): + """ + Generates output predictions for the input samples. + + Data could be single tensor, or list of tensor, tuple of tensor. + + Note: + Batch data should be put together in one tensor. + + Args: + predict_data (Tensor): Tensor of predict data. can be array, list or tuple. + + Returns: + Tensor, array(s) of predictions. 
+ + Examples: + >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) + >>> model = Model(Net()) + >>> model.predict(input_data) + """ + self._predict_network.set_train(False) + check_input_data(*predict_data, data_class=Tensor) + result = self._predict_network(*predict_data) + + check_output_data(result) + return result + + +__all__ = ["Model"] diff --git a/tests/st/networks/models/resnet50/src_thor/resnet.py b/tests/st/networks/models/resnet50/src_thor/resnet.py new file mode 100644 index 00000000000..88b99fb1613 --- /dev/null +++ b/tests/st/networks/models/resnet50/src_thor/resnet.py @@ -0,0 +1,359 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ============================================================================
"""ResNet."""
import math
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P

from .thor_layer import Conv2d_Thor, Dense_Thor


def calculate_gain(nonlinearity, param=None):
    """
    Return the gain used to scale the initialisation std for a nonlinearity.

    Args:
        nonlinearity (str): One of the linear/convolution names, 'sigmoid', 'tanh', 'relu' or 'leaky_relu'.
        param (Union[int, float]): Negative slope, only read for 'leaky_relu'. Default: None (0.01 is used).

    Returns:
        Number, the gain for the requested nonlinearity.

    Raises:
        ValueError: If `nonlinearity` is not supported, or `param` is not an int/float for 'leaky_relu'.
    """
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif isinstance(param, (int, float)) and not isinstance(param, bool):
            # True/False are instances of int, hence the explicit bool exclusion.
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))


def _calculate_fan_in_and_fan_out(tensor):
    """
    Compute fan-in and fan-out from a weight shape.

    Args:
        tensor (Union[tuple, list]): Weight shape laid out as (out, in, *kernel_dims). Despite the
            parameter name, the callers in this file pass a shape, not a Tensor.

    Returns:
        Tuple of 2 numbers, (fan_in, fan_out).

    Raises:
        ValueError: If the shape has fewer than 2 dimensions.
    """
    dimensions = len(tensor)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
    # The receptive field is the product of every dimension after (out, in). It is 1 for a
    # 2-D (dense) shape, and using all trailing dims keeps 3-D and 5-D shapes working instead
    # of assuming exactly two kernel dimensions.
    receptive_field_size = 1
    for dim in tensor[2:]:
        receptive_field_size *= dim
    fan_in = tensor[1] * receptive_field_size
    fan_out = tensor[0] * receptive_field_size
    return fan_in, fan_out


def _calculate_correct_fan(tensor, mode):
    """Return fan-in or fan-out of the shape `tensor`, selected by `mode` ('fan_in' / 'fan_out', any case)."""
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out


def _kaiming_std(inputs_shape, a, mode, nonlinearity):
    """Return gain / sqrt(fan), the std shared by the normal and uniform Kaiming initialisers."""
    fan = _calculate_correct_fan(inputs_shape, mode)
    gain = calculate_gain(nonlinearity, a)
    return gain / math.sqrt(fan)


def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    """
    Sample a float32 array of shape `inputs_shape` from N(0, std^2) with std = gain / sqrt(fan).

    Args:
        inputs_shape (tuple): Weight shape, (out, in, *kernel_dims).
        a (Union[int, float]): Negative slope handed to `calculate_gain`. Default: 0.
        mode (str): 'fan_in' or 'fan_out'. Default: 'fan_in'.
        nonlinearity (str): Nonlinearity name handed to `calculate_gain`. Default: 'leaky_relu'.

    Returns:
        numpy.ndarray, float32 array of shape `inputs_shape`.
    """
    std = _kaiming_std(inputs_shape, a, mode, nonlinearity)
    return np.random.normal(0, std, size=inputs_shape).astype(np.float32)


def kaiming_uniform(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    """
    Sample a float32 array of shape `inputs_shape` from U(-bound, bound) with bound = sqrt(3) * std.

    Args:
        inputs_shape (tuple): Weight shape, (out, in, *kernel_dims).
        a (Union[int, float]): Negative slope handed to `calculate_gain`. Default: 0.
        mode (str): 'fan_in' or 'fan_out'. Default: 'fan_in'.
        nonlinearity (str): Nonlinearity name handed to `calculate_gain`. Default: 'leaky_relu'.

    Returns:
        numpy.ndarray, float32 array of shape `inputs_shape`.
    """
    std = _kaiming_std(inputs_shape, a, mode, nonlinearity)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)


def _conv_thor(in_channel, out_channel, kernel_size, stride, damping, loss_scale, frequency):
    """Build a square-kernel 'same'-padded Conv2d_Thor whose weight is Kaiming-normal (fan_out, relu)."""
    weight_shape = (out_channel, in_channel, kernel_size, kernel_size)
    weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
    return Conv2d_Thor(in_channel, out_channel,
                       kernel_size=kernel_size, stride=stride, padding=0, pad_mode='same', weight_init=weight,
                       damping=damping, loss_scale=loss_scale, frequency=frequency)


def _conv3x3(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
    """Return a 3x3 Conv2d_Thor; damping, loss_scale and frequency are forwarded to the layer."""
    return _conv_thor(in_channel, out_channel, 3, stride, damping, loss_scale, frequency)


def _conv1x1(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
    """Return a 1x1 Conv2d_Thor; damping, loss_scale and frequency are forwarded to the layer."""
    return _conv_thor(in_channel, out_channel, 1, stride, damping, loss_scale, frequency)


def _conv7x7(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
    """Return a 7x7 Conv2d_Thor; damping, loss_scale and frequency are forwarded to the layer."""
    return _conv_thor(in_channel, out_channel, 7, stride, damping, loss_scale, frequency)
def _bn(channel):
    """Return a BatchNorm2d over `channel` channels (eps=1e-4, momentum=0.9, gamma=1, beta=0)."""
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                          gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _bn_last(channel):
    """Return the BatchNorm2d used after the last conv of a residual block.

    Currently configured identically to `_bn`.
    """
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                          gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _fc(in_channel, out_channel, damping, loss_scale, frequency):
    """Return a bias-free Dense_Thor whose weight is Kaiming-uniform with a=sqrt(5)."""
    weight_shape = (out_channel, in_channel)
    weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
    return Dense_Thor(in_channel, out_channel, has_bias=False, weight_init=weight,
                      bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency)


class ResidualBlock(nn.Cell):
    """
    ResNet V1 residual block definition.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size for the 3x3 convolution and the down-sample branch. Default: 1.
        damping (float): Forwarded to every Conv2d_Thor in the block. Default: 0.03.
        loss_scale (int): Forwarded to every Conv2d_Thor in the block. Default: 1.
        frequency (int): Forwarded to every Conv2d_Thor in the block. Default: 278.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    # Ratio between the block's output channels and its bottleneck channels.
    expansion = 4

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1,
                 damping=0.03,
                 loss_scale=1,
                 frequency=278):
        super(ResidualBlock, self).__init__()

        channel = out_channel // self.expansion
        self.conv1 = _conv1x1(in_channel, channel, stride=1, damping=damping, loss_scale=loss_scale,
                              frequency=frequency)
        self.bn1 = _bn(channel)

        self.conv2 = _conv3x3(channel, channel, stride=stride, damping=damping, loss_scale=loss_scale,
                              frequency=frequency)
        self.bn2 = _bn(channel)

        self.conv3 = _conv1x1(channel, out_channel, stride=1, damping=damping, loss_scale=loss_scale,
                              frequency=frequency)
        self.bn3 = _bn_last(out_channel)

        self.relu = nn.ReLU()

        self.down_sample = False

        # The identity branch needs a projection whenever the main branch changes the
        # spatial size (stride != 1) or the channel count.
        if stride != 1 or in_channel != out_channel:
            self.down_sample = True
        self.down_sample_layer = None

        if self.down_sample:
            self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
                                                                 damping=damping, loss_scale=loss_scale,
                                                                 frequency=frequency),
                                                        _bn(out_channel)])
        self.add = P.TensorAdd()

    def construct(self, x):
        """Apply conv-bn-relu three times (no relu after the third), add the shortcut, then relu."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.down_sample:
            identity = self.down_sample_layer(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class ResNet(nn.Cell):
    """
    ResNet architecture.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of block in different layers.
        in_channels (list): Input channel in each layer.
        out_channels (list): Output channel in each layer.
        strides (list): Stride size in each layer.
        num_classes (int): The number of classes that the training images are belonging to.
        damping (float): Forwarded to every THOR conv/dense layer of the network.
        loss_scale (int): Forwarded to every THOR conv/dense layer of the network.
        frequency (int): Forwarded to every THOR conv/dense layer of the network.

    Returns:
        Tensor, output tensor.

    Raises:
        ValueError: If `layer_nums`, `in_channels` or `out_channels` does not have length 4.

    Examples:
        >>> ResNet(ResidualBlock,
        >>>        [3, 4, 6, 3],
        >>>        [64, 256, 512, 1024],
        >>>        [256, 512, 1024, 2048],
        >>>        [1, 2, 2, 2],
        >>>        10,
        >>>        0.03,
        >>>        1,
        >>>        278)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes,
                 damping,
                 loss_scale,
                 frequency):
        super(ResNet, self).__init__()

        # NOTE(review): `strides` is indexed 0..3 below but its length is not validated here.
        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")

        self.conv1 = _conv7x7(3, 64, stride=2, damping=damping, loss_scale=loss_scale, frequency=frequency)
        self.bn1 = _bn(64)
        self.relu = P.ReLU()
        self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0],
                                       damping=damping,
                                       loss_scale=loss_scale,
                                       frequency=frequency)
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1],
                                       damping=damping,
                                       loss_scale=loss_scale,
                                       frequency=frequency)
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2], damping=damping,
                                       loss_scale=loss_scale,
                                       frequency=frequency)
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3],
                                       damping=damping,
                                       loss_scale=loss_scale,
                                       frequency=frequency)

        self.mean = P.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = _fc(out_channels[3], num_classes, damping=damping, loss_scale=loss_scale, frequency=frequency)

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride,
                    damping, loss_scale, frequency):
        """
        Make stage network of ResNet.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the first block of the stage; later blocks use stride 1.
            damping (float): Forwarded to every block.
            loss_scale (int): Forwarded to every block.
            frequency (int): Forwarded to every block.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2, damping=0.03, loss_scale=1, frequency=278)
        """
        layers = []

        # Only the first block changes channel count / stride; the rest map out_channel -> out_channel.
        resnet_block = block(in_channel, out_channel, stride=stride,
                             damping=damping, loss_scale=loss_scale, frequency=frequency)
        layers.append(resnet_block)

        for _ in range(1, layer_num):
            resnet_block = block(out_channel, out_channel, stride=1,
                                 damping=damping, loss_scale=loss_scale, frequency=frequency)
            layers.append(resnet_block)

        return nn.SequentialCell(layers)

    def construct(self, x):
        """Run stem (conv-bn-relu-maxpool), the four stages, spatial mean, flatten and the final dense layer."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # MaxPoolWithArgmax returns two outputs; only the first is used.
        c1, _ = self.maxpool(x)

        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        # Reduce over axes (2, 3) -- presumably H and W of an NCHW tensor; confirm against the input layout.
        out = self.mean(c5, (2, 3))
        out = self.flatten(out)
        out = self.end_point(out)

        return out


def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278):
    """
    Get ResNet50 neural network.

    Args:
        class_num (int): Class number. Default: 10.
        damping (float): Forwarded to `ResNet`. Default: 0.03.
        loss_scale (int): Forwarded to `ResNet`. Default: 1.
        frequency (int): Forwarded to `ResNet`. Default: 278.

    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = resnet50(10)
    """
    return ResNet(ResidualBlock,
                  [3, 4, 6, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num,
                  damping,
                  loss_scale,
                  frequency)
diff --git a/tests/st/networks/models/resnet50/src_thor/thor.py b/tests/st/networks/models/resnet50/src_thor/thor.py
new file mode 100644
index 00000000000..d4469a58271
--- /dev/null
+++ b/tests/st/networks/models/resnet50/src_thor/thor.py
@@ -0,0 +1,201 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""momentum"""
import mindspore.common.dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common.parameter import ParameterTuple
from mindspore.common.tensor import Tensor
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.parallel._utils import _get_device_num, _get_mirror_mean

from .grad_reducer_thor import DistributedGradReducerThor

momentum_opt = C.MultitypeFuncGraph("momentum_opt")


@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
    """Apply momentum optimizer to the weight parameter using Tensor."""
    success = True
    # F.depend ties the returned flag to the optimizer op.
    success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
    return success


op_add = P.AddN()
apply_decay = C.MultitypeFuncGraph("apply_decay")


@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Get grad with weight_decay."""
    if if_apply:
        return op_add((weight * weight_decay, gradient))
    return gradient


class THOR(Optimizer):
    """THOR"""

    def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0,
                 loss_scale=1.0,
                 decay_filter=lambda x: x.name not in []):
        """
        Build the optimizer state.

        Args:
            params: Parameters to optimise, handed to the `Optimizer` base class.
            learning_rate: Learning rate, handed to the `Optimizer` base class.
            momentum (float): Momentum; must not be negative.
            matrix_A, matrix_G, A_inv_max, G_inv_max: Per-layer parameters, each wrapped in a ParameterTuple.
                `construct` indexes each of them with 0..53.
            weight_decay (float): Weight decay. Default: 0.0.
            loss_scale (float): Loss scale. Default: 1.0.
            decay_filter (Callable): Called with each parameter; its result becomes that parameter's
                decay flag. The default returns True for every parameter.

        Raises:
            ValueError: If `momentum` is a negative float.
        """
        super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale)
        if isinstance(momentum, float) and momentum < 0.0:
            raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
        self.params = self.parameters
        self.moments = self.params.clone(prefix="moments", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = P.ApplyMomentum()
        self.matrix_A = ParameterTuple(matrix_A)
        self.matrix_G = ParameterTuple(matrix_G)
        self.A_inv_max = ParameterTuple(A_inv_max)
        self.G_inv_max = ParameterTuple(G_inv_max)
        self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast()
        self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft()
        self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight()
        self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul()
        self.transpose = P.Transpose()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.mul = P.Mul()
        # Indices of parameters whose name contains "conv" or "end_point", followed by a
        # trailing sentinel equal to the total parameter count.
        self.weight_idx = []
        for i in range(len(self.params)):
            if "conv" in self.params[i].name or "end_point" in self.params[i].name:
                self.weight_idx.append(i)
        self.weight_idx.append(len(self.params))
        # One scale factor per layer (54 entries, matching range(54) in construct).
        # NOTE(review): the denominators look like output feature-map areas (112*112, 56*56, 28*28,
        # 14*14, 7*7) of ResNet-50 at 224x224 input -- confirm.
        self.feature_map = [1.0 / 12544, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
                            1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
                            1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
                            1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
                            1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
                            1.0]
        mean = _get_mirror_mean()
        degree = _get_device_num()
        # NOTE(review): the integer second argument (2, 5, 3, 4) is defined by
        # DistributedGradReducerThor, which is outside this file -- see grad_reducer_thor.py.
        self.grad_reducer_Amax = DistributedGradReducerThor(self.parameters, 2, mean, degree)
        self.grad_reducer_Gmax = DistributedGradReducerThor(self.parameters, 5, mean, degree)
        self.grad_reducer_A = DistributedGradReducerThor(self.parameters, 3, mean, degree)
        self.grad_reducer_G = DistributedGradReducerThor(self.parameters, 4, mean, degree)
        self.matrix_A_inv = ()
        self.matrix_G_inv = ()
        self.matrix_max_inv = ()

        # One scalar "matrix_max<i>" parameter per layer, initialised to 1.
        for i in range(54):
            self.matrix_max_inv = self.matrix_max_inv + (
                Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),)
        self.log = P.Log()
        self.exp = P.Exp()
        self.sqrt = P.Sqrt()
        self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
        self.assign = P.Assign()
        self.cast = P.Cast()
        # When True, construct reduces and stores fresh A/G factors; when False it reuses the stored ones.
        self.thor = True
        self.weight_decay = weight_decay * loss_scale
        self.decay_flags = tuple(decay_filter(x) for x in self.parameters)

    def construct(self, gradients):
        """
        Precondition the layer-weight gradients with the A/G factors, then apply the momentum update.

        Args:
            gradients (tuple): Gradients in parameter order. The code reads `gradients[i * 3]` as the
                weight gradient of layer i for i in 0..53 and passes `gradients[i * 3 + 1]` and
                `gradients[i * 3 + 2]` through unchanged for i < 53 -- presumably the BatchNorm
                gamma/beta gradients that follow each conv; confirm against the network's parameter order.

        Returns:
            The result of mapping the momentum update over (gradients, params, moments).
        """
        params = self.params
        moments = self.moments
        if self.thor:
            matrix_A_allreduce = ()
            matrix_G_allreduce = ()
            matrix_A_max_allreduce = ()
            matrix_G_max_allreduce = ()
            for i in range(54):
                g = gradients[i * 3]
                matrix_A = self.matrix_A[i]
                matrix_G = self.matrix_G[i]
                A_max = self.A_inv_max[i]
                G_max = self.G_inv_max[i]
                # Tie each factor to its layer's gradient before it is collected for reduction.
                matrix_A = F.depend(matrix_A, g)
                matrix_G = F.depend(matrix_G, g)
                A_max = F.depend(A_max, g)
                G_max = F.depend(G_max, g)
                matrix_A_allreduce = matrix_A_allreduce + (matrix_A,)
                matrix_G_allreduce = matrix_G_allreduce + (matrix_G,)
                matrix_A_max_allreduce = matrix_A_max_allreduce + (A_max,)
                matrix_G_max_allreduce = matrix_G_max_allreduce + (G_max,)
            matrix_A_allreduce = self.grad_reducer_A(matrix_A_allreduce)
            matrix_G_allreduce = self.grad_reducer_G(matrix_G_allreduce)
            matrix_A_max_allreduce = self.grad_reducer_Amax(matrix_A_max_allreduce)
            matrix_G_max_allreduce = self.grad_reducer_Gmax(matrix_G_max_allreduce)
            new_grads = ()
            for i in range(54):
                g = gradients[i * 3]
                temp_a = matrix_A_allreduce[i]
                temp_g = matrix_G_allreduce[i]
                temp_a = self.cast(temp_a, mstype.float32)
                temp_g = self.cast(temp_g, mstype.float32)
                # exp(-log(x)) == 1 / x: scale each factor by the reciprocal of its reduced max.
                matrix_A_inv_max = self.log(matrix_A_max_allreduce[i])
                matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
                matrix_A_inv_max = self.exp(matrix_A_inv_max)
                temp_a = self.mul(temp_a, matrix_A_inv_max)
                matrix_G_inv_max = self.log(matrix_G_max_allreduce[i])
                matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
                matrix_G_inv_max = self.exp(matrix_G_inv_max)
                temp_g = self.mul(temp_g, matrix_G_inv_max)
                temp_max = self.mul(matrix_A_max_allreduce[i], matrix_G_max_allreduce[i])
                temp_max = self.mul(temp_max, self.feature_map[i])
                temp_a = self.cast(temp_a, mstype.float16)
                temp_g = self.cast(temp_g, mstype.float16)
                # Index 53 is the last layer and uses the dense matmul ops; all others use the conv ops.
                if i == 53:
                    g = self.cube_matmul_left_fc(temp_g, g)
                    g = self.cube_matmul_right_fc(g, temp_a, temp_max)
                else:
                    g = self.cube_matmul_left(temp_g, g)
                    g = self.cube_matmul_right_mul(g, temp_a, temp_max)
                # Store the normalised factors for reuse by the non-thor branch.
                fake_A = self.assign(self.matrix_A[i], temp_a)
                fake_G = self.assign(self.matrix_G[i], temp_g)
                fake_max = self.assign(self.matrix_max_inv[i], temp_max)
                g = F.depend(g, fake_A)
                g = F.depend(g, fake_G)
                g = F.depend(g, fake_max)
                if i == 53:
                    new_grads = new_grads + (g,)
                else:
                    new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
            gradients = new_grads
        else:
            new_grads = ()
            for i in range(54):
                g = gradients[i * 3]
                matrix_A = self.matrix_A[i]
                matrix_G = self.matrix_G[i]
                matrix_max = self.matrix_max_inv[i]
                matrix_A = F.depend(matrix_A, g)
                matrix_G = F.depend(matrix_G, g)
                matrix_max = F.depend(matrix_max, g)
                if i == 53:
                    g = self.cube_matmul_left_fc(matrix_G, g)
                    g = self.cube_matmul_right_fc(g, matrix_A, matrix_max)
                    new_grads = new_grads + (g,)
                else:
                    g = self.cube_matmul_left(matrix_G, g)
                    g = self.cube_matmul_right_mul(g, matrix_A, matrix_max)
                    new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
            gradients = new_grads

        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
                                       params, gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()
        success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
        return success
diff --git a/tests/st/networks/models/resnet50/src_thor/thor_layer.py b/tests/st/networks/models/resnet50/src_thor/thor_layer.py
new file mode 100644
index 00000000000..6ef86de3b55
--- /dev/null
+++ 
b/tests/st/networks/models/resnet50/src_thor/thor_layer.py @@ -0,0 +1,481 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""thor_layer""" +import numpy as np +import mindspore as ms +import mindspore.common.dtype as mstype +from mindspore._checkparam import check_bool, twice, check_int_positive +from mindspore._extends import cell_attr_register +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +from mindspore.common.tensor import Tensor +from mindspore.nn.cell import Cell +from mindspore.nn.layer.activation import get_activation +from mindspore.ops import operations as P + +C0 = 16 + + +def caculate_device_shape(matrix_dim, channel, is_A): + ll = (0) + if is_A: + if channel // C0 == 0: + matrix_dim = (matrix_dim / channel) * C0 + ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim) + else: + ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim) + return ll + + +class _Conv(Cell): + r"""Applies a N-D convolution over an input signal composed of several input + planes. 
+ """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + pad_mode, + padding, + dilation, + group, + data_format, + has_bias, + weight_init, + bias_init, + ): + super(_Conv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.pad_mode = pad_mode + self.padding = padding + self.dilation = dilation + self.group = group + self.data_format = data_format + self.has_bias = has_bias + if not (isinstance(in_channels, int) and in_channels > 0): + raise ValueError('Attr \'in_channels\' of \'Conv2D\' Op passed ' + + str(in_channels) + ', should be a int and greater than 0.') + if (not isinstance(kernel_size, tuple)) or len(kernel_size) != 2 or \ + (not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \ + kernel_size[0] < 1 or kernel_size[1] < 1: + raise ValueError('Attr \'kernel_size\' of \'Conv2D\' Op passed ' + + str(self.kernel_size) + ', should be a int or tuple and equal to or greater than 1.') + if in_channels % group != 0: + raise ValueError('Attr \'in_channels\' of \'Conv2D\' Op must be divisible by ' + 'attr \'group\' of \'Conv2D\' Op.') + if out_channels % group != 0: + raise ValueError('Attr \'out_channels\' of \'Conv2D\' Op must be divisible by ' + 'attr \'group\' of \'Conv2D\' Op.') + + self.weight = Parameter(initializer( + weight_init, [out_channels, in_channels // group, *kernel_size]), name='weight') + + if check_bool(has_bias): + self.bias = Parameter(_initializer( + bias_init, [out_channels]), name='bias') + else: + if bias_init != 'zeros': + logger.warning("Value of 'has_bias' is False, value of 'bias_init' will be ignored.") + self.bias = None + + def construct(self, *inputs): + raise NotImplementedError + + +class Conv2d_Thor(_Conv): + """Conv2d_Thor""" + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + pad_mode='same', + padding=0, + dilation=1, + group=1, + 
data_format='NCHW', + has_bias=False, + weight_init='normal', + damping=0.03, + loss_scale=1, + frequency=278, + bias_init='zeros'): + self.thor = True + ksizes = (1, kernel_size, kernel_size, 1) + self.hw = kernel_size * kernel_size + strides = (1, stride, stride, 1) + kernel_size = twice(kernel_size) + super(Conv2d_Thor, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + pad_mode, + padding, + dilation, + group, + data_format, + has_bias, + weight_init, + bias_init, + ) + self.conv2d = P.Conv2D(out_channel=self.out_channels, + kernel_size=self.kernel_size, + mode=1, + pad_mode=self.pad_mode, + pad=self.padding, + stride=self.stride, + dilation=self.dilation, + group=self.group + ) + + self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides) + self.cube_matmul = P.CusMatMulCube(transpose_a=True) + self.matrix_combine = P.CusMatrixCombine() + self.cholesky = P.CusCholeskyTrsm() + self.transpose02314 = P.CusTranspose02314() + self.matrix_A_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1] + self.matrix_G_dim = self.out_channels + self.matrix_A_device_shape, self.matrix_A_device_dim = caculate_device_shape(self.matrix_A_dim, + self.in_channels, True) + self.matrix_G_device_shape, self.matrix_G_device_dim = caculate_device_shape(self.matrix_G_dim, + self.in_channels, False) + self.matrix_A_device_temp_shape = ( + self.matrix_A_device_shape[0], self.matrix_A_device_shape[2], self.matrix_A_device_shape[1], + self.matrix_A_device_shape[3]) + self.matrix_G_device_temp_shape = ( + self.matrix_G_device_shape[0], self.matrix_G_device_shape[2], self.matrix_G_device_shape[1], + self.matrix_G_device_shape[3]) + self.matrix_A_inv = Parameter( + Tensor(np.reshape(np.identity(self.matrix_A_device_dim).astype(np.float16), self.matrix_A_device_shape)), + name='matrix_A_inv', requires_grad=False) + self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False) + self.matrix_G_inv = Parameter( + 
Tensor(np.reshape(np.identity(self.matrix_G_device_dim).astype(np.float16), self.matrix_G_device_shape)), + name="matrix_G_inv", requires_grad=False) + + self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False) + self.fake_G = Tensor( + np.reshape(np.identity(self.matrix_G_device_dim).astype(np.float16), self.matrix_G_device_shape)) + + self.shape = P.Shape() + self.reshape = P.Reshape() + self.transpose = P.Transpose() + self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False) + self.mul = P.Mul() + self.cast = P.Cast() + self.damping = Tensor(damping) + self.vector_matmul = P.CusBatchMatMul() + self.diag_block_dim = 128 + self.channels_slice_flag = False + if self.in_channels % C0 != 0: + self.channels_slice_flag = True + + self.padA_flag = False + if (self.matrix_A_dim // self.diag_block_dim) * self.diag_block_dim != self.matrix_A_dim \ + and self.matrix_A_dim > self.diag_block_dim: + self.padA_flag = True + pad_dim = self.diag_block_dim - self.matrix_A_dim % self.diag_block_dim + self.padA = P.Pad(((0, pad_dim), (0, pad_dim))) + self.device_shape_pad_flag = False + if self.matrix_A_dim != self.matrix_A_device_dim: + self.device_shape_pad_flag = True + self.device_shape_pad = P.Pad(((0, 0), (0, C0 - self.in_channels), (0, 0), (0, C0 - self.in_channels))) + self.slice = P.Slice() + self.gather = P.GatherV2() + self.freq = Tensor(frequency, mstype.int32) + self.loss_scale = Tensor(1 / loss_scale, mstype.float16) + self.axis = 0 + + dampingA_dim = self.matrix_A_dim + if (self.matrix_A_dim % self.diag_block_dim) != 0 and self.matrix_A_dim > self.diag_block_dim: + dampingA_dim = (self.matrix_A_dim // self.diag_block_dim + 1) * self.diag_block_dim + dampingG_dim = self.matrix_G_dim + if (self.matrix_G_dim % self.diag_block_dim) != 0 and self.matrix_G_dim > self.diag_block_dim: + dampingG_dim = (self.matrix_G_dim // self.diag_block_dim + 1) * self.diag_block_dim + + self.dampingA = 
Tensor(np.identity(dampingA_dim), mstype.float32) + self.dampingG = Tensor(np.identity(dampingG_dim), mstype.float32) + self.fused_abs_max1 = P.CusFusedAbsMax1([self.matrix_A_dim, self.matrix_A_dim]) + self.fused_abs_max2 = P.CusFusedAbsMax1() + self.log = P.Log() + self.exp = P.Exp() + self.sqrt = P.Sqrt() + self.getG = P.InsertGradientOf(self.save_gradient) + + def save_gradient(self, dout): + """save_gradient""" + out = dout + dout = self.mul(dout, self.loss_scale) + dout = self.mul(dout, 32.0) + dout = self.transpose02314(dout) + dout_shape = self.shape(dout) + normalizer = dout_shape[0] + + matrix_G = self.cube_matmul(dout, dout) + normalizer = self.cast(normalizer, ms.float32) + matrix_G = self.mul(matrix_G, 1.0 / normalizer) + damping_step = self.gather(self.damping, self.cov_step, 0) + self.cov_step = self.cov_step + self.freq + damping_step = self.cast(damping_step, mstype.float32) + damping = self.mul(damping_step, 32.0 / normalizer) + damping = self.sqrt(damping) + dampingG = self.cast(self.dampingG, mstype.float32) + matrix_G = matrix_G + damping * dampingG + + matrix_G_inv = self.cholesky(matrix_G) + matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv) + matrix_G_inv_max = self.fused_abs_max2(matrix_G_inv) + matrix_G_inv_max = self.fused_abs_max2(matrix_G_inv_max) + self.G_inv_max = matrix_G_inv_max + matrix_G_inv = self.matrix_combine(matrix_G_inv) + matrix_G_inv = self.reshape(matrix_G_inv, self.matrix_G_device_temp_shape) + matrix_G_inv = self.transpose(matrix_G_inv, (2, 0, 1, 3)) + matrix_G = self.cast(matrix_G_inv, mstype.float16) + self.matrix_G_inv = matrix_G + return out + + def construct(self, x): + if self.thor: + matrix_A = self.img2col(x) + matrix_A_shape = self.shape(matrix_A) + normalizer = matrix_A_shape[0] + matrix_A = self.cube_matmul(matrix_A, matrix_A) + + if self.channels_slice_flag: + matrix_A = self.reshape(matrix_A, (self.hw, C0, self.hw, C0)) + matrix_A = self.slice(matrix_A, (0, 0, 0, 0), (self.hw, self.in_channels, 
self.hw, self.in_channels)) + matrix_A = self.reshape(matrix_A, (self.matrix_A_dim, self.matrix_A_dim)) + normalizer = self.cast(normalizer, ms.float32) + matrix_A = self.mul(matrix_A, 1.0 / normalizer) + if self.padA_flag: + matrix_A = self.padA(matrix_A) + damping_step = self.gather(self.damping, self.cov_step, self.axis) + damping_step = self.cast(damping_step, mstype.float32) + damping = self.mul(damping_step, 32.0 / normalizer) + damping = self.sqrt(damping) + damping_A = self.cast(self.dampingA, mstype.float32) + matrix_A = matrix_A + damping * damping_A + matrix_A_inv = self.cholesky(matrix_A) + matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv) + matrix_A_inv_max = self.fused_abs_max1(matrix_A_inv) + matrix_A_inv_max = self.fused_abs_max2(matrix_A_inv_max) + self.A_inv_max = matrix_A_inv_max + matrix_A_inv = self.matrix_combine(matrix_A_inv) + matrix_A_inv = self.cast(matrix_A_inv, mstype.float16) + if self.padA_flag: + matrix_A_inv = self.slice(matrix_A_inv, (0, 0), (self.matrix_A_dim, self.matrix_A_dim)) + + if self.device_shape_pad_flag: + matrix_A_inv = self.reshape(matrix_A_inv, (self.hw, self.in_channels, self.hw, self.in_channels)) + matrix_A_inv = self.device_shape_pad(matrix_A_inv) + matrix_A_inv = self.reshape(matrix_A_inv, self.matrix_A_device_temp_shape) + matrix_A_inv = self.transpose(matrix_A_inv, (2, 0, 1, 3)) + self.matrix_A_inv = matrix_A_inv + self.matrix_G_inv = self.fake_G + out = self.conv2d(x, self.weight) + out = self.getG(out) + else: + out = self.conv2d(x, self.weight) + + return out + + def extra_repr(self): + """extra_repr""" + s = 'input_channels={}, output_channels={}, kernel_size={},' \ + 'stride={}, pad_mode={}, padding={}, dilation={}, ' \ + 'group={}, data_format={}, has_bias={},' \ + 'weight_init={}, bias_init={}'.format( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.pad_mode, + self.padding, + self.dilation, + self.group, + self.data_format, + self.has_bias, + self.weight, + 
class Dense_Thor(Cell):
    """Fully connected layer that also produces the statistics THOR reads.

    The forward result is ``activation(matmul(x, weight^T) + bias)``.  While
    ``self.thor`` is true, every forward pass additionally builds a damped
    matrix from ``x`` and stores the derived ``matrix_A_inv`` / ``A_inv_max``
    parameters; a gradient hook (``save_gradient``) does the same for the
    incoming gradient and stores ``matrix_G_inv`` / ``G_inv_max``.

    NOTE(review): the buffers below are sized with literals (2048x2048 for the
    A side, 1000 padded to 1024/1008 for the G side) and the batch normalizer
    is the literal 32, so the layer looks usable only with in_channels=2048,
    out_channels=1000 and a per-device batch of 32 -- confirm before reuse.

    Args:
        in_channels (int): Size of the last input dimension; must be positive.
        out_channels (int): Size of the last output dimension; must be positive.
        weight_init: Initializer (or Tensor of shape
            ``[out_channels, in_channels]``) for the weight.
        bias_init: Initializer (or Tensor of shape ``[out_channels]``) for the
            bias; only used when ``has_bias`` is true.
        damping: Damping values wrapped in a Tensor and indexed by ``cov_step``
            (presumably one entry per training step -- verify against caller).
        loss_scale: Gradients are multiplied by ``1 / loss_scale`` before the
            G statistics are computed.
        frequency (int): Amount added to ``cov_step`` on every backward pass.
        has_bias (bool): Whether a bias is added to the output.
        activation: Forwarded to ``get_activation``; ``None`` disables it.

    Raises:
        ValueError: If a Tensor ``weight_init`` / ``bias_init`` has the wrong
            rank or shape.
    """

    @cell_attr_register(attrs=['has_bias', 'activation'])
    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 damping=0.03,
                 loss_scale=1,
                 frequency=278,
                 has_bias=True,
                 activation=None):
        super(Dense_Thor, self).__init__()
        self.in_channels = check_int_positive(in_channels)
        self.out_channels = check_int_positive(out_channels)
        self.has_bias = check_bool(has_bias)
        self.thor = True
        if isinstance(weight_init, Tensor):
            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
                    weight_init.shape[1] != in_channels:
                raise ValueError("weight_init shape error")

        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")

        if self.has_bias:
            if isinstance(bias_init, Tensor):
                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                    raise ValueError("bias_init shape error")

            self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")

        # Forward-path operators.
        self.matmul = P.MatMul(transpose_b=True)
        self.bias_add = P.BiasAdd()

        self.activation = get_activation(activation)
        self.activation_flag = self.activation is not None

        # Buffers read by the optimizer; 128 * 16 = 2048 and 63 * 16 = 1008.
        self.matrix_A_inv = Parameter(Tensor(np.zeros([128, 128, 16, 16]).astype(np.float16)), name='matrix_A_inv',
                                      requires_grad=False)
        self.matrix_G_inv = Parameter(Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16)), name="matrix_G_inv",
                                      requires_grad=False)
        # Placeholder written to matrix_G_inv during the forward pass.
        self.fake_G = Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16))

        # Operators used to build the A / G statistics.
        self.cube_matmul = P.CusMatMulCube(transpose_a=True)
        self.matrix_combine = P.CusMatrixCombine()
        self.cholesky = P.CusCholeskyTrsm()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        # Index into `damping`; advanced by `freq` in save_gradient.
        self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
        self.mul = P.Mul()
        self.cast = P.Cast()
        self.damping = Tensor(damping)
        self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
        self.vector_matmul = P.CusBatchMatMul()
        self.pad = P.Pad(((0, 24), (0, 24)))
        self.pad1 = P.Pad(((0, 8), (0, 8)))
        self.slice = P.Slice()
        self.gather = P.GatherV2()
        self.assignadd = P.AssignAdd()
        self.freq = Tensor(frequency, mstype.int32)
        self.axis = 0
        self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False)
        self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False)
        self.fused_abs_max1 = P.CusFusedAbsMax1([1000, 1000])
        self.fused_abs_max2 = P.CusFusedAbsMax1()
        self.log = P.Log()
        self.exp = P.Exp()
        # Identity matrices scaled by the damping term before being added.
        self.dampingA = Tensor(np.identity(2048), mstype.float32)
        self.dampingG = Tensor(np.identity(1024), mstype.float32)
        self.add = P.TensorAdd()
        self.sqrt = P.Sqrt()
        # Identity in the forward pass; routes the gradient through save_gradient.
        self.getG = P.InsertGradientOf(self.save_gradient)

    def save_gradient(self, dout):
        """Gradient hook: refresh ``matrix_G_inv`` / ``G_inv_max`` from ``dout``.

        Args:
            dout: Gradient flowing into the layer output.

        Returns:
            The gradient exactly as received; the hook only has side effects.
        """
        out = dout
        # Undo the loss scaling, then apply the literal batch factor 32.
        dout = self.mul(dout, self.loss_scale)
        dout = self.mul(dout, 32.0)
        normalizer = 32
        matrix_G = self.cube_matmul(dout, dout)
        normalizer = self.cast(normalizer, ms.float32)
        matrix_G = self.mul(matrix_G, 1.0 / normalizer)
        # Pad by 24 on each axis to match the 1024x1024 dampingG identity.
        matrix_G = self.pad(matrix_G)
        damping_step = self.gather(self.damping, self.cov_step, 0)
        damping_step = self.cast(damping_step, mstype.float32)
        # Advance the damping index for the next statistics update.
        self.cov_step = self.cov_step + self.freq
        damping = self.sqrt(damping_step)
        dampingG = self.cast(self.dampingG, mstype.float32)
        matrix_G = matrix_G + damping * dampingG
        matrix_G_inv = self.cholesky(matrix_G)
        matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv)
        matrix_G_inv_max = self.fused_abs_max1(matrix_G_inv)
        matrix_G_inv_max = self.fused_abs_max2(matrix_G_inv_max)
        self.G_inv_max = matrix_G_inv_max
        matrix_G_inv = self.matrix_combine(matrix_G_inv)
        # Drop the padding back to 1000x1000, then pad to 1008 (a multiple of 16).
        matrix_G_inv = self.slice(matrix_G_inv, (0, 0), (1000, 1000))
        matrix_G_inv = self.pad1(matrix_G_inv)
        matrix_G_inv_shape = self.shape(matrix_G_inv)
        matrix_G_inv = self.reshape(matrix_G_inv, (matrix_G_inv_shape[0] / 16, 16, matrix_G_inv_shape[0] / 16, 16))
        matrix_G_inv = self.transpose(matrix_G_inv, (2, 0, 1, 3))
        matrix_G_inv = self.cast(matrix_G_inv, mstype.float16)
        self.matrix_G_inv = matrix_G_inv
        return out

    def construct(self, x):
        """Run the dense layer, refreshing the A statistics in THOR mode.

        Args:
            x: Input tensor fed to ``MatMul`` against the transposed weight.

        Returns:
            The layer output after the optional bias and activation.
        """
        if self.thor:
            inputs = self.cube_matmul(x, x)
            normalizer = 32
            normalizer = self.cast(normalizer, ms.float32)
            matrix_A = self.mul(inputs, 1.0 / normalizer)

            damping_step = self.gather(self.damping, self.cov_step, self.axis)
            damping_step = self.cast(damping_step, mstype.float32)
            damping = self.sqrt(damping_step)
            dampingA = self.cast(self.dampingA, mstype.float32)
            matrix_A = matrix_A + damping * dampingA
            matrix_A_inv = self.cholesky(matrix_A)
            matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv)

            # NOTE(review): both passes use fused_abs_max2, whereas
            # save_gradient starts with fused_abs_max1 -- confirm intended.
            matrix_A_inv_max = self.fused_abs_max2(matrix_A_inv)
            matrix_A_inv_max = self.fused_abs_max2(matrix_A_inv_max)
            self.A_inv_max = matrix_A_inv_max

            matrix_A_inv = self.matrix_combine(matrix_A_inv)
            matrix_A_inv_shape = self.shape(matrix_A_inv)
            matrix_A_inv = self.reshape(matrix_A_inv, (matrix_A_inv_shape[0] / 16, 16, matrix_A_inv_shape[0] / 16, 16))
            matrix_A_inv = self.transpose(matrix_A_inv, (2, 0, 1, 3))
            matrix_A_inv = self.cast(matrix_A_inv, mstype.float16)
            self.matrix_A_inv = matrix_A_inv
            # Reset G to the placeholder; the real value arrives in save_gradient.
            self.matrix_G_inv = self.fake_G
            output = self.matmul(x, self.weight)
            output = self.getG(output)
        else:
            output = self.matmul(x, self.weight)

        if self.has_bias:
            output = self.bias_add(output, self.bias)
        if self.activation_flag:
            return self.activation(output)
        return output

    def extend_repr(self):
        """Return the layer description: channels, weight, bias and activation."""
        str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \
            .format(self.in_channels, self.out_channels, self.weight, self.has_bias)
        if self.has_bias:
            str_info = str_info + ', bias={}'.format(self.bias)

        if self.activation_flag:
            str_info = str_info + ', activation={}'.format(self.activation)

        return str_info
def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch):
    """Build the per-step learning-rate schedule and return its tail.

    For step ``i`` the (fractional) epoch is ``(i + 1) / steps_per_epoch`` and
    the rate is ``lr_init * (1 - epoch / total_epochs) ** decay``, halved once
    the epoch reaches 39 and halved again once it reaches 40.

    Args:
        global_step (int): Index of the first step to keep.
        lr_init (float): Learning rate before any decay.
        decay (float): Exponent of the polynomial decay.
        total_epochs (int): Number of epochs the schedule spans.
        steps_per_epoch (int): Number of steps in one epoch.

    Returns:
        numpy.ndarray: float32 array holding the rates for steps
        ``global_step .. total_epochs * steps_per_epoch - 1``.
    """
    total_steps = steps_per_epoch * total_epochs

    def _lr_at(step):
        epoch = (step + 1) / steps_per_epoch
        lr_local = lr_init * (1.0 - float(epoch) / total_epochs) ** decay
        # Two fixed late-training halvings on top of the polynomial decay.
        if epoch >= 39:
            lr_local = lr_local * 0.5
        if epoch >= 40:
            lr_local = lr_local * 0.5
        return lr_local

    lr_each_step = np.array([_lr_at(step) for step in range(total_steps)]).astype(np.float32)
    return lr_each_step[global_step:]


def get_model_damping(global_step, damping_init, decay_rate, total_epochs, steps_per_epoch):
    """Build the per-step damping schedule and return its tail.

    For step ``i`` the (fractional) epoch is ``(i + 1) / steps_per_epoch`` and
    the damping is ``damping_init * decay_rate ** (epoch / 10)``.

    Args:
        global_step (int): Index of the first step to keep.
        damping_init (float): Damping value before any decay.
        decay_rate (float): Factor applied once per ten epochs.
        total_epochs (int): Number of epochs the schedule spans.
        steps_per_epoch (int): Number of steps in one epoch.

    Returns:
        numpy.ndarray: float32 array holding the damping for steps
        ``global_step .. total_epochs * steps_per_epoch - 1``.
    """
    total_steps = steps_per_epoch * total_epochs
    damping_each_step = np.array(
        [damping_init * (decay_rate ** (((step + 1) / steps_per_epoch) / 10))
         for step in range(total_steps)]
    ).astype(np.float32)
    return damping_each_step[global_step:]
class LossGet(Callback):
    """Callback that records the latest training loss and the mean step time.

    Args:
        per_print_times (int): Record the loss every ``per_print_times`` steps;
            ``0`` disables loss recording.
        data_size (int): Number of steps per epoch; divides the epoch
            wall-clock time to obtain the per-step time.

    Raises:
        ValueError: If ``per_print_times`` is not a non-negative int.
    """

    def __init__(self, per_print_times, data_size):
        super(LossGet, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times
        self._loss = 0.0
        self.data_size = data_size
        # Initialised here so the getters are usable before the first epoch ends.
        self.epoch_time = time.time()
        self._per_step_mseconds = 0.0

    def step_end(self, run_context):
        """Validate the step loss and remember it on every recording step."""
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs

        if isinstance(loss, (tuple, list)):
            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                loss = loss[0]

        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
            loss = np.mean(loss.asnumpy())

        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training."
                             .format(cb_params.cur_epoch_num, cur_step_in_epoch))
        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
            self._loss = loss

    def epoch_begin(self, run_context):
        """Remember when the epoch started."""
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        """Convert the elapsed epoch time into milliseconds per step."""
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        self._per_step_mseconds = epoch_mseconds / self.data_size

    def get_loss(self):
        """Return the most recently recorded loss (0.0 until one is recorded)."""
        return self._loss

    def get_per_step_time(self):
        """Return the per-step time in ms of the last finished epoch (0.0 before)."""
        return self._per_step_mseconds


def _enter_worker_dir(device_id):
    """Create the per-device scratch directory if needed and chdir into it."""
    workdir = str(device_id)
    os.makedirs(workdir, exist_ok=True)
    os.chdir(workdir)


def _train_and_eval(model, dataset, eval_dataset, loss_cb, rounds, epochs_per_round, step_size, device_id):
    """Alternate training and evaluation ``rounds`` times.

    Returns:
        tuple: ``(acc, per_step_ms)`` of the last round, or ``(0.0, 0.0)`` when
        ``rounds`` is zero.
    """
    acc = 0.0
    time_cost = 0.0
    for epoch_idx in range(rounds):
        model.train(epochs_per_round, dataset, callbacks=loss_cb)
        eval_start = time.time()
        output = model.eval(eval_dataset)
        eval_cost = (time.time() - eval_start) * 1000
        acc = float(output["acc"])
        time_cost = loss_cb.get_per_step_time()
        train_loss = loss_cb.get_loss()
        print("the {} epoch's resnet result:\n "
              "device{}, training loss {}, acc {}, "
              "training per step cost {:.2f} ms, eval cost {:.2f} ms, total_cost {:.2f} ms".format(
                  epoch_idx, device_id, train_loss, acc, time_cost, eval_cost,
                  time_cost * step_size + eval_cost))
    return acc, time_cost


def train_process(q, device_id, epoch_size, device_num, enable_hccl):
    """Worker: train and evaluate ResNet-50 (Momentum or LARS) on one device.

    Args:
        q: Queue receiving ``{'acc': ..., 'cost': ...}`` when the worker is done.
        device_id (int): Ascend device id; also used as RANK_ID and as the name
            of the scratch directory.
        epoch_size (int): Number of epochs; also the dataset repeat count.
        device_num (int): Number of devices taking part (RANK_SIZE).
        enable_hccl (bool): Whether to set up data-parallel training.
    """
    _enter_worker_dir(device_id)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
    context.set_context(device_id=device_id)
    os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = MINDSPORE_HCCL_CONFIG_PATH
    os.environ['RANK_ID'] = str(device_id)
    os.environ['RANK_SIZE'] = str(device_num)
    if enable_hccl:
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          mirror_mean=True, parameter_broadcast=True)
        auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160])
        init()

    # network
    net = resnet50(class_num=config.class_num)

    # evaluation network
    dist_eval_network = ClassifyCorrectCell(net)

    if not config.use_label_smooth:
        config.label_smooth_factor = 0.0

    # loss
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", smooth_factor=config.label_smooth_factor,
                                            num_classes=config.class_num)

    # train dataset
    dataset = create_dataset(dataset_path=dataset_path, do_train=True,
                             repeat_num=epoch_size, batch_size=config.batch_size)

    step_size = dataset.get_dataset_size()
    eval_interval = config.eval_interval
    dataset.__loop_size__ = step_size * eval_interval

    # evaluation dataset
    eval_dataset = create_dataset(dataset_path=eval_path, do_train=False,
                                  repeat_num=epoch_size, batch_size=config.eval_batch_size)

    # loss scale
    loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)

    # learning rate
    lr = Tensor(get_learning_rate(lr_init=config.lr_init, lr_end=0.0, lr_max=config.lr_max,
                                  warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size,
                                  steps_per_epoch=step_size, lr_decay_mode=config.lr_decay_mode))

    # optimizer
    def _takes_decay(param):
        """True for parameters that are neither BatchNorm beta/gamma nor bias."""
        return 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name

    decayed_params = list(filter(_takes_decay, net.trainable_params()))
    no_decayed_params = [param for param in net.trainable_params() if param not in decayed_params]
    group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
                    {'params': no_decayed_params},
                    {'order_params': net.trainable_params()}]

    if config.use_lars:
        momentum = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
                               use_nesterov=config.use_nesterov)
        opt = nn.LARS(momentum, epsilon=config.lars_epsilon, hyperpara=config.lars_coefficient,
                      weight_decay=config.weight_decay,
                      decay_filter=_takes_decay,
                      lars_filter=_takes_decay,
                      loss_scale=config.loss_scale)

    else:
        opt = nn.Momentum(group_params, lr, config.momentum,
                          weight_decay=config.weight_decay, loss_scale=config.loss_scale,
                          use_nesterov=config.use_nesterov)

    # model
    model = Model(net, loss_fn=loss, optimizer=opt,
                  loss_scale_manager=loss_scale, amp_level="O2", keep_batchnorm_fp32=False,
                  metrics={'acc': DistAccuracy(batch_size=config.eval_batch_size, device_num=device_num)},
                  eval_network=dist_eval_network)

    # model init
    print("init_start", device_id)
    model.init(dataset, eval_dataset)
    print("init_stop", device_id)

    # callbacks
    loss_cb = LossGet(1, step_size)

    # train and eval; each model.train(1, ...) call consumes
    # dataset.__loop_size__ = step_size * eval_interval steps.
    print("run_start", device_id)
    acc, time_cost = _train_and_eval(model, dataset, eval_dataset, loss_cb,
                                     rounds=int(epoch_size / eval_interval), epochs_per_round=1,
                                     step_size=step_size, device_id=device_id)
    q.put({'acc': acc, 'cost': time_cost})


def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
    """Worker: train and evaluate the THOR variant of ResNet-50 on one device.

    Args:
        q: Queue receiving ``{'acc': ..., 'cost': ...}`` when the worker is done.
        device_id (int): Ascend device id; RANK_ID is ``device_id - 4`` because
            these workers run on the second group of four devices.
        epoch_size (int): Number of epochs; also the dataset repeat count.
        device_num (int): Number of devices taking part (RANK_SIZE).
        enable_hccl (bool): Whether to set up data-parallel training.
    """
    _enter_worker_dir(device_id)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
    context.set_context(device_id=device_id)
    os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = MINDSPORE_HCCL_CONFIG_PATH_2
    os.environ['RANK_ID'] = str(device_id - 4)
    os.environ['RANK_SIZE'] = str(device_num)
    if enable_hccl:
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          mirror_mean=True, parameter_broadcast=True)
        auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1")
        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2")
        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4")
        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5")
        init()

    # network
    damping = get_model_damping(0, 0.03, 0.87, 50, 5004)
    net = resnet50_thor(class_num=thor_config.class_num, damping=damping, loss_scale=thor_config.loss_scale,
                        frequency=thor_config.frequency)

    # evaluation network
    dist_eval_network = ClassifyCorrectCell(net)

    if not thor_config.label_smooth:
        thor_config.label_smooth_factor = 0.0

    # loss
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean",
                                            smooth_factor=thor_config.label_smooth_factor,
                                            num_classes=thor_config.class_num)

    # train dataset
    dataset = create_dataset(dataset_path=dataset_path, do_train=True,
                             repeat_num=epoch_size, batch_size=thor_config.batch_size)

    step_size = dataset.get_dataset_size()
    eval_interval = thor_config.eval_interval

    # evaluation dataset
    eval_dataset = create_dataset(dataset_path=eval_path, do_train=False,
                                  repeat_num=epoch_size, batch_size=thor_config.eval_batch_size)

    # loss scale
    loss_scale = FixedLossScaleManager(thor_config.loss_scale, drop_overflow_update=False)

    # learning rate
    lr = Tensor(get_model_lr(0, 0.045, 6, 70, 5004))

    # optimizer: THOR receives the statistics parameters selected by name.
    opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), lr, thor_config.momentum,
               filter(lambda x: 'matrix_A' in x.name, net.get_parameters()),
               filter(lambda x: 'matrix_G' in x.name, net.get_parameters()),
               filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()),
               filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()),
               thor_config.weight_decay, thor_config.loss_scale)

    # model
    model = THOR_Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, amp_level="O2",
                       keep_batchnorm_fp32=False,
                       metrics={'acc': DistAccuracy(batch_size=thor_config.eval_batch_size, device_num=device_num)},
                       eval_network=dist_eval_network, frequency=thor_config.frequency)

    # model init
    print("init_start", device_id)
    model.init(dataset, eval_dataset)
    print("init_stop", device_id)

    # callbacks
    loss_cb = LossGet(1, step_size)

    # train and eval
    print("run_start", device_id)
    acc, time_cost = _train_and_eval(model, dataset, eval_dataset, loss_cb,
                                     rounds=int(epoch_size / eval_interval), epochs_per_round=eval_interval,
                                     step_size=step_size, device_id=device_id)
    q.put({'acc': acc, 'cost': time_cost})


def _mean_acc_and_cost(result_queue, worker_count):
    """Take one result per worker from the queue and average 'acc' and 'cost'."""
    acc_total = 0.0
    cost_total = 0.0
    for _ in range(worker_count):
        output = result_queue.get()
        acc_total += output['acc']
        cost_total += output['cost']
    return acc_total / worker_count, cost_total / worker_count


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_resnet_and_resnet_thor_imagenet_4p():
    """Train ResNet-50 and its THOR variant on 4 devices each and check them.

    Devices 0-3 run ``train_process`` and devices 4-7 run
    ``train_process_thor`` concurrently; the mean accuracy and per-step time of
    each group are compared against fixed thresholds.
    """
    import shutil

    q = Queue()
    q2 = Queue()

    device_num = 4
    epoch_size = 2
    epoch_size_2 = 1
    enable_hccl = True
    thor_device_offset = 4

    process = [Process(target=train_process,
                       args=(q, device_id, epoch_size, device_num, enable_hccl))
               for device_id in range(device_num)]
    process2 = [Process(target=train_process_thor,
                        args=(q2, device_id + thor_device_offset, epoch_size_2, device_num, enable_hccl))
                for device_id in range(device_num)]

    for resnet_worker, thor_worker in zip(process, process2):
        resnet_worker.start()
        thor_worker.start()

    print("Waiting for all subprocesses done...")

    # Drain the queues before joining: a child that has put data on a Queue
    # may not exit until that data has been consumed.
    acc, cost = _mean_acc_and_cost(q, device_num)
    thor_acc, thor_cost = _mean_acc_and_cost(q2, device_num)

    for resnet_worker, thor_worker in zip(process, process2):
        resnet_worker.join()
        thor_worker.join()

    # Remove every scratch directory before asserting so that a failed check
    # does not leave the other group's directories behind.
    for device_id in range(device_num + thor_device_offset):
        shutil.rmtree(str(device_id), ignore_errors=True)
    print("End training...")

    # resnet
    assert acc > 0.13
    assert cost < 21

    # THOR
    assert thor_acc > 0.22
    assert thor_cost < 22
device_num, batch_size, q.put(loss_cb.get_loss()) -@pytest.mark.level0 -@pytest.mark.platform_arm_ascend_training -@pytest.mark.platform_x86_ascend_training -@pytest.mark.env_single def test_resnet_cifar_8p(): q = Queue() device_num = 8