!4414 THOR optimizer for GPU training

Merge pull request !4414 from wangmin0104/master
This commit is contained in:
mindspore-ci-bot 2020-08-14 09:13:50 +08:00 committed by Gitee
commit 1002ee4887
17 changed files with 1104 additions and 1103 deletions

View File

@ -24,22 +24,24 @@ This is an example of training ResNet-50 V1.5 with ImageNet2012 dataset by secon
.
├── resnet_thor
├── README.md
├── src
├──scripts
├── run_distribute_train.sh # launch distributed training for Ascend
└── run_eval.sh # launch inference for Ascend
├── run_distribute_train_gpu.sh # launch distributed training for GPU
└── run_eval_gpu.sh # launch inference for GPU
├──src
├── crossentropy.py # CrossEntropy loss function
├── config.py # parameter configuration
├── resnet50.py # resnet50 backbone
├── dataset_helper.py # dataset help for minddata dataset
├── grad_reducer_thor.py # grad reducer for thor
├── model_thor.py # model
├── model_thor.py # model for train
├── resnet_thor.py # resnet50_thor backbone
├── thor.py # thor
├── thor.py # thor optimizer
├── thor_layer.py # thor layer
└── dataset_imagenet.py # data preprocessing
├── scripts
├── run_distribute_train.sh # launch distributed training(8 pcs)
└── run_eval.sh # launch inference
└── dataset.py # data preprocessing
├── eval.py # infer script
└── train.py # train script
```
@ -48,26 +50,30 @@ This is an example of training ResNet-50 V1.5 with ImageNet2012 dataset by secon
Parameters for both training and inference can be set in config.py.
```
"class_num": 1000, # dataset class number
"class_num": 1001, # dataset class number
"batch_size": 32, # batch size of input tensor
"loss_scale": 128, # loss scale
"momentum": 0.9, # momentum of THOR optimizer
"weight_decay": 5e-4, # weight decay
"epoch_size": 45, # only valid for training, which is always 1 for inference
"buffer_size": 1000, # number of queue size in data preprocessing
"image_height": 224, # image height
"image_width": 224, # image width
"save_checkpoint": True, # whether save checkpoint or not
"save_checkpoint_steps": 5004, # the step interval between two checkpoints. By default, the checkpoint will be saved every epoch
"keep_checkpoint_max": 20, # only keep the last keep_checkpoint_max checkpoint
"save_checkpoint_epochs": 1, # the epoch interval between two checkpoints. By default, the checkpoint will be saved every epoch
"keep_checkpoint_max": 15, # only keep the last keep_checkpoint_max checkpoint
"save_checkpoint_path": "./", # path to save checkpoint relative to the executed path
"label_smooth": True, # label smooth
"label_smooth_factor": 0.1, # label smooth factor
"lr_init": 0.045, # learning rate init value
"lr_decay": 6, # learning rate decay rate value
"lr_end_epoch": 70, # learning rate end epoch value
"damping_init": 0.03, # damping init value for Fisher information matrix
"damping_decay": 0.87, # damping decay rate
"frequency": 834, # the step interval to update second-order information matrix
```
## Running the example
### 1 Running on Ascend 910
### Train
#### Usage
@ -82,10 +88,10 @@ Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [DEVICE_NUM]
```bash
# distributed training example(8 pcs)
sh run_distribute_train.sh rank_table_8p.json dataset/ilsvrc
sh run_distribute_train.sh rank_table_8p.json dataset/ilsvrc 8
```
> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html).
> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/distributed_training_ascend.html).
#### Result
@ -126,3 +132,35 @@ Inference result will be stored in the example path, whose folder name is "infer
```
result: {'acc': 0.759503041} ckpt=train_parallel0/resnet-42_5004.ckpt
```
### 2 Running on GPU
### Train
```
# distributed training example
sh run_distribute_train_gpu.sh [DATASET_PATH] [DEVICE_NUM]
```
#### Result
```
# distribute training result(8 pcs)
epoch: 1 step: 5004, loss is 4.3069
epoch: 2 step: 5004, loss is 3.5695
epoch: 3 step: 5004, loss is 3.5893
epoch: 4 step: 5004, loss is 3.1987
epoch: 5 step: 5004, loss is 3.3526
......
epoch: 40 step: 5004, loss is 1.9482
epoch: 41 step: 5004, loss is 1.8950
epoch: 42 step: 5004, loss is 1.9023
......
```
### Infer
```
# infer example
sh run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH]
```
#### Result
```
result: {'acc': 0.760143245838668} ckpt_0/resnet-40_5004.ckpt
```

View File

@ -12,51 +12,64 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
eval.
"""
"""eval resnet."""
import os
import random
import argparse
import numpy as np
from mindspore import context
from mindspore import dataset as de
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.dataset_imagenet import create_dataset
from src.config import config
from src.crossentropy import CrossEntropy
from src.resnet50 import resnet50
from src.config import config
from src.dataset import create_dataset
from src.resnet_thor import resnet50 as resnet
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--do_train', type=bool, default=False, help='Do train or not.')
parser.add_argument('--do_eval', type=bool, default=True, help='Do eval or not.')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(device_id=device_id)
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
if __name__ == '__main__':
target = args_opt.device_target
net = resnet50(class_num=config.class_num)
if not config.label_smooth:
# init context
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
if target != "GPU":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size,
target=target)
# define net
net = resnet(class_num=config.class_num)
net.add_flags_recursive(thor=False)
# load checkpoint
param_dict = load_checkpoint(args_opt.checkpoint_path)
keys = list(param_dict.keys())
for key in keys:
if "damping" in key:
param_dict.pop(key)
load_param_into_net(net, param_dict)
net.set_train(False)
# define loss, model
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
if args_opt.do_eval:
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size)
step_size = dataset.get_dataset_size()
# define model
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
if args_opt.checkpoint_path:
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
model = Model(net, loss_fn=loss, metrics={'acc'})
res = model.eval(dataset)
print("result:", res, "ckpt=", args_opt.checkpoint_path)
# eval model
res = model.eval(dataset)
print("result:", res, "ckpt=", args_opt.checkpoint_path)

View File

@ -52,6 +52,6 @@ do
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 &
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 &
cd ..
done

View File

@ -0,0 +1,46 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Exactly two positional arguments are required: the dataset path and the
# number of devices. Anything else prints the usage line and aborts.
if [ "$#" -ne 2 ]; then
    echo "Usage: sh run_distribute_train_gpu.sh [DATASET_PATH] [DEVICE_NUM]"
    exit 1
fi
# Print an absolute form of the path given as $1.
#   - A path that already starts with "/" is printed unchanged.
#   - Any other path is resolved against the current working directory with
#     `realpath -m`, which does not require the path to exist.
# The expansion is quoted so that paths containing whitespace stay a single
# argument instead of being split into several.
get_real_path(){
    case "$1" in
        /*) echo "$1" ;;
        *) realpath -m "$PWD/$1" ;;
    esac
}
# Resolve the dataset path before changing directory, so that a relative
# path is interpreted against the directory the script was launched from.
PATH1=$(get_real_path "$1")

# Fail early, before the previous run directory is wiped, when the dataset
# path is not usable (same check as run_eval_gpu.sh performs).
if [ ! -d "$PATH1" ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=$2
export RANK_SIZE=$2

# Start from a clean working directory that holds a private copy of the
# entry scripts and the src package.
rm -rf ./train_parallel
mkdir ./train_parallel
cp ../*.py ./train_parallel
cp *.sh ./train_parallel
cp -r ../src ./train_parallel
cd ./train_parallel || exit

# Launch one training process per device through MPI; all output goes to
# ./train_parallel/log and the job runs in the background.
mpirun -n "$RANK_SIZE" \
    python train.py --run_distribute=True \
    --device_num="$DEVICE_NUM" --device_target="GPU" --dataset_path="$PATH1" &> log &

13
model_zoo/official/cv/resnet_thor/scripts/run_eval.sh Normal file → Executable file
View File

@ -20,6 +20,7 @@ then
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
@ -44,9 +45,6 @@ then
exit 1
fi
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/../ || exit
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
@ -58,10 +56,11 @@ then
rm -rf ./eval
fi
mkdir ./eval
cp *.py ./eval
cp -r ./src ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start infering for device $DEVICE_ID"
python eval.py --do_eval=True --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
cd ..

View File

@ -0,0 +1,65 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Exactly two positional arguments are required: the dataset path and the
# checkpoint file. Anything else prints the usage line and aborts.
if [ "$#" -ne 2 ]; then
    echo "Usage: sh run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi
# Print an absolute form of the path given as $1.
#   - A path that already starts with "/" is printed unchanged.
#   - Any other path is resolved against the current working directory with
#     `realpath -m`, which does not require the path to exist.
# The expansion is quoted so that paths containing whitespace stay a single
# argument instead of being split into several.
get_real_path(){
    case "$1" in
        /*) echo "$1" ;;
        *) realpath -m "$PWD/$1" ;;
    esac
}
# Resolve both arguments to absolute paths before changing directory.
PATH1=$(get_real_path "$1")
PATH2=$(get_real_path "$2")

# Quote the paths in the tests: unquoted, a path containing spaces makes
# `[` fail with "too many arguments" instead of reporting the real problem.
if [ ! -d "$PATH1" ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

if [ ! -f "$PATH2" ]
then
    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
    exit 1
fi

# Evaluation runs as a single process on device 0.
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

# Start from a clean ./eval directory that holds a private copy of the
# entry scripts and the src package.
if [ -d "eval" ];
then
    rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
# Runs in the background; all output goes to ./eval/log.
python eval.py --dataset_path="$PATH1" --checkpoint_path="$PATH2" --device_target="GPU" &> log &
cd ..

View File

@ -17,21 +17,46 @@ network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
# config for resnet50, imagenet2012, Ascend 910
config = ed({
"class_num": 1000,
"class_num": 1001,
"batch_size": 32,
"loss_scale": 128,
"momentum": 0.9,
"weight_decay": 5e-4,
"epoch_size": 45,
"buffer_size": 1000,
"image_height": 224,
"image_width": 224,
"save_checkpoint": True,
"save_checkpoint_steps": 5004,
"keep_checkpoint_max": 20,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 15,
"save_checkpoint_path": "./",
"label_smooth": 1,
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"frequency": 834
"lr_init": 0.045,
"lr_decay": 6,
"lr_end_epoch": 70,
"damping_init": 0.03,
"damping_decay": 0.87,
"frequency": 834,
})
# config for resnet50, imagenet2012, GPU
# Same kind of settings as the Ascend `config` above; the learning-rate and
# damping values below are the ones that differ for GPU.
config_gpu = ed({
"class_num": 1001,  # dataset class number
"batch_size": 32,  # batch size of input tensor
"loss_scale": 128,  # loss scale
"momentum": 0.9,  # momentum of THOR optimizer
"weight_decay": 5e-4,  # weight decay
"epoch_size": 45,  # only valid for training
"save_checkpoint": True,  # whether save checkpoint or not
"save_checkpoint_epochs": 1,  # the epoch interval between two checkpoints
"keep_checkpoint_max": 15,  # only keep the last keep_checkpoint_max checkpoint
"save_checkpoint_path": "./",  # path to save checkpoint relative to the executed path
"use_label_smooth": True,  # when False, eval.py forces label_smooth_factor to 0.0
"label_smooth_factor": 0.1,  # label smooth factor
"lr_init": 0.04,  # learning rate init value
"lr_decay": 5,  # learning rate decay rate value
"lr_end_epoch": 58,  # learning rate end epoch value
"damping_init": 0.02,  # damping init value for Fisher information matrix
"damping_decay": 0.87,  # damping decay rate
"frequency": 834,  # the step interval to update second-order information matrix
})

View File

@ -28,13 +28,10 @@ class CrossEntropy(_Loss):
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
# self.cast = P.Cast()
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
def construct(self, logit, label):
# one_hot_label = self.onehot(self.cast(label, mstype.int32),
# F.shape(logit)[1], self.on_value, self.off_value)、
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, one_hot_label)
loss = self.mean(loss, 0)

View File

@ -16,30 +16,36 @@
create train or eval dataset.
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.transforms.vision.c_transforms as V_C
from mindspore.communication.management import init, get_rank, get_group_size
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
"""
create a train or eval dataset
create a train or eval imagenet2012 dataset for resnet50
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
Returns:
dataset
"""
device_num = int(os.getenv("RANK_SIZE"))
rank_id = int(os.getenv("RANK_ID"))
if target == "Ascend":
device_num, rank_id = _get_rank_info()
else:
init("nccl")
rank_id = get_rank()
device_num = get_group_size()
if device_num == 1:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=False)
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id)
@ -47,29 +53,28 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
image_size = 224
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
# define map operations
if do_train:
transform_img = [
V_C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
V_C.RandomHorizontalFlip(prob=0.5),
V_C.Normalize(mean=mean, std=std),
V_C.HWC2CHW()
trans = [
C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
C.RandomHorizontalFlip(prob=0.5),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
else:
transform_img = [
V_C.Decode(),
V_C.Resize((256, 256)),
V_C.CenterCrop(image_size),
V_C.Normalize(mean=mean, std=std),
V_C.HWC2CHW()
trans = [
C.Decode(),
C.Resize(256),
C.CenterCrop(image_size),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
# type_cast_op = C2.TypeCast(mstype.float16)
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(input_columns="image", operations=transform_img, num_parallel_workers=8)
ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)
# apply shuffle operations
# ds = ds.shuffle(buffer_size=config.buffer_size)
ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
@ -78,3 +83,18 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
ds = ds.repeat(repeat_num)
return ds
def _get_rank_info():
    """
    Get the rank size and rank id of the current process.

    The ``RANK_SIZE`` environment variable decides which path is taken. When
    it is unset or empty it is treated as 1. When it is greater than 1 the
    values are obtained from ``get_group_size()`` and ``get_rank()``;
    otherwise the single-device values ``(1, 0)`` are returned.

    Returns:
        tuple(int, int): ``(rank_size, rank_id)``.

    Raises:
        ValueError: If ``RANK_SIZE`` is set to a non-empty, non-integer string.
    """
    # `or 1` makes an empty RANK_SIZE behave like an unset one instead of
    # failing in int("").
    rank_size = int(os.environ.get("RANK_SIZE") or 1)

    if rank_size > 1:
        rank_size = get_group_size()
        rank_id = get_rank()
    else:
        # Also normalises values below 1 to the single-device case.
        rank_size = 1
        rank_id = 0

    return rank_size, rank_id

View File

@ -13,34 +13,47 @@
# limitations under the License.
# ============================================================================
"""Dataset help for minddata dataset"""
from mindspore._checkparam import check_bool
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
_to_full_shapes
from mindspore.train.parallel_utils import ParallelMode
import math
import os
from mindspore._checkparam import check_bool, check_int
from mindspore import context
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _to_full_shapes
from mindspore.nn.wrap import GetNextSingleOp
from mindspore.parallel._utils import _get_device_num, _need_to_full
def _send_data(dataset):
def _send_data(dataset, epoch_num):
"""Engine dataset to write data to tdt queue."""
if not hasattr(dataset, '__has_sent__'):
exec_dataset = dataset.__TRANSFER_DATASET__
exec_dataset.send()
exec_dataset.send(epoch_num)
dataset.__has_sent__ = True
def _send_data_no_flag(dataset, epoch_num):
    """
    Send data through the dataset's transfer object unconditionally.

    Calls ``send(epoch_num)`` on ``dataset.__TRANSFER_DATASET__`` and, as the
    name says, neither checks nor sets any "already sent" flag on the dataset.

    Args:
        dataset: Object carrying a ``__TRANSFER_DATASET__`` attribute.
        epoch_num: Value forwarded to ``send``.
    """
    dataset.__TRANSFER_DATASET__.send(epoch_num)
class DatasetHelper:
"""
Help function to use the Minddata dataset.
Help function to use the MindData dataset.
According to different context, change the iter of dataset, to use the same for loop in different context.
According to different contexts, change the iterations of dataset and use the same iteration for loop in different
contexts.
Note:
The iter of DatasetHelper will give one epoch data.
The iteration of DatasetHelper will provide one epoch data.
Args:
dataset (DataSet): The dataset.
dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host.
Default: True.
dataset (DataSet): The training dataset iterator.
dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. Default: True.
sink_size (int): Control the amount of data in each sink.
If sink_size=-1, sink the complete dataset for each epoch.
If sink_size>0, sink sink_size data for each epoch. Default: -1.
epoch_num (int): Control the number of epoch data to send. Default: 1.
Examples:
>>> dataset_helper = DatasetHelper(dataset)
@ -48,81 +61,116 @@ class DatasetHelper:
>>> outputs = network(*inputs)
"""
def __init__(self, dataset, dataset_sink_mode=True, iter_first_order=0):
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1, iter_first_order=1):
check_bool(dataset_sink_mode)
self.iter = _DatasetIterMSLoopSink(dataset, iter_first_order)
check_int(sink_size)
if sink_size < -1 or sink_size == 0:
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
if dataset_sink_mode:
if context.get_context("device_target") == "Ascend":
iterclass = _DatasetIterMSLoopSink
self.iter = iterclass(dataset, sink_size, epoch_num, iter_first_order)
elif context.get_context("device_target") == "GPU":
iterclass = _DatasetIterMS
self.iter = iterclass(dataset, sink_size, epoch_num)
elif context.get_context("device_target") == "CPU":
raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")
def __iter__(self):
return self.iter.__iter__()
# A temp solution for loop sink. Delete later
def types_shapes(self):
"""Get the types and shapes from dataset on current config."""
"""Get the types and shapes from dataset on the current configuration."""
return self.iter.types_shapes()
def loop_size(self):
"""Get loop_size for every iteration."""
return self.iter.loop_size
def sink_size(self):
"""Get sink_size for each iteration."""
return self.iter.get_sink_size()
def stop_send(self):
"""Free up resources about data sink."""
self.iter.stop_send()
class _DatasetIter:
"""Base iter for dataset help"""
"""Base iter for dataset helper"""
def __init__(self, dataset, sink_size, epoch_num):
self.dataset = dataset
self.sink_size = sink_size
self.sink_count = 1
def __init__(self, dataset):
self.loop_size = 1
if not hasattr(dataset, '__ME_INITED__'):
if not hasattr(dataset, '__loop_size__'):
self.loop_size = dataset.get_dataset_size()
else:
self.loop_size = dataset.__loop_size__
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
if not hasattr(dataset, '__TRANSFER_DATASET__'):
if hasattr(dataset, '__loop_size__'):
self.sink_size = dataset.__loop_size__
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
if not hasattr(dataset, '__no_send__'):
_send_data(dataset)
_send_data(dataset, epoch_num)
else:
_send_data(dataset)
_send_data_no_flag(dataset, epoch_num)
self.ind = 0
self.dataset = dataset
dataset_types, dataset_shapes = _get_types_and_shapes(dataset)
self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes
self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
def __iter__(self):
self.ind = 0
self.index = 0
return self
def __next__(self):
if self.ind >= self.loop_count:
if self.index >= self.sink_count:
raise StopIteration()
self.ind += 1
self.index += 1
return self.op()
def types_shapes(self):
return self.dataset_types, self.dataset_shapes
def get_loop_count(self, dataset):
loop_count = 1
def get_sink_count(self, dataset):
sink_count = 1
if hasattr(dataset, '__loop_size__'):
loop_size = dataset.__loop_size__
if dataset.get_dataset_size() % loop_size != 0:
if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0:
raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
f'loop_size {loop_size} are not matched.')
loop_count = int(dataset.get_dataset_size() / loop_size)
return loop_count
f'sink_size {loop_size} are not matched.')
sink_count = math.ceil(dataset.get_dataset_size() / loop_size)
return sink_count
def get_sink_size(self):
"""get sink_size to device"""
sink_size = 1
if hasattr(self.dataset, '__loop_size__'):
sink_size = self.dataset.__loop_size__
else:
if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend":
if self.sink_size > 0:
sink_size = self.sink_size
else:
sink_size = self.dataset.get_dataset_size()
return sink_size
class _DatasetIterMSLoopSink(_DatasetIter):
"""Iter for context (device_target=Ascend)"""
def __init__(self, dataset, iter_first_order):
super(_DatasetIterMSLoopSink, self).__init__(dataset)
loop_size = dataset.__loop_size__ + iter_first_order
self.loop_count = int(dataset.get_dataset_size() / loop_size) * 2
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
# compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
# times the batch dimension of tensors for run. Now only support LoopSink.
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
def __init__(self, dataset, sink_size, epoch_num, iter_first_order):
super().__init__(dataset, sink_size, epoch_num)
sink_count = 1
if hasattr(dataset, '__loop_size__'):
loop_size = dataset.__loop_size__ + iter_first_order
if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0:
raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
f'sink_size {loop_size} are not matched.')
sink_count = math.ceil(dataset.get_dataset_size() / loop_size) * 2
self.sink_count = sink_count
ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"):
self.sink_count = 1
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
# use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for
# compile is device_number times the batch dimension of tensors for run. Now only support LoopSink.
if _need_to_full():
device_num = _get_device_num()
self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)
@ -130,3 +178,16 @@ class _DatasetIterMSLoopSink(_DatasetIter):
return tuple()
self.op = op
class _DatasetIterMS(_DatasetIter):
    """Iter for MS(enable_loop_sink=False), fetching data through GetNextSingleOp."""

    def __init__(self, dataset, sink_size, epoch_num):
        """
        Args:
            dataset: Dataset handed to the base iterator; must expose
                ``get_dataset_size()`` and ``__ME_INITED__`` after base init.
            sink_size (int): When positive, used as the number of iterations.
                Otherwise the dataset size is used.
            epoch_num (int): Forwarded to the base iterator.
        """
        super().__init__(dataset, sink_size, epoch_num)
        # One iteration per sunk step when a positive sink_size is given,
        # else one per element of the dataset.
        self.sink_count = sink_size if sink_size > 0 else dataset.get_dataset_size()
        # `op` is what the base class's __next__ calls on each step.
        self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, dataset.__ME_INITED__)

View File

@ -174,10 +174,6 @@ class DistributedGradReducerThor(Cell):
datatypes = self.hyper_map(F.partial(_get_datatype), grads)
grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)
if self.mean:
new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads)
else:
new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads)
new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads)
new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
return new_grad

View File

@ -14,27 +14,19 @@
# ============================================================================
"""Model."""
import numpy as np
import math
from mindspore.train.callback import RunContext
from mindspore import context
from mindspore import log as logger
from mindspore import nn
from mindspore._c_expression import init_exec_dataset
from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
from mindspore.common import dtype as mstype
from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.tensor import Tensor
from mindspore.nn.metrics import Loss
from mindspore.nn.metrics import get_metrics
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
from mindspore.train import amp
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
from mindspore.train.parallel_utils import ParallelMode
from mindspore.train._utils import _to_full_tensor
from mindspore.train.model import Model
from mindspore.parallel._utils import _need_to_full
from mindspore.common.dtype import pytype_to_dtype
from mindspore._c_expression import init_exec_dataset
from src.dataset_helper import DatasetHelper
def _convert_type(types):
"""
Convert from numpy type to tensor type.
@ -76,194 +68,52 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
need_run=False)
class Model:
class Model_Thor(Model):
"""
High-Level API for Training or Testing.
`Model` groups layers into an object with training and inference features.
Args:
network (Cell): The training or testing network.
network (Cell): A training or testing network.
loss_fn (Cell): Objective function, if loss_fn is None, the
network should contain the logic of loss and grads calculation, and the logic
of parallel if needed. Default: None.
optimizer (Cell): Optimizer for updating the weights. Default: None.
metrics (Union[dict, set]): Dict or set of metrics to be evaluated by the model during
metrics (Union[dict, set]): A Dictionary or a set of metrics to be evaluated by the model during
training and testing. eg: {'accuracy', 'recall'}. Default: None.
eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as
`eval_network`. Default: None.
eval_indexes (list): In case of defining the `eval_network`, if `eval_indexes` is None, all outputs of
eval_indexes (list): When defining the `eval_network`, if `eval_indexes` is None, all outputs of the
`eval_network` would be passed to metrics, otherwise `eval_indexes` must contain three
elements, representing the positions of loss value, predict value and label, the loss
value would be passed to `Loss` metric, predict value and label would be passed to other
metric. Default: None.
elements, including the positions of loss value, predicted value and label. The loss
value would be passed to the `Loss` metric, the predicted value and label would be passed
to other metric. Default: None.
amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network`, level for mixed
precision training. Supports [O0, O2]. Default: "O0".
precision training. Supports [O0, O2, O3]. Default: "O0".
- O0: Do not change.
- O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
scale the loss by LossScaleManager. If it is set, overwrite the level setting. It's a keyword argument.
O2 is recommended on GPU, O3 is recommended on Ascend.
loss_scale_manager (Union[None, LossScaleManager]): If it is None, the loss would not be scaled. Otherwise,
scale the loss by LossScaleManager. It is a key argument.
e.g. Use `loss_scale_manager=None` to set the value.
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. Default: True.
Examples:
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
>>> self.bn = nn.BatchNorm2d(64)
>>> self.relu = nn.ReLU()
>>> self.flatten = nn.Flatten()
>>> self.fc = nn.Dense(64*224*224, 12) # padding=0
>>>
>>> def construct(self, x):
>>> x = self.conv(x)
>>> x = self.bn(x)
>>> x = self.relu(x)
>>> x = self.flatten(x)
>>> out = self.fc(x)
>>> return out
>>>
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
>>> dataset = get_dataset()
>>> model.train(2, dataset)
keep_batchnorm_fp32 (bool): Keep Batchnorm running in `float32`. If it is set to true, the level setting before
will be overwritten. Default: True.
"""
def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
eval_indexes=None, amp_level="O0", frequency=278, stop_epoch=100, **kwargs):
self._network = network
self._loss_fn = loss_fn
self._optimizer = optimizer
self._loss_scale_manager = None
self._loss_scale_manager_set = False
self._keep_bn_fp32 = True
self._check_kwargs(kwargs)
self._amp_level = amp_level
self._process_amp_args(kwargs)
self._parallel_mode = _get_parallel_mode()
self._device_number = _get_device_num()
self._global_rank = _get_global_rank()
self._parameter_broadcast = _get_parameter_broadcast()
eval_indexes=None, amp_level="O0", frequency=834, **kwargs):
super(Model_Thor, self).__init__(network, loss_fn, optimizer, metrics, eval_network,
eval_indexes, amp_level, **kwargs)
self._frequency = frequency
self._stop_epoch = stop_epoch
self._train_network = self._build_train_network()
self._build_eval_network(metrics, eval_network, eval_indexes)
self._build_predict_network()
def _process_amp_args(self, kwargs):
if self._amp_level == "O0":
self._keep_bn_fp32 = False
if 'keep_batchnorm_fp32' in kwargs:
self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
if 'loss_scale_manager' in kwargs:
self._loss_scale_manager = kwargs['loss_scale_manager']
self._loss_scale_manager_set = True
def _check_kwargs(self, kwargs):
    """
    Reject keyword arguments that this model does not understand.

    Args:
        kwargs (dict): Extra keyword arguments given to the constructor.

    Raises:
        ValueError: On the first key that is neither 'loss_scale_manager'
            nor 'keep_batchnorm_fp32'.
    """
    supported = ['loss_scale_manager', 'keep_batchnorm_fp32']
    for name in kwargs:
        if name in supported:
            continue
        raise ValueError(f"Unsupport arg '{name}'")
def _build_train_network(self):
    """
    Build the network used for training.

    With an optimizer, the network is wrapped by `amp.build_train_network`; with only a
    loss function, it is wrapped by `nn.WithLossCell`; otherwise it is used as is.

    Returns:
        The network to train, flagged for auto parallel when the parallel mode is
        semi-auto or auto parallel.
    """
    net = self._network
    if self._optimizer:
        # The loss scale manager is forwarded only when the caller supplied one.
        amp_options = {'level': self._amp_level}
        if self._loss_scale_manager_set:
            amp_options['loss_scale_manager'] = self._loss_scale_manager
        amp_options['keep_batchnorm_fp32'] = self._keep_bn_fp32
        net = amp.build_train_network(net, self._optimizer, self._loss_fn, **amp_options)
    elif self._loss_fn:
        # A loss function without an optimizer: only attach the loss.
        net = nn.WithLossCell(net, self._loss_fn)

    auto_parallel_modes = (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
    if self._parallel_mode in auto_parallel_modes:
        net.set_auto_parallel()
    return net
def _build_eval_network(self, metrics, eval_network, eval_indexes):
    """
    Build the network for evaluation.

    Nothing beyond `self._metric_fns` is set when `metrics` resolves to an empty value.

    Args:
        metrics: Metrics specification, resolved through `get_metrics`.
        eval_network: User supplied evaluation network, or None to wrap the training
            network and loss function with `nn.WithEvalCell`.
        eval_indexes (list): Positions of loss, predictions and labels in the outputs of
            `eval_network`. Must be None or a list of length three.

    Raises:
        ValueError: If `eval_indexes` is neither None nor a list of three items, or if no
            `eval_network` is given and the model has no loss function.
    """
    self._metric_fns = get_metrics(metrics)
    if not self._metric_fns:
        return

    if eval_network is not None:
        if eval_indexes is not None and not (isinstance(eval_indexes, list) and len(eval_indexes) == 3):
            # Adjacent literals instead of a backslash continuation inside the string, so
            # the message does not pick up the indentation of the source line.
            raise ValueError("Eval_indexes must be a list or None. If eval_indexes is a list, length of it "
                             "must be three. But got {}".format(eval_indexes))
        self._eval_network = eval_network
        self._eval_indexes = eval_indexes
    else:
        if self._loss_fn is None:
            raise ValueError("loss_fn can not be None.")
        self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2")
        self._eval_indexes = [0, 1, 2]

    if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
        self._eval_network.set_auto_parallel()
def _build_predict_network(self):
    """
    Build the network for prediction.

    The raw network is used directly, except in semi-auto or auto parallel mode where it
    is wrapped in `_VirtualDatasetCell` and flagged for auto parallel.
    """
    self._predict_network = self._network
    if self._parallel_mode not in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
        return

    wrapped = _VirtualDatasetCell(self._network)
    self._predict_network = wrapped
    wrapped.set_auto_parallel()
def _clear_metrics(self):
    """Reset the locally accumulated state of every registered metric."""
    for metric_fn in self._metric_fns.values():
        metric_fn.clear()
def _update_metrics(self, outputs):
    """
    Update metrics local values.

    Args:
        outputs (tuple): Outputs of the evaluation network for one step.

    Raises:
        ValueError: If `outputs` is not a tuple, or if evaluation indexes are set and
            `outputs` has fewer than three elements.
    """
    if not isinstance(outputs, tuple):
        raise ValueError("The `outputs` is not tuple.")

    indexes = self._eval_indexes
    if indexes is not None and len(outputs) < 3:
        # Adjacent literals instead of a backslash continuation inside the string, so
        # the message does not pick up the indentation of the source line.
        raise ValueError("The length of `outputs` must be greater than or equal to 3, "
                         "but got {}".format(len(outputs)))

    for metric in self._metric_fns.values():
        if indexes is None:
            # Without indexes every metric receives all outputs.
            metric.update(*outputs)
        elif isinstance(metric, Loss):
            # A Loss metric only consumes the output selected by the first index.
            metric.update(outputs[indexes[0]])
        else:
            metric.update(outputs[indexes[1]], outputs[indexes[2]])
def _get_metrics(self):
    """
    Get metrics local values.

    Returns:
        dict, maps each metric name to the result of that metric's `eval()`.
    """
    return {name: metric_fn.eval() for name, metric_fn in self._metric_fns.items()}
def _get_scaling_sens(self):
    """
    Get the scaling sens.

    Returns:
        The loss scale of the loss scale manager (1 when there is no manager), divided by
        the device number when running in data parallel mode.
    """
    manager = self._loss_scale_manager
    if manager is None:
        sens = 1
    else:
        sens = manager.get_loss_scale()

    if self._parallel_mode == ParallelMode.DATA_PARALLEL:
        sens /= self._device_number
    return sens
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, iter_first_order):
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1,
epoch_num=1, iter_first_order=1):
"""Initializes dataset."""
need_wrap = False
if dataset_sink_mode:
@ -275,7 +125,7 @@ class Model:
if not is_train:
dataset.__loop_size__ = 1
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, iter_first_order)
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)
# remove later to deal with loop sink
if need_wrap:
@ -283,133 +133,31 @@ class Model:
network.set_train(is_train)
network.phase = phase
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
network.set_auto_parallel()
return dataset_helper, network
def init(self, train_dataset=None, valid_dataset=None):
    """
    Initialize compute graphs and data graphs with sink mode.

    Note:
        Pre-init process only supports `GRAPH_MODE` and `Ascend` target currently.

    Args:
        train_dataset (Dataset): A training dataset iterator. If `train_dataset` is defined,
            training graphs will be initialized. Default: None.
        valid_dataset (Dataset): An evaluating dataset iterator. If `valid_dataset` is defined,
            evaluation graphs will be initialized, and `metrics` in `Model` can not be None.
            Default: None.

    Raises:
        RuntimeError: If the context is not graph mode on an Ascend target, or if
            `valid_dataset` is given while no metric is configured.
        ValueError: If both datasets are None or empty.

    Examples:
        >>> train_dataset = get_train_dataset()
        >>> valid_dataset = get_valid_dataset()
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
        >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'})
        >>> model.init(train_dataset, valid_dataset)
        >>> model.train(2, train_dataset)
        >>> model.eval(valid_dataset)
    """
    in_graph_mode = context.get_context("mode") == context.GRAPH_MODE
    if not in_graph_mode or context.get_context("device_target") != "Ascend":
        raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.')

    if not (train_dataset or valid_dataset):
        raise ValueError('Both train_dataset and valid_dataset can not be None or empty.')

    _device_number_check(self._parallel_mode, self._device_number)

    if train_dataset:
        _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
        self._train_network.set_train()
        self._train_network.phase = 'train'
        if self._parameter_broadcast:
            self._train_network.set_broadcast_flag()

        train_helper, self._train_network = self._exec_preprocess(self._train_network,
                                                                  is_train=True,
                                                                  phase='train',
                                                                  dataset=train_dataset,
                                                                  dataset_sink_mode=True)
        # Only the first item is needed: it is used to compile the graph, not to train.
        for first_inputs in train_helper:
            self._train_network.compile(*first_inputs)
            break

    if valid_dataset:
        if not self._metric_fns:
            raise RuntimeError('If define `valid_dataset`, metric fn can not be None or empty.')
        self._eval_network.set_train(False)
        self._eval_network.phase = 'eval'

        valid_helper, self._eval_network = self._exec_preprocess(self._eval_network,
                                                                 is_train=False,
                                                                 phase='eval',
                                                                 dataset=valid_dataset,
                                                                 dataset_sink_mode=True)
        # Same as above: compile with the first item only.
        for first_inputs in valid_helper:
            self._eval_network.compile(*first_inputs)
            break
def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
    """
    Training.

    Args:
        epoch (int): Total number of iterations on the data.
        train_dataset (Dataset): A training dataset iterator. If there is no
                                 loss_fn, a tuple with multiple data (data1, data2, data3, ...) will be
                                 returned and passed to the network. Otherwise, a tuple (data, label) will
                                 be returned, and the data and label are passed to the network and loss
                                 function respectively.
        callbacks (list): List of callback objects which should be executed while training. Default: None.
        dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
                                  In pynative mode, the training process will be performed with
                                  dataset not sink.
    """
    epoch = check_int_positive(epoch)
    self._train_network.set_train()

    if self._parameter_broadcast:
        self._train_network.set_broadcast_flag()

    # build callback list
    cb_params = _InternalCallbackParam()
    cb_params.train_network = self._train_network
    cb_params.epoch_num = epoch
    cb_params.batch_num = train_dataset.get_dataset_size()
    cb_params.mode = "train"
    cb_params.loss_fn = self._loss_fn
    cb_params.optimizer = self._optimizer
    cb_params.parallel_mode = self._parallel_mode
    cb_params.device_number = self._device_number
    cb_params.train_dataset = train_dataset
    cb_params.list_callback = callbacks

    with _CallbackManager(callbacks) as list_callback:
        if not dataset_sink_mode:
            self._train_process(epoch, train_dataset, list_callback, cb_params)
        elif context.get_context("mode") == context.PYNATIVE_MODE:
            # The two literals are concatenated, so the first one must end with a space;
            # it used to render as "currently.So".
            logger.warning("The pynative mode cannot support dataset sink mode currently. "
                           "So the training process will be performed with dataset not sink.")
            self._train_process(epoch, train_dataset, list_callback, cb_params)
        else:
            self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1):
"""
Training process. The data would be passed to network through dataset channel.
Args:
epoch (int): Total number of iterations on the data.
train_dataset (Dataset): A training dataset iterator. If there is no
loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be
loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned, and the data and label are passed to the network and loss
be returned. The data and label would be passed to the network and loss
function respectively.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
sink_size (int): Control the amount of data in each sink. Default: -1.
"""
if sink_size == -1:
epoch_num = epoch
else:
epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())
iter_first_order = self._frequency - 1
iter_second_order = 1
train_dataset.__loop_size__ = iter_second_order
@ -418,308 +166,82 @@ class Model:
phase='train',
dataset=train_dataset,
dataset_sink_mode=True,
sink_size=sink_size,
epoch_num=epoch_num,
iter_first_order=iter_first_order)
self._train_network = train_network
cb_params.train_network = self._train_network
cb_params.cur_step_num = 0
loop_size = dataset_helper.loop_size()
run_context = RunContext(cb_params)
list_callback.begin(run_context)
# used to stop training for early stop, such as stopAtTIme or stopATStep
should_stop = False
has_do_dataset_init = False
switch_branch_one = True
index_first_order = 0
train_network_init_flag = True
has_do_dataset_init = False
for i in range(epoch):
cb_params.cur_epoch_num = i + 1
list_callback.epoch_begin(run_context)
# for data sink dataset_helper only iter once, other wise iter epoch_size times.
for inputs in dataset_helper:
if _need_to_full() and context.get_context("device_target") == "GPU":
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
list_callback.step_begin(run_context)
if switch_branch_one:
cb_params.cur_step_num += loop_size
self._train_network.add_flags_recursive(thor=True)
self._train_network.phase = 'train0'
if context.get_context("device_target") == "GPU":
if switch_branch_one:
cb_params.cur_step_num += 1
if train_network_init_flag:
self._train_network.add_flags_recursive(thor=True)
self._train_network.phase = 'train0'
switch_branch_one = not switch_branch_one
outputs = self._train_network(*inputs)
cb_params.net_outputs = outputs
list_callback.step_end(run_context)
else:
cb_params.cur_step_num += 1
if train_network_init_flag:
self._train_network.add_flags_recursive(thor=False)
train_network_init_flag = False
self._train_network.phase = 'train1'
outputs = self._train_network(*inputs)
cb_params.net_outputs = outputs
index_first_order += 1
if index_first_order == iter_first_order:
index_first_order = 0
switch_branch_one = not switch_branch_one
list_callback.step_end(run_context)
else:
cb_params.cur_step_num += iter_first_order
self._train_network.add_flags_recursive(thor=False)
self._train_network.phase = 'train1'
if not has_do_dataset_init:
_exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset')
has_do_dataset_init = True
switch_branch_one = not switch_branch_one
outputs = self._train_network(*inputs)
cb_params.net_outputs = outputs
list_callback.step_end(run_context)
if switch_branch_one:
cb_params.cur_step_num += 1
if train_network_init_flag:
self._train_network.add_flags_recursive(thor=True)
self._train_network.phase = 'train0'
else:
cb_params.cur_step_num += iter_first_order
if train_network_init_flag:
self._train_network.add_flags_recursive(thor=False)
train_network_init_flag = False
self._train_network.phase = 'train1'
if not has_do_dataset_init:
_exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset')
has_do_dataset_init = True
switch_branch_one = not switch_branch_one
outputs = self._train_network(*inputs)
cb_params.net_outputs = outputs
list_callback.step_end(run_context)
list_callback.epoch_end(run_context)
should_stop = should_stop or run_context.get_stop_requested()
if should_stop:
break
dataset_helper.stop_send()
list_callback.end(run_context)
def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
    """
    Training process. The data would be passed to network directly.

    Args:
        epoch (int): Total number of iterations on the data.
        train_dataset (Dataset): A training dataset iterator. If there is no
                                 loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
                                 returned and passed to the network. Otherwise, a tuple (data, label) should
                                 be returned, and the data and label are passed to the network and loss
                                 function respectively.
        list_callback (Callback): Executor of callback list. Default: None.
        cb_params (_InternalCallbackParam): Callback parameters. Default: None.

    Raises:
        ValueError: If a loss function is set and a dataset item does not have exactly two elements.
    """
    dataset_helper, _ = self._exec_preprocess(self._train_network,
                                              is_train=True,
                                              phase='train',
                                              dataset=train_dataset,
                                              dataset_sink_mode=False)
    cb_params.cur_step_num = 0
    run_context = RunContext(cb_params)
    list_callback.begin(run_context)
    # Set once a callback requests a stop; ends both the step loop and the epoch loop.
    should_stop = False

    for i in range(epoch):
        cb_params.cur_epoch_num = i + 1
        list_callback.epoch_begin(run_context)

        for next_element in dataset_helper:
            len_element = len(next_element)
            if self._loss_fn and len_element != 2:
                # The first literal ends with a space; it used to render as "shouldreturn".
                raise ValueError("when loss_fn is not None, train_dataset should "
                                 "return two elements, but got {}".format(len_element))
            cb_params.cur_step_num += 1
            list_callback.step_begin(run_context)

            # With drop-overflow-update enabled, the scaling sens is appended as an extra input.
            if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
                scaling_sens = self._get_scaling_sens()
                next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)

            outputs = self._train_network(*next_element)
            cb_params.net_outputs = outputs
            if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
                # The second output is the overflow flag; it drives the loss scale update.
                _, overflow, _ = outputs
                overflow = np.all(overflow.asnumpy())
                self._loss_scale_manager.update_loss_scale(overflow)

            list_callback.step_end(run_context)
            should_stop = should_stop or run_context.get_stop_requested()
            if should_stop:
                break

        train_dataset.reset()

        list_callback.epoch_end(run_context)
        should_stop = should_stop or run_context.get_stop_requested()
        if should_stop:
            break

    list_callback.end(run_context)
def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
    """
    Training API where the iteration is controlled by python front-end.

    When setting pynative mode, the training process will be performed with dataset not sink.

    Note:
        CPU is not supported when dataset_sink_mode is true.
        If dataset_sink_mode is True, epoch of training should be equal to the count of repeat
        operation in dataset processing. Otherwise, errors could occur since the amount of data
        is not the amount training requires.
        If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features
        of data will be transferred one by one. The limitation of data transmission per time is 256M.

    Args:
        epoch (int): Total number of iterations on the data.
        train_dataset (Dataset): A training dataset iterator. If there is no
                                 loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
                                 returned and passed to the network. Otherwise, a tuple (data, label) should
                                 be returned, and the data and label are passed to the network and loss
                                 function respectively.
        callbacks (list): List of callback objects which should be executed while training. Default: None.
        dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
                                  In pynative mode, the training process will be performed with
                                  dataset not sink.

    Examples:
        >>> dataset = get_dataset()
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
        >>> loss_scale_manager = FixedLossScaleManager()
        >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
        >>> model.train(2, dataset)
    """
    repeat_count = train_dataset.get_repeat_count()
    sink_requested = dataset_sink_mode is True
    # Only a warning: a mismatch between epoch and repeat count does not stop training here.
    if sink_requested and epoch != repeat_count:
        logger.warning(f"The epoch_size {epoch} is not the same with dataset repeat_count {repeat_count}")

    check_bool(dataset_sink_mode)
    _device_number_check(self._parallel_mode, self._device_number)
    _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)

    self._train(epoch, train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)
def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
    """
    Evaluation. The data would be passed to network through dataset channel.

    Args:
        valid_dataset (Dataset): Dataset to evaluate the model.
        list_callback (Callback): Executor of callback list. Default: None.
        cb_params (_InternalCallbackParam): Callback parameters. Default: None.

    Returns:
        Dict, returns the loss value & metrics values for the model in test mode.
    """
    run_context = RunContext(cb_params)

    helper, self._eval_network = self._exec_preprocess(self._eval_network,
                                                       is_train=False,
                                                       phase='eval',
                                                       dataset=valid_dataset,
                                                       dataset_sink_mode=True)
    # Preprocessing may hand back a different network object; expose it to callbacks.
    cb_params.eval_network = self._eval_network
    list_callback.begin(run_context)

    for batch in helper:
        cb_params.cur_step_num += 1
        list_callback.step_begin(run_context)

        step_outputs = self._eval_network(*batch)
        cb_params.net_outputs = step_outputs
        list_callback.step_end(run_context)
        self._update_metrics(step_outputs)

    results = self._get_metrics()
    cb_params.metrics = results
    list_callback.end(run_context)
    return results
def _eval_process(self, valid_dataset, list_callback=None, cb_params=None):
    """
    Evaluation. The data would be passed to network directly.

    Args:
        valid_dataset (Dataset): Dataset to evaluate the model.
        list_callback (Callback): Executor of callback list. Default: None.
        cb_params (_InternalCallbackParam): Callback parameters. Default: None.

    Returns:
        Dict, returns the loss value & metrics values for the model in test mode.
    """
    run_context = RunContext(cb_params)
    list_callback.begin(run_context)

    # The network returned by preprocessing is not needed in non-sink mode.
    helper, _ = self._exec_preprocess(self._eval_network,
                                      is_train=False,
                                      phase='eval',
                                      dataset=valid_dataset,
                                      dataset_sink_mode=False)

    for batch in helper:
        cb_params.cur_step_num += 1
        list_callback.step_begin(run_context)

        step_outputs = self._eval_network(*batch)
        cb_params.net_outputs = step_outputs
        list_callback.step_end(run_context)
        self._update_metrics(step_outputs)

    results = self._get_metrics()
    cb_params.metrics = results
    list_callback.end(run_context)
    return results
def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True):
    """
    Evaluation API where the iteration is controlled by python front-end.

    Configure to pynative mode, the evaluation will be performed with dataset non-sink mode.

    Note:
        CPU is not supported when dataset_sink_mode is true.
        If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features
        of data will be transferred one by one. The limitation of data transmission per time is 256M.

    Args:
        valid_dataset (Dataset): Dataset to evaluate the model.
        callbacks (list): List of callback objects which should be executed
                          while evaluating. Default: None.
        dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.

    Returns:
        Dict, returns the loss value & metrics values for the model in test mode.

    Raises:
        ValueError: If the model has no metric configured.

    Examples:
        >>> dataset = get_dataset()
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
        >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
        >>> model.eval(dataset)
    """
    check_bool(dataset_sink_mode)
    _device_number_check(self._parallel_mode, self._device_number)
    if not self._metric_fns:
        raise ValueError("metric fn can not be None or empty.")

    cb_params = _InternalCallbackParam()
    cb_params.eval_network = self._eval_network
    cb_params.valid_dataset = valid_dataset
    cb_params.batch_num = valid_dataset.get_dataset_size()
    cb_params.mode = "eval"
    cb_params.cur_step_num = 0

    self._eval_network.set_train(mode=False)
    self._eval_network.phase = 'eval'

    # Metrics accumulate across steps, so start every evaluation from a clean state.
    self._clear_metrics()

    # Pick the sink or non-sink evaluation loop; both take the same arguments.
    run_eval = self._eval_dataset_sink_process if dataset_sink_mode else self._eval_process
    with _CallbackManager(callbacks) as list_callback:
        return run_eval(valid_dataset, list_callback, cb_params)
def predict(self, *predict_data):
    """
    Generates output predictions for the input samples.

    Data could be a single tensor, a list of tensors or a tuple of tensors.

    Note:
        Batch data should be put together in one tensor.

    Args:
        predict_data (Tensor): Tensor of predict data. Can be array, list or tuple.

    Returns:
        Tensor, array(s) of predictions.

    Examples:
        >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
        >>> model = Model(Net())
        >>> model.predict(input_data)
    """
    predict_net = self._predict_network
    predict_net.set_train(False)
    # Inputs are validated before the forward pass, the result right after it.
    check_input_data(*predict_data, data_class=Tensor)
    predictions = predict_net(*predict_data)
    check_output_data(predictions)
    return predictions
__all__ = ["Model"]
__all__ = ["Model_Thor"]

View File

@ -1,262 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
def _weight_variable(shape, factor=0.01):
    """
    Create a randomly initialized weight tensor.

    Args:
        shape (tuple): Shape of the weight.
        factor (float): Multiplier applied to the standard normal samples. Default: 0.01.

    Returns:
        Tensor, float32 samples from a standard normal distribution scaled by `factor`.
    """
    samples = np.random.randn(*shape).astype(np.float32)
    return Tensor(samples * factor)
def _conv3x3(in_channel, out_channel, stride=1):
    """
    Build a 3x3 convolution with 'same' padding and randomly initialized weights.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride of the convolution. Default: 1.

    Returns:
        Cell, the `nn.Conv2d` layer.
    """
    kernel = 3
    init_weight = _weight_variable((out_channel, in_channel, kernel, kernel))
    return nn.Conv2d(in_channel, out_channel, kernel_size=kernel, stride=stride,
                     padding=0, pad_mode='same', weight_init=init_weight)
def _conv1x1(in_channel, out_channel, stride=1):
    """
    Build a 1x1 convolution with 'same' padding and randomly initialized weights.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride of the convolution. Default: 1.

    Returns:
        Cell, the `nn.Conv2d` layer.
    """
    kernel = 1
    init_weight = _weight_variable((out_channel, in_channel, kernel, kernel))
    return nn.Conv2d(in_channel, out_channel, kernel_size=kernel, stride=stride,
                     padding=0, pad_mode='same', weight_init=init_weight)
def _conv7x7(in_channel, out_channel, stride=1):
    """
    Build a 7x7 convolution with 'same' padding and randomly initialized weights.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride of the convolution. Default: 1.

    Returns:
        Cell, the `nn.Conv2d` layer.
    """
    kernel = 7
    init_weight = _weight_variable((out_channel, in_channel, kernel, kernel))
    return nn.Conv2d(in_channel, out_channel, kernel_size=kernel, stride=stride,
                     padding=0, pad_mode='same', weight_init=init_weight)
def _bn(channel):
    """
    Build a 2D batch normalization layer whose gamma starts at 1.

    Args:
        channel (int): Number of channels to normalize.

    Returns:
        Cell, the `nn.BatchNorm2d` layer.
    """
    bn_layer = nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, gamma_init=1,
                              beta_init=0, moving_mean_init=0, moving_var_init=1)
    return bn_layer
def _bn_last(channel):
    """
    Build a 2D batch normalization layer whose gamma starts at 0.

    Identical to `_bn` except for `gamma_init=0`.

    Args:
        channel (int): Number of channels to normalize.

    Returns:
        Cell, the `nn.BatchNorm2d` layer.
    """
    bn_layer = nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, gamma_init=0,
                              beta_init=0, moving_mean_init=0, moving_var_init=1)
    return bn_layer
def _fc(in_channel, out_channel):
    """
    Build a dense layer with randomly initialized weights and a zero bias.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.

    Returns:
        Cell, the `nn.Dense` layer.
    """
    init_weight = _weight_variable((out_channel, in_channel))
    return nn.Dense(in_channel, out_channel, has_bias=True,
                    weight_init=init_weight, bias_init=0)
class ResidualBlock(nn.Cell):
    """
    ResNet V1 residual block definition.

    Three convolutions (1x1, 3x3, 1x1), each followed by batch normalization, plus a
    shortcut that is projected by a 1x1 convolution when the stride is not 1 or the
    channel count changes.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size for the 3x3 convolution and the shortcut projection. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    # Ratio between the block's output channels and its inner (bottleneck) channels.
    expansion = 4

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1):
        super(ResidualBlock, self).__init__()

        mid_channel = out_channel // self.expansion
        self.conv1 = _conv1x1(in_channel, mid_channel, stride=1)
        self.bn1 = _bn(mid_channel)

        self.conv2 = _conv3x3(mid_channel, mid_channel, stride=stride)
        self.bn2 = _bn(mid_channel)

        self.conv3 = _conv1x1(mid_channel, out_channel, stride=1)
        self.bn3 = _bn_last(out_channel)

        self.relu = nn.ReLU()

        # The shortcut needs a projection whenever its shape would not match the main path.
        self.down_sample = bool(stride != 1 or in_channel != out_channel)
        self.down_sample_layer = None
        if self.down_sample:
            projection = [_conv1x1(in_channel, out_channel, stride), _bn(out_channel)]
            self.down_sample_layer = nn.SequentialCell(projection)
        self.add = P.TensorAdd()

    def construct(self, x):
        shortcut = x

        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)

        y = self.conv2(y)
        y = self.bn2(y)
        y = self.relu(y)

        y = self.conv3(y)
        y = self.bn3(y)

        if self.down_sample:
            shortcut = self.down_sample_layer(shortcut)

        y = self.add(y, shortcut)
        y = self.relu(y)
        return y
