!4414 THOR optimizer for GPU training

Merge pull request !4414 from wangmin0104/master
mindspore-ci-bot 2020-08-14 09:13:50 +08:00 committed by Gitee
commit 1002ee4887
17 changed files with 1104 additions and 1103 deletions


@ -24,22 +24,24 @@ This is an example of training ResNet-50 V1.5 with ImageNet2012 dataset by secon
. .
├── resnet_thor ├── resnet_thor
├── README.md ├── README.md
├── scripts
├── run_distribute_train.sh # launch distributed training for Ascend
└── run_eval.sh # launch inferring for Ascend
├── run_distribute_train_gpu.sh # launch distributed training for GPU
└── run_eval_gpu.sh # launch inferring for GPU
├── src ├── src
├── crossentropy.py # CrossEntropy loss function ├── crossentropy.py # CrossEntropy loss function
├── config.py # parameter configuration ├── config.py # parameter configuration
├── resnet50.py # resnet50 backbone
├── dataset_helper.py # dataset help for minddata dataset ├── dataset_helper.py # dataset help for minddata dataset
├── grad_reducer_thor.py # grad reducer for thor ├── grad_reducer_thor.py # grad reducer for thor
├── model_thor.py # model ├── model_thor.py # model for train
├── resnet_thor.py # resnet50_thor backbone ├── resnet_thor.py # resnet50_thor backbone
├── thor.py # thor ├── thor.py # thor optimizer
├── thor_layer.py # thor layer ├── thor_layer.py # thor layer
└── dataset_imagenet.py # data preprocessing └── dataset.py # data preprocessing
├── scripts
├── run_distribute_train.sh # launch distributed training(8 pcs)
└── run_eval.sh # launch infering
├── eval.py # infer script ├── eval.py # infer script
└── train.py # train script └── train.py # train script
``` ```
@ -48,26 +50,30 @@ This is an example of training ResNet-50 V1.5 with ImageNet2012 dataset by secon
Parameters for both training and inference can be set in config.py. Parameters for both training and inference can be set in config.py.
``` ```
"class_num": 1000, # dataset class number "class_num": 1001, # dataset class number
"batch_size": 32, # batch size of input tensor "batch_size": 32, # batch size of input tensor
"loss_scale": 128, # loss scale "loss_scale": 128, # loss scale
"momentum": 0.9, # momentum of THOR optimizer "momentum": 0.9, # momentum of THOR optimizer
"weight_decay": 5e-4, # weight decay "weight_decay": 5e-4, # weight decay
"epoch_size": 45, # only valid for taining, which is always 1 for inference "epoch_size": 45, # only valid for taining, which is always 1 for inference
"buffer_size": 1000, # number of queue size in data preprocessing
"image_height": 224, # image height
"image_width": 224, # image width
"save_checkpoint": True, # whether save checkpoint or not "save_checkpoint": True, # whether save checkpoint or not
"save_checkpoint_steps": 5004, # the step interval between two checkpoints. By default, the checkpoint will be saved every epoch "save_checkpoint_epochs": 1, # the epoch interval between two checkpoints. By default, the checkpoint will be saved every epoch
"keep_checkpoint_max": 20, # only keep the last keep_checkpoint_max checkpoint "keep_checkpoint_max": 15, # only keep the last keep_checkpoint_max checkpoint
"save_checkpoint_path": "./", # path to save checkpoint relative to the executed path "save_checkpoint_path": "./", # path to save checkpoint relative to the executed path
"label_smooth": True, # label smooth "label_smooth": True, # label smooth
"label_smooth_factor": 0.1, # label smooth factor "label_smooth_factor": 0.1, # label smooth factor
"lr_init": 0.045, # learning rate init value
"lr_decay": 6, # learning rate decay rate value
"lr_end_epoch": 70, # learning rate end epoch value
"damping_init": 0.03, # damping init value for Fisher information matrix
"damping_decay": 0.87, # damping decay rate
"frequency": 834, # the step interval to update second-order information matrix "frequency": 834, # the step interval to update second-order information matrix
``` ```
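The `lr_*` and `damping_*` entries above are schedule parameters rather than single constants: train.py expands them into per-step learning-rate and damping arrays for the THOR optimizer. The sketch below is illustrative only; the helper names and exact decay formulas are assumptions, not code from this PR, and show one plausible way such schedules can be derived from these values.
```python
# Illustrative schedule construction (assumed formulas, hypothetical helpers).
import numpy as np

def build_lr(lr_init, lr_decay, lr_end_epoch, steps_per_epoch=5004):
    # Polynomial-style decay from lr_init towards zero over lr_end_epoch epochs.
    steps = np.arange(lr_end_epoch * steps_per_epoch)
    epochs = (steps + 1) / steps_per_epoch
    return lr_init * (1.0 - epochs / (lr_end_epoch + 1)) ** lr_decay

def build_damping(damping_init, damping_decay, epoch_size=45, steps_per_epoch=5004):
    # Exponential decay of the damping added to the Fisher information matrix.
    steps = np.arange(epoch_size * steps_per_epoch)
    epochs = (steps + 1) / steps_per_epoch
    return damping_init * damping_decay ** (epochs / 10)

lr = build_lr(0.045, 6, 70)            # values from the Ascend config above
damping = build_damping(0.03, 0.87)
print(lr[:3], damping[:3])
```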
## Running the example ## Running the example
### 1 Running on Ascend 910
### Train ### Train
#### Usage #### Usage
@ -82,10 +88,10 @@ Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [DEVICE_NUM]
```bash ```bash
# distributed training example(8 pcs) # distributed training example(8 pcs)
sh run_distribute_train.sh rank_table_8p.json dataset/ilsvrc sh run_distribute_train.sh rank_table_8p.json dataset/ilsvrc 8
``` ```
> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). > About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/distributed_training_ascend.html).
#### Result #### Result
@ -126,3 +132,35 @@ Inference result will be stored in the example path, whose folder name is "infer
``` ```
result: {'acc': 0.759503041} ckpt=train_parallel0/resnet-42_5004.ckpt result: {'acc': 0.759503041} ckpt=train_parallel0/resnet-42_5004.ckpt
``` ```
### 2 Running on GPU
### Train
```
# distributed training example
sh run_distribute_train_gpu.sh [DATASET_PATH] [DEVICE_NUM]
```
#### Result
```
# distribute training result(8 pcs)
epoch: 1 step: 5004, loss is 4.3069
epoch: 2 step: 5004, loss is 3.5695
epoch: 3 step: 5004, loss is 3.5893
epoch: 4 step: 5004, loss is 3.1987
epoch: 5 step: 5004, loss is 3.3526
......
epoch: 40 step: 5004, loss is 1.9482
epoch: 41 step: 5004, loss is 1.8950
epoch: 42 step: 5004, loss is 1.9023
......
```
### Infer
```
# infer example
sh run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH]
```
#### Result
```
result: {'acc': 0.760143245838668} ckpt_0/resnet-40_5004.ckpt
```


@ -12,51 +12,64 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" """train resnet."""
eval.
"""
import os import os
import random
import argparse import argparse
import numpy as np
from mindspore import context from mindspore import context
from mindspore import dataset as de
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.dataset_imagenet import create_dataset
from src.config import config
from src.crossentropy import CrossEntropy from src.crossentropy import CrossEntropy
from src.resnet50 import resnet50 from src.config import config
from src.dataset import create_dataset
from src.resnet_thor import resnet50 as resnet
parser = argparse.ArgumentParser(description='Image classification') parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--do_train', type=bool, default=False, help='Do train or not.')
parser.add_argument('--do_eval', type=bool, default=True, help='Do eval or not.')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
args_opt = parser.parse_args() args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID')) random.seed(1)
np.random.seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) de.config.set_seed(1)
context.set_context(device_id=device_id)
if __name__ == '__main__': if __name__ == '__main__':
target = args_opt.device_target
net = resnet50(class_num=config.class_num) # init context
if not config.label_smooth: context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
config.label_smooth_factor = 0.0 if target != "GPU":
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
if args_opt.do_eval: # create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size) dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size,
step_size = dataset.get_dataset_size() target=target)
if args_opt.checkpoint_path: # define net
net = resnet(class_num=config.class_num)
net.add_flags_recursive(thor=False)
# load checkpoint
param_dict = load_checkpoint(args_opt.checkpoint_path) param_dict = load_checkpoint(args_opt.checkpoint_path)
keys = list(param_dict.keys())
for key in keys:
if "damping" in key:
param_dict.pop(key)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
net.set_train(False) net.set_train(False)
model = Model(net, loss_fn=loss, metrics={'acc'}) # define loss, model
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
# define model
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
# eval model
res = model.eval(dataset) res = model.eval(dataset)
print("result:", res, "ckpt=", args_opt.checkpoint_path) print("result:", res, "ckpt=", args_opt.checkpoint_path)


@ -52,6 +52,6 @@ do
echo "start training for rank $RANK_ID, device $DEVICE_ID" echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log env > env.log
python train.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 & python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 &
cd .. cd ..
done done


@ -0,0 +1,46 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_distribute_train_gpu.sh [DATASET_PATH] [DEVICE_NUM]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
ulimit -u unlimited
export DEVICE_NUM=$2
export RANK_SIZE=$2
rm -rf ./train_parallel
mkdir ./train_parallel
cp ../*.py ./train_parallel
cp *.sh ./train_parallel
cp -r ../src ./train_parallel
cd ./train_parallel || exit
mpirun -n $RANK_SIZE \
python train.py --run_distribute=True \
--device_num=$DEVICE_NUM --device_target="GPU" --dataset_path=$PATH1 &> log &

model_zoo/official/cv/resnet_thor/scripts/run_eval.sh Normal file → Executable file

@ -20,6 +20,7 @@ then
exit 1 exit 1
fi fi
get_real_path(){ get_real_path(){
if [ "${1:0:1}" == "/" ]; then if [ "${1:0:1}" == "/" ]; then
echo "$1" echo "$1"
@ -44,9 +45,6 @@ then
exit 1 exit 1
fi fi
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/../ || exit
ulimit -u unlimited ulimit -u unlimited
export DEVICE_NUM=1 export DEVICE_NUM=1
export DEVICE_ID=0 export DEVICE_ID=0
@ -58,10 +56,11 @@ then
rm -rf ./eval rm -rf ./eval
fi fi
mkdir ./eval mkdir ./eval
cp *.py ./eval cp ../*.py ./eval
cp -r ./src ./eval cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit cd ./eval || exit
env > env.log env > env.log
echo "start infering for device $DEVICE_ID" echo "start evaluation for device $DEVICE_ID"
python eval.py --do_eval=True --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log & python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
cd .. cd ..


@ -0,0 +1,65 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $PATH2 ]
then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 --device_target="GPU" &> log &
cd ..


@ -17,21 +17,46 @@ network config setting, will be used in train.py and eval.py
""" """
from easydict import EasyDict as ed from easydict import EasyDict as ed
# config for resnet50, imagenet2012, Ascend 910
config = ed({ config = ed({
"class_num": 1000, "class_num": 1001,
"batch_size": 32, "batch_size": 32,
"loss_scale": 128, "loss_scale": 128,
"momentum": 0.9, "momentum": 0.9,
"weight_decay": 5e-4, "weight_decay": 5e-4,
"epoch_size": 45, "epoch_size": 45,
"buffer_size": 1000,
"image_height": 224,
"image_width": 224,
"save_checkpoint": True, "save_checkpoint": True,
"save_checkpoint_steps": 5004, "save_checkpoint_epochs": 1,
"keep_checkpoint_max": 20, "keep_checkpoint_max": 15,
"save_checkpoint_path": "./", "save_checkpoint_path": "./",
"label_smooth": 1, "use_label_smooth": True,
"label_smooth_factor": 0.1, "label_smooth_factor": 0.1,
"frequency": 834 "lr_init": 0.045,
"lr_decay": 6,
"lr_end_epoch": 70,
"damping_init": 0.03,
"damping_decay": 0.87,
"frequency": 834,
})
# config for resnet50, imagenet2012, GPU
config_gpu = ed({
"class_num": 1001,
"batch_size": 32,
"loss_scale": 128,
"momentum": 0.9,
"weight_decay": 5e-4,
"epoch_size": 45,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 15,
"save_checkpoint_path": "./",
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0.04,
"lr_decay": 5,
"lr_end_epoch": 58,
"damping_init": 0.02,
"damping_decay": 0.87,
"frequency": 834,
}) })
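With one EasyDict per backend, callers are expected to pick `config` (Ascend) or `config_gpu` according to the device target. A hedged sketch of that selection; the actual wiring in this PR's train.py and eval.py may differ:
```python
# Hypothetical selection helper (abridged dicts, not the full configs above).
from easydict import EasyDict as ed

config = ed({"class_num": 1001, "lr_init": 0.045, "damping_init": 0.03})     # Ascend
config_gpu = ed({"class_num": 1001, "lr_init": 0.04, "damping_init": 0.02})  # GPU

def pick_config(device_target):
    return config_gpu if device_target == "GPU" else config

cfg = pick_config("GPU")
print(cfg.lr_init, cfg.damping_init)  # 0.04 0.02
```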


@ -28,13 +28,10 @@ class CrossEntropy(_Loss):
self.onehot = P.OneHot() self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
# self.cast = P.Cast()
self.ce = nn.SoftmaxCrossEntropyWithLogits() self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False) self.mean = P.ReduceMean(False)
def construct(self, logit, label): def construct(self, logit, label):
# one_hot_label = self.onehot(self.cast(label, mstype.int32),
# F.shape(logit)[1], self.on_value, self.off_value)、
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value) one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, one_hot_label) loss = self.ce(logit, one_hot_label)
loss = self.mean(loss, 0) loss = self.mean(loss, 0)


@ -16,30 +16,36 @@
create train or eval dataset. create train or eval dataset.
""" """
import os import os
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2 import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.transforms.vision.c_transforms as V_C from mindspore.communication.management import init, get_rank, get_group_size
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
""" """
create a train or eval dataset create a train or eval imagenet2012 dataset for resnet50
Args: Args:
dataset_path(string): the path of dataset. dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval. do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1 repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32 batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
Returns: Returns:
dataset dataset
""" """
device_num = int(os.getenv("RANK_SIZE")) if target == "Ascend":
rank_id = int(os.getenv("RANK_ID")) device_num, rank_id = _get_rank_info()
else:
init("nccl")
rank_id = get_rank()
device_num = get_group_size()
if device_num == 1: if device_num == 1:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=False) ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
else: else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id) num_shards=device_num, shard_id=rank_id)
@ -47,29 +53,28 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
image_size = 224 image_size = 224
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255] mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255] std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
# define map operations
if do_train: if do_train:
transform_img = [ trans = [
V_C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
V_C.RandomHorizontalFlip(prob=0.5), C.RandomHorizontalFlip(prob=0.5),
V_C.Normalize(mean=mean, std=std), C.Normalize(mean=mean, std=std),
V_C.HWC2CHW() C.HWC2CHW()
] ]
else: else:
transform_img = [ trans = [
V_C.Decode(), C.Decode(),
V_C.Resize((256, 256)), C.Resize(256),
V_C.CenterCrop(image_size), C.CenterCrop(image_size),
V_C.Normalize(mean=mean, std=std), C.Normalize(mean=mean, std=std),
V_C.HWC2CHW() C.HWC2CHW()
] ]
# type_cast_op = C2.TypeCast(mstype.float16)
type_cast_op = C2.TypeCast(mstype.int32) type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(input_columns="image", operations=transform_img, num_parallel_workers=8) ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8) ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)
# apply shuffle operations
# ds = ds.shuffle(buffer_size=config.buffer_size)
# apply batch operations # apply batch operations
ds = ds.batch(batch_size, drop_remainder=True) ds = ds.batch(batch_size, drop_remainder=True)
@ -78,3 +83,18 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
ds = ds.repeat(repeat_num) ds = ds.repeat(repeat_num)
return ds return ds
def _get_rank_info():
"""
get rank size and rank id
"""
rank_size = int(os.environ.get("RANK_SIZE", 1))
if rank_size > 1:
rank_size = get_group_size()
rank_id = get_rank()
else:
rank_size = 1
rank_id = 0
return rank_size, rank_id
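A usage sketch for the `create_dataset` signature above. The dataset path is a placeholder, and the GPU branch calls `init("nccl")`, so multi-GPU use assumes the process was launched via mpirun as in run_distribute_train_gpu.sh:
```python
# Assumes it is run from the example's root directory so that src/ is importable.
from src.config import config_gpu as config
from src.dataset import create_dataset

eval_ds = create_dataset(dataset_path="/path/to/imagenet/val",  # placeholder path
                         do_train=False,
                         batch_size=config.batch_size,
                         target="GPU")
print(eval_ds.get_dataset_size())  # number of batches per pass
```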


@ -13,34 +13,47 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Dataset help for minddata dataset""" """Dataset help for minddata dataset"""
from mindspore._checkparam import check_bool import math
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode import os
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
_to_full_shapes from mindspore._checkparam import check_bool, check_int
from mindspore.train.parallel_utils import ParallelMode from mindspore import context
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _to_full_shapes
from mindspore.nn.wrap import GetNextSingleOp
from mindspore.parallel._utils import _get_device_num, _need_to_full
def _send_data(dataset): def _send_data(dataset, epoch_num):
"""Engine dataset to write data to tdt queue.""" """Engine dataset to write data to tdt queue."""
if not hasattr(dataset, '__has_sent__'): if not hasattr(dataset, '__has_sent__'):
exec_dataset = dataset.__TRANSFER_DATASET__ exec_dataset = dataset.__TRANSFER_DATASET__
exec_dataset.send() exec_dataset.send(epoch_num)
dataset.__has_sent__ = True dataset.__has_sent__ = True
def _send_data_no_flag(dataset, epoch_num):
"""Engine dataset to write data to tdt queue directly."""
exec_dataset = dataset.__TRANSFER_DATASET__
exec_dataset.send(epoch_num)
class DatasetHelper: class DatasetHelper:
""" """
Help function to use the Minddata dataset. Help function to use the MindData dataset.
According to different context, change the iter of dataset, to use the same for loop in different context. According to different contexts, change the iterations of dataset and use the same iteration for loop in different
contexts.
Note: Note:
The iter of DatasetHelper will give one epoch data. The iteration of DatasetHelper will provide one epoch data.
Args: Args:
dataset (DataSet): The dataset. dataset (DataSet): The training dataset iterator.
dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. Default: True.
Default: True. sink_size (int): Control the amount of data in each sink.
If sink_size=-1, sink the complete dataset for each epoch.
If sink_size>0, sink sink_size data for each epoch. Default: -1.
epoch_num (int): Control the number of epoch data to send. Default: 1.
Examples: Examples:
>>> dataset_helper = DatasetHelper(dataset) >>> dataset_helper = DatasetHelper(dataset)
@ -48,81 +61,116 @@ class DatasetHelper:
>>> outputs = network(*inputs) >>> outputs = network(*inputs)
""" """
def __init__(self, dataset, dataset_sink_mode=True, iter_first_order=0): def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1, iter_first_order=1):
check_bool(dataset_sink_mode) check_bool(dataset_sink_mode)
self.iter = _DatasetIterMSLoopSink(dataset, iter_first_order) check_int(sink_size)
if sink_size < -1 or sink_size == 0:
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
if dataset_sink_mode:
if context.get_context("device_target") == "Ascend":
iterclass = _DatasetIterMSLoopSink
self.iter = iterclass(dataset, sink_size, epoch_num, iter_first_order)
elif context.get_context("device_target") == "GPU":
iterclass = _DatasetIterMS
self.iter = iterclass(dataset, sink_size, epoch_num)
elif context.get_context("device_target") == "CPU":
raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")
def __iter__(self): def __iter__(self):
return self.iter.__iter__() return self.iter.__iter__()
# A temp solution for loop sink. Delete later # A temp solution for loop sink. Delete later
def types_shapes(self): def types_shapes(self):
"""Get the types and shapes from dataset on current config.""" """Get the types and shapes from dataset on the current configuration."""
return self.iter.types_shapes() return self.iter.types_shapes()
def loop_size(self): def sink_size(self):
"""Get loop_size for every iteration.""" """Get sink_size for each iteration."""
return self.iter.loop_size return self.iter.get_sink_size()
def stop_send(self):
"""Free up resources about data sink."""
self.iter.stop_send()
class _DatasetIter: class _DatasetIter:
"""Base iter for dataset help""" """Base iter for dataset helper"""
def __init__(self, dataset, sink_size, epoch_num):
self.dataset = dataset
self.sink_size = sink_size
self.sink_count = 1
def __init__(self, dataset): if not hasattr(dataset, '__TRANSFER_DATASET__'):
self.loop_size = 1 if hasattr(dataset, '__loop_size__'):
if not hasattr(dataset, '__ME_INITED__'): self.sink_size = dataset.__loop_size__
if not hasattr(dataset, '__loop_size__'): dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
self.loop_size = dataset.get_dataset_size()
else:
self.loop_size = dataset.__loop_size__
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
if not hasattr(dataset, '__no_send__'): if not hasattr(dataset, '__no_send__'):
_send_data(dataset) _send_data(dataset, epoch_num)
else: else:
_send_data(dataset) _send_data_no_flag(dataset, epoch_num)
self.ind = 0 self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
self.dataset = dataset self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
dataset_types, dataset_shapes = _get_types_and_shapes(dataset)
self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes
def __iter__(self): def __iter__(self):
self.ind = 0 self.index = 0
return self return self
def __next__(self): def __next__(self):
if self.ind >= self.loop_count: if self.index >= self.sink_count:
raise StopIteration() raise StopIteration()
self.ind += 1 self.index += 1
return self.op() return self.op()
def types_shapes(self): def types_shapes(self):
return self.dataset_types, self.dataset_shapes return self.dataset_types, self.dataset_shapes
def get_loop_count(self, dataset): def get_sink_count(self, dataset):
loop_count = 1 sink_count = 1
if hasattr(dataset, '__loop_size__'): if hasattr(dataset, '__loop_size__'):
loop_size = dataset.__loop_size__ loop_size = dataset.__loop_size__
if dataset.get_dataset_size() % loop_size != 0: if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0:
raise ValueError(f'Dataset size {dataset.get_dataset_size()} and ' raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
f'loop_size {loop_size} are not matched.') f'sink_size {loop_size} are not matched.')
loop_count = int(dataset.get_dataset_size() / loop_size) sink_count = math.ceil(dataset.get_dataset_size() / loop_size)
return loop_count return sink_count
def get_sink_size(self):
"""get sink_size to device"""
sink_size = 1
if hasattr(self.dataset, '__loop_size__'):
sink_size = self.dataset.__loop_size__
else:
if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend":
if self.sink_size > 0:
sink_size = self.sink_size
else:
sink_size = self.dataset.get_dataset_size()
return sink_size
class _DatasetIterMSLoopSink(_DatasetIter): class _DatasetIterMSLoopSink(_DatasetIter):
"""Iter for context (device_target=Ascend)""" """Iter for context (device_target=Ascend)"""
def __init__(self, dataset, sink_size, epoch_num, iter_first_order):
def __init__(self, dataset, iter_first_order): super().__init__(dataset, sink_size, epoch_num)
super(_DatasetIterMSLoopSink, self).__init__(dataset) sink_count = 1
if hasattr(dataset, '__loop_size__'):
loop_size = dataset.__loop_size__ + iter_first_order loop_size = dataset.__loop_size__ + iter_first_order
self.loop_count = int(dataset.get_dataset_size() / loop_size) * 2 if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0:
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
# compile, and slice tensor to run. The batch dimension of tensors for compile is device_number f'sink_size {loop_size} are not matched.')
# times the batch dimension of tensors for run. Now only support LoopSink. sink_count = math.ceil(dataset.get_dataset_size() / loop_size) * 2
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): self.sink_count = sink_count
ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"):
self.sink_count = 1
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
# use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for
# compile is device_number times the batch dimension of tensors for run. Now only support LoopSink.
if _need_to_full():
device_num = _get_device_num() device_num = _get_device_num()
self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num) self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)
@ -130,3 +178,16 @@ class _DatasetIterMSLoopSink(_DatasetIter):
return tuple() return tuple()
self.op = op self.op = op
class _DatasetIterMS(_DatasetIter):
"""Iter for MS(enable_loop_sink=False)."""
def __init__(self, dataset, sink_size, epoch_num):
super().__init__(dataset, sink_size, epoch_num)
if sink_size > 0:
self.sink_count = sink_size
else:
self.sink_count = dataset.get_dataset_size()
queue_name = dataset.__ME_INITED__
self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
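For the Ascend loop-sink iterator above, one THOR "round" is a single second-order step (`__loop_size__` is set to 1 by the model) plus `iter_first_order = frequency - 1` first-order steps, and every round costs two sinks: one for graph 'train0' and one for 'train1'. The arithmetic with the numbers used in this example:
```python
# sink_count arithmetic for the Ascend branch, using this example's numbers.
import math

steps_per_epoch = 5004
frequency = 834
iter_second_order = 1                              # train_dataset.__loop_size__
iter_first_order = frequency - 1                   # 833 first-order steps
loop_size = iter_second_order + iter_first_order   # 834 steps per round

sink_count = math.ceil(steps_per_epoch / loop_size) * 2
print(loop_size, sink_count)  # 834 12 -> 6 rounds of ('train0', 'train1') per epoch
```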


@ -174,10 +174,6 @@ class DistributedGradReducerThor(Cell):
datatypes = self.hyper_map(F.partial(_get_datatype), grads) datatypes = self.hyper_map(F.partial(_get_datatype), grads)
grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads) grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)
if self.mean:
new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads) new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads)
else:
new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads)
new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad) new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
return new_grad return new_grad
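Removing the `if self.mean:` branch makes mean reduction unconditional: gradients are all-reduced (summed) across the `degree` devices and then scaled by `1/degree`. A NumPy stand-in for that arithmetic, with no real communication:
```python
import numpy as np

degree = 8
per_device_grads = [np.full(4, i + 1.0) for i in range(degree)]  # fake gradients
summed = np.sum(per_device_grads, axis=0)  # what AllReduce(sum) would return
mean_grad = summed * (1.0 / degree)        # the unconditional 1/degree scaling
print(mean_grad)                           # [4.5 4.5 4.5 4.5]
```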


@ -14,27 +14,19 @@
# ============================================================================ # ============================================================================
"""Model.""" """Model."""
import numpy as np import math
from mindspore.train.callback import RunContext
from mindspore import context from mindspore import context
from mindspore import log as logger
from mindspore import nn from mindspore import nn
from mindspore._c_expression import init_exec_dataset
from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
from mindspore.common import dtype as mstype
from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.tensor import Tensor
from mindspore.nn.metrics import Loss
from mindspore.nn.metrics import get_metrics
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
from mindspore.train import amp
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
from mindspore.train.parallel_utils import ParallelMode from mindspore.train.parallel_utils import ParallelMode
from mindspore.train._utils import _to_full_tensor
from mindspore.train.model import Model
from mindspore.parallel._utils import _need_to_full
from mindspore.common.dtype import pytype_to_dtype
from mindspore._c_expression import init_exec_dataset
from src.dataset_helper import DatasetHelper from src.dataset_helper import DatasetHelper
def _convert_type(types): def _convert_type(types):
""" """
Convert from numpy type to tensor type. Convert from numpy type to tensor type.
@ -76,194 +68,52 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
need_run=False) need_run=False)
class Model: class Model_Thor(Model):
""" """
High-Level API for Training or Testing. High-Level API for Training or Testing.
`Model` groups layers into an object with training and inference features. `Model` groups layers into an object with training and inference features.
Args: Args:
network (Cell): The training or testing network. network (Cell): A training or testing network.
loss_fn (Cell): Objective function, if loss_fn is None, the loss_fn (Cell): Objective function, if loss_fn is None, the
network should contain the logic of loss and grads calculation, and the logic network should contain the logic of loss and grads calculation, and the logic
of parallel if needed. Default: None. of parallel if needed. Default: None.
optimizer (Cell): Optimizer for updating the weights. Default: None. optimizer (Cell): Optimizer for updating the weights. Default: None.
metrics (Union[dict, set]): Dict or set of metrics to be evaluated by the model during metrics (Union[dict, set]): A Dictionary or a set of metrics to be evaluated by the model during
training and testing. eg: {'accuracy', 'recall'}. Default: None. training and testing. eg: {'accuracy', 'recall'}. Default: None.
eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as
`eval_network`. Default: None. `eval_network`. Default: None.
eval_indexes (list): In case of defining the `eval_network`, if `eval_indexes` is None, all outputs of eval_indexes (list): When defining the `eval_network`, if `eval_indexes` is None, all outputs of the
`eval_network` would be passed to metrics, otherwise `eval_indexes` must contain three `eval_network` would be passed to metrics, otherwise `eval_indexes` must contain three
elements, representing the positions of loss value, predict value and label, the loss elements, including the positions of loss value, predicted value and label. The loss
value would be passed to `Loss` metric, predict value and label would be passed to other value would be passed to the `Loss` metric, the predicted value and label would be passed
metric. Default: None. to other metric. Default: None.
amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network`, level for mixed amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network`, level for mixed
precision training. Supports [O0, O2]. Default: "O0". precision training. Supports [O0, O2, O3]. Default: "O0".
- O0: Do not change. - O0: Do not change.
- O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale. - O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else O2 is recommended on GPU, O3 is recommended on Ascend.
scale the loss by LossScaleManager. If it is set, overwrite the level setting. It's a keyword argument.
loss_scale_manager (Union[None, LossScaleManager]): If it is None, the loss would not be scaled. Otherwise,
scale the loss by LossScaleManager. It is a key argument.
e.g. Use `loss_scale_manager=None` to set the value. e.g. Use `loss_scale_manager=None` to set the value.
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. Default: True. keep_batchnorm_fp32 (bool): Keep Batchnorm running in `float32`. If it is set to true, the level setting before
will be overwritten. Default: True.
Examples:
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
>>> self.bn = nn.BatchNorm2d(64)
>>> self.relu = nn.ReLU()
>>> self.flatten = nn.Flatten()
>>> self.fc = nn.Dense(64*224*224, 12) # padding=0
>>>
>>> def construct(self, x):
>>> x = self.conv(x)
>>> x = self.bn(x)
>>> x = self.relu(x)
>>> x = self.flatten(x)
>>> out = self.fc(x)
>>> return out
>>>
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
>>> dataset = get_dataset()
>>> model.train(2, dataset)
""" """
def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
eval_indexes=None, amp_level="O0", frequency=278, stop_epoch=100, **kwargs): eval_indexes=None, amp_level="O0", frequency=834, **kwargs):
self._network = network super(Model_Thor, self).__init__(network, loss_fn, optimizer, metrics, eval_network,
self._loss_fn = loss_fn eval_indexes, amp_level, **kwargs)
self._optimizer = optimizer
self._loss_scale_manager = None
self._loss_scale_manager_set = False
self._keep_bn_fp32 = True
self._check_kwargs(kwargs)
self._amp_level = amp_level
self._process_amp_args(kwargs)
self._parallel_mode = _get_parallel_mode()
self._device_number = _get_device_num()
self._global_rank = _get_global_rank()
self._parameter_broadcast = _get_parameter_broadcast()
self._frequency = frequency self._frequency = frequency
self._stop_epoch = stop_epoch
self._train_network = self._build_train_network() self._train_network = self._build_train_network()
self._build_eval_network(metrics, eval_network, eval_indexes)
self._build_predict_network()
def _process_amp_args(self, kwargs): def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1,
if self._amp_level == "O0": epoch_num=1, iter_first_order=1):
self._keep_bn_fp32 = False
if 'keep_batchnorm_fp32' in kwargs:
self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
if 'loss_scale_manager' in kwargs:
self._loss_scale_manager = kwargs['loss_scale_manager']
self._loss_scale_manager_set = True
def _check_kwargs(self, kwargs):
for arg in kwargs:
if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
raise ValueError(f"Unsupport arg '{arg}'")
def _build_train_network(self):
"""Build train network"""
network = self._network
if self._optimizer:
if self._loss_scale_manager_set:
network = amp.build_train_network(network,
self._optimizer,
self._loss_fn,
level=self._amp_level,
loss_scale_manager=self._loss_scale_manager,
keep_batchnorm_fp32=self._keep_bn_fp32)
else:
network = amp.build_train_network(network,
self._optimizer,
self._loss_fn,
level=self._amp_level,
keep_batchnorm_fp32=self._keep_bn_fp32)
elif self._loss_fn:
network = nn.WithLossCell(network, self._loss_fn)
# If need to check if loss_fn is not None, but optimizer is None
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
network.set_auto_parallel()
return network
def _build_eval_network(self, metrics, eval_network, eval_indexes):
"""Build the network for evaluation."""
self._metric_fns = get_metrics(metrics)
if not self._metric_fns:
return
if eval_network is not None:
if eval_indexes is not None and not (isinstance(eval_indexes, list) and len(eval_indexes) == 3):
raise ValueError("Eval_indexes must be a list or None. If eval_indexes is a list, length of it \
must be three. But got {}".format(eval_indexes))
self._eval_network = eval_network
self._eval_indexes = eval_indexes
else:
if self._loss_fn is None:
raise ValueError("loss_fn can not be None.")
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2")
self._eval_indexes = [0, 1, 2]
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
self._eval_network.set_auto_parallel()
def _build_predict_network(self):
"""Build the network for prediction."""
self._predict_network = self._network
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
self._predict_network = _VirtualDatasetCell(self._network)
self._predict_network.set_auto_parallel()
def _clear_metrics(self):
"""Clear metrics local values."""
for metric in self._metric_fns.values():
metric.clear()
def _update_metrics(self, outputs):
"""Update metrics local values."""
if not isinstance(outputs, tuple):
raise ValueError("The `outputs` is not tuple.")
if self._eval_indexes is not None and len(outputs) < 3:
raise ValueError("The length of `outputs` must be greater than or equal to 3, \
but got {}".format(len(outputs)))
for metric in self._metric_fns.values():
if self._eval_indexes is None:
metric.update(*outputs)
else:
if isinstance(metric, Loss):
metric.update(outputs[self._eval_indexes[0]])
else:
metric.update(outputs[self._eval_indexes[1]], outputs[self._eval_indexes[2]])
def _get_metrics(self):
"""Get metrics local values."""
metrics = dict()
for key, value in self._metric_fns.items():
metrics[key] = value.eval()
return metrics
def _get_scaling_sens(self):
"""get the scaling sens"""
scaling_sens = 1
if self._loss_scale_manager is not None:
scaling_sens = self._loss_scale_manager.get_loss_scale()
if self._parallel_mode == ParallelMode.DATA_PARALLEL:
scaling_sens /= self._device_number
return scaling_sens
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, iter_first_order):
"""Initializes dataset.""" """Initializes dataset."""
need_wrap = False need_wrap = False
if dataset_sink_mode: if dataset_sink_mode:
@ -275,7 +125,7 @@ class Model:
if not is_train: if not is_train:
dataset.__loop_size__ = 1 dataset.__loop_size__ = 1
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, iter_first_order) dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)
# remove later to deal with loop sink # remove later to deal with loop sink
if need_wrap: if need_wrap:
@ -283,133 +133,31 @@ class Model:
network.set_train(is_train) network.set_train(is_train)
network.phase = phase network.phase = phase
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
network.set_auto_parallel()
return dataset_helper, network return dataset_helper, network
def init(self, train_dataset=None, valid_dataset=None): def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1):
"""
Initializes compute graphs and data graphs with sink mode.
Note:
Pre-init process only supports `GRAPH_MODE` and `Ascend` target currently.
Args:
train_dataset (Dataset): A training dataset iterator. If define `train_dataset`, training graphs will be
initialized. Default: None.
valid_dataset (Dataset): A evaluating dataset iterator. If define `valid_dataset`, evaluation graphs will
be initialized, and `metrics` in `Model` can not be None. Default: None.
Examples:
>>> train_dataset = get_train_dataset()
>>> valid_dataset = get_valid_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'})
>>> model.init(train_dataset, valid_dataset)
>>> model.train(2, train_dataset)
>>> model.eval(valid_dataset)
"""
if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend":
raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.')
if not train_dataset and not valid_dataset:
raise ValueError('Both train_dataset and valid_dataset can not be None or empty.')
_device_number_check(self._parallel_mode, self._device_number)
if train_dataset:
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
self._train_network.set_train()
self._train_network.phase = 'train'
if self._parameter_broadcast:
self._train_network.set_broadcast_flag()
train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
is_train=True,
phase='train',
dataset=train_dataset,
dataset_sink_mode=True)
self._train_network = train_network
for inputs in train_dataset_helper:
self._train_network.compile(*inputs)
break
if valid_dataset:
if not self._metric_fns:
raise RuntimeError('If define `valid_dataset`, metric fn can not be None or empty.')
self._eval_network.set_train(False)
self._eval_network.phase = 'eval'
valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
is_train=False,
phase='eval',
dataset=valid_dataset,
dataset_sink_mode=True)
self._eval_network = eval_network
for inputs in valid_dataset_helper:
self._eval_network.compile(*inputs)
break
def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
"""
Training.
Args:
epoch (int): Total number of iterations on the data.
train_dataset (Dataset): A training dataset iterator. If there is no
loss_fn, a tuple with multiply data (data1, data2, data3, ...) will be
returned and passed to the network. Otherwise, a tuple (data, label) will
be returned, and the data and label are passed to the network and loss
function respectively.
callbacks (list): List of callback object. Callbacks which should be executed while training. Default: None.
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
Configure pynative mode, the training process will be performed with
dataset not sink.
"""
epoch = check_int_positive(epoch)
self._train_network.set_train()
if self._parameter_broadcast:
self._train_network.set_broadcast_flag()
# build callback list
cb_params = _InternalCallbackParam()
cb_params.train_network = self._train_network
cb_params.epoch_num = epoch
cb_params.batch_num = train_dataset.get_dataset_size()
cb_params.mode = "train"
cb_params.loss_fn = self._loss_fn
cb_params.optimizer = self._optimizer
cb_params.parallel_mode = self._parallel_mode
cb_params.device_number = self._device_number
cb_params.train_dataset = train_dataset
cb_params.list_callback = callbacks
with _CallbackManager(callbacks) as list_callback:
if not dataset_sink_mode:
self._train_process(epoch, train_dataset, list_callback, cb_params)
elif context.get_context("mode") == context.PYNATIVE_MODE:
logger.warning("The pynative mode cannot support dataset sink mode currently."
"So the training process will be performed with dataset not sink.")
self._train_process(epoch, train_dataset, list_callback, cb_params)
else:
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
""" """
Training process. The data would be passed to network through dataset channel. Training process. The data would be passed to network through dataset channel.
Args: Args:
epoch (int): Total number of iterations on the data. epoch (int): Total number of iterations on the data.
train_dataset (Dataset): A training dataset iterator. If there is no train_dataset (Dataset): A training dataset iterator. If there is no
loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
returned and passed to the network. Otherwise, a tuple (data, label) should returned and passed to the network. Otherwise, a tuple (data, label) should
be returned, and the data and label are passed to the network and loss be returned. The data and label would be passed to the network and loss
function respectively. function respectively.
list_callback (Callback): Executor of callback list. Default: None. list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None.
sink_size (int): Control the amount of data in each sink. Default: -1.
""" """
if sink_size == -1:
epoch_num = epoch
else:
epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())
iter_first_order = self._frequency - 1 iter_first_order = self._frequency - 1
iter_second_order = 1 iter_second_order = 1
train_dataset.__loop_size__ = iter_second_order train_dataset.__loop_size__ = iter_second_order
@ -418,33 +166,66 @@ class Model:
phase='train', phase='train',
dataset=train_dataset, dataset=train_dataset,
dataset_sink_mode=True, dataset_sink_mode=True,
sink_size=sink_size,
epoch_num=epoch_num,
iter_first_order=iter_first_order) iter_first_order=iter_first_order)
self._train_network = train_network self._train_network = train_network
cb_params.train_network = self._train_network cb_params.train_network = self._train_network
cb_params.cur_step_num = 0 cb_params.cur_step_num = 0
loop_size = dataset_helper.loop_size()
run_context = RunContext(cb_params) run_context = RunContext(cb_params)
list_callback.begin(run_context) list_callback.begin(run_context)
# used to stop training for early stop, such as stopAtTime or stopAtStep # used to stop training for early stop, such as stopAtTime or stopAtStep
should_stop = False should_stop = False
has_do_dataset_init = False
switch_branch_one = True switch_branch_one = True
index_first_order = 0
train_network_init_flag = True
has_do_dataset_init = False
for i in range(epoch): for i in range(epoch):
cb_params.cur_epoch_num = i + 1 cb_params.cur_epoch_num = i + 1
list_callback.epoch_begin(run_context) list_callback.epoch_begin(run_context)
# for data sink dataset_helper only iter once, otherwise iter epoch_size times. # for data sink dataset_helper only iter once, otherwise iter epoch_size times.
for inputs in dataset_helper: for inputs in dataset_helper:
if _need_to_full() and context.get_context("device_target") == "GPU":
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
list_callback.step_begin(run_context) list_callback.step_begin(run_context)
if context.get_context("device_target") == "GPU":
if switch_branch_one: if switch_branch_one:
cb_params.cur_step_num += loop_size cb_params.cur_step_num += 1
if train_network_init_flag:
self._train_network.add_flags_recursive(thor=True)
self._train_network.phase = 'train0'
switch_branch_one = not switch_branch_one
outputs = self._train_network(*inputs)
cb_params.net_outputs = outputs
list_callback.step_end(run_context)
else:
cb_params.cur_step_num += 1
if train_network_init_flag:
self._train_network.add_flags_recursive(thor=False)
train_network_init_flag = False
self._train_network.phase = 'train1'
outputs = self._train_network(*inputs)
cb_params.net_outputs = outputs
index_first_order += 1
if index_first_order == iter_first_order:
index_first_order = 0
switch_branch_one = not switch_branch_one
list_callback.step_end(run_context)
else:
if switch_branch_one:
cb_params.cur_step_num += 1
if train_network_init_flag:
self._train_network.add_flags_recursive(thor=True) self._train_network.add_flags_recursive(thor=True)
self._train_network.phase = 'train0' self._train_network.phase = 'train0'
else: else:
cb_params.cur_step_num += iter_first_order cb_params.cur_step_num += iter_first_order
if train_network_init_flag:
self._train_network.add_flags_recursive(thor=False) self._train_network.add_flags_recursive(thor=False)
train_network_init_flag = False
self._train_network.phase = 'train1' self._train_network.phase = 'train1'
if not has_do_dataset_init: if not has_do_dataset_init:
_exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset') _exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset')
@ -458,268 +239,9 @@ class Model:
should_stop = should_stop or run_context.get_stop_requested() should_stop = should_stop or run_context.get_stop_requested()
if should_stop: if should_stop:
break break
dataset_helper.stop_send()
list_callback.end(run_context) list_callback.end(run_context)
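On GPU, the loop above executes one second-order step (phase 'train0', `thor=True`) followed by `frequency - 1` first-order steps (phase 'train1', `thor=False`), then repeats. A tiny sketch of that per-step phase decision; the real code tracks it with `switch_branch_one` and `index_first_order` as shown above:
```python
# Phase schedule implied by the GPU branch: step 0 of every `frequency`-step
# window updates second-order information, the rest are first-order steps.
frequency = 834

def phase_for_step(global_step):
    return "train0" if global_step % frequency == 0 else "train1"

print([phase_for_step(s) for s in range(5)])  # ['train0', 'train1', 'train1', ...]
print(phase_for_step(frequency))              # 'train0' again at the next window
```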
def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
"""
Training process. The data would be passed to network directly.
Args: __all__ = ["Model_Thor"]
epoch (int): Total number of iterations on the data.
train_dataset (Dataset): A training dataset iterator. If there is no
loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned, and the data and label are passed to the network and loss
function respectively.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
"""
dataset_helper, _ = self._exec_preprocess(self._train_network,
is_train=True,
phase='train',
dataset=train_dataset,
dataset_sink_mode=False)
cb_params.cur_step_num = 0
run_context = RunContext(cb_params)
list_callback.begin(run_context)
# used to stop training for early stop, such as stopAtTIme or stopATStep
should_stop = False
for i in range(epoch):
cb_params.cur_epoch_num = i + 1
list_callback.epoch_begin(run_context)
for next_element in dataset_helper:
len_element = len(next_element)
if self._loss_fn and len_element != 2:
raise ValueError("when loss_fn is not None, train_dataset should"
"return two elements, but got {}".format(len_element))
cb_params.cur_step_num += 1
list_callback.step_begin(run_context)
overflow = False
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
scaling_sens = self._get_scaling_sens()
next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)
outputs = self._train_network(*next_element)
cb_params.net_outputs = outputs
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
_, overflow, _ = outputs
overflow = np.all(overflow.asnumpy())
self._loss_scale_manager.update_loss_scale(overflow)
list_callback.step_end(run_context)
should_stop = should_stop or run_context.get_stop_requested()
if should_stop:
break
train_dataset.reset()
list_callback.epoch_end(run_context)
should_stop = should_stop or run_context.get_stop_requested()
if should_stop:
break
list_callback.end(run_context)
def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
"""
Training API where the iteration is controlled by python front-end.
When setting pynative mode, the training process will be performed with dataset not sink.
Note:
CPU is not supported when dataset_sink_mode is true.
If dataset_sink_mode is True, epoch of training should be equal to the count of repeat
operation in dataset processing. Otherwise, errors could occur since the amount of data
is not the amount training requires.
If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features
of data will be transferred one by one. The limitation of data transmission per time is 256M.
Args:
epoch (int): Total number of iterations on the data.
train_dataset (Dataset): A training dataset iterator. If there is no
loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned, and the data and label are passed to the network and loss
function respectively.
callbacks (list): List of callback object. Callbacks which should be excuted while training. Default: None.
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
Configure pynative mode, the training process will be performed with
dataset not sink.
Examples:
>>> dataset = get_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> loss_scale_manager = FixedLossScaleManager()
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
>>> model.train(2, dataset)
"""
repeat_count = train_dataset.get_repeat_count()
if epoch != repeat_count and dataset_sink_mode is True:
logger.warning(f"The epoch_size {epoch} is not the same with dataset repeat_count {repeat_count}")
check_bool(dataset_sink_mode)
_device_number_check(self._parallel_mode, self._device_number)
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
self._train(epoch,
train_dataset,
callbacks=callbacks,
dataset_sink_mode=dataset_sink_mode)
def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
"""
Evaluation. The data is passed to the network through the dataset channel.
Args:
valid_dataset (Dataset): Dataset to evaluate the model.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
Returns:
Dict, returns the loss value & metrics values for the model in test mode.
"""
run_context = RunContext(cb_params)
dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
is_train=False,
phase='eval',
dataset=valid_dataset,
dataset_sink_mode=True)
self._eval_network = eval_network
cb_params.eval_network = self._eval_network
list_callback.begin(run_context)
for inputs in dataset_helper:
cb_params.cur_step_num += 1
list_callback.step_begin(run_context)
outputs = self._eval_network(*inputs)
cb_params.net_outputs = outputs
list_callback.step_end(run_context)
self._update_metrics(outputs)
metrics = self._get_metrics()
cb_params.metrics = metrics
list_callback.end(run_context)
return metrics
def _eval_process(self, valid_dataset, list_callback=None, cb_params=None):
"""
Evaluation. The data is passed to the network directly.
Args:
valid_dataset (Dataset): Dataset to evaluate the model.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
Returns:
Dict, returns the loss value & metrics values for the model in test mode.
"""
run_context = RunContext(cb_params)
list_callback.begin(run_context)
dataset_helper, _ = self._exec_preprocess(self._eval_network,
is_train=False,
phase='eval',
dataset=valid_dataset,
dataset_sink_mode=False)
for next_element in dataset_helper:
cb_params.cur_step_num += 1
list_callback.step_begin(run_context)
outputs = self._eval_network(*next_element)
cb_params.net_outputs = outputs
list_callback.step_end(run_context)
self._update_metrics(outputs)
metrics = self._get_metrics()
cb_params.metrics = metrics
list_callback.end(run_context)
return metrics
def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True):
"""
Evaluation API where the iteration is controlled by the Python front-end.
In PyNative mode, evaluation is performed in dataset non-sink mode.
Note:
CPU is not supported when dataset_sink_mode is true.
If dataset_sink_mode is True, data will be sent to the device. If the device is Ascend, the
data features are transferred one by one, and each transfer is limited to 256 MB.
Args:
valid_dataset (Dataset): Dataset to evaluate the model.
callbacks (list): List of callback objects, which should be executed while evaluating. Default: None.
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
Returns:
Dict, returns the loss value & metrics values for the model in test mode.
Examples:
>>> dataset = get_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
>>> model.eval(dataset)
"""
check_bool(dataset_sink_mode)
_device_number_check(self._parallel_mode, self._device_number)
if not self._metric_fns:
raise ValueError("metric fn can not be None or empty.")
cb_params = _InternalCallbackParam()
cb_params.eval_network = self._eval_network
cb_params.valid_dataset = valid_dataset
cb_params.batch_num = valid_dataset.get_dataset_size()
cb_params.mode = "eval"
cb_params.cur_step_num = 0
self._eval_network.set_train(mode=False)
self._eval_network.phase = 'eval'
self._clear_metrics()
with _CallbackManager(callbacks) as list_callback:
if dataset_sink_mode:
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
return self._eval_process(valid_dataset, list_callback, cb_params)
def predict(self, *predict_data):
"""
Generates output predictions for the input samples.
The data can be a single tensor, a list of tensors, or a tuple of tensors.
Note:
Batch data should be put together in one tensor.
Args:
predict_data (Tensor): The predict data, which can be a single tensor, or a list or tuple of tensors.
Returns:
Tensor, array(s) of predictions.
Examples:
>>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
>>> model = Model(Net())
>>> model.predict(input_data)
"""
self._predict_network.set_train(False)
check_input_data(*predict_data, data_class=Tensor)
result = self._predict_network(*predict_data)
check_output_data(result)
return result
__all__ = ["Model"]

View File

@@ -1,262 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)
def _conv3x3(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 3, 3)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)
def _conv1x1(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 1, 1)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)
def _conv7x7(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 7, 7)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)
def _bn(channel):
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
def _bn_last(channel):
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)
def _fc(in_channel, out_channel):
weight_shape = (out_channel, in_channel)
weight = _weight_variable(weight_shape)
return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)
class ResidualBlock(nn.Cell):
"""
ResNet V1 residual block definition.
Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
Returns:
Tensor, output tensor.
Examples:
>>> ResidualBlock(3, 256, stride=2)
"""
expansion = 4
def __init__(self,
in_channel,
out_channel,
stride=1):
super(ResidualBlock, self).__init__()
channel = out_channel // self.expansion
self.conv1 = _conv1x1(in_channel, channel, stride=1)
self.bn1 = _bn(channel)
self.conv2 = _conv3x3(channel, channel, stride=stride)
self.bn2 = _bn(channel)
self.conv3 = _conv1x1(channel, out_channel, stride=1)
self.bn3 = _bn_last(out_channel)
self.relu = nn.ReLU()
self.down_sample = False
if stride != 1 or in_channel != out_channel:
self.down_sample = True
self.down_sample_layer = None
if self.down_sample:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride),
_bn(out_channel)])
self.add = P.TensorAdd()
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.down_sample:
identity = self.down_sample_layer(identity)
out = self.add(out, identity)
out = self.relu(out)
return out
class ResNet(nn.Cell):
"""
ResNet architecture.
Args:
block (Cell): Block for network.
layer_nums (list): Numbers of block in different layers.
in_channels (list): Input channel in each layer.
out_channels (list): Output channel in each layer.
strides (list): Stride size in each layer.
num_classes (int): The number of classes that the training images are belonging to.
Returns:
Tensor, output tensor.
Examples:
>>> ResNet(ResidualBlock,
>>> [3, 4, 6, 3],
>>> [64, 256, 512, 1024],
>>> [256, 512, 1024, 2048],
>>> [1, 2, 2, 2],
>>> 10)
"""
def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
strides,
num_classes):
super(ResNet, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
self.conv1 = _conv7x7(3, 64, stride=2)
self.bn1 = _bn(64)
self.relu = P.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=strides[0])
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=strides[1])
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=strides[2])
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=strides[3])
self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = _fc(out_channels[3], num_classes)
def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
"""
Make stage network of ResNet.
Args:
block (Cell): Resnet block.
layer_num (int): Layer number.
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer.
Returns:
SequentialCell, the output layer.
Examples:
>>> _make_layer(ResidualBlock, 3, 128, 256, 2)
"""
layers = []
resnet_block = block(in_channel, out_channel, stride=stride)
layers.append(resnet_block)
for _ in range(1, layer_num):
resnet_block = block(out_channel, out_channel, stride=1)
layers.append(resnet_block)
return nn.SequentialCell(layers)
def construct(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
c1 = self.maxpool(x)
c2 = self.layer1(c1)
c3 = self.layer2(c2)
c4 = self.layer3(c3)
c5 = self.layer4(c4)
out = self.mean(c5, (2, 3))
out = self.flatten(out)
out = self.end_point(out)
return out
def resnet50(class_num=10):
"""
Get ResNet50 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of ResNet50 neural network.
Examples:
>>> net = resnet50(10)
"""
return ResNet(ResidualBlock,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num)
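As a quick sanity check of the plain ResNet-50 definition above (which this PR removes in favour of the THOR variant), a forward pass on a dummy ImageNet-sized batch would look like the following; the shapes are only illustrative.

```
# Illustrative shape check for the plain ResNet-50 defined above.
import numpy as np
from mindspore import Tensor

net = resnet50(class_num=10)
dummy = Tensor(np.random.randn(1, 3, 224, 224).astype(np.float32))
logits = net(dummy)
print(logits.shape)  # expected (1, 10): one logit vector per image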

View File

@@ -18,8 +18,9 @@
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore import context
-from src.thor_layer import Conv2d_Thor, Dense_Thor
from src.thor_layer import Conv2d_Thor, Dense_Thor, Conv2d_Thor_GPU, Dense_Thor_GPU
def calculate_gain(nonlinearity, param=None):
@@ -81,7 +82,7 @@ def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
return np.random.normal(0, std, size=inputs_shape).astype(np.float32)
-def kaiming_uniform(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
fan = _calculate_correct_fan(inputs_shape, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
@@ -89,28 +90,51 @@ def kaiming_uniform(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu')
return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)
-def _conv3x3(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)
def _conv3x3(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278, batch_size=32):
weight_shape = (out_channel, in_channel, 3, 3)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
-return Conv2d_Thor(in_channel, out_channel,
if context.get_context('device_target') == "Ascend":
layer = Conv2d_Thor(in_channel, out_channel,
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight,
-damping=damping, loss_scale=loss_scale, frequency=frequency)
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
else:
layer = Conv2d_Thor_GPU(in_channel, out_channel,
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
return layer
-def _conv1x1(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
def _conv1x1(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278, batch_size=32):
weight_shape = (out_channel, in_channel, 1, 1)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
-return Conv2d_Thor(in_channel, out_channel,
if context.get_context('device_target') == "Ascend":
layer = Conv2d_Thor(in_channel, out_channel,
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight,
-damping=damping, loss_scale=loss_scale, frequency=frequency)
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
else:
layer = Conv2d_Thor_GPU(in_channel, out_channel,
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
return layer
-def _conv7x7(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
def _conv7x7(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278, batch_size=32):
weight_shape = (out_channel, in_channel, 7, 7)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
-return Conv2d_Thor(in_channel, out_channel,
if context.get_context('device_target') == "Ascend":
layer = Conv2d_Thor(in_channel, out_channel,
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight,
-damping=damping, loss_scale=loss_scale, frequency=frequency)
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
else:
layer = Conv2d_Thor_GPU(in_channel, out_channel,
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
return layer
def _bn(channel):
@@ -120,14 +144,21 @@ def _bn(channel):
def _bn_last(channel):
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
-gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)
-def _fc(in_channel, out_channel, damping, loss_scale, frequency):
def _fc(in_channel, out_channel, damping, loss_scale, frequency, batch_size=32):
weight_shape = (out_channel, in_channel)
weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
-return Dense_Thor(in_channel, out_channel, has_bias=False, weight_init=weight,
-bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency)
if context.get_context('device_target') == "Ascend":
layer = Dense_Thor(in_channel, out_channel, has_bias=False, weight_init=weight,
bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency,
batch_size=batch_size)
else:
layer = Dense_Thor_GPU(in_channel, out_channel, has_bias=False, weight_init=weight,
bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency,
batch_size=batch_size)
return layer
class ResidualBlock(nn.Cell):
@@ -153,20 +184,21 @@ class ResidualBlock(nn.Cell):
stride=1,
damping=0.03,
loss_scale=1,
-frequency=278):
frequency=278,
batch_size=32):
super(ResidualBlock, self).__init__()
channel = out_channel // self.expansion
self.conv1 = _conv1x1(in_channel, channel, stride=1, damping=damping, loss_scale=loss_scale,
-frequency=frequency)
frequency=frequency, batch_size=batch_size)
self.bn1 = _bn(channel)
self.conv2 = _conv3x3(channel, channel, stride=stride, damping=damping, loss_scale=loss_scale,
-frequency=frequency)
frequency=frequency, batch_size=batch_size)
self.bn2 = _bn(channel)
self.conv3 = _conv1x1(channel, out_channel, stride=1, damping=damping, loss_scale=loss_scale,
-frequency=frequency)
frequency=frequency, batch_size=batch_size)
self.bn3 = _bn_last(out_channel)
self.relu = nn.ReLU()
@@ -180,7 +212,8 @@ class ResidualBlock(nn.Cell):
if self.down_sample:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
damping=damping, loss_scale=loss_scale,
-frequency=frequency),
frequency=frequency,
batch_size=batch_size),
_bn(out_channel)])
self.add = P.TensorAdd()
@@ -239,16 +272,19 @@ class ResNet(nn.Cell):
num_classes,
damping,
loss_scale,
-frequency):
frequency,
batch_size):
super(ResNet, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
-self.conv1 = _conv7x7(3, 64, stride=2, damping=damping, loss_scale=loss_scale, frequency=frequency)
self.conv1 = _conv7x7(3, 64, stride=2, damping=damping, loss_scale=loss_scale,
frequency=frequency, batch_size=batch_size)
self.bn1 = _bn(64)
self.relu = P.ReLU()
-self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
# self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.layer1 = self._make_layer(block,
layer_nums[0],
@@ -257,7 +293,8 @@ class ResNet(nn.Cell):
stride=strides[0],
damping=damping,
loss_scale=loss_scale,
-frequency=frequency)
frequency=frequency,
batch_size=batch_size)
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
@@ -265,14 +302,16 @@ class ResNet(nn.Cell):
stride=strides[1],
damping=damping,
loss_scale=loss_scale,
-frequency=frequency)
frequency=frequency,
batch_size=batch_size)
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=strides[2], damping=damping,
loss_scale=loss_scale,
-frequency=frequency)
frequency=frequency,
batch_size=batch_size)
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
@@ -280,14 +319,16 @@ class ResNet(nn.Cell):
stride=strides[3],
damping=damping,
loss_scale=loss_scale,
-frequency=frequency)
frequency=frequency,
batch_size=batch_size)
self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
-self.end_point = _fc(out_channels[3], num_classes, damping=damping, loss_scale=loss_scale, frequency=frequency)
self.end_point = _fc(out_channels[3], num_classes, damping=damping, loss_scale=loss_scale,
frequency=frequency, batch_size=batch_size)
def _make_layer(self, block, layer_num, in_channel, out_channel, stride,
-damping, loss_scale, frequency):
damping, loss_scale, frequency, batch_size):
"""
Make stage network of ResNet.
@@ -307,12 +348,14 @@ class ResNet(nn.Cell):
layers = []
resnet_block = block(in_channel, out_channel, stride=stride,
-damping=damping, loss_scale=loss_scale, frequency=frequency)
damping=damping, loss_scale=loss_scale, frequency=frequency,
batch_size=batch_size)
layers.append(resnet_block)
for _ in range(1, layer_num):
resnet_block = block(out_channel, out_channel, stride=1,
-damping=damping, loss_scale=loss_scale, frequency=frequency)
damping=damping, loss_scale=loss_scale, frequency=frequency,
batch_size=batch_size)
layers.append(resnet_block)
return nn.SequentialCell(layers)
@@ -321,7 +364,7 @@ class ResNet(nn.Cell):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
-c1, _ = self.maxpool(x)
c1 = self.maxpool(x)
c2 = self.layer1(c1)
c3 = self.layer2(c2)
@@ -335,7 +378,7 @@ class ResNet(nn.Cell):
return out
-def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278):
def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278, batch_size=32):
"""
Get ResNet50 neural network.
@@ -356,4 +399,5 @@ def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278):
class_num,
damping,
loss_scale,
-frequency)
frequency,
batch_size)
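The net effect of this file's changes is that every layer factory now dispatches on the global device target, so one resnet50() call builds either the Ascend THOR layers (Conv2d_Thor / Dense_Thor) or their GPU counterparts. A hedged sketch of constructing the GPU variant follows; the flat damping array is only illustrative, since the training script builds a decaying schedule instead.

```
# Building the THOR ResNet-50 for GPU; the context check inside the factories
# selects Conv2d_Thor_GPU / Dense_Thor_GPU automatically.
import numpy as np
from mindspore import context
from src.resnet_thor import resnet50

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
damping = np.full((5004 * 45,), 0.03, dtype=np.float32)  # illustrative flat schedule
net = resnet50(class_num=1001, damping=damping, loss_scale=128,
               frequency=278, batch_size=32)
```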

View File

@@ -12,27 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
-"""momentum"""
"""THOR"""
-import mindspore.common.dtype as mstype
-from mindspore.common.initializer import initializer
-from mindspore.common.parameter import Parameter
-from mindspore.common.parameter import ParameterTuple
-from mindspore.common.tensor import Tensor
-from mindspore.nn.optim.optimizer import Optimizer
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops import _selected_ops
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import check_bool
from mindspore._checkparam import Validator as validator
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.parallel._utils import _get_device_num, _get_mirror_mean
from src.grad_reducer_thor import DistributedGradReducerThor
-momentum_opt = C.MultitypeFuncGraph("momentum_opt")
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")
-@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
-def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
-"""Apply momentum optimizer to the weight parameter using Tensor."""
-success = True
-success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
-return success
op_add = P.AddN()
apply_decay = C.MultitypeFuncGraph("apply_decay")
@@ -46,6 +39,119 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
return gradient
@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
"""Apply momentum optimizer to the weight parameter using Tensor."""
success = True
success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
return success
class THOR_GPU(Optimizer):
"""
THOR
"""
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max,
weight_decay=0.0, loss_scale=1.0, use_nesterov=False, decay_filter=lambda x: x.name not in []):
super(THOR_GPU, self).__init__(learning_rate, params, weight_decay, loss_scale)
validator.check_value_type("momentum", momentum, [float], self.cls_name)
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters
self.use_nesterov = check_bool(use_nesterov)
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = _selected_ops.ApplyMomentum(use_nesterov=self.use_nesterov)
self.feature_map = [1.0 / 12544, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
1.0 / 196, 1.0 / 196, 1.0 / 196,
1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
1.0]
self.feature_map_new = [x ** 0.5 for x in self.feature_map]
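# Note: each feature_map entry is 1/(H*W) of the corresponding layer's output feature
# map in ResNet-50 (112*112 = 12544 for the stem convolution, 56*56 = 3136 in stage 1,
# 28*28 = 784 in stage 2, 14*14 = 196 in stage 3, 7*7 = 49 in stage 4, and 1.0 for the
# final dense layer); feature_map_new stores the square roots, which are applied to both
# matrix_A and matrix_G below so that their product carries the full normalization.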
self.transpose = P.Transpose()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.matmul = P.MatMul()
self.matrix_A = ParameterTuple(matrix_A)
self.matrix_G = ParameterTuple(matrix_G)
self.A_inv_max = ParameterTuple(A_inv_max)
self.G_inv_max = ParameterTuple(G_inv_max)
self.assign = P.Assign()
self.mul = P.Mul()
mean = _get_mirror_mean()
degree = _get_device_num()
self.grad_reducer_thorA = DistributedGradReducerThor(self.parameters, 0, mean, degree)
self.grad_reducer_thorG = DistributedGradReducerThor(self.parameters, 0, mean, degree)
self.weight_decay = weight_decay
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
self.update_gradient = P.UpdateThorGradient(split_dim=128)
def construct(self, gradients):
params = self.params
moments = self.moments
gradients = self.scale_grad(gradients)
new_grads = ()
if self.thor:
matrix_A_allreduce = ()
matrix_G_allreduce = ()
for i in range(54):
g = gradients[i * 3]
matrix_A = self.matrix_A[i]
matrix_G = self.matrix_G[i]
matrix_A = F.depend(matrix_A, g)
matrix_G = F.depend(matrix_G, g)
matrix_A = self.mul(matrix_A, self.feature_map_new[i])
matrix_G = self.mul(matrix_G, self.feature_map_new[i])
matrix_A_allreduce = matrix_A_allreduce + (matrix_A,)
matrix_G_allreduce = matrix_G_allreduce + (matrix_G,)
matrix_A_allreduce = self.grad_reducer_thorA(matrix_A_allreduce)
matrix_G_allreduce = self.grad_reducer_thorG(matrix_G_allreduce)
for i in range(54):
g = gradients[i * 3]
g_shape = self.shape(g)
g = self.reshape(g, (g_shape[0], -1))
matrix_A = matrix_A_allreduce[i]
matrix_G = matrix_G_allreduce[i]
g = self.update_gradient(matrix_G, g, matrix_A)
fake_A = self.assign(self.matrix_A[i], matrix_A)
fake_G = self.assign(self.matrix_G[i], matrix_G)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
if i == 53:
new_grads = new_grads + (g,)
else:
g = self.reshape(g, g_shape)
new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
else:
for i in range(54):
g = gradients[i * 3]
g_shape = self.shape(g)
g = self.reshape(g, (g_shape[0], -1))
matrix_A = self.matrix_A[i]
matrix_G = self.matrix_G[i]
matrix_A = F.depend(matrix_A, g)
matrix_G = F.depend(matrix_G, g)
g = self.update_gradient(matrix_G, g, matrix_A)
if i == 53:
new_grads = new_grads + (g,)
else:
g = self.reshape(g, g_shape)
new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
gradients = new_grads
if self.weight_decay > 0:
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
params, gradients)
lr = self.get_lr()
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
return success
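# Note on the indexing used in construct(): the THOR ResNet-50 exposes 54 layer groups
# (53 convolutions plus the final dense layer). For group i, gradients[i * 3] is the
# conv/dense weight gradient and gradients[i * 3 + 1] / gradients[i * 3 + 2] are the two
# BatchNorm parameters that follow it, which are passed through unchanged; group 53 is
# the dense layer, whose gradient is already 2-D and is not reshaped back.
# update_gradient(G, g, A) applies the two stored Kronecker factors (already inverted
# inside the thor layers), i.e. conceptually g <- G^-1 * g * A^-1 as in K-FAC/THOR.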
class THOR(Optimizer):
"""THOR"""
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0,
@@ -195,5 +301,5 @@ class THOR(Optimizer):
params, gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()
-success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
return success

View File

@@ -15,7 +15,6 @@
"""thor_layer"""
import numpy as np
-import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore._checkparam import check_bool, twice, check_int_positive
from mindspore._extends import cell_attr_register
@@ -25,8 +24,10 @@ from mindspore.common.tensor import Tensor
from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
from mindspore.ops import operations as P
C0 = 16
def caculate_device_shape(matrix_dim, channel, is_A):
ll = (0)
if is_A:
@@ -37,6 +38,31 @@ def caculate_device_shape(matrix_dim, channel, is_A):
ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
return ll
def caculate_matmul_shape(matrix_A_dim, matrix_G_dim, split_dim):
split_dimA = split_dim
split_dimG = split_dim
if matrix_A_dim % split_dim == 0:
batch_w = matrix_A_dim // split_dim
else:
if matrix_A_dim < split_dim:
batch_w = 1
split_dimA = matrix_A_dim
else:
batch_w = matrix_A_dim // split_dim + 1
if matrix_G_dim % split_dim == 0:
batch_h = matrix_G_dim // split_dim
else:
if matrix_G_dim < split_dim:
batch_h = 1
split_dimG = matrix_G_dim
else:
batch_h = matrix_G_dim // split_dim + 1
matrix_A_shape = (batch_h, batch_w, split_dimA, split_dimA)
matrix_G_shape = (batch_h, split_dimG, split_dimG)
return matrix_A_shape, matrix_G_shape
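# Worked example (illustrative): the first 3x3 convolution of stage 1 has
# matrix_A_dim = 64 * 3 * 3 = 576 and matrix_G_dim = 64 with split_dim = 128.
# 576 is larger than 128 and not a multiple of it, so batch_w = 576 // 128 + 1 = 5 with
# split_dimA = 128; 64 is smaller than 128, so batch_h = 1 with split_dimG = 64.
# The blocked buffers are therefore matrix_A_shape = (1, 5, 128, 128) and
# matrix_G_shape = (1, 64, 64).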
class _Conv(Cell):
r"""Applies a N-D convolution over an input signal composed of several input
planes.
@@ -97,6 +123,286 @@ class _Conv(Cell):
raise NotImplementedError
class Conv2d_Thor_GPU(_Conv):
"""Conv2d_Thor"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
data_format='NCHW',
has_bias=False,
weight_init='normal',
damping=0.03,
loss_scale=1,
frequency=278,
batch_size=32,
bias_init='zeros'):
self.thor = True
self.hw = kernel_size * kernel_size
kernel_size = twice(kernel_size)
super(Conv2d_Thor_GPU, self).__init__(
in_channels,
out_channels,
kernel_size,
stride,
pad_mode,
padding,
dilation,
group,
data_format,
has_bias,
weight_init,
bias_init,
)
self.conv2d = P.Conv2D(out_channel=self.out_channels,
kernel_size=self.kernel_size,
mode=1,
pad_mode=self.pad_mode,
pad=self.padding,
stride=self.stride,
dilation=self.dilation,
group=self.group
)
self.matrix_A_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
self.matrix_G_dim = self.out_channels
split_dim = 128
matrix_A_shape, matrix_G_shape = caculate_matmul_shape(self.matrix_A_dim, self.matrix_G_dim, split_dim)
self.matrix_A_inv = Parameter(np.zeros(matrix_A_shape).astype(np.float32),
name='matrix_A_inv', requires_grad=False)
self.matrix_G_inv = Parameter(np.zeros(matrix_G_shape).astype(np.float32),
name='matrix_G_inv', requires_grad=False)
self.broadcast_to = P.BroadcastTo(matrix_A_shape)
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
self.img2col = P.Im2Col(kernel_size=kernel_size, stride=stride, pad_mode="same")
self.matmul = P.MatMul(transpose_b=True)
self.shape = P.Shape()
self.reshape = P.Reshape()
self.mul = P.Mul()
self.getG = P.InsertGradientOf(self.save_gradient)
self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
self.batch_size = Tensor(batch_size, mstype.float16)
self.transpose = P.Transpose()
self.cast = P.Cast()
self.gather = P.GatherV2()
self.freq = Tensor(frequency, mstype.int32)
self.axis = 0
self.sqrt = P.Sqrt()
self.reduce_mean = P.ReduceMean(keep_dims=False)
self.damping = Parameter(Tensor(damping), name="damping_value", requires_grad=False)
self.dampingA = Tensor(np.identity(self.matrix_A_dim), mstype.float32)
self.dampingG = Tensor(np.identity(self.matrix_G_dim), mstype.float32)
self.cholesky = P.Cholesky(split_dim=split_dim)
self.vector_matmul = P.BatchMatMul(transpose_a=True)
def save_gradient(self, dout):
"""save_gradient"""
out = dout
dout = self.mul(dout, self.loss_scale)
dout = self.mul(dout, self.batch_size)
dout = self.reduce_mean(dout, 0)
dout_shape = self.shape(dout)
dout = self.reshape(dout, (dout_shape[0], -1))
dout_shape = self.shape(dout)
normalizer = dout_shape[1]
dout = self.cast(dout, mstype.float32)
matrix_G = self.matmul(dout, dout)
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
damping_step = self.gather(self.damping, self.cov_step, 0)
damping_step = self.cast(damping_step, mstype.float32)
self.cov_step = self.cov_step + self.freq
damping = self.mul(damping_step, 1.0 / normalizer)
damping = self.sqrt(damping)
matrix_G = matrix_G + damping * self.dampingG
matrix_G = self.cholesky(matrix_G)
matrix_G = self.vector_matmul(matrix_G, matrix_G)
self.matrix_G_inv = matrix_G
return out
def construct(self, x):
if self.thor:
matrix_A = self.img2col(x)
matrix_A_shape = self.shape(matrix_A)
matrix_A = self.reshape(matrix_A, (matrix_A_shape[0]*matrix_A_shape[1]*matrix_A_shape[2],
matrix_A_shape[3], -1))
matrix_A = self.reduce_mean(matrix_A, 1)
matrix_A_shape = self.shape(matrix_A)
normalizer = matrix_A_shape[1]
matrix_A = self.cast(matrix_A, mstype.float32)
matrix_A = self.matmul(matrix_A, matrix_A)
matrix_A = self.mul(matrix_A, 1.0 / normalizer)
damping_step = self.gather(self.damping, self.cov_step, self.axis)
damping_step = self.cast(damping_step, mstype.float32)
damping = self.mul(damping_step, 1.0 / normalizer)
damping = self.sqrt(damping)
matrix_A = matrix_A + damping * self.dampingA
matrix_A = self.cholesky(matrix_A)
matrix_A = self.vector_matmul(matrix_A, matrix_A)
matrix_A = self.broadcast_to(matrix_A)
self.matrix_A_inv = matrix_A
out = self.conv2d(x, self.weight)
out = self.getG(out)
else:
out = self.conv2d(x, self.weight)
return out
def extra_repr(self):
"""extra_repr"""
s = 'input_channels={}, output_channels={}, kernel_size={},' \
'stride={}, pad_mode={}, padding={}, dilation={}, ' \
'group={}, data_format={}, has_bias={},' \
'weight_init={}, bias_init={}'.format(
self.in_channels,
self.out_channels,
self.kernel_size,
self.stride,
self.pad_mode,
self.padding,
self.dilation,
self.group,
self.data_format,
self.has_bias,
self.weight,
self.bias)
if self.has_bias:
s += ', bias={}'.format(self.bias)
return s
class Dense_Thor_GPU(Cell):
"""Dense_Thor"""
@cell_attr_register(attrs=['has_bias', 'activation'])
def __init__(self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
damping=0.03,
loss_scale=1,
frequency=278,
batch_size=32,
has_bias=True,
activation=None):
super(Dense_Thor_GPU, self).__init__()
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias)
self.thor = True
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
raise ValueError("weight_init shape error")
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
raise ValueError("bias_init shape error")
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
self.matmul = P.MatMul(transpose_b=True)
self.bias_add = P.BiasAdd()
self.activation = get_activation(activation)
self.activation_flag = self.activation is not None
split_dim = 128
matrix_A_shape, matrix_G_shape = caculate_matmul_shape(self.in_channels, self.out_channels, split_dim)
self.matrix_A_inv = Parameter(Tensor(np.zeros(matrix_A_shape).astype(np.float32)),
name='matrix_A_inv', requires_grad=False)
self.matrix_G_inv = Parameter(Tensor(np.zeros(matrix_G_shape).astype(np.float32)),
name="matrix_G_inv", requires_grad=False)
self.broadcast_to = P.BroadcastTo(matrix_A_shape)
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
self.shape = P.Shape()
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.mul = P.Mul()
self.cube_matmul = P.MatMul(transpose_a=True)
self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
self.batch_size = Tensor(batch_size, mstype.float16)
self.getG = P.InsertGradientOf(self.save_gradient)
self.damping = Parameter(Tensor(damping), name="damping_value", requires_grad=False)
self.dampingA = Tensor(np.identity(in_channels), mstype.float32)
self.dampingG = Tensor(np.identity(out_channels), mstype.float32)
self.cast = P.Cast()
self.gather = P.GatherV2()
self.freq = Tensor(frequency, mstype.int32)
self.axis = 0
self.add = P.TensorAdd()
self.sqrt = P.Sqrt()
self.cholesky = P.Cholesky(split_dim=split_dim)
self.vector_matmul = P.BatchMatMul(transpose_a=True)
def save_gradient(self, dout):
"""save_gradient"""
out = dout
dout = self.mul(dout, self.loss_scale)
dout = self.mul(dout, self.batch_size)
dout_shape = self.shape(dout)
normalizer = dout_shape[0]
dout = self.cast(dout, mstype.float32)
matrix_G = self.cube_matmul(dout, dout)
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
damping_step = self.gather(self.damping, self.cov_step, 0)
damping_step = self.cast(damping_step, mstype.float32)
self.cov_step = self.cov_step + self.freq
damping = self.sqrt(damping_step)
matrix_G = matrix_G + damping * self.dampingG
matrix_G = self.cholesky(matrix_G)
matrix_G = self.vector_matmul(matrix_G, matrix_G)
self.matrix_G_inv = matrix_G
return out
def construct(self, x):
"""construct"""
if self.thor:
inputs = self.cast(x, mstype.float32)
inputs = self.cube_matmul(inputs, inputs)
inputs_shape = self.shape(inputs)
normalizer = inputs_shape[0]
matrix_A = self.mul(inputs, 1.0 / normalizer)
damping_step = self.gather(self.damping, self.cov_step, self.axis)
damping_step = self.cast(damping_step, mstype.float32)
damping = self.sqrt(damping_step)
matrix_A = matrix_A + damping * self.dampingA
matrix_A = self.cholesky(matrix_A)
matrix_A = self.vector_matmul(matrix_A, matrix_A)
matrix_A = self.broadcast_to(matrix_A)
self.matrix_A_inv = matrix_A
output = self.matmul(x, self.weight)
output = self.getG(output)
else:
output = self.matmul(x, self.weight)
if self.has_bias:
output = self.bias_add(output, self.bias)
if self.activation_flag:
return self.activation(output)
return output
def extend_repr(self):
"""extend_repr"""
str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \
.format(self.in_channels, self.out_channels, self.weight, self.has_bias)
if self.has_bias:
str_info = str_info + ', bias={}'.format(self.bias)
if self.activation_flag:
str_info = str_info + ', activation={}'.format(self.activation)
return str_info
class Conv2d_Thor(_Conv):
"""Conv2d_Thor"""
def __init__(self,
@@ -114,6 +420,7 @@ class Conv2d_Thor(_Conv):
damping=0.03,
loss_scale=1,
frequency=278,
batch_size=32,
bias_init='zeros'):
self.thor = True
ksizes = (1, kernel_size, kernel_size, 1)
@@ -143,7 +450,7 @@ class Conv2d_Thor(_Conv):
dilation=self.dilation,
group=self.group
)
self.batch_size = batch_size
self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides)
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
self.matrix_combine = P.CusMatrixCombine()
@@ -228,7 +535,7 @@ class Conv2d_Thor(_Conv):
normalizer = dout_shape[0]
matrix_G = self.cube_matmul(dout, dout)
-normalizer = self.cast(normalizer, ms.float32)
normalizer = self.cast(normalizer, mstype.float32)
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
damping_step = self.gather(self.damping, self.cov_step, 0)
self.cov_step = self.cov_step + self.freq
@@ -261,7 +568,7 @@ class Conv2d_Thor(_Conv):
matrix_A = self.reshape(matrix_A, (self.hw, C0, self.hw, C0))
matrix_A = self.slice(matrix_A, (0, 0, 0, 0), (self.hw, self.in_channels, self.hw, self.in_channels))
matrix_A = self.reshape(matrix_A, (self.matrix_A_dim, self.matrix_A_dim))
-normalizer = self.cast(normalizer, ms.float32)
normalizer = self.cast(normalizer, mstype.float32)
matrix_A = self.mul(matrix_A, 1.0 / normalizer)
if self.padA_flag:
matrix_A = self.padA(matrix_A)
@@ -330,6 +637,7 @@ class Dense_Thor(Cell):
damping=0.03,
loss_scale=1,
frequency=278,
batch_size=32,
has_bias=True,
activation=None):
super(Dense_Thor, self).__init__()
@@ -337,6 +645,7 @@ class Dense_Thor(Cell):
self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias)
self.thor = True
self.batch_size = batch_size
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
@@ -376,8 +685,8 @@ class Dense_Thor(Cell):
self.damping = Tensor(damping)
self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
self.vector_matmul = P.CusBatchMatMul()
-self.pad = P.Pad(((0, 24), (0, 24)))
self.pad = P.Pad(((0, 23), (0, 23)))
-self.pad1 = P.Pad(((0, 8), (0, 8)))
self.pad1 = P.Pad(((0, 7), (0, 7)))
self.slice = P.Slice()
self.gather = P.GatherV2()
self.assignadd = P.AssignAdd()
@@ -385,7 +694,7 @@ class Dense_Thor(Cell):
self.axis = 0
self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False)
self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False)
-self.fused_abs_max1 = P.CusFusedAbsMax1([1000, 1000])
self.fused_abs_max1 = P.CusFusedAbsMax1([1001, 1001])
self.fused_abs_max2 = P.CusFusedAbsMax1()
self.log = P.Log()
self.exp = P.Exp()
@@ -402,7 +711,7 @@ class Dense_Thor(Cell):
dout = self.mul(dout, 32.0)
normalizer = 32
matrix_G = self.cube_matmul(dout, dout)
-normalizer = self.cast(normalizer, ms.float32)
normalizer = self.cast(normalizer, mstype.float32)
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
matrix_G = self.pad(matrix_G)
damping_step = self.gather(self.damping, self.cov_step, 0)
@@ -417,7 +726,7 @@ class Dense_Thor(Cell):
matrix_G_inv_max = self.fused_abs_max2(matrix_G_inv_max)
self.G_inv_max = matrix_G_inv_max
matrix_G_inv = self.matrix_combine(matrix_G_inv)
-matrix_G_inv = self.slice(matrix_G_inv, (0, 0), (1000, 1000))
matrix_G_inv = self.slice(matrix_G_inv, (0, 0), (1001, 1001))
matrix_G_inv = self.pad1(matrix_G_inv)
matrix_G_inv_shape = self.shape(matrix_G_inv)
matrix_G_inv = self.reshape(matrix_G_inv, (matrix_G_inv_shape[0] / 16, 16, matrix_G_inv_shape[0] / 16, 16))
@@ -431,7 +740,7 @@ class Dense_Thor(Cell):
if self.thor:
inputs = self.cube_matmul(x, x)
normalizer = 32
-normalizer = self.cast(normalizer, ms.float32)
normalizer = self.cast(normalizer, mstype.float32)
matrix_A = self.mul(inputs, 1.0 / normalizer)
damping_step = self.gather(self.damping, self.cov_step, self.axis)
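Because the THOR layers above register their Kronecker factors as non-trainable Parameters (matrix_A_inv, matrix_G_inv, A_inv_max, G_inv_max), the optimizer locates them purely by parameter name; this is also why the training script below filters net.get_parameters() rather than trainable_params(). A minimal sketch of that lookup:

```
# The factor Parameters carry requires_grad=False, so they are collected by name:
matrix_a_params = filter(lambda x: 'matrix_A' in x.name, net.get_parameters())
matrix_g_params = filter(lambda x: 'matrix_G' in x.name, net.get_parameters())
```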

View File

@@ -12,44 +12,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
-"""train_imagenet."""
"""train resnet."""
-import argparse
import os
import random
import argparse
import numpy as np
-from mindspore import Tensor
from mindspore import context
-from mindspore.communication.management import init
from mindspore import Tensor
from mindspore import dataset as de
from mindspore.parallel._auto_parallel_context import auto_parallel_context
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
-from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import ParallelMode
-from src.model_thor import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
-from src.resnet_thor import resnet50
from mindspore.train.loss_scale_manager import FixedLossScaleManager
-from src.thor import THOR
from mindspore.communication.management import init, get_rank, get_group_size
-from src.config import config
-from src.crossentropy import CrossEntropy
-from src.dataset_imagenet import create_dataset
-random.seed(1)
from src.model_thor import Model_Thor as Model
-np.random.seed(1)
from src.resnet_thor import resnet50
from src.dataset import create_dataset
from src.crossentropy import CrossEntropy
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
-parser.add_argument('--device_num', type=int, default=1, help='Device num.')
-parser.add_argument('--do_train', type=bool, default=True, help='Do train or not.')
-parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--device_num', type=int, default=1, help='Device num')
args_opt = parser.parse_args()
-device_id = int(os.getenv('DEVICE_ID'))
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
if args_opt.device_target == "Ascend":
from src.thor import THOR
from src.config import config
else:
from src.thor import THOR_GPU as THOR
from src.config import config_gpu as config
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
-def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch):
def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch, decay_epochs=100):
"""get_model_lr"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
@@ -57,9 +59,9 @@ def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch):
epoch = (i + 1) / steps_per_epoch
base = (1.0 - float(epoch) / total_epochs) ** decay
lr_local = lr_init * base
-if epoch >= 39:
if epoch >= decay_epochs:
lr_local = lr_local * 0.5
-if epoch >= 40:
if epoch >= decay_epochs + 1:
lr_local = lr_local * 0.5
lr_each_step.append(lr_local)
current_step = global_step
@@ -76,7 +78,6 @@ def get_model_damping(global_step, damping_init, decay_rate, total_epochs, steps
epoch = (step + 1) / steps_per_epoch
damping_here = damping_init * (decay_rate ** (epoch / 10))
damping_each_step.append(damping_here)
current_step = global_step
damping_each_step = np.array(damping_each_step).astype(np.float32)
damping_now = damping_each_step[current_step:]
@@ -84,49 +85,70 @@
if __name__ == '__main__':
-if not args_opt.do_eval and args_opt.run_distribute:
target = args_opt.device_target
ckpt_save_dir = config.save_checkpoint_path
# init context
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
if args_opt.run_distribute:
# Ascend target
if target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-mirror_mean=True, parameter_broadcast=True)
mirror_mean=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5")
init()
# GPU target
else:
init("nccl")
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
-epoch_size = config.epoch_size
# create dataset
-damping = get_model_damping(0, 0.03, 0.87, 50, 5004)
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
batch_size=config.batch_size, target=target)
# define net
step_size = dataset.get_dataset_size()
damping = get_model_damping(0, config.damping_init, config.damping_decay, 90, step_size)
lr = get_model_lr(0, config.lr_init, config.lr_decay, config.lr_end_epoch, step_size, decay_epochs=39)
net = resnet50(class_num=config.class_num, damping=damping, loss_scale=config.loss_scale,
-frequency=config.frequency)
frequency=config.frequency, batch_size=config.batch_size)
-if not config.label_smooth:
# define loss, model
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
-if args_opt.do_train:
-dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
-batch_size=config.batch_size)
-step_size = dataset.get_dataset_size()
-loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
-lr = Tensor(get_model_lr(0, 0.045, 6, 70, 5004))
-opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), config.momentum,
filter(lambda x: 'matrix_A' in x.name, net.get_parameters()),
filter(lambda x: 'matrix_G' in x.name, net.get_parameters()),
filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()),
filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()),
config.weight_decay, config.loss_scale)
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
if target == "Ascend":
model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', loss_scale_manager=loss_scale,
keep_batchnorm_fp32=False, metrics={'acc'}, frequency=config.frequency)
else:
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=True, frequency=config.frequency)
# define callbacks
time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
if config.save_checkpoint:
-config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
-ckpt_cb = ModelCheckpoint(prefix="resnet", directory=config.save_checkpoint_path, config=config_ck)
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)
cb += [ckpt_cb]
-model.train(epoch_size, dataset, callbacks=cb)
# train model
model.train(config.epoch_size, dataset, callbacks=cb)
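For reference, the learning-rate and damping schedules defined above are plain NumPy computations and can be inspected offline; the sketch below re-implements them outside the script (the hyper-parameter values are the ones visible in this diff and are only illustrative).

```
import numpy as np

def model_lr(lr_init, decay, total_epochs, steps_per_epoch, decay_epochs):
    # Polynomial decay per step, halved once at decay_epochs and again one epoch later,
    # matching get_model_lr above.
    lrs = []
    for i in range(steps_per_epoch * total_epochs):
        epoch = (i + 1) / steps_per_epoch
        lr = lr_init * (1.0 - epoch / total_epochs) ** decay
        if epoch >= decay_epochs:
            lr *= 0.5
        if epoch >= decay_epochs + 1:
            lr *= 0.5
        lrs.append(lr)
    return np.array(lrs, dtype=np.float32)

def model_damping(damping_init, decay_rate, total_epochs, steps_per_epoch):
    # Exponential decay of the THOR damping every 10 epochs, matching get_model_damping above.
    steps = steps_per_epoch * total_epochs
    epochs = (np.arange(steps) + 1) / steps_per_epoch
    return (damping_init * decay_rate ** (epochs / 10)).astype(np.float32)

print(model_lr(0.045, 6, 45, 5004, 39)[:3])
print(model_damping(0.03, 0.87, 45, 5004)[:3])
```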