forked from mindspore-Ecosystem/mindspore
!6191 add imagenet for alexnet
Merge pull request !6191 from wukesong/imagenet-alexent
This commit is contained in: commit f7c5c5265f
@@ -20,8 +20,8 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt

 import ast
 import argparse
-from src.config import alexnet_cfg as cfg
-from src.dataset import create_dataset_cifar10
+from src.config import alexnet_cifar10_cfg, alexnet_imagenet_cfg
+from src.dataset import create_dataset_cifar10, create_dataset_imagenet
 from src.alexnet import AlexNet
 import mindspore.nn as nn
 from mindspore import context
@@ -32,28 +32,50 @@ from mindspore.nn.metrics import Accuracy

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
+    parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
+                        help='dataset name.')
     parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
                         help='device where the code will be implemented (default: Ascend)')
     parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved')
     parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if mode is test, must provide the\
                         path where the trained ckpt file is saved')
     parser.add_argument('--dataset_sink_mode', type=ast.literal_eval,
                         default=True, help='dataset_sink_mode is False or True')
     args = parser.parse_args()

     context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

-    network = AlexNet(cfg.num_classes)
-    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
-    repeat_size = cfg.epoch_size
-    opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
-    model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})
-
-    print("============== Starting Testing ==============")
-    param_dict = load_checkpoint(args.ckpt_path)
-    load_param_into_net(network, param_dict)
-    ds_eval = create_dataset_cifar10(args.data_path,
-                                     cfg.batch_size,
-                                     status="test")
-    acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
-    print("============== {} ==============".format(acc))
+    if args.dataset_name == 'cifar10':
+        cfg = alexnet_cifar10_cfg
+        network = AlexNet(cfg.num_classes)
+        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+        opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
+        ds_eval = create_dataset_cifar10(args.data_path, cfg.batch_size, status="test", target=args.device_target)
+
+        param_dict = load_checkpoint(args.ckpt_path)
+        print("load checkpoint from [{}].".format(args.ckpt_path))
+        load_param_into_net(network, param_dict)
+        network.set_train(False)
+
+        model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})
+
+    elif args.dataset_name == 'imagenet':
+        cfg = alexnet_imagenet_cfg
+        network = AlexNet(cfg.num_classes)
+        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+        ds_eval = create_dataset_imagenet(args.data_path, cfg.batch_size, training=False)
+
+        param_dict = load_checkpoint(args.ckpt_path)
+        print("load checkpoint from [{}].".format(args.ckpt_path))
+        load_param_into_net(network, param_dict)
+        network.set_train(False)
+
+        model = Model(network, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
+
+    else:
+        raise ValueError("Unsupported dataset.")
+
+    result = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
+    print("result : {}".format(result))
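With this change, evaluation is selected per dataset; following the invocation style in the README quoted above, a hypothetical run (paths are placeholders, not part of this commit) would be python eval.py --dataset_name=imagenet --data_path=/path/to/imagenet/val --ckpt_path=./ckpt/checkpoint_alexnet.ckpt. Note that for imagenet the metrics argument is the name set {'top_1_accuracy', 'top_5_accuracy'}; MindSpore's Model resolves these strings to its built-in top-k accuracy metrics, so model.eval returns both values in one dict.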
@@ -0,0 +1,53 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 3 ]
+then
+    echo "Usage: sh run_train.sh [RANK_TABLE_FILE] [cifar10|imagenet] [DATA_PATH]"
+    exit 1
+fi
+
+if [ ! -f $1 ]
+then
+    echo "error: RANK_TABLE_FILE=$1 is not a file"
+    exit 1
+fi
+
+ulimit -u unlimited
+export DEVICE_NUM=8
+export RANK_SIZE=8
+RANK_TABLE_FILE=$(realpath $1)
+export RANK_TABLE_FILE
+export DATASET_NAME=$2
+export DATA_PATH=$3
+echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
+
+export SERVER_ID=0
+rank_start=$((DEVICE_NUM * SERVER_ID))
+for((i=0; i<${DEVICE_NUM}; i++))
+do
+    export DEVICE_ID=$i
+    export RANK_ID=$((rank_start + i))
+    rm -rf ./train_parallel$i
+    mkdir ./train_parallel$i
+    cp -r ./src ./train_parallel$i
+    cp ./train.py ./train_parallel$i
+    echo "start training for rank $RANK_ID, device $DEVICE_ID"
+    cd ./train_parallel$i || exit
+    env > env.log
+    python train.py --device_id=$i --dataset_name=$DATASET_NAME --data_path=$DATA_PATH > log 2>&1 &
+    cd ..
+done
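Per the script's own usage string, a hypothetical launch (the rank table path is a placeholder) is sh run_distribute_train_ascend.sh ./rank_table_8p.json imagenet /data/imagenet/train. Each of the 8 devices gets its own working directory train_parallel$i with a copy of src and train.py, so the per-rank env.log and log files do not collide, and RANK_ID = rank_start + i maps device i on this server to its global rank.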
@@ -17,6 +17,4 @@
 # a simple tutorial follows; more parameters can be set
 script_self=$(readlink -f "$0")
 self_path=$(dirname "${script_self}")
-DATA_PATH=$1
-CKPT_PATH=$2
-python -s ${self_path}/../eval.py --data_path=./$DATA_PATH --device_target="Ascend" --ckpt_path=./$CKPT_PATH > log.txt 2>&1 &
+python -s ${self_path}/../eval.py --device_target="Ascend" > log.txt 2>&1 &
@@ -17,6 +17,4 @@
 # a simple tutorial follows; more parameters can be set
 script_self=$(readlink -f "$0")
 self_path=$(dirname "${script_self}")
-DATA_PATH=$1
-CKPT_PATH=$2
-python -s ${self_path}/../eval.py --data_path=./$DATA_PATH --device_target="GPU" --ckpt_path=./$CKPT_PATH > log.txt 2>&1 &
+python -s ${self_path}/../eval.py --device_target="GPU" > log.txt 2>&1 &
model_zoo/official/cv/alexnet/scripts/run_standalone_train_ascend.sh (7 lines changed; Executable file → Normal file)
@@ -14,9 +14,10 @@
 # limitations under the License.
 # ============================================================================

+export DEVICE_NUM=1
+export RANK_SIZE=1
+
 # a simple tutorial follows; more parameters can be set
 script_self=$(readlink -f "$0")
 self_path=$(dirname "${script_self}")
-DATA_PATH=$1
-CKPT_PATH=$2
-python -s ${self_path}/../train.py --data_path=./$DATA_PATH --device_target="Ascend" --ckpt_path=./$CKPT_PATH > log.txt 2>&1 &
+python -s ${self_path}/../train.py --device_target="Ascend" > log.txt 2>&1 &
@@ -14,9 +14,10 @@
 # limitations under the License.
 # ============================================================================

+export DEVICE_NUM=1
+export RANK_SIZE=1
+
 # a simple tutorial follows; more parameters can be set
 script_self=$(readlink -f "$0")
 self_path=$(dirname "${script_self}")
-DATA_PATH=$1
-CKPT_PATH=$2
-python -s ${self_path}/../train.py --data_path=./$DATA_PATH --device_target="GPU" --ckpt_path=./$CKPT_PATH > log.txt 2>&1 &
+python -s ${self_path}/../train.py --device_target="GPU" > log.txt 2>&1 &
@@ -51,6 +51,7 @@ class AlexNet(nn.Cell):
         self.fc3 = fc_with_initialize(4096, num_classes)

     def construct(self, x):
+        """define network"""
         x = self.conv1(x)
         x = self.relu(x)
         x = self.max_pool2d(x)
@@ -18,7 +18,7 @@ network config setting, will be used in train.py

 from easydict import EasyDict as edict

-alexnet_cfg = edict({
+alexnet_cifar10_cfg = edict({
     'num_classes': 10,
     'learning_rate': 0.002,
     'momentum': 0.9,
@@ -30,3 +30,31 @@ alexnet_cfg = edict({
     'save_checkpoint_steps': 1562,
     'keep_checkpoint_max': 10,
 })
+
+alexnet_imagenet_cfg = edict({
+    'num_classes': 1000,
+    'learning_rate': 0.13,
+    'momentum': 0.9,
+    'epoch_size': 150,
+    'batch_size': 256,
+    'buffer_size': None,  # invalid parameter
+    'image_height': 227,
+    'image_width': 227,
+    'save_checkpoint_steps': 625,
+    'keep_checkpoint_max': 10,
+
+    # opt
+    'weight_decay': 0.0001,
+    'loss_scale': 1024,
+
+    # lr
+    'is_dynamic_loss_scale': 0,
+    'label_smooth': 1,
+    'label_smooth_factor': 0.1,
+    'lr_scheduler': 'cosine_annealing',
+    'warmup_epochs': 5,
+    'lr_epochs': [30, 60, 90, 120],
+    'lr_gamma': 0.1,
+    'T_max': 150,
+    'eta_min': 0.0,
+})
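Because both configs are EasyDict instances, the rest of the code reads fields interchangeably as keys or attributes; a minimal sketch (assuming only that easydict is installed, values copied from the config above):

    from easydict import EasyDict as edict
    cfg = edict({'num_classes': 1000, 'learning_rate': 0.13})
    assert cfg.num_classes == cfg['num_classes'] == 1000  # attribute and key access are equivalent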
@@ -16,20 +16,32 @@
 Produce the dataset
 """

+import os
+
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.c_transforms as C
 import mindspore.dataset.vision.c_transforms as CV
 from mindspore.common import dtype as mstype
-from .config import alexnet_cfg as cfg
+from mindspore.communication.management import get_rank, get_group_size
+from .config import alexnet_cifar10_cfg, alexnet_imagenet_cfg


-def create_dataset_cifar10(data_path, batch_size=32, repeat_size=1, status="train"):
+def create_dataset_cifar10(data_path, batch_size=32, repeat_size=1, status="train", target="Ascend"):
     """
     create dataset for train or test
     """
-    cifar_ds = ds.Cifar10Dataset(data_path)
+    if target == "Ascend":
+        device_num, rank_id = _get_rank_info()
+
+    if target != "Ascend" or device_num == 1:
+        cifar_ds = ds.Cifar10Dataset(data_path)
+    else:
+        cifar_ds = ds.Cifar10Dataset(data_path, num_parallel_workers=8, shuffle=True,
+                                     num_shards=device_num, shard_id=rank_id)
     rescale = 1.0 / 255.0
     shift = 0.0
+    cfg = alexnet_cifar10_cfg

     resize_op = CV.Resize((cfg.image_height, cfg.image_width))
     rescale_op = CV.Rescale(rescale, shift)
@@ -39,16 +51,97 @@ def create_dataset_cifar10(data_path, batch_size=32, repeat_size=1, status="train
     random_horizontal_op = CV.RandomHorizontalFlip()
     channel_swap_op = CV.HWC2CHW()
     typecast_op = C.TypeCast(mstype.int32)
-    cifar_ds = cifar_ds.map(operations=typecast_op, input_columns="label")
+    cifar_ds = cifar_ds.map(input_columns="label", operations=typecast_op, num_parallel_workers=8)
     if status == "train":
-        cifar_ds = cifar_ds.map(operations=random_crop_op, input_columns="image")
-        cifar_ds = cifar_ds.map(operations=random_horizontal_op, input_columns="image")
-    cifar_ds = cifar_ds.map(operations=resize_op, input_columns="image")
-    cifar_ds = cifar_ds.map(operations=rescale_op, input_columns="image")
-    cifar_ds = cifar_ds.map(operations=normalize_op, input_columns="image")
-    cifar_ds = cifar_ds.map(operations=channel_swap_op, input_columns="image")
+        cifar_ds = cifar_ds.map(input_columns="image", operations=random_crop_op, num_parallel_workers=8)
+        cifar_ds = cifar_ds.map(input_columns="image", operations=random_horizontal_op, num_parallel_workers=8)
+    cifar_ds = cifar_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=8)
+    cifar_ds = cifar_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=8)
+    cifar_ds = cifar_ds.map(input_columns="image", operations=normalize_op, num_parallel_workers=8)
+    cifar_ds = cifar_ds.map(input_columns="image", operations=channel_swap_op, num_parallel_workers=8)

     cifar_ds = cifar_ds.shuffle(buffer_size=cfg.buffer_size)
     cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True)
     cifar_ds = cifar_ds.repeat(repeat_size)
     return cifar_ds
+
+
+def create_dataset_imagenet(dataset_path, batch_size=32, repeat_num=1, training=True,
+                            num_parallel_workers=None, shuffle=None, sampler=None, class_indexing=None):
+    """
+    create a train or eval imagenet2012 dataset for alexnet
+
+    Args:
+        dataset_path(string): the path of dataset.
+        training(bool): whether dataset is used for train or eval.
+        repeat_num(int): the repeat times of dataset. Default: 1
+        batch_size(int): the batch size of dataset. Default: 32
+
+    Returns:
+        dataset
+    """
+
+    device_num, rank_id = _get_rank_info()
+    cfg = alexnet_imagenet_cfg
+
+    if num_parallel_workers is None:
+        num_parallel_workers = int(64 / device_num)
+    data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers,
+                                     shuffle=shuffle, sampler=sampler, class_indexing=class_indexing,
+                                     num_shards=device_num, shard_id=rank_id)
+
+    assert cfg.image_height == cfg.image_width, "imagenet_cfg.image_height not equal imagenet_cfg.image_width"
+    image_size = cfg.image_height
+    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
+    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
+
+    # define map operations
+    if training:
+        transform_img = [
+            CV.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
+            CV.RandomHorizontalFlip(prob=0.5),
+            CV.Normalize(mean=mean, std=std),
+            CV.HWC2CHW()
+        ]
+    else:
+        transform_img = [
+            CV.Decode(),
+            CV.Resize((256, 256)),
+            CV.CenterCrop(image_size),
+            CV.Normalize(mean=mean, std=std),
+            CV.HWC2CHW()
+        ]
+
+    transform_label = [C.TypeCast(mstype.int32)]
+
+    data_set = data_set.map(input_columns="image", num_parallel_workers=num_parallel_workers,
+                            operations=transform_img)
+    data_set = data_set.map(input_columns="label", num_parallel_workers=num_parallel_workers,
+                            operations=transform_label)
+
+    num_parallel_workers2 = int(16 / device_num)
+    data_set = data_set.batch(batch_size, num_parallel_workers=num_parallel_workers2, drop_remainder=True)
+
+    # apply dataset repeat operation
+    data_set = data_set.repeat(repeat_num)
+
+    return data_set
+
+
+def _get_rank_info():
+    """
+    get rank size and rank id
+    """
+    rank_size = int(os.environ.get("RANK_SIZE", 1))
+
+    if rank_size > 1:
+        rank_size = get_group_size()
+        rank_id = get_rank()
+    else:
+        # rank_size = rank_id = None
+        rank_size = 1
+        rank_id = 0
+
+    return rank_size, rank_id
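The sharding decision above hinges on the RANK_SIZE environment variable exported by the launch scripts; a minimal mindspore-free sketch of that logic, for illustration only:

    import os
    rank_size = int(os.environ.get("RANK_SIZE", 1))  # run_distribute_train_ascend.sh exports RANK_SIZE=8
    sharded = rank_size > 1  # only then are datasets built with num_shards/shard_id per rank
    print(sharded)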
@@ -13,10 +13,11 @@
 # limitations under the License.
 # ============================================================================
 """learning rate generator"""
+import math
+from collections import Counter
 import numpy as np


-def get_lr(current_step, lr_max, total_epochs, steps_per_epoch):
+def get_lr_cifar10(current_step, lr_max, total_epochs, steps_per_epoch):
     """
     generate learning rate array
@@ -42,3 +43,85 @@ def get_lr(current_step, lr_max, total_epochs, steps_per_epoch):
     learning_rate = lr_each_step[current_step:]

     return learning_rate
+
+
+def get_lr_imagenet(cfg, steps_per_epoch):
+    """generate learning rate array"""
+    if cfg.lr_scheduler == 'exponential':
+        lr = warmup_step_lr(cfg.learning_rate,
+                            cfg.lr_epochs,
+                            steps_per_epoch,
+                            cfg.warmup_epochs,
+                            cfg.epoch_size,
+                            gamma=cfg.lr_gamma,
+                            )
+    elif cfg.lr_scheduler == 'cosine_annealing':
+        lr = warmup_cosine_annealing_lr(cfg.learning_rate,
+                                        steps_per_epoch,
+                                        cfg.warmup_epochs,
+                                        cfg.epoch_size,
+                                        cfg.T_max,
+                                        cfg.eta_min)
+    else:
+        raise NotImplementedError(cfg.lr_scheduler)
+
+    return lr
+
+
+def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
+    """Linear learning rate"""
+    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
+    lr = float(init_lr) + lr_inc * current_step
+    return lr
+
+
+def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
+    """Linear warm up learning rate"""
+    base_lr = lr
+    warmup_init_lr = 0
+    total_steps = int(max_epoch * steps_per_epoch)
+    warmup_steps = int(warmup_epochs * steps_per_epoch)
+    milestones = lr_epochs
+    milestones_steps = []
+    for milestone in milestones:
+        milestones_step = milestone * steps_per_epoch
+        milestones_steps.append(milestones_step)
+
+    lr_each_step = []
+    lr = base_lr
+    milestones_steps_counter = Counter(milestones_steps)
+    for i in range(total_steps):
+        if i < warmup_steps:
+            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
+        else:
+            lr = lr * gamma**milestones_steps_counter[i]
+        lr_each_step.append(lr)
+
+    return np.array(lr_each_step).astype(np.float32)
+
+
+def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
+    return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
+
+
+def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
+    lr_epochs = []
+    for i in range(1, max_epoch):
+        if i % epoch_size == 0:
+            lr_epochs.append(i)
+    return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
+
+
+def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
+    """Cosine annealing learning rate"""
+    base_lr = lr
+    warmup_init_lr = 0
+    total_steps = int(max_epoch * steps_per_epoch)
+    warmup_steps = int(warmup_epochs * steps_per_epoch)
+
+    lr_each_step = []
+    for i in range(total_steps):
+        last_epoch = i // steps_per_epoch
+        if i < warmup_steps:
+            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
+        else:
+            lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
+        lr_each_step.append(lr)
+
+    return np.array(lr_each_step).astype(np.float32)
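A quick check of the cosine schedule added here, using the imagenet config values (0.13 peak lr, 625 steps/epoch, 150 epochs); this sketch assumes it is run from the alexnet model directory so the src package is importable:

    from src.generator_lr import warmup_cosine_annealing_lr
    lr = warmup_cosine_annealing_lr(0.13, steps_per_epoch=625, warmup_epochs=5, max_epoch=150, T_max=150)
    # 93750 steps; warmup climbs from near 0, peaks near 0.13 after 5 epochs, decays toward eta_min=0
    print(len(lr), lr[0], lr.max(), lr[-1])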
@@ -0,0 +1,34 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""get parameters for Momentum optimizer"""
+def get_param_groups(network):
+    """get parameters"""
+    decay_params = []
+    no_decay_params = []
+    for x in network.trainable_params():
+        parameter_name = x.name
+        if parameter_name.endswith('.bias'):
+            # biases do not use weight decay
+            no_decay_params.append(x)
+        elif parameter_name.endswith('.gamma'):
+            # BN gamma does not use weight decay; note this network currently has no BN layers
+            no_decay_params.append(x)
+        elif parameter_name.endswith('.beta'):
+            # BN beta does not use weight decay; note this network currently has no BN layers
+            no_decay_params.append(x)
+        else:
+            decay_params.append(x)
+
+    return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
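The returned structure is the group-list form nn.Momentum accepts, so the no-decay group's 'weight_decay': 0.0 overrides the optimizer-level weight_decay. A minimal sketch of the grouping rule alone, with hypothetical stand-ins for Parameter objects (plain Python, not MindSpore):

    class P:  # hypothetical stand-in exposing just a parameter name
        def __init__(self, name): self.name = name
    params = [P('fc1.weight'), P('fc1.bias'), P('bn.gamma')]
    no_decay = [p for p in params if p.name.endswith(('.bias', '.gamma', '.beta'))]
    print([p.name for p in no_decay])  # ['fc1.bias', 'bn.gamma']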
@@ -20,14 +20,18 @@ python train.py --data_path /YourDataPath

 import ast
 import argparse
-from src.config import alexnet_cfg as cfg
-from src.dataset import create_dataset_cifar10
-from src.generator_lr import get_lr
+import os
+from src.config import alexnet_cifar10_cfg, alexnet_imagenet_cfg
+from src.dataset import create_dataset_cifar10, create_dataset_imagenet
+from src.generator_lr import get_lr_cifar10, get_lr_imagenet
 from src.alexnet import AlexNet
+from src.get_param_groups import get_param_groups
 import mindspore.nn as nn
+from mindspore.communication.management import init, get_rank
 from mindspore import context
 from mindspore import Tensor
 from mindspore.train import Model
+from mindspore.context import ParallelMode
 from mindspore.nn.metrics import Accuracy
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.common import set_seed
@@ -36,27 +40,102 @@ set_seed(1)

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
+    parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
+                        help='dataset name.')
     parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
                         help='device where the code will be implemented (default: Ascend)')
     parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved')
     parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if mode is test, must provide the\
                         path where the trained ckpt file is saved')
     parser.add_argument('--dataset_sink_mode', type=ast.literal_eval,
                         default=True, help='dataset_sink_mode is False or True')
+    parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: 0)')
     args = parser.parse_args()

-    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+    if args.dataset_name == "cifar10":
+        cfg = alexnet_cifar10_cfg
+    elif args.dataset_name == "imagenet":
+        cfg = alexnet_imagenet_cfg
+    else:
+        raise ValueError("Unsupported dataset.")
+
+    device_target = args.device_target
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+    context.set_context(save_graphs=False)
+    device_num = int(os.environ.get("DEVICE_NUM", 1))
+
+    if device_target == "Ascend":
+        context.set_context(device_id=args.device_id)
+
+        if device_num > 1:
+            context.reset_auto_parallel_context()
+            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True)
+            init()
+    elif device_target == "GPU":
+        init()
+        if device_num > 1:
+            context.reset_auto_parallel_context()
+            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True)
+    else:
+        raise ValueError("Unsupported platform.")
+
+    if args.dataset_name == "cifar10":
+        ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, target=args.device_target)
+    elif args.dataset_name == "imagenet":
+        ds_train = create_dataset_imagenet(args.data_path, cfg.batch_size)
+    else:
+        raise ValueError("Unsupported dataset.")

-    ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, 1)
     network = AlexNet(cfg.num_classes)
-    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
-    lr = Tensor(get_lr(0, cfg.learning_rate, cfg.epoch_size, ds_train.get_dataset_size()))
-    opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum)
-    model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})
+
+    loss_scale_manager = None
+    metrics = None
+    if args.dataset_name == 'cifar10':
+        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+        lr = Tensor(get_lr_cifar10(0, cfg.learning_rate, cfg.epoch_size, ds_train.get_dataset_size()))
+        opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum)
+        metrics = {"Accuracy": Accuracy()}
+
+    elif args.dataset_name == 'imagenet':
+        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+
+        lr = Tensor(get_lr_imagenet(cfg, ds_train.get_dataset_size()))
+
+        opt = nn.Momentum(params=get_param_groups(network),
+                          learning_rate=lr,
+                          momentum=cfg.momentum,
+                          weight_decay=cfg.weight_decay,
+                          loss_scale=cfg.loss_scale)
+
+        from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
+        if cfg.is_dynamic_loss_scale == 1:
+            loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
+        else:
+            loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
+
+    else:
+        raise ValueError("Unsupported dataset.")
+
+    if device_target == "Ascend":
+        model = Model(network, loss_fn=loss, optimizer=opt, metrics=metrics, amp_level="O2", keep_batchnorm_fp32=False,
+                      loss_scale_manager=loss_scale_manager)
+    elif device_target == "GPU":
+        model = Model(network, loss_fn=loss, optimizer=opt, metrics=metrics, loss_scale_manager=loss_scale_manager)
+    else:
+        raise ValueError("Unsupported platform.")
+
+    if device_num > 1:
+        ckpt_save_dir = os.path.join(args.ckpt_path + "_" + str(get_rank()))
+    else:
+        ckpt_save_dir = args.ckpt_path

     time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
-    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
+    config_ck = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(),
                                  keep_checkpoint_max=cfg.keep_checkpoint_max)
-    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet", directory=args.ckpt_path, config=config_ck)
+    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet", directory=ckpt_save_dir, config=config_ck)

     print("============== Starting Training ==============")
     model.train(cfg.epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
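One note on the imagenet branch above: the loss scale is applied twice by design. FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False) multiplies the loss before the reduced-precision backward pass, and nn.Momentum(..., loss_scale=cfg.loss_scale) divides the gradients back out, so the two values must match, as they do here via cfg.loss_scale. A minimal sketch (assuming only that mindspore is installed):

    from mindspore.train.loss_scale_manager import FixedLossScaleManager
    manager = FixedLossScaleManager(1024, drop_overflow_update=False)
    print(manager.get_loss_scale())  # 1024.0 — fixed scale; every step applies the update, no overflow skipping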