class ResNet(nn.Cell):
    """
    ResNet architecture.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of block in different layers.
        in_channels (list): Input channel in each layer.
        out_channels (list): Output channel in each layer.
        strides (list): Stride size in each layer.
        num_classes (int): The number of classes that the training images are belonging to.

    Returns:
        Tensor, output tensor.

    Raises:
        ValueError: If `layer_nums`, `in_channels` and `out_channels` do not all have length 4.

    Examples:
        >>> ResNet(ResidualBlock,
        >>>        [3, 4, 6, 3],
        >>>        [64, 256, 512, 1024],
        >>>        [256, 512, 1024, 2048],
        >>>        [1, 2, 2, 2],
        >>>        10)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes):
        super(ResNet, self).__init__()

        if any(len(spec) != 4 for spec in (layer_nums, in_channels, out_channels)):
            raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")

        # Stem: 7x7 stride-2 convolution, batch norm, ReLU and a stride-2 max pool.
        self.conv1 = _conv7x7(3, 64, stride=2)
        self.bn1 = _bn(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        def build_stage(index):
            """Build the stage described by position `index` of the four spec lists."""
            return self._make_layer(block,
                                    layer_nums[index],
                                    in_channel=in_channels[index],
                                    out_channel=out_channels[index],
                                    stride=strides[index])

        self.layer1 = build_stage(0)
        self.layer2 = build_stage(1)
        self.layer3 = build_stage(2)
        self.layer4 = build_stage(3)

        # Head: mean over axes (2, 3), flatten, then the dense classifier.
        self.mean = P.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = _fc(out_channels[3], num_classes)

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
        """
        Make stage network of ResNet.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the first block of the stage.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
        # Only the first block changes the channel count and applies the stride.
        stage_blocks = [block(in_channel, out_channel, stride=stride)]
        stage_blocks += [block(out_channel, out_channel, stride=1) for _ in range(1, layer_num)]
        return nn.SequentialCell(stage_blocks)

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        feat = self.maxpool(x)

        feat = self.layer1(feat)
        feat = self.layer2(feat)
        feat = self.layer3(feat)
        feat = self.layer4(feat)

        logits = self.mean(feat, (2, 3))
        logits = self.flatten(logits)
        logits = self.end_point(logits)
        return logits
