forked from mindspore-Ecosystem/mindspore
!4414 THOR optimizer for GPU training
Merge pull request !4414 from wangmin0104/master
This commit is contained in:
commit
1002ee4887
|
@ -24,22 +24,24 @@ This is an example of training ResNet-50 V1.5 with ImageNet2012 dataset by secon
|
|||
.
|
||||
├── resnet_thor
|
||||
├── README.md
|
||||
├── src
|
||||
├──scripts
|
||||
├── run_distribute_train.sh # launch distributed training for Ascend
|
||||
└── run_eval.sh # launch infering for Ascend
|
||||
├── run_distribute_train_gpu.sh # launch distributed training for GPU
|
||||
└── run_eval_gpu.sh # launch infering for GPU
|
||||
├──src
|
||||
├── crossentropy.py # CrossEntropy loss function
|
||||
├── config.py # parameter configuration
|
||||
├── resnet50.py # resnet50 backbone
|
||||
├── dataset_helper.py # dataset help for minddata dataset
|
||||
├── grad_reducer_thor.py # grad reducer for thor
|
||||
├── model_thor.py # model
|
||||
├── model_thor.py # model for train
|
||||
├── resnet_thor.py # resnet50_thor backone
|
||||
├── thor.py # thor
|
||||
├── thor.py # thor optimizer
|
||||
├── thor_layer.py # thor layer
|
||||
└── dataset_imagenet.py # data preprocessing
|
||||
├── scripts
|
||||
├── run_distribute_train.sh # launch distributed training(8 pcs)
|
||||
└── run_eval.sh # launch infering
|
||||
└── dataset.py # data preprocessing
|
||||
├── eval.py # infer script
|
||||
└── train.py # train script
|
||||
|
||||
```
|
||||
|
||||
|
||||
|
@ -48,26 +50,30 @@ This is an example of training ResNet-50 V1.5 with ImageNet2012 dataset by secon
|
|||
Parameters for both training and inference can be set in config.py.
|
||||
|
||||
```
|
||||
"class_num": 1000, # dataset class number
|
||||
"class_num": 1001, # dataset class number
|
||||
"batch_size": 32, # batch size of input tensor
|
||||
"loss_scale": 128, # loss scale
|
||||
"momentum": 0.9, # momentum of THOR optimizer
|
||||
"weight_decay": 5e-4, # weight decay
|
||||
"epoch_size": 45, # only valid for taining, which is always 1 for inference
|
||||
"buffer_size": 1000, # number of queue size in data preprocessing
|
||||
"image_height": 224, # image height
|
||||
"image_width": 224, # image width
|
||||
"save_checkpoint": True, # whether save checkpoint or not
|
||||
"save_checkpoint_steps": 5004, # the step interval between two checkpoints. By default, the checkpoint will be saved every epoch
|
||||
"keep_checkpoint_max": 20, # only keep the last keep_checkpoint_max checkpoint
|
||||
"save_checkpoint_epochs": 1, # the epoch interval between two checkpoints. By default, the checkpoint will be saved every epoch
|
||||
"keep_checkpoint_max": 15, # only keep the last keep_checkpoint_max checkpoint
|
||||
"save_checkpoint_path": "./", # path to save checkpoint relative to the executed path
|
||||
"label_smooth": True, # label smooth
|
||||
"label_smooth_factor": 0.1, # label smooth factor
|
||||
"lr_init": 0.045, # learning rate init value
|
||||
"lr_decay": 6, # learning rate decay rate value
|
||||
"lr_end_epoch": 70, # learning rate end epoch value
|
||||
"damping_init": 0.03, # damping init value for Fisher information matrix
|
||||
"damping_decay": 0.87, # damping decay rate
|
||||
"frequency": 834, # the step interval to update second-order information matrix
|
||||
```
|
||||
|
||||
## Running the example
|
||||
|
||||
### 1 Running on Ascend 910
|
||||
|
||||
### Train
|
||||
|
||||
#### Usage
|
||||
|
@ -82,10 +88,10 @@ Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [DEVICE_NUM]
|
|||
|
||||
```bash
|
||||
# distributed training example(8 pcs)
|
||||
sh run_distribute_train.sh rank_table_8p.json dataset/ilsvrc
|
||||
sh run_distribute_train.sh rank_table_8p.json dataset/ilsvrc 8
|
||||
```
|
||||
|
||||
> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html).
|
||||
> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/distributed_training_ascend.html).
|
||||
|
||||
#### Result
|
||||
|
||||
|
@ -126,3 +132,35 @@ Inference result will be stored in the example path, whose folder name is "infer
|
|||
```
|
||||
result: {'acc': 0.759503041} ckpt=train_parallel0/resnet-42_5004.ckpt
|
||||
```
|
||||
|
||||
### 2 Running on GPU
|
||||
|
||||
### Train
|
||||
```
|
||||
# distributed training example
|
||||
sh run_distribute_train_gpu.sh [DATASET_PATH] [DEVICE_NUM]
|
||||
```
|
||||
#### Result
|
||||
```
|
||||
# distribute training result(8 pcs)
|
||||
epoch: 1 step: 5004, loss is 4.3069
|
||||
epoch: 2 step: 5004, loss is 3.5695
|
||||
epoch: 3 step: 5004, loss is 3.5893
|
||||
epoch: 4 step: 5004, loss is 3.1987
|
||||
epoch: 5 step: 5004, loss is 3.3526
|
||||
......
|
||||
epoch: 40 step: 5004, loss is 1.9482
|
||||
epoch: 41 step: 5004, loss is 1.8950
|
||||
epoch: 42 step: 5004, loss is 1.9023
|
||||
......
|
||||
```
|
||||
|
||||
### Infer
|
||||
```
|
||||
# infer example
|
||||
sh run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
#### Result
|
||||
```
|
||||
result: {'acc': 0.760143245838668} ckpt_0/resnet-40_5004.ckpt
|
||||
```
|
||||
|
|
|
@ -12,51 +12,64 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
eval.
|
||||
"""
|
||||
"""train resnet."""
|
||||
import os
|
||||
import random
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
from mindspore import dataset as de
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.dataset_imagenet import create_dataset
|
||||
from src.config import config
|
||||
from src.crossentropy import CrossEntropy
|
||||
from src.resnet50 import resnet50
|
||||
from src.config import config
|
||||
from src.dataset import create_dataset
|
||||
from src.resnet_thor import resnet50 as resnet
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
|
||||
parser.add_argument('--do_train', type=bool, default=False, help='Do train or not.')
|
||||
parser.add_argument('--do_eval', type=bool, default=True, help='Do eval or not.')
|
||||
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
context.set_context(device_id=device_id)
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
target = args_opt.device_target
|
||||
|
||||
net = resnet50(class_num=config.class_num)
|
||||
if not config.label_smooth:
|
||||
# init context
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
|
||||
if target != "GPU":
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id)
|
||||
|
||||
# create dataset
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size,
|
||||
target=target)
|
||||
|
||||
# define net
|
||||
net = resnet(class_num=config.class_num)
|
||||
net.add_flags_recursive(thor=False)
|
||||
|
||||
# load checkpoint
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
keys = list(param_dict.keys())
|
||||
for key in keys:
|
||||
if "damping" in key:
|
||||
param_dict.pop(key)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
|
||||
# define loss, model
|
||||
if not config.use_label_smooth:
|
||||
config.label_smooth_factor = 0.0
|
||||
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
|
||||
|
||||
if args_opt.do_eval:
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size)
|
||||
step_size = dataset.get_dataset_size()
|
||||
# define model
|
||||
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
|
||||
|
||||
if args_opt.checkpoint_path:
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
|
||||
model = Model(net, loss_fn=loss, metrics={'acc'})
|
||||
res = model.eval(dataset)
|
||||
print("result:", res, "ckpt=", args_opt.checkpoint_path)
|
||||
# eval model
|
||||
res = model.eval(dataset)
|
||||
print("result:", res, "ckpt=", args_opt.checkpoint_path)
|
||||
|
|
|
@ -52,6 +52,6 @@ do
|
|||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
|
||||
env > env.log
|
||||
python train.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 &
|
||||
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 &
|
||||
cd ..
|
||||
done
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train_gpu.sh [DATASET_PATH] [DEVICE_NUM]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=$2
|
||||
export RANK_SIZE=$2
|
||||
|
||||
rm -rf ./train_parallel
|
||||
mkdir ./train_parallel
|
||||
cp ../*.py ./train_parallel
|
||||
cp *.sh ./train_parallel
|
||||
cp -r ../src ./train_parallel
|
||||
cd ./train_parallel || exit
|
||||
|
||||
mpirun -n $RANK_SIZE \
|
||||
python train.py --run_distribute=True \
|
||||
--device_num=$DEVICE_NUM --device_target="GPU" --dataset_path=$PATH1 &> log &
|
|
@ -20,6 +20,7 @@ then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
|
@ -44,9 +45,6 @@ then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
|
||||
cd $BASE_PATH/../ || exit
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
|
@ -58,10 +56,11 @@ then
|
|||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp *.py ./eval
|
||||
cp -r ./src ./eval
|
||||
cp ../*.py ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
env > env.log
|
||||
echo "start infering for device $DEVICE_ID"
|
||||
python eval.py --do_eval=True --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
|
||||
echo "start evaluation for device $DEVICE_ID"
|
||||
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
|
||||
cd ..
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: sh run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
|
||||
|
||||
if [ ! -d $PATH1 ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $PATH2 ]
|
||||
then
|
||||
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
env > env.log
|
||||
echo "start evaluation for device $DEVICE_ID"
|
||||
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 --device_target="GPU" &> log &
|
||||
cd ..
|
|
@ -17,21 +17,46 @@ network config setting, will be used in train.py and eval.py
|
|||
"""
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
# config for resnet50, imagenet2012, Ascend 910
|
||||
config = ed({
|
||||
"class_num": 1000,
|
||||
"class_num": 1001,
|
||||
"batch_size": 32,
|
||||
"loss_scale": 128,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 5e-4,
|
||||
"epoch_size": 45,
|
||||
"buffer_size": 1000,
|
||||
"image_height": 224,
|
||||
"image_width": 224,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_steps": 5004,
|
||||
"keep_checkpoint_max": 20,
|
||||
"save_checkpoint_epochs": 1,
|
||||
"keep_checkpoint_max": 15,
|
||||
"save_checkpoint_path": "./",
|
||||
"label_smooth": 1,
|
||||
"use_label_smooth": True,
|
||||
"label_smooth_factor": 0.1,
|
||||
"frequency": 834
|
||||
"lr_init": 0.045,
|
||||
"lr_decay": 6,
|
||||
"lr_end_epoch": 70,
|
||||
"damping_init": 0.03,
|
||||
"damping_decay": 0.87,
|
||||
"frequency": 834,
|
||||
})
|
||||
|
||||
# config for resnet50, imagenet2012, GPU
|
||||
config_gpu = ed({
|
||||
"class_num": 1001,
|
||||
"batch_size": 32,
|
||||
"loss_scale": 128,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 5e-4,
|
||||
"epoch_size": 45,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 1,
|
||||
"keep_checkpoint_max": 15,
|
||||
"save_checkpoint_path": "./",
|
||||
"use_label_smooth": True,
|
||||
"label_smooth_factor": 0.1,
|
||||
"lr_init": 0.04,
|
||||
"lr_decay": 5,
|
||||
"lr_end_epoch": 58,
|
||||
"damping_init": 0.02,
|
||||
"damping_decay": 0.87,
|
||||
"frequency": 834,
|
||||
})
|
||||
|
|
|
@ -28,13 +28,10 @@ class CrossEntropy(_Loss):
|
|||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
||||
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
|
||||
# self.cast = P.Cast()
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||
self.mean = P.ReduceMean(False)
|
||||
|
||||
def construct(self, logit, label):
|
||||
# one_hot_label = self.onehot(self.cast(label, mstype.int32),
|
||||
# F.shape(logit)[1], self.on_value, self.off_value)、
|
||||
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
|
||||
loss = self.ce(logit, one_hot_label)
|
||||
loss = self.mean(loss, 0)
|
||||
|
|
|
@ -16,30 +16,36 @@
|
|||
create train or eval dataset.
|
||||
"""
|
||||
import os
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.vision.c_transforms as C
|
||||
import mindspore.dataset.transforms.c_transforms as C2
|
||||
import mindspore.dataset.transforms.vision.c_transforms as V_C
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
|
||||
|
||||
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
|
||||
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
|
||||
"""
|
||||
create a train or eval dataset
|
||||
create a train or eval imagenet2012 dataset for resnet50
|
||||
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
do_train(bool): whether dataset is used for train or eval.
|
||||
repeat_num(int): the repeat times of dataset. Default: 1
|
||||
batch_size(int): the batch size of dataset. Default: 32
|
||||
target(str): the device target. Default: Ascend
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
|
||||
device_num = int(os.getenv("RANK_SIZE"))
|
||||
rank_id = int(os.getenv("RANK_ID"))
|
||||
if target == "Ascend":
|
||||
device_num, rank_id = _get_rank_info()
|
||||
else:
|
||||
init("nccl")
|
||||
rank_id = get_rank()
|
||||
device_num = get_group_size()
|
||||
|
||||
if device_num == 1:
|
||||
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=False)
|
||||
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
|
||||
else:
|
||||
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
|
@ -47,29 +53,28 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
|
|||
image_size = 224
|
||||
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
|
||||
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
|
||||
|
||||
# define map operations
|
||||
if do_train:
|
||||
transform_img = [
|
||||
V_C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
|
||||
V_C.RandomHorizontalFlip(prob=0.5),
|
||||
V_C.Normalize(mean=mean, std=std),
|
||||
V_C.HWC2CHW()
|
||||
trans = [
|
||||
C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
|
||||
C.RandomHorizontalFlip(prob=0.5),
|
||||
C.Normalize(mean=mean, std=std),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
else:
|
||||
transform_img = [
|
||||
V_C.Decode(),
|
||||
V_C.Resize((256, 256)),
|
||||
V_C.CenterCrop(image_size),
|
||||
V_C.Normalize(mean=mean, std=std),
|
||||
V_C.HWC2CHW()
|
||||
trans = [
|
||||
C.Decode(),
|
||||
C.Resize(256),
|
||||
C.CenterCrop(image_size),
|
||||
C.Normalize(mean=mean, std=std),
|
||||
C.HWC2CHW()
|
||||
]
|
||||
# type_cast_op = C2.TypeCast(mstype.float16)
|
||||
|
||||
type_cast_op = C2.TypeCast(mstype.int32)
|
||||
|
||||
ds = ds.map(input_columns="image", operations=transform_img, num_parallel_workers=8)
|
||||
ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)
|
||||
|
||||
# apply shuffle operations
|
||||
# ds = ds.shuffle(buffer_size=config.buffer_size)
|
||||
ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
|
||||
ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)
|
||||
|
||||
# apply batch operations
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
|
@ -78,3 +83,18 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
|
|||
ds = ds.repeat(repeat_num)
|
||||
|
||||
return ds
|
||||
|
||||
def _get_rank_info():
|
||||
"""
|
||||
get rank size and rank id
|
||||
"""
|
||||
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
||||
|
||||
if rank_size > 1:
|
||||
rank_size = get_group_size()
|
||||
rank_id = get_rank()
|
||||
else:
|
||||
rank_size = 1
|
||||
rank_id = 0
|
||||
|
||||
return rank_size, rank_id
|
|
@ -13,34 +13,47 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Dataset help for minddata dataset"""
|
||||
from mindspore._checkparam import check_bool
|
||||
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
|
||||
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
|
||||
_to_full_shapes
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
import math
|
||||
import os
|
||||
|
||||
from mindspore._checkparam import check_bool, check_int
|
||||
from mindspore import context
|
||||
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _to_full_shapes
|
||||
from mindspore.nn.wrap import GetNextSingleOp
|
||||
from mindspore.parallel._utils import _get_device_num, _need_to_full
|
||||
|
||||
|
||||
def _send_data(dataset):
|
||||
def _send_data(dataset, epoch_num):
|
||||
"""Engine dataset to write data to tdt queue."""
|
||||
if not hasattr(dataset, '__has_sent__'):
|
||||
exec_dataset = dataset.__TRANSFER_DATASET__
|
||||
exec_dataset.send()
|
||||
exec_dataset.send(epoch_num)
|
||||
dataset.__has_sent__ = True
|
||||
|
||||
|
||||
def _send_data_no_flag(dataset, epoch_num):
|
||||
"""Engine dataset to write data to tdt queue directly."""
|
||||
exec_dataset = dataset.__TRANSFER_DATASET__
|
||||
exec_dataset.send(epoch_num)
|
||||
|
||||
|
||||
class DatasetHelper:
|
||||
"""
|
||||
Help function to use the Minddata dataset.
|
||||
Help function to use the MindData dataset.
|
||||
|
||||
According to different context, change the iter of dataset, to use the same for loop in different context.
|
||||
According to different contexts, change the iterations of dataset and use the same iteration for loop in different
|
||||
contexts.
|
||||
|
||||
Note:
|
||||
The iter of DatasetHelper will give one epoch data.
|
||||
The iteration of DatasetHelper will provide one epoch data.
|
||||
|
||||
Args:
|
||||
dataset (DataSet): The dataset.
|
||||
dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host.
|
||||
Default: True.
|
||||
dataset (DataSet): The training dataset iterator.
|
||||
dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. Default: True.
|
||||
sink_size (int): Control the amount of data in each sink.
|
||||
If sink_size=-1, sink the complete dataset for each epoch.
|
||||
If sink_size>0, sink sink_size data for each epoch. Default: -1.
|
||||
epoch_num (int): Control the number of epoch data to send. Default: 1.
|
||||
|
||||
Examples:
|
||||
>>> dataset_helper = DatasetHelper(dataset)
|
||||
|
@ -48,81 +61,116 @@ class DatasetHelper:
|
|||
>>> outputs = network(*inputs)
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, dataset_sink_mode=True, iter_first_order=0):
|
||||
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1, iter_first_order=1):
|
||||
check_bool(dataset_sink_mode)
|
||||
self.iter = _DatasetIterMSLoopSink(dataset, iter_first_order)
|
||||
check_int(sink_size)
|
||||
if sink_size < -1 or sink_size == 0:
|
||||
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
|
||||
|
||||
if dataset_sink_mode:
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
iterclass = _DatasetIterMSLoopSink
|
||||
self.iter = iterclass(dataset, sink_size, epoch_num, iter_first_order)
|
||||
elif context.get_context("device_target") == "GPU":
|
||||
iterclass = _DatasetIterMS
|
||||
self.iter = iterclass(dataset, sink_size, epoch_num)
|
||||
elif context.get_context("device_target") == "CPU":
|
||||
raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")
|
||||
|
||||
def __iter__(self):
|
||||
return self.iter.__iter__()
|
||||
|
||||
# A temp solution for loop sink. Delete later
|
||||
def types_shapes(self):
|
||||
"""Get the types and shapes from dataset on current config."""
|
||||
"""Get the types and shapes from dataset on the current configuration."""
|
||||
return self.iter.types_shapes()
|
||||
|
||||
def loop_size(self):
|
||||
"""Get loop_size for every iteration."""
|
||||
return self.iter.loop_size
|
||||
def sink_size(self):
|
||||
"""Get sink_size for each iteration."""
|
||||
return self.iter.get_sink_size()
|
||||
|
||||
def stop_send(self):
|
||||
"""Free up resources about data sink."""
|
||||
self.iter.stop_send()
|
||||
|
||||
|
||||
class _DatasetIter:
|
||||
"""Base iter for dataset help"""
|
||||
"""Base iter for dataset helper"""
|
||||
def __init__(self, dataset, sink_size, epoch_num):
|
||||
self.dataset = dataset
|
||||
self.sink_size = sink_size
|
||||
self.sink_count = 1
|
||||
|
||||
def __init__(self, dataset):
|
||||
self.loop_size = 1
|
||||
if not hasattr(dataset, '__ME_INITED__'):
|
||||
if not hasattr(dataset, '__loop_size__'):
|
||||
self.loop_size = dataset.get_dataset_size()
|
||||
else:
|
||||
self.loop_size = dataset.__loop_size__
|
||||
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
|
||||
if not hasattr(dataset, '__TRANSFER_DATASET__'):
|
||||
if hasattr(dataset, '__loop_size__'):
|
||||
self.sink_size = dataset.__loop_size__
|
||||
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
|
||||
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
|
||||
|
||||
if not hasattr(dataset, '__no_send__'):
|
||||
_send_data(dataset)
|
||||
_send_data(dataset, epoch_num)
|
||||
else:
|
||||
_send_data(dataset)
|
||||
_send_data_no_flag(dataset, epoch_num)
|
||||
|
||||
self.ind = 0
|
||||
self.dataset = dataset
|
||||
dataset_types, dataset_shapes = _get_types_and_shapes(dataset)
|
||||
self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes
|
||||
self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
|
||||
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
|
||||
|
||||
def __iter__(self):
|
||||
self.ind = 0
|
||||
self.index = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.ind >= self.loop_count:
|
||||
if self.index >= self.sink_count:
|
||||
raise StopIteration()
|
||||
self.ind += 1
|
||||
self.index += 1
|
||||
return self.op()
|
||||
|
||||
def types_shapes(self):
|
||||
return self.dataset_types, self.dataset_shapes
|
||||
|
||||
def get_loop_count(self, dataset):
|
||||
loop_count = 1
|
||||
def get_sink_count(self, dataset):
|
||||
sink_count = 1
|
||||
if hasattr(dataset, '__loop_size__'):
|
||||
loop_size = dataset.__loop_size__
|
||||
if dataset.get_dataset_size() % loop_size != 0:
|
||||
if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0:
|
||||
raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
|
||||
f'loop_size {loop_size} are not matched.')
|
||||
loop_count = int(dataset.get_dataset_size() / loop_size)
|
||||
return loop_count
|
||||
f'sink_size {loop_size} are not matched.')
|
||||
sink_count = math.ceil(dataset.get_dataset_size() / loop_size)
|
||||
return sink_count
|
||||
|
||||
def get_sink_size(self):
|
||||
"""get sink_size to device"""
|
||||
sink_size = 1
|
||||
if hasattr(self.dataset, '__loop_size__'):
|
||||
sink_size = self.dataset.__loop_size__
|
||||
else:
|
||||
if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend":
|
||||
if self.sink_size > 0:
|
||||
sink_size = self.sink_size
|
||||
else:
|
||||
sink_size = self.dataset.get_dataset_size()
|
||||
return sink_size
|
||||
|
||||
|
||||
class _DatasetIterMSLoopSink(_DatasetIter):
|
||||
"""Iter for context (device_target=Ascend)"""
|
||||
|
||||
def __init__(self, dataset, iter_first_order):
|
||||
super(_DatasetIterMSLoopSink, self).__init__(dataset)
|
||||
loop_size = dataset.__loop_size__ + iter_first_order
|
||||
self.loop_count = int(dataset.get_dataset_size() / loop_size) * 2
|
||||
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
|
||||
# compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
|
||||
# times the batch dimension of tensors for run. Now only support LoopSink.
|
||||
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
def __init__(self, dataset, sink_size, epoch_num, iter_first_order):
|
||||
super().__init__(dataset, sink_size, epoch_num)
|
||||
sink_count = 1
|
||||
if hasattr(dataset, '__loop_size__'):
|
||||
loop_size = dataset.__loop_size__ + iter_first_order
|
||||
if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0:
|
||||
raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
|
||||
f'sink_size {loop_size} are not matched.')
|
||||
sink_count = math.ceil(dataset.get_dataset_size() / loop_size) * 2
|
||||
self.sink_count = sink_count
|
||||
ms_role = os.getenv("MS_ROLE")
|
||||
if ms_role in ("MS_PSERVER", "MS_SCHED"):
|
||||
self.sink_count = 1
|
||||
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
|
||||
# use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for
|
||||
# compile is device_number times the batch dimension of tensors for run. Now only support LoopSink.
|
||||
if _need_to_full():
|
||||
device_num = _get_device_num()
|
||||
self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)
|
||||
|
||||
|
@ -130,3 +178,16 @@ class _DatasetIterMSLoopSink(_DatasetIter):
|
|||
return tuple()
|
||||
|
||||
self.op = op
|
||||
|
||||
|
||||
class _DatasetIterMS(_DatasetIter):
|
||||
"""Iter for MS(enable_loop_sink=False)."""
|
||||
def __init__(self, dataset, sink_size, epoch_num):
|
||||
super().__init__(dataset, sink_size, epoch_num)
|
||||
if sink_size > 0:
|
||||
self.sink_count = sink_size
|
||||
else:
|
||||
self.sink_count = dataset.get_dataset_size()
|
||||
|
||||
queue_name = dataset.__ME_INITED__
|
||||
self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
|
||||
|
|
|
@ -174,10 +174,6 @@ class DistributedGradReducerThor(Cell):
|
|||
datatypes = self.hyper_map(F.partial(_get_datatype), grads)
|
||||
grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)
|
||||
|
||||
if self.mean:
|
||||
new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads)
|
||||
else:
|
||||
new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads)
|
||||
|
||||
new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads)
|
||||
new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
|
||||
return new_grad
|
||||
|
|
|
@ -14,27 +14,19 @@
|
|||
# ============================================================================
|
||||
"""Model."""
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
from mindspore.train.callback import RunContext
|
||||
from mindspore import context
|
||||
from mindspore import log as logger
|
||||
from mindspore import nn
|
||||
from mindspore._c_expression import init_exec_dataset
|
||||
from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.dtype import pytype_to_dtype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn.metrics import Loss
|
||||
from mindspore.nn.metrics import get_metrics
|
||||
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
|
||||
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
||||
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
|
||||
from mindspore.train import amp
|
||||
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
from mindspore.train._utils import _to_full_tensor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.parallel._utils import _need_to_full
|
||||
from mindspore.common.dtype import pytype_to_dtype
|
||||
from mindspore._c_expression import init_exec_dataset
|
||||
|
||||
from src.dataset_helper import DatasetHelper
|
||||
|
||||
|
||||
def _convert_type(types):
|
||||
"""
|
||||
Convert from numpy type to tensor type.
|
||||
|
@ -76,194 +68,52 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
|
|||
need_run=False)
|
||||
|
||||
|
||||
class Model:
|
||||
class Model_Thor(Model):
|
||||
"""
|
||||
High-Level API for Training or Testing.
|
||||
|
||||
`Model` groups layers into an object with training and inference features.
|
||||
|
||||
Args:
|
||||
network (Cell): The training or testing network.
|
||||
network (Cell): A training or testing network.
|
||||
loss_fn (Cell): Objective function, if loss_fn is None, the
|
||||
network should contain the logic of loss and grads calculation, and the logic
|
||||
of parallel if needed. Default: None.
|
||||
optimizer (Cell): Optimizer for updating the weights. Default: None.
|
||||
metrics (Union[dict, set]): Dict or set of metrics to be evaluated by the model during
|
||||
metrics (Union[dict, set]): A Dictionary or a set of metrics to be evaluated by the model during
|
||||
training and testing. eg: {'accuracy', 'recall'}. Default: None.
|
||||
eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as
|
||||
`eval_network`. Default: None.
|
||||
eval_indexes (list): In case of defining the `eval_network`, if `eval_indexes` is None, all outputs of
|
||||
eval_indexes (list): When defining the `eval_network`, if `eval_indexes` is None, all outputs of the
|
||||
`eval_network` would be passed to metrics, otherwise `eval_indexes` must contain three
|
||||
elements, representing the positions of loss value, predict value and label, the loss
|
||||
value would be passed to `Loss` metric, predict value and label would be passed to other
|
||||
metric. Default: None.
|
||||
elements, including the positions of loss value, predicted value and label. The loss
|
||||
value would be passed to the `Loss` metric, the predicted value and label would be passed
|
||||
to other metric. Default: None.
|
||||
amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network`, level for mixed
|
||||
precision training. Supports [O0, O2]. Default: "O0".
|
||||
precision training. Supports [O0, O2, O3]. Default: "O0".
|
||||
|
||||
- O0: Do not change.
|
||||
- O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.
|
||||
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
|
||||
|
||||
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
|
||||
scale the loss by LossScaleManager. If it is set, overwrite the level setting. It's a eyword argument.
|
||||
O2 is recommended on GPU, O3 is recommended on Ascend.
|
||||
|
||||
loss_scale_manager (Union[None, LossScaleManager]): If it is None, the loss would not be scaled. Otherwise,
|
||||
scale the loss by LossScaleManager. It is a key argument.
|
||||
e.g. Use `loss_scale_manager=None` to set the value.
|
||||
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. Default: True.
|
||||
|
||||
Examples:
|
||||
>>> class Net(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(Net, self).__init__()
|
||||
>>> self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
|
||||
>>> self.bn = nn.BatchNorm2d(64)
|
||||
>>> self.relu = nn.ReLU()
|
||||
>>> self.flatten = nn.Flatten()
|
||||
>>> self.fc = nn.Dense(64*224*224, 12) # padding=0
|
||||
>>>
|
||||
>>> def construct(self, x):
|
||||
>>> x = self.conv(x)
|
||||
>>> x = self.bn(x)
|
||||
>>> x = self.relu(x)
|
||||
>>> x = self.flatten(x)
|
||||
>>> out = self.fc(x)
|
||||
>>> return out
|
||||
>>>
|
||||
>>> net = Net()
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
|
||||
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
|
||||
>>> dataset = get_dataset()
|
||||
>>> model.train(2, dataset)
|
||||
keep_batchnorm_fp32 (bool): Keep Batchnorm running in `float32`. If it is set to true, the level setting before
|
||||
will be overwritten. Default: True.
|
||||
"""
|
||||
|
||||
def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
|
||||
eval_indexes=None, amp_level="O0", frequency=278, stop_epoch=100, **kwargs):
|
||||
self._network = network
|
||||
self._loss_fn = loss_fn
|
||||
self._optimizer = optimizer
|
||||
self._loss_scale_manager = None
|
||||
self._loss_scale_manager_set = False
|
||||
self._keep_bn_fp32 = True
|
||||
self._check_kwargs(kwargs)
|
||||
self._amp_level = amp_level
|
||||
self._process_amp_args(kwargs)
|
||||
self._parallel_mode = _get_parallel_mode()
|
||||
self._device_number = _get_device_num()
|
||||
self._global_rank = _get_global_rank()
|
||||
self._parameter_broadcast = _get_parameter_broadcast()
|
||||
eval_indexes=None, amp_level="O0", frequency=834, **kwargs):
|
||||
super(Model_Thor, self).__init__(network, loss_fn, optimizer, metrics, eval_network,
|
||||
eval_indexes, amp_level, **kwargs)
|
||||
self._frequency = frequency
|
||||
self._stop_epoch = stop_epoch
|
||||
|
||||
self._train_network = self._build_train_network()
|
||||
self._build_eval_network(metrics, eval_network, eval_indexes)
|
||||
self._build_predict_network()
|
||||
|
||||
def _process_amp_args(self, kwargs):
|
||||
if self._amp_level == "O0":
|
||||
self._keep_bn_fp32 = False
|
||||
if 'keep_batchnorm_fp32' in kwargs:
|
||||
self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
|
||||
if 'loss_scale_manager' in kwargs:
|
||||
self._loss_scale_manager = kwargs['loss_scale_manager']
|
||||
self._loss_scale_manager_set = True
|
||||
|
||||
def _check_kwargs(self, kwargs):
|
||||
for arg in kwargs:
|
||||
if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
|
||||
raise ValueError(f"Unsupport arg '{arg}'")
|
||||
|
||||
def _build_train_network(self):
|
||||
"""Build train network"""
|
||||
network = self._network
|
||||
if self._optimizer:
|
||||
if self._loss_scale_manager_set:
|
||||
network = amp.build_train_network(network,
|
||||
self._optimizer,
|
||||
self._loss_fn,
|
||||
level=self._amp_level,
|
||||
loss_scale_manager=self._loss_scale_manager,
|
||||
keep_batchnorm_fp32=self._keep_bn_fp32)
|
||||
else:
|
||||
network = amp.build_train_network(network,
|
||||
self._optimizer,
|
||||
self._loss_fn,
|
||||
level=self._amp_level,
|
||||
keep_batchnorm_fp32=self._keep_bn_fp32)
|
||||
elif self._loss_fn:
|
||||
network = nn.WithLossCell(network, self._loss_fn)
|
||||
# If need to check if loss_fn is not None, but optimizer is None
|
||||
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
network.set_auto_parallel()
|
||||
return network
|
||||
|
||||
def _build_eval_network(self, metrics, eval_network, eval_indexes):
|
||||
"""Build the network for evaluation."""
|
||||
self._metric_fns = get_metrics(metrics)
|
||||
if not self._metric_fns:
|
||||
return
|
||||
|
||||
if eval_network is not None:
|
||||
if eval_indexes is not None and not (isinstance(eval_indexes, list) and len(eval_indexes) == 3):
|
||||
raise ValueError("Eval_indexes must be a list or None. If eval_indexes is a list, length of it \
|
||||
must be three. But got {}".format(eval_indexes))
|
||||
|
||||
self._eval_network = eval_network
|
||||
self._eval_indexes = eval_indexes
|
||||
else:
|
||||
if self._loss_fn is None:
|
||||
raise ValueError("loss_fn can not be None.")
|
||||
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2")
|
||||
self._eval_indexes = [0, 1, 2]
|
||||
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
self._eval_network.set_auto_parallel()
|
||||
|
||||
def _build_predict_network(self):
|
||||
"""Build the network for prediction."""
|
||||
self._predict_network = self._network
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
self._predict_network = _VirtualDatasetCell(self._network)
|
||||
self._predict_network.set_auto_parallel()
|
||||
|
||||
def _clear_metrics(self):
|
||||
"""Clear metrics local values."""
|
||||
for metric in self._metric_fns.values():
|
||||
metric.clear()
|
||||
|
||||
def _update_metrics(self, outputs):
|
||||
"""Update metrics local values."""
|
||||
if not isinstance(outputs, tuple):
|
||||
raise ValueError("The `outputs` is not tuple.")
|
||||
|
||||
if self._eval_indexes is not None and len(outputs) < 3:
|
||||
raise ValueError("The length of `outputs` must be greater than or equal to 3, \
|
||||
but got {}".format(len(outputs)))
|
||||
|
||||
for metric in self._metric_fns.values():
|
||||
if self._eval_indexes is None:
|
||||
metric.update(*outputs)
|
||||
else:
|
||||
if isinstance(metric, Loss):
|
||||
metric.update(outputs[self._eval_indexes[0]])
|
||||
else:
|
||||
metric.update(outputs[self._eval_indexes[1]], outputs[self._eval_indexes[2]])
|
||||
|
||||
def _get_metrics(self):
|
||||
"""Get metrics local values."""
|
||||
metrics = dict()
|
||||
for key, value in self._metric_fns.items():
|
||||
metrics[key] = value.eval()
|
||||
return metrics
|
||||
|
||||
def _get_scaling_sens(self):
|
||||
"""get the scaling sens"""
|
||||
scaling_sens = 1
|
||||
if self._loss_scale_manager is not None:
|
||||
scaling_sens = self._loss_scale_manager.get_loss_scale()
|
||||
if self._parallel_mode == ParallelMode.DATA_PARALLEL:
|
||||
scaling_sens /= self._device_number
|
||||
return scaling_sens
|
||||
|
||||
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, iter_first_order):
|
||||
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1,
|
||||
epoch_num=1, iter_first_order=1):
|
||||
"""Initializes dataset."""
|
||||
need_wrap = False
|
||||
if dataset_sink_mode:
|
||||
|
@ -275,7 +125,7 @@ class Model:
|
|||
if not is_train:
|
||||
dataset.__loop_size__ = 1
|
||||
|
||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, iter_first_order)
|
||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)
|
||||
|
||||
# remove later to deal with loop sink
|
||||
if need_wrap:
|
||||
|
@ -283,133 +133,31 @@ class Model:
|
|||
network.set_train(is_train)
|
||||
network.phase = phase
|
||||
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
network.set_auto_parallel()
|
||||
|
||||
return dataset_helper, network
|
||||
|
||||
def init(self, train_dataset=None, valid_dataset=None):
|
||||
"""
|
||||
Initializes compute graphs and data graphs with sink mode.
|
||||
|
||||
Note:
|
||||
Pre-init process only supports `GRAPH_MODE` and `Ascend` target currently.
|
||||
|
||||
Args:
|
||||
train_dataset (Dataset): A training dataset iterator. If define `train_dataset`, training graphs will be
|
||||
initialized. Default: None.
|
||||
valid_dataset (Dataset): A evaluating dataset iterator. If define `valid_dataset`, evaluation graphs will
|
||||
be initialized, and `metrics` in `Model` can not be None. Default: None.
|
||||
|
||||
Examples:
|
||||
>>> train_dataset = get_train_dataset()
|
||||
>>> valid_dataset = get_valid_dataset()
|
||||
>>> net = Net()
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
|
||||
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'})
|
||||
>>> model.init(train_dataset, valid_dataset)
|
||||
>>> model.train(2, train_dataset)
|
||||
>>> model.eval(valid_dataset)
|
||||
"""
|
||||
if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend":
|
||||
raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.')
|
||||
|
||||
if not train_dataset and not valid_dataset:
|
||||
raise ValueError('Both train_dataset and valid_dataset can not be None or empty.')
|
||||
|
||||
_device_number_check(self._parallel_mode, self._device_number)
|
||||
|
||||
if train_dataset:
|
||||
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
|
||||
self._train_network.set_train()
|
||||
self._train_network.phase = 'train'
|
||||
|
||||
if self._parameter_broadcast:
|
||||
self._train_network.set_broadcast_flag()
|
||||
|
||||
train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
|
||||
is_train=True,
|
||||
phase='train',
|
||||
dataset=train_dataset,
|
||||
dataset_sink_mode=True)
|
||||
self._train_network = train_network
|
||||
for inputs in train_dataset_helper:
|
||||
self._train_network.compile(*inputs)
|
||||
break
|
||||
|
||||
if valid_dataset:
|
||||
if not self._metric_fns:
|
||||
raise RuntimeError('If define `valid_dataset`, metric fn can not be None or empty.')
|
||||
|
||||
self._eval_network.set_train(False)
|
||||
self._eval_network.phase = 'eval'
|
||||
valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
|
||||
is_train=False,
|
||||
phase='eval',
|
||||
dataset=valid_dataset,
|
||||
dataset_sink_mode=True)
|
||||
self._eval_network = eval_network
|
||||
for inputs in valid_dataset_helper:
|
||||
self._eval_network.compile(*inputs)
|
||||
break
|
||||
|
||||
def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
|
||||
"""
|
||||
Training.
|
||||
|
||||
Args:
|
||||
epoch (int): Total number of iterations on the data.
|
||||
train_dataset (Dataset): A training dataset iterator. If there is no
|
||||
loss_fn, a tuple with multiply data (data1, data2, data3, ...) will be
|
||||
returned and passed to the network. Otherwise, a tuple (data, label) will
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
function respectively.
|
||||
callbacks (list): List of callback object. Callbacks which should be executed while training. Default: None.
|
||||
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
|
||||
Configure pynative mode, the training process will be performed with
|
||||
dataset not sink.
|
||||
"""
|
||||
epoch = check_int_positive(epoch)
|
||||
self._train_network.set_train()
|
||||
|
||||
if self._parameter_broadcast:
|
||||
self._train_network.set_broadcast_flag()
|
||||
|
||||
# build callback list
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.train_network = self._train_network
|
||||
cb_params.epoch_num = epoch
|
||||
cb_params.batch_num = train_dataset.get_dataset_size()
|
||||
cb_params.mode = "train"
|
||||
cb_params.loss_fn = self._loss_fn
|
||||
cb_params.optimizer = self._optimizer
|
||||
cb_params.parallel_mode = self._parallel_mode
|
||||
cb_params.device_number = self._device_number
|
||||
cb_params.train_dataset = train_dataset
|
||||
cb_params.list_callback = callbacks
|
||||
|
||||
with _CallbackManager(callbacks) as list_callback:
|
||||
if not dataset_sink_mode:
|
||||
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
||||
elif context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
logger.warning("The pynative mode cannot support dataset sink mode currently."
|
||||
"So the training process will be performed with dataset not sink.")
|
||||
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
||||
else:
|
||||
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
|
||||
|
||||
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
|
||||
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1):
|
||||
"""
|
||||
Training process. The data would be passed to network through dataset channel.
|
||||
|
||||
Args:
|
||||
epoch (int): Total number of iterations on the data.
|
||||
train_dataset (Dataset): A training dataset iterator. If there is no
|
||||
loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be
|
||||
loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
|
||||
returned and passed to the network. Otherwise, a tuple (data, label) should
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
be returned. The data and label would be passed to the network and loss
|
||||
function respectively.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
sink_size (int): Control the amount of data in each sink. Default: -1.
|
||||
"""
|
||||
if sink_size == -1:
|
||||
epoch_num = epoch
|
||||
else:
|
||||
epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())
|
||||
|
||||
iter_first_order = self._frequency - 1
|
||||
iter_second_order = 1
|
||||
train_dataset.__loop_size__ = iter_second_order
|
||||
|
@ -418,308 +166,82 @@ class Model:
|
|||
phase='train',
|
||||
dataset=train_dataset,
|
||||
dataset_sink_mode=True,
|
||||
sink_size=sink_size,
|
||||
epoch_num=epoch_num,
|
||||
iter_first_order=iter_first_order)
|
||||
|
||||
self._train_network = train_network
|
||||
cb_params.train_network = self._train_network
|
||||
cb_params.cur_step_num = 0
|
||||
|
||||
loop_size = dataset_helper.loop_size()
|
||||
run_context = RunContext(cb_params)
|
||||
list_callback.begin(run_context)
|
||||
|
||||
# used to stop training for early stop, such as stopAtTIme or stopATStep
|
||||
should_stop = False
|
||||
has_do_dataset_init = False
|
||||
switch_branch_one = True
|
||||
index_first_order = 0
|
||||
train_network_init_flag = True
|
||||
has_do_dataset_init = False
|
||||
|
||||
for i in range(epoch):
|
||||
cb_params.cur_epoch_num = i + 1
|
||||
list_callback.epoch_begin(run_context)
|
||||
|
||||
# for data sink dataset_helper only iter once, other wise iter epoch_size times.
|
||||
for inputs in dataset_helper:
|
||||
if _need_to_full() and context.get_context("device_target") == "GPU":
|
||||
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
|
||||
list_callback.step_begin(run_context)
|
||||
if switch_branch_one:
|
||||
cb_params.cur_step_num += loop_size
|
||||
self._train_network.add_flags_recursive(thor=True)
|
||||
self._train_network.phase = 'train0'
|
||||
if context.get_context("device_target") == "GPU":
|
||||
if switch_branch_one:
|
||||
cb_params.cur_step_num += 1
|
||||
if train_network_init_flag:
|
||||
self._train_network.add_flags_recursive(thor=True)
|
||||
self._train_network.phase = 'train0'
|
||||
switch_branch_one = not switch_branch_one
|
||||
outputs = self._train_network(*inputs)
|
||||
cb_params.net_outputs = outputs
|
||||
list_callback.step_end(run_context)
|
||||
else:
|
||||
cb_params.cur_step_num += 1
|
||||
if train_network_init_flag:
|
||||
self._train_network.add_flags_recursive(thor=False)
|
||||
train_network_init_flag = False
|
||||
self._train_network.phase = 'train1'
|
||||
outputs = self._train_network(*inputs)
|
||||
cb_params.net_outputs = outputs
|
||||
index_first_order += 1
|
||||
if index_first_order == iter_first_order:
|
||||
index_first_order = 0
|
||||
switch_branch_one = not switch_branch_one
|
||||
list_callback.step_end(run_context)
|
||||
else:
|
||||
cb_params.cur_step_num += iter_first_order
|
||||
self._train_network.add_flags_recursive(thor=False)
|
||||
self._train_network.phase = 'train1'
|
||||
if not has_do_dataset_init:
|
||||
_exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset')
|
||||
has_do_dataset_init = True
|
||||
switch_branch_one = not switch_branch_one
|
||||
outputs = self._train_network(*inputs)
|
||||
cb_params.net_outputs = outputs
|
||||
list_callback.step_end(run_context)
|
||||
if switch_branch_one:
|
||||
cb_params.cur_step_num += 1
|
||||
if train_network_init_flag:
|
||||
self._train_network.add_flags_recursive(thor=True)
|
||||
self._train_network.phase = 'train0'
|
||||
else:
|
||||
cb_params.cur_step_num += iter_first_order
|
||||
if train_network_init_flag:
|
||||
self._train_network.add_flags_recursive(thor=False)
|
||||
train_network_init_flag = False
|
||||
self._train_network.phase = 'train1'
|
||||
if not has_do_dataset_init:
|
||||
_exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset')
|
||||
has_do_dataset_init = True
|
||||
switch_branch_one = not switch_branch_one
|
||||
outputs = self._train_network(*inputs)
|
||||
cb_params.net_outputs = outputs
|
||||
list_callback.step_end(run_context)
|
||||
|
||||
list_callback.epoch_end(run_context)
|
||||
should_stop = should_stop or run_context.get_stop_requested()
|
||||
if should_stop:
|
||||
break
|
||||
dataset_helper.stop_send()
|
||||
|
||||
list_callback.end(run_context)
|
||||
|
||||
def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
|
||||
"""
|
||||
Training process. The data would be passed to network directly.
|
||||
|
||||
Args:
|
||||
epoch (int): Total number of iterations on the data.
|
||||
train_dataset (Dataset): A training dataset iterator. If there is no
|
||||
loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be
|
||||
returned and passed to the network. Otherwise, a tuple (data, label) should
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
function respectively.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
"""
|
||||
dataset_helper, _ = self._exec_preprocess(self._train_network,
|
||||
is_train=True,
|
||||
phase='train',
|
||||
dataset=train_dataset,
|
||||
dataset_sink_mode=False)
|
||||
cb_params.cur_step_num = 0
|
||||
run_context = RunContext(cb_params)
|
||||
list_callback.begin(run_context)
|
||||
# used to stop training for early stop, such as stopAtTIme or stopATStep
|
||||
should_stop = False
|
||||
|
||||
for i in range(epoch):
|
||||
cb_params.cur_epoch_num = i + 1
|
||||
|
||||
list_callback.epoch_begin(run_context)
|
||||
|
||||
for next_element in dataset_helper:
|
||||
len_element = len(next_element)
|
||||
if self._loss_fn and len_element != 2:
|
||||
raise ValueError("when loss_fn is not None, train_dataset should"
|
||||
"return two elements, but got {}".format(len_element))
|
||||
cb_params.cur_step_num += 1
|
||||
list_callback.step_begin(run_context)
|
||||
|
||||
overflow = False
|
||||
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
|
||||
scaling_sens = self._get_scaling_sens()
|
||||
next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)
|
||||
|
||||
outputs = self._train_network(*next_element)
|
||||
cb_params.net_outputs = outputs
|
||||
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
|
||||
_, overflow, _ = outputs
|
||||
overflow = np.all(overflow.asnumpy())
|
||||
self._loss_scale_manager.update_loss_scale(overflow)
|
||||
|
||||
list_callback.step_end(run_context)
|
||||
should_stop = should_stop or run_context.get_stop_requested()
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
train_dataset.reset()
|
||||
|
||||
list_callback.epoch_end(run_context)
|
||||
should_stop = should_stop or run_context.get_stop_requested()
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
list_callback.end(run_context)
|
||||
|
||||
def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
|
||||
"""
|
||||
Training API where the iteration is controlled by python front-end.
|
||||
|
||||
When setting pynative mode, the training process will be performed with dataset not sink.
|
||||
|
||||
Note:
|
||||
CPU is not supported when dataset_sink_mode is true.
|
||||
If dataset_sink_mode is True, epoch of training should be equal to the count of repeat
|
||||
operation in dataset processing. Otherwise, errors could occur since the amount of data
|
||||
is not the amount training requires.
|
||||
If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features
|
||||
of data will be transferred one by one. The limitation of data transmission per time is 256M.
|
||||
|
||||
Args:
|
||||
epoch (int): Total number of iterations on the data.
|
||||
train_dataset (Dataset): A training dataset iterator. If there is no
|
||||
loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be
|
||||
returned and passed to the network. Otherwise, a tuple (data, label) should
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
function respectively.
|
||||
callbacks (list): List of callback object. Callbacks which should be excuted while training. Default: None.
|
||||
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
|
||||
Configure pynative mode, the training process will be performed with
|
||||
dataset not sink.
|
||||
|
||||
|
||||
Examples:
|
||||
>>> dataset = get_dataset()
|
||||
>>> net = Net()
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
|
||||
>>> loss_scale_manager = FixedLossScaleManager()
|
||||
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
|
||||
>>> model.train(2, dataset)
|
||||
"""
|
||||
repeat_count = train_dataset.get_repeat_count()
|
||||
if epoch != repeat_count and dataset_sink_mode is True:
|
||||
logger.warning(f"The epoch_size {epoch} is not the same with dataset repeat_count {repeat_count}")
|
||||
check_bool(dataset_sink_mode)
|
||||
_device_number_check(self._parallel_mode, self._device_number)
|
||||
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
|
||||
|
||||
self._train(epoch,
|
||||
train_dataset,
|
||||
callbacks=callbacks,
|
||||
dataset_sink_mode=dataset_sink_mode)
|
||||
|
||||
def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
|
||||
"""
|
||||
Evaluation. The data would be passed to network through dataset channel.
|
||||
|
||||
Args:
|
||||
valid_dataset (Dataset): Dataset to evaluate the model.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
|
||||
Returns:
|
||||
Dict, returns the loss value & metrics values for the model in test mode.
|
||||
"""
|
||||
run_context = RunContext(cb_params)
|
||||
|
||||
dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
|
||||
is_train=False,
|
||||
phase='eval',
|
||||
dataset=valid_dataset,
|
||||
dataset_sink_mode=True)
|
||||
self._eval_network = eval_network
|
||||
cb_params.eval_network = self._eval_network
|
||||
list_callback.begin(run_context)
|
||||
|
||||
for inputs in dataset_helper:
|
||||
cb_params.cur_step_num += 1
|
||||
list_callback.step_begin(run_context)
|
||||
|
||||
outputs = self._eval_network(*inputs)
|
||||
|
||||
cb_params.net_outputs = outputs
|
||||
list_callback.step_end(run_context)
|
||||
self._update_metrics(outputs)
|
||||
|
||||
metrics = self._get_metrics()
|
||||
cb_params.metrics = metrics
|
||||
list_callback.end(run_context)
|
||||
|
||||
return metrics
|
||||
|
||||
def _eval_process(self, valid_dataset, list_callback=None, cb_params=None):
|
||||
"""
|
||||
Evaluation. The data would be passed to network directly.
|
||||
|
||||
Args:
|
||||
valid_dataset (Dataset): Dataset to evaluate the model.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
|
||||
Returns:
|
||||
Dict, returns the loss value & metrics values for the model in test mode.
|
||||
"""
|
||||
run_context = RunContext(cb_params)
|
||||
list_callback.begin(run_context)
|
||||
|
||||
dataset_helper, _ = self._exec_preprocess(self._eval_network,
|
||||
is_train=False,
|
||||
phase='eval',
|
||||
dataset=valid_dataset,
|
||||
dataset_sink_mode=False)
|
||||
for next_element in dataset_helper:
|
||||
cb_params.cur_step_num += 1
|
||||
list_callback.step_begin(run_context)
|
||||
outputs = self._eval_network(*next_element)
|
||||
cb_params.net_outputs = outputs
|
||||
list_callback.step_end(run_context)
|
||||
self._update_metrics(outputs)
|
||||
|
||||
metrics = self._get_metrics()
|
||||
cb_params.metrics = metrics
|
||||
list_callback.end(run_context)
|
||||
return metrics
|
||||
|
||||
def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True):
|
||||
"""
|
||||
Evaluation API where the iteration is controlled by python front-end.
|
||||
|
||||
Configure to pynative mode, the evaluation will be performed with dataset non-sink mode.
|
||||
|
||||
Note:
|
||||
CPU is not supported when dataset_sink_mode is true.
|
||||
If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features
|
||||
of data will be transferred one by one. The limitation of data transmission per time is 256M.
|
||||
|
||||
Args:
|
||||
valid_dataset (Dataset): Dataset to evaluate the model.
|
||||
callbacks (list): List of callback object. Callbacks which should be excuted
|
||||
while training. Default: None.
|
||||
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
|
||||
|
||||
Returns:
|
||||
Dict, returns the loss value & metrics values for the model in test mode.
|
||||
|
||||
Examples:
|
||||
>>> dataset = get_dataset()
|
||||
>>> net = Net()
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
|
||||
>>> model.eval(dataset)
|
||||
"""
|
||||
check_bool(dataset_sink_mode)
|
||||
_device_number_check(self._parallel_mode, self._device_number)
|
||||
if not self._metric_fns:
|
||||
raise ValueError("metric fn can not be None or empty.")
|
||||
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.eval_network = self._eval_network
|
||||
cb_params.valid_dataset = valid_dataset
|
||||
cb_params.batch_num = valid_dataset.get_dataset_size()
|
||||
cb_params.mode = "eval"
|
||||
cb_params.cur_step_num = 0
|
||||
|
||||
self._eval_network.set_train(mode=False)
|
||||
self._eval_network.phase = 'eval'
|
||||
|
||||
self._clear_metrics()
|
||||
|
||||
with _CallbackManager(callbacks) as list_callback:
|
||||
if dataset_sink_mode:
|
||||
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
|
||||
return self._eval_process(valid_dataset, list_callback, cb_params)
|
||||
|
||||
def predict(self, *predict_data):
|
||||
"""
|
||||
Generates output predictions for the input samples.
|
||||
|
||||
Data could be single tensor, or list of tensor, tuple of tensor.
|
||||
|
||||
Note:
|
||||
Batch data should be put together in one tensor.
|
||||
|
||||
Args:
|
||||
predict_data (Tensor): Tensor of predict data. can be array, list or tuple.
|
||||
|
||||
Returns:
|
||||
Tensor, array(s) of predictions.
|
||||
|
||||
Examples:
|
||||
>>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
|
||||
>>> model = Model(Net())
|
||||
>>> model.predict(input_data)
|
||||
"""
|
||||
self._predict_network.set_train(False)
|
||||
check_input_data(*predict_data, data_class=Tensor)
|
||||
result = self._predict_network(*predict_data)
|
||||
|
||||
check_output_data(result)
|
||||
return result
|
||||
|
||||
|
||||
__all__ = ["Model"]
|
||||
__all__ = ["Model_Thor"]
|
||||
|
|
|
@ -1,262 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""ResNet."""
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
|
||||
def _weight_variable(shape, factor=0.01):
|
||||
init_value = np.random.randn(*shape).astype(np.float32) * factor
|
||||
return Tensor(init_value)
|
||||
|
||||
|
||||
def _conv3x3(in_channel, out_channel, stride=1):
|
||||
weight_shape = (out_channel, in_channel, 3, 3)
|
||||
weight = _weight_variable(weight_shape)
|
||||
return nn.Conv2d(in_channel, out_channel,
|
||||
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)
|
||||
|
||||
|
||||
def _conv1x1(in_channel, out_channel, stride=1):
|
||||
weight_shape = (out_channel, in_channel, 1, 1)
|
||||
weight = _weight_variable(weight_shape)
|
||||
return nn.Conv2d(in_channel, out_channel,
|
||||
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)
|
||||
|
||||
|
||||
def _conv7x7(in_channel, out_channel, stride=1):
|
||||
weight_shape = (out_channel, in_channel, 7, 7)
|
||||
weight = _weight_variable(weight_shape)
|
||||
return nn.Conv2d(in_channel, out_channel,
|
||||
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)
|
||||
|
||||
|
||||
def _bn(channel):
|
||||
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
|
||||
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
|
||||
|
||||
|
||||
def _bn_last(channel):
|
||||
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
|
||||
gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)
|
||||
|
||||
|
||||
def _fc(in_channel, out_channel):
|
||||
weight_shape = (out_channel, in_channel)
|
||||
weight = _weight_variable(weight_shape)
|
||||
return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Cell):
|
||||
"""
|
||||
ResNet V1 residual block definition.
|
||||
|
||||
Args:
|
||||
in_channel (int): Input channel.
|
||||
out_channel (int): Output channel.
|
||||
stride (int): Stride size for the first convolutional layer. Default: 1.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
>>> ResidualBlock(3, 256, stride=2)
|
||||
"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
stride=1):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
channel = out_channel // self.expansion
|
||||
self.conv1 = _conv1x1(in_channel, channel, stride=1)
|
||||
self.bn1 = _bn(channel)
|
||||
|
||||
self.conv2 = _conv3x3(channel, channel, stride=stride)
|
||||
self.bn2 = _bn(channel)
|
||||
|
||||
self.conv3 = _conv1x1(channel, out_channel, stride=1)
|
||||
self.bn3 = _bn_last(out_channel)
|
||||
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
self.down_sample = False
|
||||
|
||||
if stride != 1 or in_channel != out_channel:
|
||||
self.down_sample = True
|
||||
self.down_sample_layer = None
|
||||
|
||||
if self.down_sample:
|
||||
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride),
|
||||
_bn(out_channel)])
|
||||
self.add = P.TensorAdd()
|
||||
|
||||
def construct(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.down_sample:
|
||||
identity = self.down_sample_layer(identity)
|
||||
|
||||
out = self.add(out, identity)
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Cell):
|
||||
"""
|
||||
ResNet architecture.
|
||||
|
||||
Args:
|
||||
block (Cell): Block for network.
|
||||
layer_nums (list): Numbers of block in different layers.
|
||||
in_channels (list): Input channel in each layer.
|
||||
out_channels (list): Output channel in each layer.
|
||||
strides (list): Stride size in each layer.
|
||||
num_classes (int): The number of classes that the training images are belonging to.
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
>>> ResNet(ResidualBlock,
|
||||
>>> [3, 4, 6, 3],
|
||||
>>> [64, 256, 512, 1024],
|
||||
>>> [256, 512, 1024, 2048],
|
||||
>>> [1, 2, 2, 2],
|
||||
>>> 10)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
block,
|
||||
layer_nums,
|
||||
in_channels,
|
||||
out_channels,
|
||||
strides,
|
||||
num_classes):
|
||||
super(ResNet, self).__init__()
|
||||
|
||||
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
|
||||
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
|
||||
|
||||
self.conv1 = _conv7x7(3, 64, stride=2)
|
||||
self.bn1 = _bn(64)
|
||||
self.relu = P.ReLU()
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
|
||||
|
||||
self.layer1 = self._make_layer(block,
|
||||
layer_nums[0],
|
||||
in_channel=in_channels[0],
|
||||
out_channel=out_channels[0],
|
||||
stride=strides[0])
|
||||
self.layer2 = self._make_layer(block,
|
||||
layer_nums[1],
|
||||
in_channel=in_channels[1],
|
||||
out_channel=out_channels[1],
|
||||
stride=strides[1])
|
||||
self.layer3 = self._make_layer(block,
|
||||
layer_nums[2],
|
||||
in_channel=in_channels[2],
|
||||
out_channel=out_channels[2],
|
||||
stride=strides[2])
|
||||
self.layer4 = self._make_layer(block,
|
||||
layer_nums[3],
|
||||
in_channel=in_channels[3],
|
||||
out_channel=out_channels[3],
|
||||
stride=strides[3])
|
||||
|
||||
self.mean = P.ReduceMean(keep_dims=True)
|
||||
self.flatten = nn.Flatten()
|
||||
self.end_point = _fc(out_channels[3], num_classes)
|
||||
|
||||
def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
|
||||
"""
|
||||
Make stage network of ResNet.
|
||||
|
||||
Args:
|
||||
block (Cell): Resnet block.
|
||||
layer_num (int): Layer number.
|
||||
in_channel (int): Input channel.
|
||||
out_channel (int): Output channel.
|
||||
stride (int): Stride size for the first convolutional layer.
|
||||
|
||||
Returns:
|
||||
SequentialCell, the output layer.
|
||||
|
||||
Examples:
|
||||
>>> _make_layer(ResidualBlock, 3, 128, 256, 2)
|
||||
"""
|
||||
layers = []
|
||||
|
||||
resnet_block = block(in_channel, out_channel, stride=stride)
|
||||
layers.append(resnet_block)
|
||||
|
||||
for _ in range(1, layer_num):
|
||||
resnet_block = block(out_channel, out_channel, stride=1)
|
||||
layers.append(resnet_block)
|
||||
|
||||
return nn.SequentialCell(layers)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
c1 = self.maxpool(x)
|
||||
|
||||
c2 = self.layer1(c1)
|
||||
c3 = self.layer2(c2)
|
||||
c4 = self.layer3(c3)
|
||||
c5 = self.layer4(c4)
|
||||
|
||||
out = self.mean(c5, (2, 3))
|
||||
out = self.flatten(out)
|
||||
out = self.end_point(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def resnet50(class_num=10):
|
||||
"""
|
||||
Get ResNet50 neural network.
|
||||
|
||||
Args:
|
||||
class_num (int): Class number.
|
||||
|
||||
Returns:
|
||||
Cell, cell instance of ResNet50 neural network.
|
||||
|
||||
Examples:
|
||||
>>> net = resnet50(10)
|
||||
"""
|
||||
return ResNet(ResidualBlock,
|
||||
[3, 4, 6, 3],
|
||||
[64, 256, 512, 1024],
|
||||
[256, 512, 1024, 2048],
|
||||
[1, 2, 2, 2],
|
||||
class_num)
|
|
@ -18,8 +18,9 @@ import numpy as np
|
|||
import mindspore.nn as nn
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import context
|
||||
|
||||
from src.thor_layer import Conv2d_Thor, Dense_Thor
|
||||
from src.thor_layer import Conv2d_Thor, Dense_Thor, Conv2d_Thor_GPU, Dense_Thor_GPU
|
||||
|
||||
|
||||
def calculate_gain(nonlinearity, param=None):
|
||||
|
@ -81,7 +82,7 @@ def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
|||
return np.random.normal(0, std, size=inputs_shape).astype(np.float32)
|
||||
|
||||
|
||||
def kaiming_uniform(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
|
||||
fan = _calculate_correct_fan(inputs_shape, mode)
|
||||
gain = calculate_gain(nonlinearity, a)
|
||||
std = gain / math.sqrt(fan)
|
||||
|
@ -89,28 +90,51 @@ def kaiming_uniform(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
|||
return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)
|
||||
|
||||
|
||||
def _conv3x3(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
|
||||
def _weight_variable(shape, factor=0.01):
|
||||
init_value = np.random.randn(*shape).astype(np.float32) * factor
|
||||
return Tensor(init_value)
|
||||
|
||||
|
||||
def _conv3x3(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278, batch_size=32):
|
||||
weight_shape = (out_channel, in_channel, 3, 3)
|
||||
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
|
||||
return Conv2d_Thor(in_channel, out_channel,
|
||||
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight,
|
||||
damping=damping, loss_scale=loss_scale, frequency=frequency)
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
layer = Conv2d_Thor(in_channel, out_channel,
|
||||
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight,
|
||||
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
|
||||
else:
|
||||
layer = Conv2d_Thor_GPU(in_channel, out_channel,
|
||||
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight,
|
||||
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
|
||||
return layer
|
||||
|
||||
|
||||
def _conv1x1(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
|
||||
def _conv1x1(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278, batch_size=32):
|
||||
weight_shape = (out_channel, in_channel, 1, 1)
|
||||
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
|
||||
return Conv2d_Thor(in_channel, out_channel,
|
||||
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight,
|
||||
damping=damping, loss_scale=loss_scale, frequency=frequency)
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
layer = Conv2d_Thor(in_channel, out_channel,
|
||||
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight,
|
||||
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
|
||||
else:
|
||||
layer = Conv2d_Thor_GPU(in_channel, out_channel,
|
||||
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight,
|
||||
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
|
||||
return layer
|
||||
|
||||
|
||||
def _conv7x7(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
|
||||
def _conv7x7(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278, batch_size=32):
|
||||
weight_shape = (out_channel, in_channel, 7, 7)
|
||||
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
|
||||
return Conv2d_Thor(in_channel, out_channel,
|
||||
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight,
|
||||
damping=damping, loss_scale=loss_scale, frequency=frequency)
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
layer = Conv2d_Thor(in_channel, out_channel,
|
||||
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight,
|
||||
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
|
||||
else:
|
||||
layer = Conv2d_Thor_GPU(in_channel, out_channel,
|
||||
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight,
|
||||
damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
|
||||
return layer
|
||||
|
||||
|
||||
def _bn(channel):
|
||||
|
@ -120,14 +144,21 @@ def _bn(channel):
|
|||
|
||||
def _bn_last(channel):
|
||||
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
|
||||
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
|
||||
gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)
|
||||
|
||||
|
||||
def _fc(in_channel, out_channel, damping, loss_scale, frequency):
|
||||
def _fc(in_channel, out_channel, damping, loss_scale, frequency, batch_size=32):
|
||||
weight_shape = (out_channel, in_channel)
|
||||
weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
|
||||
return Dense_Thor(in_channel, out_channel, has_bias=False, weight_init=weight,
|
||||
bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency)
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
layer = Dense_Thor(in_channel, out_channel, has_bias=False, weight_init=weight,
|
||||
bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency,
|
||||
batch_size=batch_size)
|
||||
else:
|
||||
layer = Dense_Thor_GPU(in_channel, out_channel, has_bias=False, weight_init=weight,
|
||||
bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency,
|
||||
batch_size=batch_size)
|
||||
return layer
|
||||
|
||||
|
||||
class ResidualBlock(nn.Cell):
|
||||
|
@ -153,20 +184,21 @@ class ResidualBlock(nn.Cell):
|
|||
stride=1,
|
||||
damping=0.03,
|
||||
loss_scale=1,
|
||||
frequency=278):
|
||||
frequency=278,
|
||||
batch_size=32):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
channel = out_channel // self.expansion
|
||||
self.conv1 = _conv1x1(in_channel, channel, stride=1, damping=damping, loss_scale=loss_scale,
|
||||
frequency=frequency)
|
||||
frequency=frequency, batch_size=batch_size)
|
||||
self.bn1 = _bn(channel)
|
||||
|
||||
self.conv2 = _conv3x3(channel, channel, stride=stride, damping=damping, loss_scale=loss_scale,
|
||||
frequency=frequency)
|
||||
frequency=frequency, batch_size=batch_size)
|
||||
self.bn2 = _bn(channel)
|
||||
|
||||
self.conv3 = _conv1x1(channel, out_channel, stride=1, damping=damping, loss_scale=loss_scale,
|
||||
frequency=frequency)
|
||||
frequency=frequency, batch_size=batch_size)
|
||||
self.bn3 = _bn_last(out_channel)
|
||||
|
||||
self.relu = nn.ReLU()
|
||||
|
@ -180,7 +212,8 @@ class ResidualBlock(nn.Cell):
|
|||
if self.down_sample:
|
||||
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
|
||||
damping=damping, loss_scale=loss_scale,
|
||||
frequency=frequency),
|
||||
frequency=frequency,
|
||||
batch_size=batch_size),
|
||||
_bn(out_channel)])
|
||||
self.add = P.TensorAdd()
|
||||
|
||||
|
@ -239,16 +272,19 @@ class ResNet(nn.Cell):
|
|||
num_classes,
|
||||
damping,
|
||||
loss_scale,
|
||||
frequency):
|
||||
frequency,
|
||||
batch_size):
|
||||
super(ResNet, self).__init__()
|
||||
|
||||
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
|
||||
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
|
||||
|
||||
self.conv1 = _conv7x7(3, 64, stride=2, damping=damping, loss_scale=loss_scale, frequency=frequency)
|
||||
self.conv1 = _conv7x7(3, 64, stride=2, damping=damping, loss_scale=loss_scale,
|
||||
frequency=frequency, batch_size=batch_size)
|
||||
self.bn1 = _bn(64)
|
||||
self.relu = P.ReLU()
|
||||
self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
|
||||
# self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
|
||||
|
||||
self.layer1 = self._make_layer(block,
|
||||
layer_nums[0],
|
||||
|
@ -257,7 +293,8 @@ class ResNet(nn.Cell):
|
|||
stride=strides[0],
|
||||
damping=damping,
|
||||
loss_scale=loss_scale,
|
||||
frequency=frequency)
|
||||
frequency=frequency,
|
||||
batch_size=batch_size)
|
||||
self.layer2 = self._make_layer(block,
|
||||
layer_nums[1],
|
||||
in_channel=in_channels[1],
|
||||
|
@ -265,14 +302,16 @@ class ResNet(nn.Cell):
|
|||
stride=strides[1],
|
||||
damping=damping,
|
||||
loss_scale=loss_scale,
|
||||
frequency=frequency)
|
||||
frequency=frequency,
|
||||
batch_size=batch_size)
|
||||
self.layer3 = self._make_layer(block,
|
||||
layer_nums[2],
|
||||
in_channel=in_channels[2],
|
||||
out_channel=out_channels[2],
|
||||
stride=strides[2], damping=damping,
|
||||
loss_scale=loss_scale,
|
||||
frequency=frequency)
|
||||
frequency=frequency,
|
||||
batch_size=batch_size)
|
||||
self.layer4 = self._make_layer(block,
|
||||
layer_nums[3],
|
||||
in_channel=in_channels[3],
|
||||
|
@ -280,14 +319,16 @@ class ResNet(nn.Cell):
|
|||
stride=strides[3],
|
||||
damping=damping,
|
||||
loss_scale=loss_scale,
|
||||
frequency=frequency)
|
||||
frequency=frequency,
|
||||
batch_size=batch_size)
|
||||
|
||||
self.mean = P.ReduceMean(keep_dims=True)
|
||||
self.flatten = nn.Flatten()
|
||||
self.end_point = _fc(out_channels[3], num_classes, damping=damping, loss_scale=loss_scale, frequency=frequency)
|
||||
self.end_point = _fc(out_channels[3], num_classes, damping=damping, loss_scale=loss_scale,
|
||||
frequency=frequency, batch_size=batch_size)
|
||||
|
||||
def _make_layer(self, block, layer_num, in_channel, out_channel, stride,
|
||||
damping, loss_scale, frequency):
|
||||
damping, loss_scale, frequency, batch_size):
|
||||
"""
|
||||
Make stage network of ResNet.
|
||||
|
||||
|
@ -307,12 +348,14 @@ class ResNet(nn.Cell):
|
|||
layers = []
|
||||
|
||||
resnet_block = block(in_channel, out_channel, stride=stride,
|
||||
damping=damping, loss_scale=loss_scale, frequency=frequency)
|
||||
damping=damping, loss_scale=loss_scale, frequency=frequency,
|
||||
batch_size=batch_size)
|
||||
layers.append(resnet_block)
|
||||
|
||||
for _ in range(1, layer_num):
|
||||
resnet_block = block(out_channel, out_channel, stride=1,
|
||||
damping=damping, loss_scale=loss_scale, frequency=frequency)
|
||||
damping=damping, loss_scale=loss_scale, frequency=frequency,
|
||||
batch_size=batch_size)
|
||||
layers.append(resnet_block)
|
||||
|
||||
return nn.SequentialCell(layers)
|
||||
|
@ -321,7 +364,7 @@ class ResNet(nn.Cell):
|
|||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
c1, _ = self.maxpool(x)
|
||||
c1 = self.maxpool(x)
|
||||
|
||||
c2 = self.layer1(c1)
|
||||
c3 = self.layer2(c2)
|
||||
|
@ -335,7 +378,7 @@ class ResNet(nn.Cell):
|
|||
return out
|
||||
|
||||
|
||||
def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278):
|
||||
def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278, batch_size=32):
|
||||
"""
|
||||
Get ResNet50 neural network.
|
||||
|
||||
|
@ -356,4 +399,5 @@ def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278):
|
|||
class_num,
|
||||
damping,
|
||||
loss_scale,
|
||||
frequency)
|
||||
frequency,
|
||||
batch_size)
|
||||
|
|
|
@ -12,27 +12,20 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""momentum"""
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
"""THOR"""
|
||||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore.ops import _selected_ops
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore._checkparam import check_bool
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
from mindspore.parallel._utils import _get_device_num, _get_mirror_mean
|
||||
from src.grad_reducer_thor import DistributedGradReducerThor
|
||||
|
||||
momentum_opt = C.MultitypeFuncGraph("momentum_opt")
|
||||
|
||||
|
||||
@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
|
||||
def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
|
||||
"""Apply momentum optimizer to the weight parameter using Tensor."""
|
||||
success = True
|
||||
success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
|
||||
return success
|
||||
|
||||
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")
|
||||
|
||||
op_add = P.AddN()
|
||||
apply_decay = C.MultitypeFuncGraph("apply_decay")
|
||||
|
@ -46,6 +39,119 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
|
|||
return gradient
|
||||
|
||||
|
||||
@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
|
||||
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
|
||||
"""Apply momentum optimizer to the weight parameter using Tensor."""
|
||||
success = True
|
||||
success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
|
||||
return success
|
||||
|
||||
|
||||
class THOR_GPU(Optimizer):
|
||||
"""
|
||||
THOR
|
||||
"""
|
||||
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max,
|
||||
weight_decay=0.0, loss_scale=1.0, use_nesterov=False, decay_filter=lambda x: x.name not in []):
|
||||
super(THOR_GPU, self).__init__(learning_rate, params, weight_decay, loss_scale)
|
||||
validator.check_value_type("momentum", momentum, [float], self.cls_name)
|
||||
if isinstance(momentum, float) and momentum < 0.0:
|
||||
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
|
||||
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
|
||||
self.params = self.parameters
|
||||
self.use_nesterov = check_bool(use_nesterov)
|
||||
self.moments = self.params.clone(prefix="moments", init='zeros')
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = _selected_ops.ApplyMomentum(use_nesterov=self.use_nesterov)
|
||||
|
||||
self.feature_map = [1.0 / 12544, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
|
||||
1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
|
||||
1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
|
||||
1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
|
||||
1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
|
||||
1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
|
||||
1.0 / 196, 1.0 / 196, 1.0 / 196,
|
||||
1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
|
||||
1.0]
|
||||
self.feature_map_new = [x ** 0.5 for x in self.feature_map]
|
||||
self.transpose = P.Transpose()
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.matmul = P.MatMul()
|
||||
self.matrix_A = ParameterTuple(matrix_A)
|
||||
self.matrix_G = ParameterTuple(matrix_G)
|
||||
self.A_inv_max = ParameterTuple(A_inv_max)
|
||||
self.G_inv_max = ParameterTuple(G_inv_max)
|
||||
self.assign = P.Assign()
|
||||
self.mul = P.Mul()
|
||||
|
||||
mean = _get_mirror_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer_thorA = DistributedGradReducerThor(self.parameters, 0, mean, degree)
|
||||
self.grad_reducer_thorG = DistributedGradReducerThor(self.parameters, 0, mean, degree)
|
||||
self.weight_decay = weight_decay
|
||||
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
|
||||
self.update_gradient = P.UpdateThorGradient(split_dim=128)
|
||||
|
||||
def construct(self, gradients):
|
||||
params = self.params
|
||||
moments = self.moments
|
||||
gradients = self.scale_grad(gradients)
|
||||
new_grads = ()
|
||||
if self.thor:
|
||||
matrix_A_allreduce = ()
|
||||
matrix_G_allreduce = ()
|
||||
for i in range(54):
|
||||
g = gradients[i * 3]
|
||||
matrix_A = self.matrix_A[i]
|
||||
matrix_G = self.matrix_G[i]
|
||||
matrix_A = F.depend(matrix_A, g)
|
||||
matrix_G = F.depend(matrix_G, g)
|
||||
matrix_A = self.mul(matrix_A, self.feature_map_new[i])
|
||||
matrix_G = self.mul(matrix_G, self.feature_map_new[i])
|
||||
matrix_A_allreduce = matrix_A_allreduce + (matrix_A,)
|
||||
matrix_G_allreduce = matrix_G_allreduce + (matrix_G,)
|
||||
matrix_A_allreduce = self.grad_reducer_thorA(matrix_A_allreduce)
|
||||
matrix_G_allreduce = self.grad_reducer_thorG(matrix_G_allreduce)
|
||||
for i in range(54):
|
||||
g = gradients[i * 3]
|
||||
g_shape = self.shape(g)
|
||||
g = self.reshape(g, (g_shape[0], -1))
|
||||
matrix_A = matrix_A_allreduce[i]
|
||||
matrix_G = matrix_G_allreduce[i]
|
||||
g = self.update_gradient(matrix_G, g, matrix_A)
|
||||
fake_A = self.assign(self.matrix_A[i], matrix_A)
|
||||
fake_G = self.assign(self.matrix_G[i], matrix_G)
|
||||
g = F.depend(g, fake_A)
|
||||
g = F.depend(g, fake_G)
|
||||
if i == 53:
|
||||
new_grads = new_grads + (g,)
|
||||
else:
|
||||
g = self.reshape(g, g_shape)
|
||||
new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
|
||||
else:
|
||||
for i in range(54):
|
||||
g = gradients[i * 3]
|
||||
g_shape = self.shape(g)
|
||||
g = self.reshape(g, (g_shape[0], -1))
|
||||
matrix_A = self.matrix_A[i]
|
||||
matrix_G = self.matrix_G[i]
|
||||
matrix_A = F.depend(matrix_A, g)
|
||||
matrix_G = F.depend(matrix_G, g)
|
||||
g = self.update_gradient(matrix_G, g, matrix_A)
|
||||
if i == 53:
|
||||
new_grads = new_grads + (g,)
|
||||
else:
|
||||
g = self.reshape(g, g_shape)
|
||||
new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
|
||||
gradients = new_grads
|
||||
if self.weight_decay > 0:
|
||||
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
|
||||
params, gradients)
|
||||
lr = self.get_lr()
|
||||
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
|
||||
return success
|
||||
|
||||
class THOR(Optimizer):
|
||||
"""THOR"""
|
||||
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0,
|
||||
|
@ -195,5 +301,5 @@ class THOR(Optimizer):
|
|||
params, gradients)
|
||||
gradients = self.scale_grad(gradients)
|
||||
lr = self.get_lr()
|
||||
success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
|
||||
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
|
||||
return success
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
"""thor_layer"""
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore._checkparam import check_bool, twice, check_int_positive
|
||||
from mindspore._extends import cell_attr_register
|
||||
|
@ -25,8 +24,10 @@ from mindspore.common.tensor import Tensor
|
|||
from mindspore.nn.cell import Cell
|
||||
from mindspore.nn.layer.activation import get_activation
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
C0 = 16
|
||||
|
||||
|
||||
def caculate_device_shape(matrix_dim, channel, is_A):
|
||||
ll = (0)
|
||||
if is_A:
|
||||
|
@ -37,6 +38,31 @@ def caculate_device_shape(matrix_dim, channel, is_A):
|
|||
ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
|
||||
return ll
|
||||
|
||||
|
||||
def caculate_matmul_shape(matrix_A_dim, matrix_G_dim, split_dim):
|
||||
split_dimA = split_dim
|
||||
split_dimG = split_dim
|
||||
if matrix_A_dim % split_dim == 0:
|
||||
batch_w = matrix_A_dim // split_dim
|
||||
else:
|
||||
if matrix_A_dim < split_dim:
|
||||
batch_w = 1
|
||||
split_dimA = matrix_A_dim
|
||||
else:
|
||||
batch_w = matrix_A_dim // split_dim + 1
|
||||
|
||||
if matrix_G_dim % split_dim == 0:
|
||||
batch_h = matrix_G_dim // split_dim
|
||||
else:
|
||||
if matrix_G_dim < split_dim:
|
||||
batch_h = 1
|
||||
split_dimG = matrix_G_dim
|
||||
else:
|
||||
batch_h = matrix_G_dim // split_dim + 1
|
||||
matrix_A_shape = (batch_h, batch_w, split_dimA, split_dimA)
|
||||
matrix_G_shape = (batch_h, split_dimG, split_dimG)
|
||||
return matrix_A_shape, matrix_G_shape
|
||||
|
||||
class _Conv(Cell):
|
||||
r"""Applies a N-D convolution over an input signal composed of several input
|
||||
planes.
|
||||
|
@ -97,6 +123,286 @@ class _Conv(Cell):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class Conv2d_Thor_GPU(_Conv):
|
||||
"""Conv2d_Thor"""
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
pad_mode='same',
|
||||
padding=0,
|
||||
dilation=1,
|
||||
group=1,
|
||||
data_format='NCHW',
|
||||
has_bias=False,
|
||||
weight_init='normal',
|
||||
damping=0.03,
|
||||
loss_scale=1,
|
||||
frequency=278,
|
||||
batch_size=32,
|
||||
bias_init='zeros'):
|
||||
self.thor = True
|
||||
self.hw = kernel_size * kernel_size
|
||||
kernel_size = twice(kernel_size)
|
||||
super(Conv2d_Thor_GPU, self).__init__(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
pad_mode,
|
||||
padding,
|
||||
dilation,
|
||||
group,
|
||||
data_format,
|
||||
has_bias,
|
||||
weight_init,
|
||||
bias_init,
|
||||
)
|
||||
self.conv2d = P.Conv2D(out_channel=self.out_channels,
|
||||
kernel_size=self.kernel_size,
|
||||
mode=1,
|
||||
pad_mode=self.pad_mode,
|
||||
pad=self.padding,
|
||||
stride=self.stride,
|
||||
dilation=self.dilation,
|
||||
group=self.group
|
||||
)
|
||||
|
||||
self.matrix_A_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
|
||||
self.matrix_G_dim = self.out_channels
|
||||
|
||||
split_dim = 128
|
||||
matrix_A_shape, matrix_G_shape = caculate_matmul_shape(self.matrix_A_dim, self.matrix_G_dim, split_dim)
|
||||
self.matrix_A_inv = Parameter(np.zeros(matrix_A_shape).astype(np.float32),
|
||||
name='matrix_A_inv', requires_grad=False)
|
||||
self.matrix_G_inv = Parameter(np.zeros(matrix_G_shape).astype(np.float32),
|
||||
name='matrix_A_inv', requires_grad=False)
|
||||
self.broadcast_to = P.BroadcastTo(matrix_A_shape)
|
||||
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
|
||||
self.img2col = P.Im2Col(kernel_size=kernel_size, stride=stride, pad_mode="same")
|
||||
self.matmul = P.MatMul(transpose_b=True)
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.mul = P.Mul()
|
||||
self.getG = P.InsertGradientOf(self.save_gradient)
|
||||
self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
|
||||
self.batch_size = Tensor(batch_size, mstype.float16)
|
||||
self.transpose = P.Transpose()
|
||||
self.cast = P.Cast()
|
||||
self.gather = P.GatherV2()
|
||||
self.freq = Tensor(frequency, mstype.int32)
|
||||
self.axis = 0
|
||||
self.sqrt = P.Sqrt()
|
||||
self.reduce_mean = P.ReduceMean(keep_dims=False)
|
||||
self.damping = Parameter(Tensor(damping), name="damping_value", requires_grad=False)
|
||||
self.dampingA = Tensor(np.identity(self.matrix_A_dim), mstype.float32)
|
||||
self.dampingG = Tensor(np.identity(self.matrix_G_dim), mstype.float32)
|
||||
self.cholesky = P.Cholesky(split_dim=split_dim)
|
||||
self.vector_matmul = P.BatchMatMul(transpose_a=True)
|
||||
|
||||
def save_gradient(self, dout):
|
||||
"""save_gradient"""
|
||||
out = dout
|
||||
dout = self.mul(dout, self.loss_scale)
|
||||
dout = self.mul(dout, self.batch_size)
|
||||
dout = self.reduce_mean(dout, 0)
|
||||
dout_shape = self.shape(dout)
|
||||
dout = self.reshape(dout, (dout_shape[0], -1))
|
||||
dout_shape = self.shape(dout)
|
||||
normalizer = dout_shape[1]
|
||||
dout = self.cast(dout, mstype.float32)
|
||||
matrix_G = self.matmul(dout, dout)
|
||||
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
|
||||
damping_step = self.gather(self.damping, self.cov_step, 0)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
self.cov_step = self.cov_step + self.freq
|
||||
damping = self.mul(damping_step, 1.0 / normalizer)
|
||||
damping = self.sqrt(damping)
|
||||
matrix_G = matrix_G + damping * self.dampingG
|
||||
matrix_G = self.cholesky(matrix_G)
|
||||
matrix_G = self.vector_matmul(matrix_G, matrix_G)
|
||||
self.matrix_G_inv = matrix_G
|
||||
return out
|
||||
|
||||
def construct(self, x):
|
||||
if self.thor:
|
||||
matrix_A = self.img2col(x)
|
||||
matrix_A_shape = self.shape(matrix_A)
|
||||
matrix_A = self.reshape(matrix_A, (matrix_A_shape[0]*matrix_A_shape[1]*matrix_A_shape[2],
|
||||
matrix_A_shape[3], -1))
|
||||
matrix_A = self.reduce_mean(matrix_A, 1)
|
||||
matrix_A_shape = self.shape(matrix_A)
|
||||
normalizer = matrix_A_shape[1]
|
||||
matrix_A = self.cast(matrix_A, mstype.float32)
|
||||
matrix_A = self.matmul(matrix_A, matrix_A)
|
||||
matrix_A = self.mul(matrix_A, 1.0 / normalizer)
|
||||
damping_step = self.gather(self.damping, self.cov_step, self.axis)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
damping = self.mul(damping_step, 1.0 / normalizer)
|
||||
damping = self.sqrt(damping)
|
||||
matrix_A = matrix_A + damping * self.dampingA
|
||||
matrix_A = self.cholesky(matrix_A)
|
||||
matrix_A = self.vector_matmul(matrix_A, matrix_A)
|
||||
matrix_A = self.broadcast_to(matrix_A)
|
||||
self.matrix_A_inv = matrix_A
|
||||
out = self.conv2d(x, self.weight)
|
||||
out = self.getG(out)
|
||||
else:
|
||||
out = self.conv2d(x, self.weight)
|
||||
|
||||
return out
|
||||
|
||||
def extra_repr(self):
|
||||
"""extra_repr"""
|
||||
s = 'input_channels={}, output_channels={}, kernel_size={},' \
|
||||
'stride={}, pad_mode={}, padding={}, dilation={}, ' \
|
||||
'group={}, data_format={}, has_bias={},' \
|
||||
'weight_init={}, bias_init={}'.format(
|
||||
self.in_channels,
|
||||
self.out_channels,
|
||||
self.kernel_size,
|
||||
self.stride,
|
||||
self.pad_mode,
|
||||
self.padding,
|
||||
self.dilation,
|
||||
self.group,
|
||||
self.data_format,
|
||||
self.has_bias,
|
||||
self.weight,
|
||||
self.bias)
|
||||
|
||||
if self.has_bias:
|
||||
s += ', bias={}'.format(self.bias)
|
||||
return s
|
||||
|
||||
|
||||
class Dense_Thor_GPU(Cell):
|
||||
"""Dense_Thor"""
|
||||
@cell_attr_register(attrs=['has_bias', 'activation'])
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
weight_init='normal',
|
||||
bias_init='zeros',
|
||||
damping=0.03,
|
||||
loss_scale=1,
|
||||
frequency=278,
|
||||
batch_size=32,
|
||||
has_bias=True,
|
||||
activation=None):
|
||||
super(Dense_Thor_GPU, self).__init__()
|
||||
self.in_channels = check_int_positive(in_channels)
|
||||
self.out_channels = check_int_positive(out_channels)
|
||||
self.has_bias = check_bool(has_bias)
|
||||
self.thor = True
|
||||
if isinstance(weight_init, Tensor):
|
||||
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
|
||||
weight_init.shape[1] != in_channels:
|
||||
raise ValueError("weight_init shape error")
|
||||
|
||||
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
|
||||
|
||||
if self.has_bias:
|
||||
if isinstance(bias_init, Tensor):
|
||||
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
|
||||
raise ValueError("bias_init shape error")
|
||||
|
||||
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
|
||||
|
||||
self.matmul = P.MatMul(transpose_b=True)
|
||||
self.bias_add = P.BiasAdd()
|
||||
|
||||
self.activation = get_activation(activation)
|
||||
self.activation_flag = self.activation is not None
|
||||
split_dim = 128
|
||||
matrix_A_shape, matrix_G_shape = caculate_matmul_shape(self.in_channels, self.out_channels, split_dim)
|
||||
self.matrix_A_inv = Parameter(Tensor(np.zeros(matrix_A_shape).astype(np.float32)),
|
||||
name='matrix_A_inv', requires_grad=False)
|
||||
self.matrix_G_inv = Parameter(Tensor(np.zeros(matrix_G_shape).astype(np.float32)),
|
||||
name="matrix_G_inv", requires_grad=False)
|
||||
self.broadcast_to = P.BroadcastTo(matrix_A_shape)
|
||||
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
self.mul = P.Mul()
|
||||
self.cube_matmul = P.MatMul(transpose_a=True)
|
||||
self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
|
||||
self.batch_size = Tensor(batch_size, mstype.float16)
|
||||
self.getG = P.InsertGradientOf(self.save_gradient)
|
||||
self.damping = Parameter(Tensor(damping), name="damping_value", requires_grad=False)
|
||||
self.dampingA = Tensor(np.identity(in_channels), mstype.float32)
|
||||
self.dampingG = Tensor(np.identity(out_channels), mstype.float32)
|
||||
self.cast = P.Cast()
|
||||
self.gather = P.GatherV2()
|
||||
self.freq = Tensor(frequency, mstype.int32)
|
||||
self.axis = 0
|
||||
self.add = P.TensorAdd()
|
||||
self.sqrt = P.Sqrt()
|
||||
self.cholesky = P.Cholesky(split_dim=split_dim)
|
||||
self.vector_matmul = P.BatchMatMul(transpose_a=True)
|
||||
|
||||
def save_gradient(self, dout):
|
||||
"""save_gradient"""
|
||||
out = dout
|
||||
dout = self.mul(dout, self.loss_scale)
|
||||
dout = self.mul(dout, self.batch_size)
|
||||
dout_shape = self.shape(dout)
|
||||
normalizer = dout_shape[0]
|
||||
dout = self.cast(dout, mstype.float32)
|
||||
matrix_G = self.cube_matmul(dout, dout)
|
||||
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
|
||||
damping_step = self.gather(self.damping, self.cov_step, 0)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
self.cov_step = self.cov_step + self.freq
|
||||
damping = self.sqrt(damping_step)
|
||||
matrix_G = matrix_G + damping * self.dampingG
|
||||
matrix_G = self.cholesky(matrix_G)
|
||||
matrix_G = self.vector_matmul(matrix_G, matrix_G)
|
||||
self.matrix_G_inv = matrix_G
|
||||
return out
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
if self.thor:
|
||||
inputs = self.cast(x, mstype.float32)
|
||||
inputs = self.cube_matmul(inputs, inputs)
|
||||
inputs_shape = self.shape(inputs)
|
||||
normalizer = inputs_shape[0]
|
||||
matrix_A = self.mul(inputs, 1.0 / normalizer)
|
||||
damping_step = self.gather(self.damping, self.cov_step, self.axis)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
damping = self.sqrt(damping_step)
|
||||
matrix_A = matrix_A + damping * self.dampingA
|
||||
matrix_A = self.cholesky(matrix_A)
|
||||
matrix_A = self.vector_matmul(matrix_A, matrix_A)
|
||||
matrix_A = self.broadcast_to(matrix_A)
|
||||
self.matrix_A_inv = matrix_A
|
||||
output = self.matmul(x, self.weight)
|
||||
output = self.getG(output)
|
||||
else:
|
||||
output = self.matmul(x, self.weight)
|
||||
|
||||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
if self.activation_flag:
|
||||
return self.activation(output)
|
||||
return output
|
||||
|
||||
def extend_repr(self):
|
||||
"""extend_repr"""
|
||||
str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \
|
||||
.format(self.in_channels, self.out_channels, self.weight, self.has_bias)
|
||||
if self.has_bias:
|
||||
str_info = str_info + ', bias={}'.format(self.bias)
|
||||
|
||||
if self.activation_flag:
|
||||
str_info = str_info + ', activation={}'.format(self.activation)
|
||||
|
||||
return str_info
|
||||
|
||||
|
||||
class Conv2d_Thor(_Conv):
|
||||
"""Conv2d_Thor"""
|
||||
def __init__(self,
|
||||
|
@ -114,6 +420,7 @@ class Conv2d_Thor(_Conv):
|
|||
damping=0.03,
|
||||
loss_scale=1,
|
||||
frequency=278,
|
||||
batch_size=32,
|
||||
bias_init='zeros'):
|
||||
self.thor = True
|
||||
ksizes = (1, kernel_size, kernel_size, 1)
|
||||
|
@ -143,7 +450,7 @@ class Conv2d_Thor(_Conv):
|
|||
dilation=self.dilation,
|
||||
group=self.group
|
||||
)
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides)
|
||||
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
|
||||
self.matrix_combine = P.CusMatrixCombine()
|
||||
|
@ -228,7 +535,7 @@ class Conv2d_Thor(_Conv):
|
|||
normalizer = dout_shape[0]
|
||||
|
||||
matrix_G = self.cube_matmul(dout, dout)
|
||||
normalizer = self.cast(normalizer, ms.float32)
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
|
||||
damping_step = self.gather(self.damping, self.cov_step, 0)
|
||||
self.cov_step = self.cov_step + self.freq
|
||||
|
@ -261,7 +568,7 @@ class Conv2d_Thor(_Conv):
|
|||
matrix_A = self.reshape(matrix_A, (self.hw, C0, self.hw, C0))
|
||||
matrix_A = self.slice(matrix_A, (0, 0, 0, 0), (self.hw, self.in_channels, self.hw, self.in_channels))
|
||||
matrix_A = self.reshape(matrix_A, (self.matrix_A_dim, self.matrix_A_dim))
|
||||
normalizer = self.cast(normalizer, ms.float32)
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
matrix_A = self.mul(matrix_A, 1.0 / normalizer)
|
||||
if self.padA_flag:
|
||||
matrix_A = self.padA(matrix_A)
|
||||
|
@ -330,6 +637,7 @@ class Dense_Thor(Cell):
|
|||
damping=0.03,
|
||||
loss_scale=1,
|
||||
frequency=278,
|
||||
batch_size=32,
|
||||
has_bias=True,
|
||||
activation=None):
|
||||
super(Dense_Thor, self).__init__()
|
||||
|
@ -337,6 +645,7 @@ class Dense_Thor(Cell):
|
|||
self.out_channels = check_int_positive(out_channels)
|
||||
self.has_bias = check_bool(has_bias)
|
||||
self.thor = True
|
||||
self.batch_size = batch_size
|
||||
if isinstance(weight_init, Tensor):
|
||||
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
|
||||
weight_init.shape[1] != in_channels:
|
||||
|
@ -376,8 +685,8 @@ class Dense_Thor(Cell):
|
|||
self.damping = Tensor(damping)
|
||||
self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
|
||||
self.vector_matmul = P.CusBatchMatMul()
|
||||
self.pad = P.Pad(((0, 24), (0, 24)))
|
||||
self.pad1 = P.Pad(((0, 8), (0, 8)))
|
||||
self.pad = P.Pad(((0, 23), (0, 23)))
|
||||
self.pad1 = P.Pad(((0, 7), (0, 7)))
|
||||
self.slice = P.Slice()
|
||||
self.gather = P.GatherV2()
|
||||
self.assignadd = P.AssignAdd()
|
||||
|
@ -385,7 +694,7 @@ class Dense_Thor(Cell):
|
|||
self.axis = 0
|
||||
self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False)
|
||||
self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False)
|
||||
self.fused_abs_max1 = P.CusFusedAbsMax1([1000, 1000])
|
||||
self.fused_abs_max1 = P.CusFusedAbsMax1([1001, 1001])
|
||||
self.fused_abs_max2 = P.CusFusedAbsMax1()
|
||||
self.log = P.Log()
|
||||
self.exp = P.Exp()
|
||||
|
@ -402,7 +711,7 @@ class Dense_Thor(Cell):
|
|||
dout = self.mul(dout, 32.0)
|
||||
normalizer = 32
|
||||
matrix_G = self.cube_matmul(dout, dout)
|
||||
normalizer = self.cast(normalizer, ms.float32)
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
|
||||
matrix_G = self.pad(matrix_G)
|
||||
damping_step = self.gather(self.damping, self.cov_step, 0)
|
||||
|
@ -417,7 +726,7 @@ class Dense_Thor(Cell):
|
|||
matrix_G_inv_max = self.fused_abs_max2(matrix_G_inv_max)
|
||||
self.G_inv_max = matrix_G_inv_max
|
||||
matrix_G_inv = self.matrix_combine(matrix_G_inv)
|
||||
matrix_G_inv = self.slice(matrix_G_inv, (0, 0), (1000, 1000))
|
||||
matrix_G_inv = self.slice(matrix_G_inv, (0, 0), (1001, 1001))
|
||||
matrix_G_inv = self.pad1(matrix_G_inv)
|
||||
matrix_G_inv_shape = self.shape(matrix_G_inv)
|
||||
matrix_G_inv = self.reshape(matrix_G_inv, (matrix_G_inv_shape[0] / 16, 16, matrix_G_inv_shape[0] / 16, 16))
|
||||
|
@ -431,7 +740,7 @@ class Dense_Thor(Cell):
|
|||
if self.thor:
|
||||
inputs = self.cube_matmul(x, x)
|
||||
normalizer = 32
|
||||
normalizer = self.cast(normalizer, ms.float32)
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
matrix_A = self.mul(inputs, 1.0 / normalizer)
|
||||
|
||||
damping_step = self.gather(self.damping, self.cov_step, self.axis)
|
||||
|
|
|
@ -12,44 +12,46 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_imagenet."""
|
||||
import argparse
|
||||
"""train resnet."""
|
||||
import os
|
||||
import random
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.communication.management import init
|
||||
from mindspore import Tensor
|
||||
from mindspore import dataset as de
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.train.model import ParallelMode
|
||||
from src.model_thor import Model
|
||||
from src.resnet_thor import resnet50
|
||||
from src.thor import THOR
|
||||
from src.config import config
|
||||
from src.crossentropy import CrossEntropy
|
||||
from src.dataset_imagenet import create_dataset
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
from src.model_thor import Model_Thor as Model
|
||||
from src.resnet_thor import resnet50
|
||||
from src.dataset import create_dataset
|
||||
from src.crossentropy import CrossEntropy
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
|
||||
parser.add_argument('--do_train', type=bool, default=True, help='Do train or not.')
|
||||
parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
|
||||
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Device num')
|
||||
args_opt = parser.parse_args()
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
|
||||
if args_opt.device_target == "Ascend":
|
||||
from src.thor import THOR
|
||||
from src.config import config
|
||||
else:
|
||||
from src.thor import THOR_GPU as THOR
|
||||
from src.config import config_gpu as config
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
|
||||
|
||||
def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch):
|
||||
def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch, decay_epochs=100):
|
||||
"""get_model_lr"""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
|
@ -57,9 +59,9 @@ def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch):
|
|||
epoch = (i + 1) / steps_per_epoch
|
||||
base = (1.0 - float(epoch) / total_epochs) ** decay
|
||||
lr_local = lr_init * base
|
||||
if epoch >= 39:
|
||||
if epoch >= decay_epochs:
|
||||
lr_local = lr_local * 0.5
|
||||
if epoch >= 40:
|
||||
if epoch >= decay_epochs + 1:
|
||||
lr_local = lr_local * 0.5
|
||||
lr_each_step.append(lr_local)
|
||||
current_step = global_step
|
||||
|
@ -76,7 +78,6 @@ def get_model_damping(global_step, damping_init, decay_rate, total_epochs, steps
|
|||
epoch = (step + 1) / steps_per_epoch
|
||||
damping_here = damping_init * (decay_rate ** (epoch / 10))
|
||||
damping_each_step.append(damping_here)
|
||||
|
||||
current_step = global_step
|
||||
damping_each_step = np.array(damping_each_step).astype(np.float32)
|
||||
damping_now = damping_each_step[current_step:]
|
||||
|
@ -84,49 +85,70 @@ def get_model_damping(global_step, damping_init, decay_rate, total_epochs, steps
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if not args_opt.do_eval and args_opt.run_distribute:
|
||||
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
mirror_mean=True, parameter_broadcast=True)
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1")
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2")
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4")
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5")
|
||||
target = args_opt.device_target
|
||||
ckpt_save_dir = config.save_checkpoint_path
|
||||
|
||||
init()
|
||||
# init context
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
|
||||
|
||||
epoch_size = config.epoch_size
|
||||
damping = get_model_damping(0, 0.03, 0.87, 50, 5004)
|
||||
if args_opt.run_distribute:
|
||||
# Ascend target
|
||||
if target == "Ascend":
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
|
||||
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
mirror_mean=True)
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1")
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2")
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4")
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5")
|
||||
init()
|
||||
# GPU target
|
||||
else:
|
||||
init("nccl")
|
||||
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
mirror_mean=True)
|
||||
ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
|
||||
|
||||
# create dataset
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
|
||||
batch_size=config.batch_size, target=target)
|
||||
|
||||
# define net
|
||||
step_size = dataset.get_dataset_size()
|
||||
damping = get_model_damping(0, config.damping_init, config.damping_decay, 90, step_size)
|
||||
lr = get_model_lr(0, config.lr_init, config.lr_decay, config.lr_end_epoch, step_size, decay_epochs=39)
|
||||
net = resnet50(class_num=config.class_num, damping=damping, loss_scale=config.loss_scale,
|
||||
frequency=config.frequency)
|
||||
frequency=config.frequency, batch_size=config.batch_size)
|
||||
|
||||
if not config.label_smooth:
|
||||
# define loss, model
|
||||
if not config.use_label_smooth:
|
||||
config.label_smooth_factor = 0.0
|
||||
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
|
||||
if args_opt.do_train:
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
|
||||
batch_size=config.batch_size)
|
||||
step_size = dataset.get_dataset_size()
|
||||
|
||||
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
lr = Tensor(get_model_lr(0, 0.045, 6, 70, 5004))
|
||||
opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
|
||||
filter(lambda x: 'matrix_A' in x.name, net.get_parameters()),
|
||||
filter(lambda x: 'matrix_G' in x.name, net.get_parameters()),
|
||||
filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()),
|
||||
filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()),
|
||||
config.weight_decay, config.loss_scale)
|
||||
|
||||
opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), config.momentum,
|
||||
filter(lambda x: 'matrix_A' in x.name, net.get_parameters()),
|
||||
filter(lambda x: 'matrix_G' in x.name, net.get_parameters()),
|
||||
filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()),
|
||||
filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()),
|
||||
config.weight_decay, config.loss_scale)
|
||||
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
if target == "Ascend":
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', loss_scale_manager=loss_scale,
|
||||
keep_batchnorm_fp32=False, metrics={'acc'}, frequency=config.frequency)
|
||||
else:
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
|
||||
amp_level="O2", keep_batchnorm_fp32=True, frequency=config.frequency)
|
||||
|
||||
time_cb = TimeMonitor(data_size=step_size)
|
||||
loss_cb = LossMonitor()
|
||||
cb = [time_cb, loss_cb]
|
||||
if config.save_checkpoint:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=config.save_checkpoint_path, config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
# define callbacks
|
||||
time_cb = TimeMonitor(data_size=step_size)
|
||||
loss_cb = LossMonitor()
|
||||
cb = [time_cb, loss_cb]
|
||||
if config.save_checkpoint:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
|
||||
model.train(epoch_size, dataset, callbacks=cb)
|
||||
# train model
|
||||
model.train(config.epoch_size, dataset, callbacks=cb)
|
||||
|
|
Loading…
Reference in New Issue