def resnet50(class_num=10):
    """
    Get ResNet50 neural network.

    Args:
        class_num (int): Class number. Default: 10.

    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = resnet50(10)
    """
    # Per-stage configuration: number of blocks, input/output channels and stride.
    blocks_per_stage = [3, 4, 6, 3]
    stage_in_channels = [64, 256, 512, 1024]
    stage_out_channels = [256, 512, 1024, 2048]
    stage_strides = [1, 2, 2, 2]
    return ResNet(ResidualBlock,
                  blocks_per_stage,
                  stage_in_channels,
                  stage_out_channels,
                  stage_strides,
                  class_num)

View File

@ -18,8 +18,9 @@ import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore import context
from src.thor_layer import Conv2d_Thor, Dense_Thor
from src.thor_layer import Conv2d_Thor, Dense_Thor, Conv2d_Thor_GPU, Dense_Thor_GPU
def calculate_gain(nonlinearity, param=None):
@ -81,7 +82,7 @@ def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
return np.random.normal(0, std, size=inputs_shape).astype(np.float32)
def kaiming_uniform(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
fan = _calculate_correct_fan(inputs_shape, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
@ -89,28 +90,51 @@ def kaiming_uniform(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu')
return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)
def _conv3x3(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
def _weight_variable(shape, factor=0.01):
    """
    Create a randomly initialized weight tensor.

    Args:
        shape (tuple): Shape of the weight.
        factor (float): Multiplier applied to the standard normal samples. Default: 0.01.

    Returns:
        Tensor, float32 samples from a standard normal distribution scaled by `factor`.
    """
    samples = np.random.randn(*shape).astype(np.float32)
    return Tensor(samples * factor)
def _conv3x3(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278, batch_size=32):
weight_shape = (out_channel, in_channel, 3, 3)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
return Conv2d_Thor(in_channel, out_channel,
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency)
if context.get_context('device_target') == "Ascend":
layer = Conv2d_Thor(in_channel, out_channel,
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
else:
layer = Conv2d_Thor_GPU(in_channel, out_channel,
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
return layer
def _conv1x1(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
def _conv1x1(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278, batch_size=32):
weight_shape = (out_channel, in_channel, 1, 1)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
return Conv2d_Thor(in_channel, out_channel,
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency)
if context.get_context('device_target') == "Ascend":
layer = Conv2d_Thor(in_channel, out_channel,
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
else:
layer = Conv2d_Thor_GPU(in_channel, out_channel,
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
return layer
def _conv7x7(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
def _conv7x7(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278, batch_size=32):
weight_shape = (out_channel, in_channel, 7, 7)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
return Conv2d_Thor(in_channel, out_channel,
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency)
if context.get_context('device_target') == "Ascend":
layer = Conv2d_Thor(in_channel, out_channel,
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
else:
layer = Conv2d_Thor_GPU(in_channel, out_channel,
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
return layer
def _bn(channel):
@ -120,14 +144,21 @@ def _bn(channel):
def _bn_last(channel):
    """Return the BatchNorm2d used as the last normalisation of a block.

    Configured with eps=1e-4, momentum=0.9, zero-initialised gamma and beta,
    and moving statistics starting at mean 0 / variance 1.
    """
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                          gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)
def _fc(in_channel, out_channel, damping, loss_scale, frequency, batch_size=32):
    """Build the bias-free THOR dense layer for the active device target.

    The weight of shape (out_channel, in_channel) is produced by
    ``kaiming_uniform(a=math.sqrt(5))``.  ``damping``, ``loss_scale``,
    ``frequency`` and ``batch_size`` are forwarded unchanged to the layer
    constructor.

    Returns:
        A ``Dense_Thor`` when the context ``device_target`` is "Ascend",
        a ``Dense_Thor_GPU`` for every other target.
    """
    weight_shape = (out_channel, in_channel)
    weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
    # Both layer classes take the same keyword set, so only the class differs.
    if context.get_context('device_target') == "Ascend":
        dense_cls = Dense_Thor
    else:
        dense_cls = Dense_Thor_GPU
    return dense_cls(in_channel, out_channel, has_bias=False, weight_init=weight,
                     bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency,
                     batch_size=batch_size)
class ResidualBlock(nn.Cell):
@ -153,20 +184,21 @@ class ResidualBlock(nn.Cell):
stride=1,
damping=0.03,
loss_scale=1,
frequency=278):
frequency=278,
batch_size=32):
super(ResidualBlock, self).__init__()
channel = out_channel // self.expansion
self.conv1 = _conv1x1(in_channel, channel, stride=1, damping=damping, loss_scale=loss_scale,
frequency=frequency)
frequency=frequency, batch_size=batch_size)
self.bn1 = _bn(channel)
self.conv2 = _conv3x3(channel, channel, stride=stride, damping=damping, loss_scale=loss_scale,
frequency=frequency)
frequency=frequency, batch_size=batch_size)
self.bn2 = _bn(channel)
self.conv3 = _conv1x1(channel, out_channel, stride=1, damping=damping, loss_scale=loss_scale,
frequency=frequency)
frequency=frequency, batch_size=batch_size)
self.bn3 = _bn_last(out_channel)
self.relu = nn.ReLU()
@ -180,7 +212,8 @@ class ResidualBlock(nn.Cell):
if self.down_sample:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
damping=damping, loss_scale=loss_scale,
frequency=frequency),
frequency=frequency,
batch_size=batch_size),
_bn(out_channel)])
self.add = P.TensorAdd()
@ -239,16 +272,19 @@ class ResNet(nn.Cell):
num_classes,
damping,
loss_scale,
frequency):
frequency,
batch_size):
super(ResNet, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
self.conv1 = _conv7x7(3, 64, stride=2, damping=damping, loss_scale=loss_scale, frequency=frequency)
self.conv1 = _conv7x7(3, 64, stride=2, damping=damping, loss_scale=loss_scale,
frequency=frequency, batch_size=batch_size)
self.bn1 = _bn(64)
self.relu = P.ReLU()
self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
# self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.layer1 = self._make_layer(block,
layer_nums[0],
@ -257,7 +293,8 @@ class ResNet(nn.Cell):
stride=strides[0],
damping=damping,
loss_scale=loss_scale,
frequency=frequency)
frequency=frequency,
batch_size=batch_size)
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
@ -265,14 +302,16 @@ class ResNet(nn.Cell):
stride=strides[1],
damping=damping,
loss_scale=loss_scale,
frequency=frequency)
frequency=frequency,
batch_size=batch_size)
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=strides[2], damping=damping,
loss_scale=loss_scale,
frequency=frequency)
frequency=frequency,
batch_size=batch_size)
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
@ -280,14 +319,16 @@ class ResNet(nn.Cell):
stride=strides[3],
damping=damping,
loss_scale=loss_scale,
frequency=frequency)
frequency=frequency,
batch_size=batch_size)
self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = _fc(out_channels[3], num_classes, damping=damping, loss_scale=loss_scale, frequency=frequency)
self.end_point = _fc(out_channels[3], num_classes, damping=damping, loss_scale=loss_scale,
frequency=frequency, batch_size=batch_size)
def _make_layer(self, block, layer_num, in_channel, out_channel, stride,
damping, loss_scale, frequency):
damping, loss_scale, frequency, batch_size):
"""
Make stage network of ResNet.
@ -307,12 +348,14 @@ class ResNet(nn.Cell):
layers = []
resnet_block = block(in_channel, out_channel, stride=stride,
damping=damping, loss_scale=loss_scale, frequency=frequency)
damping=damping, loss_scale=loss_scale, frequency=frequency,
batch_size=batch_size)
layers.append(resnet_block)
for _ in range(1, layer_num):
resnet_block = block(out_channel, out_channel, stride=1,
damping=damping, loss_scale=loss_scale, frequency=frequency)
damping=damping, loss_scale=loss_scale, frequency=frequency,
batch_size=batch_size)
layers.append(resnet_block)
return nn.SequentialCell(layers)
@ -321,7 +364,7 @@ class ResNet(nn.Cell):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
c1, _ = self.maxpool(x)
c1 = self.maxpool(x)
c2 = self.layer1(c1)
c3 = self.layer2(c2)
@ -335,7 +378,7 @@ class ResNet(nn.Cell):
return out
def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278):
def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278, batch_size=32):
"""
Get ResNet50 neural network.
@ -356,4 +399,5 @@ def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278):
class_num,
damping,
loss_scale,
frequency)
frequency,
batch_size)

View File

@ -12,27 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""momentum"""
import mindspore.common.dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common.parameter import ParameterTuple
from mindspore.common.tensor import Tensor
from mindspore.nn.optim.optimizer import Optimizer
"""THOR"""
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops import _selected_ops
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import check_bool
from mindspore._checkparam import Validator as validator
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.parallel._utils import _get_device_num, _get_mirror_mean
from src.grad_reducer_thor import DistributedGradReducerThor
momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
"""Apply momentum optimizer to the weight parameter using Tensor."""
success = True
success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
return success
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")
op_add = P.AddN()
apply_decay = C.MultitypeFuncGraph("apply_decay")
@ -46,6 +39,119 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
return gradient
@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
    """Run one momentum update for a single weight/moment/gradient triple.

    ``opt`` is called as ``opt(weight, moment, learning_rate, gradient, momentum)``;
    the returned flag is ``True`` made dependent on that call so the update
    is kept in the graph.
    """
    update = opt(weight, moment, learning_rate, gradient, momentum)
    return F.depend(True, update)
class THOR_GPU(Optimizer):
    """
    THOR optimizer for GPU training.

    Each step rewrites the weight gradient of every layer with
    ``P.UpdateThorGradient`` using that layer's ``matrix_G`` / ``matrix_A``
    parameters, optionally applies weight decay, and finally runs
    ``ApplyMomentum`` on all parameters.

    NOTE(review): ``construct`` reads ``self.thor`` but this class never
    assigns it -- presumably it is set from outside (e.g. as a cell flag)
    before the graph is compiled; confirm.
    """
    def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max,
                 weight_decay=0.0, loss_scale=1.0, use_nesterov=False, decay_filter=lambda x: x.name not in []):
        """Set up momentum state, per-layer THOR matrices and reducers.

        Args:
            params: parameters to optimise (passed to ``Optimizer``).
            learning_rate: learning rate (passed to ``Optimizer``).
            momentum: float momentum, must be >= 0.0.
            matrix_A, matrix_G: iterables of parameters wrapped in
                ``ParameterTuple``; indexed per layer in ``construct``.
            A_inv_max, G_inv_max: wrapped in ``ParameterTuple`` and stored;
                not read by ``construct`` of this class.
            weight_decay: decay factor; decay is applied only when > 0.
            loss_scale: passed to ``Optimizer``.
            use_nesterov: forwarded to ``ApplyMomentum``.
            decay_filter: predicate selecting the parameters that receive
                weight decay (the default accepts every parameter).

        Raises:
            ValueError: if ``momentum`` is a negative float.
        """
        super(THOR_GPU, self).__init__(learning_rate, params, weight_decay, loss_scale)
        validator.check_value_type("momentum", momentum, [float], self.cls_name)
        if isinstance(momentum, float) and momentum < 0.0:
            raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
        self.params = self.parameters
        self.use_nesterov = check_bool(use_nesterov)
        # One zero-initialised moment buffer per parameter.
        self.moments = self.params.clone(prefix="moments", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = _selected_ops.ApplyMomentum(use_nesterov=self.use_nesterov)

        # 54 per-layer scale factors, one for each layer index used in construct.
        # NOTE(review): the values look like 1 / (H * W) of each layer's output
        # feature map for a 224x224 ResNet-50 -- confirm.
        self.feature_map = [1.0 / 12544, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
                            1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
                            1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
                            1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
                            1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
                            1.0]
        # Square roots of the factors above; these are what construct multiplies by.
        self.feature_map_new = [x ** 0.5 for x in self.feature_map]
        self.transpose = P.Transpose()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul()
        self.matrix_A = ParameterTuple(matrix_A)
        self.matrix_G = ParameterTuple(matrix_G)
        self.A_inv_max = ParameterTuple(A_inv_max)
        self.G_inv_max = ParameterTuple(G_inv_max)
        self.assign = P.Assign()
        self.mul = P.Mul()

        # Separate reducers for the A and G matrix tuples.
        mean = _get_mirror_mean()
        degree = _get_device_num()
        self.grad_reducer_thorA = DistributedGradReducerThor(self.parameters, 0, mean, degree)
        self.grad_reducer_thorG = DistributedGradReducerThor(self.parameters, 0, mean, degree)
        self.weight_decay = weight_decay
        self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
        self.update_gradient = P.UpdateThorGradient(split_dim=128)

    def construct(self, gradients):
        """Precondition the gradients with the THOR matrices and apply momentum.

        ``gradients[i * 3]`` is treated as the weight gradient of layer ``i``
        (0..53); ``gradients[i * 3 + 1]`` and ``gradients[i * 3 + 2]`` are passed
        through untouched (presumably the BatchNorm gamma/beta gradients that
        follow each convolution -- confirm).  Layer 53 contributes only its
        weight gradient and is not reshaped back.

        Returns:
            The result of mapping ``_momentum_opt`` over all
            (gradient, parameter, moment) triples.
        """
        params = self.params
        moments = self.moments
        gradients = self.scale_grad(gradients)
        new_grads = ()
        if self.thor:
            # Step that refreshes the stored matrices: scale, reduce across
            # devices, use them, and write them back into the parameters.
            matrix_A_allreduce = ()
            matrix_G_allreduce = ()
            for i in range(54):
                g = gradients[i * 3]
                matrix_A = self.matrix_A[i]
                matrix_G = self.matrix_G[i]
                # Order the matrix reads after the gradient has been produced.
                matrix_A = F.depend(matrix_A, g)
                matrix_G = F.depend(matrix_G, g)
                matrix_A = self.mul(matrix_A, self.feature_map_new[i])
                matrix_G = self.mul(matrix_G, self.feature_map_new[i])
                matrix_A_allreduce = matrix_A_allreduce + (matrix_A,)
                matrix_G_allreduce = matrix_G_allreduce + (matrix_G,)
            matrix_A_allreduce = self.grad_reducer_thorA(matrix_A_allreduce)
            matrix_G_allreduce = self.grad_reducer_thorG(matrix_G_allreduce)
            for i in range(54):
                g = gradients[i * 3]
                g_shape = self.shape(g)
                # Flatten to 2-D (first dim kept) for the THOR update op.
                g = self.reshape(g, (g_shape[0], -1))
                matrix_A = matrix_A_allreduce[i]
                matrix_G = matrix_G_allreduce[i]
                g = self.update_gradient(matrix_G, g, matrix_A)
                # Persist the reduced matrices; depend keeps the assigns in the graph.
                fake_A = self.assign(self.matrix_A[i], matrix_A)
                fake_G = self.assign(self.matrix_G[i], matrix_G)
                g = F.depend(g, fake_A)
                g = F.depend(g, fake_G)
                if i == 53:
                    # Last layer: only the weight gradient, left in 2-D form.
                    new_grads = new_grads + (g,)
                else:
                    g = self.reshape(g, g_shape)
                    new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
        else:
            # Step that reuses the matrices already stored in the parameters.
            for i in range(54):
                g = gradients[i * 3]
                g_shape = self.shape(g)
                g = self.reshape(g, (g_shape[0], -1))
                matrix_A = self.matrix_A[i]
                matrix_G = self.matrix_G[i]
                matrix_A = F.depend(matrix_A, g)
                matrix_G = F.depend(matrix_G, g)
                g = self.update_gradient(matrix_G, g, matrix_A)
                if i == 53:
                    new_grads = new_grads + (g,)
                else:
                    g = self.reshape(g, g_shape)
                    new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
        gradients = new_grads

        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
                                       params, gradients)
        lr = self.get_lr()
        success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
        return success
class THOR(Optimizer):
"""THOR"""
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0,
@ -195,5 +301,5 @@ class THOR(Optimizer):
params, gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()
success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
return success

View File

@ -15,7 +15,6 @@
"""thor_layer"""
import numpy as np
import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore._checkparam import check_bool, twice, check_int_positive
from mindspore._extends import cell_attr_register
@ -25,8 +24,10 @@ from mindspore.common.tensor import Tensor
from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
from mindspore.ops import operations as P
C0 = 16
def caculate_device_shape(matrix_dim, channel, is_A):
ll = (0)
if is_A:
@ -37,6 +38,31 @@ def caculate_device_shape(matrix_dim, channel, is_A):
ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
return ll
def caculate_matmul_shape(matrix_A_dim, matrix_G_dim, split_dim):
    """Compute the tiled storage shapes for the A and G matrices.

    Each dimension is cut into ``split_dim``-sized tiles, rounding the tile
    count up.  A non-zero dimension smaller than ``split_dim`` becomes a single
    tile of its own size.

    Returns:
        A pair ``(matrix_A_shape, matrix_G_shape)`` where
        ``matrix_A_shape = (batch_h, batch_w, tile_A, tile_A)`` and
        ``matrix_G_shape = (batch_h, tile_G, tile_G)``; ``batch_w`` / ``tile_A``
        come from ``matrix_A_dim`` and ``batch_h`` / ``tile_G`` from
        ``matrix_G_dim``.
    """
    def _tiles(dim):
        # Returns (number of tiles, edge length of one tile) for one dimension.
        whole, leftover = divmod(dim, split_dim)
        if leftover == 0:
            return whole, split_dim
        if dim < split_dim:
            return 1, dim
        return whole + 1, split_dim

    batch_w, tile_a = _tiles(matrix_A_dim)
    batch_h, tile_g = _tiles(matrix_G_dim)
    return (batch_h, batch_w, tile_a, tile_a), (batch_h, tile_g, tile_g)
class _Conv(Cell):
r"""Applies a N-D convolution over an input signal composed of several input
planes.
@ -97,6 +123,286 @@ class _Conv(Cell):
raise NotImplementedError
class Conv2d_Thor_GPU(_Conv):
    """2-D convolution that also maintains THOR second-order statistics (GPU).

    While ``self.thor`` is true the forward pass writes a statistic derived
    from the im2col'd input into ``matrix_A_inv`` and installs a gradient hook
    that writes a statistic derived from the output gradient into
    ``matrix_G_inv``.  With ``self.thor`` false it is a plain convolution.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 data_format='NCHW',
                 has_bias=False,
                 weight_init='normal',
                 damping=0.03,
                 loss_scale=1,
                 frequency=278,
                 batch_size=32,
                 bias_init='zeros'):
        """Create the convolution, its THOR parameters and helper ops.

        Args:
            in_channels, out_channels, kernel_size, stride, pad_mode, padding,
            dilation, group, data_format, has_bias, weight_init, bias_init:
                forwarded to the ``_Conv`` base class.
            damping: damping value(s); indexed by ``cov_step`` at run time, so
                callers presumably pass a per-step array -- confirm.
            loss_scale: the gradient seen by the hook is multiplied by
                ``1 / loss_scale``.
            frequency: amount added to ``cov_step`` each time the gradient
                hook runs.
            batch_size: the gradient seen by the hook is multiplied by it.
        """
        self.thor = True
        self.hw = kernel_size * kernel_size
        kernel_size = twice(kernel_size)
        super(Conv2d_Thor_GPU, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            pad_mode,
            padding,
            dilation,
            group,
            data_format,
            has_bias,
            weight_init,
            bias_init,
        )
        self.conv2d = P.Conv2D(out_channel=self.out_channels,
                               kernel_size=self.kernel_size,
                               mode=1,
                               pad_mode=self.pad_mode,
                               pad=self.padding,
                               stride=self.stride,
                               dilation=self.dilation,
                               group=self.group
                               )

        self.matrix_A_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
        self.matrix_G_dim = self.out_channels

        split_dim = 128
        matrix_A_shape, matrix_G_shape = caculate_matmul_shape(self.matrix_A_dim, self.matrix_G_dim, split_dim)
        self.matrix_A_inv = Parameter(np.zeros(matrix_A_shape).astype(np.float32),
                                      name='matrix_A_inv', requires_grad=False)
        # Fixed: this parameter used to be registered as 'matrix_A_inv' too, so
        # the A and G statistics of a layer shared one parameter name and could
        # not be told apart by name.
        self.matrix_G_inv = Parameter(np.zeros(matrix_G_shape).astype(np.float32),
                                      name='matrix_G_inv', requires_grad=False)
        self.broadcast_to = P.BroadcastTo(matrix_A_shape)
        # Index into the damping schedule; advanced by `frequency` in save_gradient.
        self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
        self.img2col = P.Im2Col(kernel_size=kernel_size, stride=stride, pad_mode="same")
        self.matmul = P.MatMul(transpose_b=True)
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.mul = P.Mul()
        # Identity in the forward pass; calls save_gradient on the way back.
        self.getG = P.InsertGradientOf(self.save_gradient)
        self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
        self.batch_size = Tensor(batch_size, mstype.float16)
        self.transpose = P.Transpose()
        self.cast = P.Cast()
        self.gather = P.GatherV2()
        self.freq = Tensor(frequency, mstype.int32)
        self.axis = 0
        self.sqrt = P.Sqrt()
        self.reduce_mean = P.ReduceMean(keep_dims=False)
        self.damping = Parameter(Tensor(damping), name="damping_value", requires_grad=False)
        # Identity matrices that the damping scalar is multiplied with.
        self.dampingA = Tensor(np.identity(self.matrix_A_dim), mstype.float32)
        self.dampingG = Tensor(np.identity(self.matrix_G_dim), mstype.float32)
        self.cholesky = P.Cholesky(split_dim=split_dim)
        self.vector_matmul = P.BatchMatMul(transpose_a=True)

    def save_gradient(self, dout):
        """Gradient hook: update ``matrix_G_inv`` and pass ``dout`` through.

        The gradient is rescaled by ``1 / loss_scale`` and ``batch_size``,
        averaged over axis 0, flattened to 2-D and multiplied with its own
        transpose; the result is divided by its second dimension, damped, and
        sent through ``cholesky`` followed by a transposed batched product.

        NOTE(review): the attribute name suggests the last two ops yield the
        inverse of the damped matrix -- confirm against the
        ``P.Cholesky(split_dim=...)`` kernel.

        Returns:
            ``dout`` unchanged.
        """
        out = dout
        dout = self.mul(dout, self.loss_scale)
        dout = self.mul(dout, self.batch_size)
        dout = self.reduce_mean(dout, 0)
        dout_shape = self.shape(dout)
        dout = self.reshape(dout, (dout_shape[0], -1))
        dout_shape = self.shape(dout)
        normalizer = dout_shape[1]
        dout = self.cast(dout, mstype.float32)
        matrix_G = self.matmul(dout, dout)
        matrix_G = self.mul(matrix_G, 1.0 / normalizer)
        damping_step = self.gather(self.damping, self.cov_step, 0)
        damping_step = self.cast(damping_step, mstype.float32)
        # Advance the damping index only here, after the forward pass of the
        # same step has already read it.
        self.cov_step = self.cov_step + self.freq
        damping = self.mul(damping_step, 1.0 / normalizer)
        damping = self.sqrt(damping)
        matrix_G = matrix_G + damping * self.dampingG
        matrix_G = self.cholesky(matrix_G)
        matrix_G = self.vector_matmul(matrix_G, matrix_G)
        self.matrix_G_inv = matrix_G
        return out

    def construct(self, x):
        """Convolve ``x``; in THOR mode also refresh ``matrix_A_inv``.

        In THOR mode the im2col'd input is reshaped to 3-D, averaged over
        axis 1, multiplied with its own transpose, divided by its second
        dimension, damped, processed by ``cholesky`` + transposed batched
        product, broadcast to the stored shape and written to
        ``matrix_A_inv``.  The convolution output is routed through the
        gradient hook so ``save_gradient`` runs during back-propagation.
        """
        if self.thor:
            matrix_A = self.img2col(x)
            matrix_A_shape = self.shape(matrix_A)
            matrix_A = self.reshape(matrix_A, (matrix_A_shape[0]*matrix_A_shape[1]*matrix_A_shape[2],
                                               matrix_A_shape[3], -1))
            matrix_A = self.reduce_mean(matrix_A, 1)
            matrix_A_shape = self.shape(matrix_A)
            normalizer = matrix_A_shape[1]
            matrix_A = self.cast(matrix_A, mstype.float32)
            matrix_A = self.matmul(matrix_A, matrix_A)
            matrix_A = self.mul(matrix_A, 1.0 / normalizer)
            damping_step = self.gather(self.damping, self.cov_step, self.axis)
            damping_step = self.cast(damping_step, mstype.float32)
            damping = self.mul(damping_step, 1.0 / normalizer)
            damping = self.sqrt(damping)
            matrix_A = matrix_A + damping * self.dampingA
            matrix_A = self.cholesky(matrix_A)
            matrix_A = self.vector_matmul(matrix_A, matrix_A)
            matrix_A = self.broadcast_to(matrix_A)
            self.matrix_A_inv = matrix_A
            out = self.conv2d(x, self.weight)
            out = self.getG(out)
        else:
            out = self.conv2d(x, self.weight)

        return out

    def extra_repr(self):
        """Return a textual summary of the layer configuration.

        NOTE(review): Dense_Thor_GPU in this file names the equivalent hook
        ``extend_repr``; confirm which name the Cell base class calls.
        """
        # The last two slots are labelled *_init but are filled with the
        # weight and bias attributes themselves.
        s = 'input_channels={}, output_channels={}, kernel_size={},' \
            'stride={}, pad_mode={}, padding={}, dilation={}, ' \
            'group={}, data_format={}, has_bias={},' \
            'weight_init={}, bias_init={}'.format(
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                self.stride,
                self.pad_mode,
                self.padding,
                self.dilation,
                self.group,
                self.data_format,
                self.has_bias,
                self.weight,
                self.bias)

        if self.has_bias:
            s += ', bias={}'.format(self.bias)
        return s
class Dense_Thor_GPU(Cell):
    """Fully connected layer that also maintains THOR second-order statistics (GPU).

    While ``self.thor`` is true the forward pass writes a statistic derived
    from the layer input into ``matrix_A_inv`` and installs a gradient hook
    that writes a statistic derived from the output gradient into
    ``matrix_G_inv``.  With ``self.thor`` false it is a plain dense layer.
    """
    @cell_attr_register(attrs=['has_bias', 'activation'])
    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 damping=0.03,
                 loss_scale=1,
                 frequency=278,
                 batch_size=32,
                 has_bias=True,
                 activation=None):
        """Create the dense layer, its THOR parameters and helper ops.

        Args:
            in_channels, out_channels: positive layer dimensions.
            weight_init: initializer or a Tensor of shape
                (out_channels, in_channels).
            bias_init: initializer or a Tensor of shape (out_channels,);
                only used when ``has_bias`` is true.
            damping: damping value(s); indexed by ``cov_step`` at run time, so
                callers presumably pass a per-step array -- confirm.
            loss_scale: the gradient seen by the hook is multiplied by
                ``1 / loss_scale``.
            frequency: amount added to ``cov_step`` each time the gradient
                hook runs.
            batch_size: the gradient seen by the hook is multiplied by it.
            has_bias: whether a bias parameter is created and added.
            activation: passed to ``get_activation``; applied to the output
                when the result is not None.

        Raises:
            ValueError: if a Tensor ``weight_init`` / ``bias_init`` has the
                wrong shape.
        """
        super(Dense_Thor_GPU, self).__init__()
        self.in_channels = check_int_positive(in_channels)
        self.out_channels = check_int_positive(out_channels)
        self.has_bias = check_bool(has_bias)
        self.thor = True
        if isinstance(weight_init, Tensor):
            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
                    weight_init.shape[1] != in_channels:
                raise ValueError("weight_init shape error")

        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")

        if self.has_bias:
            if isinstance(bias_init, Tensor):
                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                    raise ValueError("bias_init shape error")

            self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")

        self.matmul = P.MatMul(transpose_b=True)
        self.bias_add = P.BiasAdd()

        self.activation = get_activation(activation)
        self.activation_flag = self.activation is not None
        split_dim = 128
        matrix_A_shape, matrix_G_shape = caculate_matmul_shape(self.in_channels, self.out_channels, split_dim)
        self.matrix_A_inv = Parameter(Tensor(np.zeros(matrix_A_shape).astype(np.float32)),
                                      name='matrix_A_inv', requires_grad=False)
        self.matrix_G_inv = Parameter(Tensor(np.zeros(matrix_G_shape).astype(np.float32)),
                                      name="matrix_G_inv", requires_grad=False)
        self.broadcast_to = P.BroadcastTo(matrix_A_shape)
        # Index into the damping schedule; advanced by `frequency` in save_gradient.
        self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.mul = P.Mul()
        # Computes M^T @ M for a 2-D M.
        self.cube_matmul = P.MatMul(transpose_a=True)
        self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
        self.batch_size = Tensor(batch_size, mstype.float16)
        # Identity in the forward pass; calls save_gradient on the way back.
        self.getG = P.InsertGradientOf(self.save_gradient)
        self.damping = Parameter(Tensor(damping), name="damping_value", requires_grad=False)
        # Identity matrices that the damping scalar is multiplied with.
        self.dampingA = Tensor(np.identity(in_channels), mstype.float32)
        self.dampingG = Tensor(np.identity(out_channels), mstype.float32)
        self.cast = P.Cast()
        self.gather = P.GatherV2()
        self.freq = Tensor(frequency, mstype.int32)
        self.axis = 0
        self.add = P.TensorAdd()
        self.sqrt = P.Sqrt()
        self.cholesky = P.Cholesky(split_dim=split_dim)
        self.vector_matmul = P.BatchMatMul(transpose_a=True)

    def save_gradient(self, dout):
        """Gradient hook: update ``matrix_G_inv`` and pass ``dout`` through.

        The gradient is rescaled by ``1 / loss_scale`` and ``batch_size``,
        multiplied with its own transpose (``dout^T @ dout``), divided by its
        first dimension, damped, and sent through ``cholesky`` followed by a
        transposed batched product.

        Returns:
            ``dout`` unchanged.
        """
        out = dout
        dout = self.mul(dout, self.loss_scale)
        dout = self.mul(dout, self.batch_size)
        dout_shape = self.shape(dout)
        # First dimension of the incoming gradient, read before the product.
        normalizer = dout_shape[0]
        dout = self.cast(dout, mstype.float32)
        matrix_G = self.cube_matmul(dout, dout)
        matrix_G = self.mul(matrix_G, 1.0 / normalizer)
        damping_step = self.gather(self.damping, self.cov_step, 0)
        damping_step = self.cast(damping_step, mstype.float32)
        # Advance the damping index only here, after the forward pass of the
        # same step has already read it.
        self.cov_step = self.cov_step + self.freq
        damping = self.sqrt(damping_step)
        matrix_G = matrix_G + damping * self.dampingG
        matrix_G = self.cholesky(matrix_G)
        matrix_G = self.vector_matmul(matrix_G, matrix_G)
        self.matrix_G_inv = matrix_G
        return out

    def construct(self, x):
        """Apply the dense layer; in THOR mode also refresh ``matrix_A_inv``.

        In THOR mode ``x^T @ x`` is divided by the first dimension of ``x``,
        damped, processed by ``cholesky`` + transposed batched product,
        broadcast to the stored shape and written to ``matrix_A_inv``.  The
        layer output is routed through the gradient hook so ``save_gradient``
        runs during back-propagation.  Bias and activation are applied last,
        when configured.
        """
        if self.thor:
            inputs = self.cast(x, mstype.float32)
            # Fixed: the normalizer must be read from x itself.  It used to be
            # taken from the shape of x^T @ x, whose first dimension is
            # in_channels rather than the number of rows of x, unlike
            # save_gradient which reads dout's first dimension before its product.
            inputs_shape = self.shape(inputs)
            normalizer = inputs_shape[0]
            inputs = self.cube_matmul(inputs, inputs)
            matrix_A = self.mul(inputs, 1.0 / normalizer)

            damping_step = self.gather(self.damping, self.cov_step, self.axis)
            damping_step = self.cast(damping_step, mstype.float32)
            damping = self.sqrt(damping_step)
            matrix_A = matrix_A + damping * self.dampingA
            matrix_A = self.cholesky(matrix_A)
            matrix_A = self.vector_matmul(matrix_A, matrix_A)
            matrix_A = self.broadcast_to(matrix_A)
            self.matrix_A_inv = matrix_A
            output = self.matmul(x, self.weight)
            output = self.getG(output)
        else:
            output = self.matmul(x, self.weight)

        if self.has_bias:
            output = self.bias_add(output, self.bias)
        if self.activation_flag:
            return self.activation(output)
        return output

    def extend_repr(self):
        """Return a textual summary of the layer configuration."""
        str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \
            .format(self.in_channels, self.out_channels, self.weight, self.has_bias)
        if self.has_bias:
            str_info = str_info + ', bias={}'.format(self.bias)

        if self.activation_flag:
            str_info = str_info + ', activation={}'.format(self.activation)

        return str_info
class Conv2d_Thor(_Conv):
"""Conv2d_Thor"""
def __init__(self,
@ -114,6 +420,7 @@ class Conv2d_Thor(_Conv):
damping=0.03,
loss_scale=1,
frequency=278,
batch_size=32,
bias_init='zeros'):
self.thor = True
ksizes = (1, kernel_size, kernel_size, 1)
@ -143,7 +450,7 @@ class Conv2d_Thor(_Conv):
dilation=self.dilation,
group=self.group
)
self.batch_size = batch_size
self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides)
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
self.matrix_combine = P.CusMatrixCombine()
@ -228,7 +535,7 @@ class Conv2d_Thor(_Conv):
normalizer = dout_shape[0]
matrix_G = self.cube_matmul(dout, dout)
normalizer = self.cast(normalizer, ms.float32)
normalizer = self.cast(normalizer, mstype.float32)
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
damping_step = self.gather(self.damping, self.cov_step, 0)
self.cov_step = self.cov_step + self.freq
@ -261,7 +568,7 @@ class Conv2d_Thor(_Conv):
matrix_A = self.reshape(matrix_A, (self.hw, C0, self.hw, C0))
matrix_A = self.slice(matrix_A, (0, 0, 0, 0), (self.hw, self.in_channels, self.hw, self.in_channels))
matrix_A = self.reshape(matrix_A, (self.matrix_A_dim, self.matrix_A_dim))
normalizer = self.cast(normalizer, ms.float32)
normalizer = self.cast(normalizer, mstype.float32)
matrix_A = self.mul(matrix_A, 1.0 / normalizer)
if self.padA_flag:
matrix_A = self.padA(matrix_A)
@ -330,6 +637,7 @@ class Dense_Thor(Cell):
damping=0.03,
loss_scale=1,
frequency=278,
batch_size=32,
has_bias=True,
activation=None):
super(Dense_Thor, self).__init__()
@ -337,6 +645,7 @@ class Dense_Thor(Cell):
self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias)
self.thor = True
self.batch_size = batch_size
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
@ -376,8 +685,8 @@ class Dense_Thor(Cell):
self.damping = Tensor(damping)
self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
self.vector_matmul = P.CusBatchMatMul()
self.pad = P.Pad(((0, 24), (0, 24)))
self.pad1 = P.Pad(((0, 8), (0, 8)))
self.pad = P.Pad(((0, 23), (0, 23)))
self.pad1 = P.Pad(((0, 7), (0, 7)))
self.slice = P.Slice()
self.gather = P.GatherV2()
self.assignadd = P.AssignAdd()
@ -385,7 +694,7 @@ class Dense_Thor(Cell):
self.axis = 0
self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False)
self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False)
self.fused_abs_max1 = P.CusFusedAbsMax1([1000, 1000])
self.fused_abs_max1 = P.CusFusedAbsMax1([1001, 1001])
self.fused_abs_max2 = P.CusFusedAbsMax1()
self.log = P.Log()
self.exp = P.Exp()
@ -402,7 +711,7 @@ class Dense_Thor(Cell):
dout = self.mul(dout, 32.0)
normalizer = 32
matrix_G = self.cube_matmul(dout, dout)
normalizer = self.cast(normalizer, ms.float32)
normalizer = self.cast(normalizer, mstype.float32)
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
matrix_G = self.pad(matrix_G)
damping_step = self.gather(self.damping, self.cov_step, 0)
@ -417,7 +726,7 @@ class Dense_Thor(Cell):
matrix_G_inv_max = self.fused_abs_max2(matrix_G_inv_max)
self.G_inv_max = matrix_G_inv_max
matrix_G_inv = self.matrix_combine(matrix_G_inv)
matrix_G_inv = self.slice(matrix_G_inv, (0, 0), (1000, 1000))
matrix_G_inv = self.slice(matrix_G_inv, (0, 0), (1001, 1001))
matrix_G_inv = self.pad1(matrix_G_inv)
matrix_G_inv_shape = self.shape(matrix_G_inv)
matrix_G_inv = self.reshape(matrix_G_inv, (matrix_G_inv_shape[0] / 16, 16, matrix_G_inv_shape[0] / 16, 16))
@ -431,7 +740,7 @@ class Dense_Thor(Cell):
if self.thor:
inputs = self.cube_matmul(x, x)
normalizer = 32
normalizer = self.cast(normalizer, ms.float32)
normalizer = self.cast(normalizer, mstype.float32)
matrix_A = self.mul(inputs, 1.0 / normalizer)
damping_step = self.gather(self.damping, self.cov_step, self.axis)

View File

@ -12,44 +12,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_imagenet."""
import argparse
"""train resnet."""
import os
import random
import argparse
import numpy as np
from mindspore import Tensor
from mindspore import context
from mindspore.communication.management import init
from mindspore import Tensor
from mindspore import dataset as de
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import ParallelMode
from src.model_thor import Model
from src.resnet_thor import resnet50
from src.thor import THOR
from src.config import config
from src.crossentropy import CrossEntropy
from src.dataset_imagenet import create_dataset
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.communication.management import init, get_rank, get_group_size
random.seed(1)
np.random.seed(1)
from src.model_thor import Model_Thor as Model
from src.resnet_thor import resnet50
from src.dataset import create_dataset
from src.crossentropy import CrossEntropy
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--do_train', type=bool, default=True, help='Do train or not.')
parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--device_num', type=int, default=1, help='Device num')
args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
if args_opt.device_target == "Ascend":
from src.thor import THOR
from src.config import config
else:
from src.thor import THOR_GPU as THOR
from src.config import config_gpu as config
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch):
def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch, decay_epochs=100):
"""get_model_lr"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
@ -57,9 +59,9 @@ def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch):
epoch = (i + 1) / steps_per_epoch
base = (1.0 - float(epoch) / total_epochs) ** decay
lr_local = lr_init * base
if epoch >= 39:
if epoch >= decay_epochs:
lr_local = lr_local * 0.5
if epoch >= 40:
if epoch >= decay_epochs + 1:
lr_local = lr_local * 0.5
lr_each_step.append(lr_local)
current_step = global_step
@ -76,7 +78,6 @@ def get_model_damping(global_step, damping_init, decay_rate, total_epochs, steps
epoch = (step + 1) / steps_per_epoch
damping_here = damping_init * (decay_rate ** (epoch / 10))
damping_each_step.append(damping_here)
current_step = global_step
damping_each_step = np.array(damping_each_step).astype(np.float32)
damping_now = damping_each_step[current_step:]
@ -84,49 +85,70 @@ def get_model_damping(global_step, damping_init, decay_rate, total_epochs, steps
# Script entry point: configures the execution context, builds the dataset, network, loss,
# THOR optimizer and callbacks, then starts training.
# NOTE(review): this block looks like a rendered diff without +/- markers or indentation, so
# several statements appear in two variants; the comments below label them as variants
# without asserting which one the real file currently contains.
if __name__ == '__main__':
# Variant A of the distributed setup: data-parallel context with parameter_broadcast=True,
# followed by all-reduce fusion split indices for the groups named hccl_world_groupsum1..5.
if not args_opt.do_eval and args_opt.run_distribute:
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True, parameter_broadcast=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5")
# Target device and default checkpoint directory come from the CLI and the config object.
target = args_opt.device_target
ckpt_save_dir = config.save_checkpoint_path
init()
# init context
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
epoch_size = config.epoch_size
# Damping schedule built from literal constants; a config-driven call appears further down.
damping = get_model_damping(0, 0.03, 0.87, 50, 5004)
# Variant B of the distributed setup, split by target device.
if args_opt.run_distribute:
# Ascend target
if target == "Ascend":
# int() raises TypeError here if DEVICE_ID is not set in the environment.
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5")
init()
# GPU target
else:
# Non-Ascend path: init is called with "nccl", the device count is taken from
# get_group_size(), and the checkpoint directory gets a per-rank "ckpt_<rank>/" suffix.
init("nccl")
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
batch_size=config.batch_size, target=target)
# define net
# Steps per epoch come from the dataset size and feed both schedules below.
step_size = dataset.get_dataset_size()
damping = get_model_damping(0, config.damping_init, config.damping_decay, 90, step_size)
lr = get_model_lr(0, config.lr_init, config.lr_decay, config.lr_end_epoch, step_size, decay_epochs=39)
# NOTE(review): the resnet50(...) call below is followed by two alternative closing lines
# (with and without batch_size); only one of them can belong to the real file.
net = resnet50(class_num=config.class_num, damping=damping, loss_scale=config.loss_scale,
frequency=config.frequency)
frequency=config.frequency, batch_size=config.batch_size)
# Two variants of the label-smoothing switch (config.label_smooth / config.use_label_smooth);
# when smoothing is disabled the factor is forced to 0.0 before the loss is built.
if not config.label_smooth:
# define loss, model
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
# Variant A of the training setup, guarded by --do_train: builds the dataset, a fixed loss
# scale manager, a learning-rate Tensor from literal constants, and the THOR optimizer.
if args_opt.do_train:
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
batch_size=config.batch_size)
step_size = dataset.get_dataset_size()
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
lr = Tensor(get_model_lr(0, 0.045, 6, 70, 5004))
# THOR receives the trainable parameters plus four parameter groups selected by name
# substring: 'matrix_A', 'matrix_G', 'A_inv_max' and 'G_inv_max'.
opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
filter(lambda x: 'matrix_A' in x.name, net.get_parameters()),
filter(lambda x: 'matrix_G' in x.name, net.get_parameters()),
filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()),
filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()),
config.weight_decay, config.loss_scale)
# Variant B of the optimizer: same arguments, but lr is wrapped in Tensor at the call site.
opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), config.momentum,
filter(lambda x: 'matrix_A' in x.name, net.get_parameters()),
filter(lambda x: 'matrix_G' in x.name, net.get_parameters()),
filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()),
filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()),
config.weight_decay, config.loss_scale)
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
# Both Model(...) calls use amp_level O2; they differ in keep_batchnorm_fp32
# (False for Ascend, True otherwise).
if target == "Ascend":
model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', loss_scale_manager=loss_scale,
keep_batchnorm_fp32=False, metrics={'acc'}, frequency=config.frequency)
else:
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=True, frequency=config.frequency)
# Variant A of the callbacks: checkpoint interval read directly from
# config.save_checkpoint_steps, written to config.save_checkpoint_path.
time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=config.save_checkpoint_path, config=config_ck)
cb += [ckpt_cb]
# define callbacks
# Variant B of the callbacks: checkpoint interval expressed in epochs
# (save_checkpoint_epochs * step_size), written to ckpt_save_dir.
time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)
cb += [ckpt_cb]
# Two variants of the training call; both pass the epoch count taken from config.epoch_size.
model.train(epoch_size, dataset, callbacks=cb)
# train model
model.train(config.epoch_size, dataset, callbacks=cb)