!6191 add imagenet for alexnet

Merge pull request !6191 from wukesong/imagenet-alexent
mindspore-ci-bot 2020-09-17 11:28:46 +08:00 committed by Gitee
commit f7c5c5265f
12 changed files with 444 additions and 53 deletions

View File

@ -20,8 +20,8 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
import ast
import argparse
from src.config import alexnet_cfg as cfg
from src.dataset import create_dataset_cifar10
from src.config import alexnet_cifar10_cfg, alexnet_imagenet_cfg
from src.dataset import create_dataset_cifar10, create_dataset_imagenet
from src.alexnet import AlexNet
import mindspore.nn as nn
from mindspore import context
@ -32,28 +32,50 @@ from mindspore.nn.metrics import Accuracy
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
help='dataset name.')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=ast.literal_eval,
default=True, help='dataset_sink_mode is False or True')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
repeat_size = cfg.epoch_size
opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Testing ==============")
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
ds_eval = create_dataset_cifar10(args.data_path,
cfg.batch_size,
status="test")
acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
print("============== {} ==============".format(acc))
if args.dataset_name == 'cifar10':
cfg = alexnet_cifar10_cfg
network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
ds_eval = create_dataset_cifar10(args.data_path, cfg.batch_size, status="test", target=args.device_target)
param_dict = load_checkpoint(args.ckpt_path)
print("load checkpoint from [{}].".format(args.ckpt_path))
load_param_into_net(network, param_dict)
network.set_train(False)
model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})
elif args.dataset_name == 'imagenet':
cfg = alexnet_imagenet_cfg
network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
ds_eval = create_dataset_imagenet(args.data_path, cfg.batch_size, training=False)
param_dict = load_checkpoint(args.ckpt_path)
print("load checkpoint from [{}].".format(args.ckpt_path))
load_param_into_net(network, param_dict)
network.set_train(False)
model = Model(network, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
else:
raise ValueError("Unsupported dataset.")
result = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
print("result : {}".format(result))

View File

@ -0,0 +1,53 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_train.sh [RANK_TABLE_FILE] [cifar10|imagenet] [DATA_PATH]"
exit 1
fi
if [ ! -f $1 ]
then
echo "error: RANK_TABLE_FILE=$1 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
RANK_TABLE_FILE=$(realpath $1)
export RANK_TABLE_FILE
export DATASET_NAME=$2
export DATA_PATH=$3
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp -r ./src ./train_parallel$i
cp ./train.py ./train_parallel$i
echo "start training for rank $RANK_ID, device $DEVICE_ID"
cd ./train_parallel$i ||exit
env > env.log
python train.py --device_id=$i --dataset_name=$DATASET_NAME --data_path=$DATA_PATH > log 2>&1 &
cd ..
done

View File

@ -17,6 +17,4 @@
# a simple tutorial follows; more parameters can be set
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
DATA_PATH=$1
CKPT_PATH=$2
python -s ${self_path}/../eval.py --data_path=./$DATA_PATH --device_target="Ascend" --ckpt_path=./$CKPT_PATH > log.txt 2>&1 &
python -s ${self_path}/../eval.py --device_target="Ascend" > log.txt 2>&1 &

View File

@ -17,6 +17,4 @@
# a simple tutorial follows; more parameters can be set
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
DATA_PATH=$1
CKPT_PATH=$2
python -s ${self_path}/../eval.py --data_path=./$DATA_PATH --device_target="GPU" --ckpt_path=./$CKPT_PATH > log.txt 2>&1 &
python -s ${self_path}/../eval.py --device_target="GPU" > log.txt 2>&1 &

View File

@ -14,9 +14,10 @@
# limitations under the License.
# ============================================================================
export DEVICE_NUM=1
export RANK_SIZE=1
# a simple tutorial follows; more parameters can be set
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
DATA_PATH=$1
CKPT_PATH=$2
python -s ${self_path}/../train.py --data_path=./$DATA_PATH --device_target="Ascend" --ckpt_path=./$CKPT_PATH > log.txt 2>&1 &
python -s ${self_path}/../train.py --device_target="Ascend" > log.txt 2>&1 &

View File

@ -14,9 +14,10 @@
# limitations under the License.
# ============================================================================
export DEVICE_NUM=1
export RANK_SIZE=1
# a simple tutorial follows; more parameters can be set
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
DATA_PATH=$1
CKPT_PATH=$2
python -s ${self_path}/../train.py --data_path=./$DATA_PATH --device_target="GPU" --ckpt_path=./$CKPT_PATH > log.txt 2>&1 &
python -s ${self_path}/../train.py --device_target="GPU" > log.txt 2>&1 &

View File

@ -51,6 +51,7 @@ class AlexNet(nn.Cell):
self.fc3 = fc_with_initialize(4096, num_classes)
def construct(self, x):
"""define network"""
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)

View File

@ -18,7 +18,7 @@ network config setting, will be used in train.py
from easydict import EasyDict as edict
alexnet_cfg = edict({
alexnet_cifar10_cfg = edict({
'num_classes': 10,
'learning_rate': 0.002,
'momentum': 0.9,
@ -30,3 +30,31 @@ alexnet_cfg = edict({
'save_checkpoint_steps': 1562,
'keep_checkpoint_max': 10,
})
alexnet_imagenet_cfg = edict({
'num_classes': 1000,
'learning_rate': 0.13,
'momentum': 0.9,
'epoch_size': 150,
'batch_size': 256,
'buffer_size': None, # invalid parameter
'image_height': 227,
'image_width': 227,
'save_checkpoint_steps': 625,
'keep_checkpoint_max': 10,
# opt
'weight_decay': 0.0001,
'loss_scale': 1024,
# lr
'is_dynamic_loss_scale': 0,
'label_smooth': 1,
'label_smooth_factor': 0.1,
'lr_scheduler': 'cosine_annealing',
'warmup_epochs': 5,
'lr_epochs': [30, 60, 90, 120],
'lr_gamma': 0.1,
'T_max': 150,
'eta_min': 0.0,
})
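A back-of-the-envelope check on the config above (an aside, assuming the standard ImageNet-1k train split of 1,281,167 images): a per-device batch of 256 across the 8 devices of the distributed script gives roughly 625 steps per epoch, which appears to be where 'save_checkpoint_steps': 625 comes from.

```python
# Rough arithmetic; the train-split size is an assumption about the setup.
train_images = 1_281_167
per_device_batch = 256
devices = 8                       # matches DEVICE_NUM in the distributed script
steps_per_epoch = train_images // (per_device_batch * devices)
print(steps_per_epoch)            # 625, matching 'save_checkpoint_steps'
```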

View File

@ -16,20 +16,32 @@
Produce the dataset
"""
import os
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.common import dtype as mstype
from .config import alexnet_cfg as cfg
from mindspore.communication.management import get_rank, get_group_size
from .config import alexnet_cifar10_cfg, alexnet_imagenet_cfg
def create_dataset_cifar10(data_path, batch_size=32, repeat_size=1, status="train"):
def create_dataset_cifar10(data_path, batch_size=32, repeat_size=1, status="train", target="Ascend"):
"""
create dataset for train or test
"""
cifar_ds = ds.Cifar10Dataset(data_path)
if target == "Ascend":
device_num, rank_id = _get_rank_info()
if target != "Ascend" or device_num == 1:
cifar_ds = ds.Cifar10Dataset(data_path)
else:
cifar_ds = ds.Cifar10Dataset(data_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id)
rescale = 1.0 / 255.0
shift = 0.0
cfg = alexnet_cifar10_cfg
resize_op = CV.Resize((cfg.image_height, cfg.image_width))
rescale_op = CV.Rescale(rescale, shift)
@ -39,16 +51,97 @@ def create_dataset_cifar10(data_path, batch_size=32, repeat_size=1, status="trai
random_horizontal_op = CV.RandomHorizontalFlip()
channel_swap_op = CV.HWC2CHW()
typecast_op = C.TypeCast(mstype.int32)
cifar_ds = cifar_ds.map(operations=typecast_op, input_columns="label")
cifar_ds = cifar_ds.map(input_columns="label", operations=typecast_op, num_parallel_workers=8)
if status == "train":
cifar_ds = cifar_ds.map(operations=random_crop_op, input_columns="image")
cifar_ds = cifar_ds.map(operations=random_horizontal_op, input_columns="image")
cifar_ds = cifar_ds.map(operations=resize_op, input_columns="image")
cifar_ds = cifar_ds.map(operations=rescale_op, input_columns="image")
cifar_ds = cifar_ds.map(operations=normalize_op, input_columns="image")
cifar_ds = cifar_ds.map(operations=channel_swap_op, input_columns="image")
cifar_ds = cifar_ds.map(input_columns="image", operations=random_crop_op, num_parallel_workers=8)
cifar_ds = cifar_ds.map(input_columns="image", operations=random_horizontal_op, num_parallel_workers=8)
cifar_ds = cifar_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=8)
cifar_ds = cifar_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=8)
cifar_ds = cifar_ds.map(input_columns="image", operations=normalize_op, num_parallel_workers=8)
cifar_ds = cifar_ds.map(input_columns="image", operations=channel_swap_op, num_parallel_workers=8)
cifar_ds = cifar_ds.shuffle(buffer_size=cfg.buffer_size)
cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True)
cifar_ds = cifar_ds.repeat(repeat_size)
return cifar_ds
def create_dataset_imagenet(dataset_path, batch_size=32, repeat_num=1, training=True,
num_parallel_workers=None, shuffle=None, sampler=None, class_indexing=None):
"""
create a train or eval imagenet2012 dataset for AlexNet
Args:
dataset_path(string): the path of the dataset.
batch_size(int): the batch size of the dataset. Default: 32
repeat_num(int): the repeat times of the dataset. Default: 1
training(bool): whether the dataset is used for training or evaluation. Default: True
Returns:
dataset
"""
device_num, rank_id = _get_rank_info()
cfg = alexnet_imagenet_cfg
if num_parallel_workers is None:
num_parallel_workers = int(64 / device_num)
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers,
shuffle=shuffle, sampler=sampler, class_indexing=class_indexing,
num_shards=device_num, shard_id=rank_id)
assert cfg.image_height == cfg.image_width, "imagenet_cfg.image_height not equal imagenet_cfg.image_width"
image_size = cfg.image_height
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
# define map operations
if training:
transform_img = [
CV.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
CV.RandomHorizontalFlip(prob=0.5),
CV.Normalize(mean=mean, std=std),
CV.HWC2CHW()
]
else:
transform_img = [
CV.Decode(),
CV.Resize((256, 256)),
CV.CenterCrop(image_size),
CV.Normalize(mean=mean, std=std),
CV.HWC2CHW()
]
transform_label = [C.TypeCast(mstype.int32)]
data_set = data_set.map(input_columns="image", num_parallel_workers=num_parallel_workers,
operations=transform_img)
data_set = data_set.map(input_columns="label", num_parallel_workers=num_parallel_workers,
operations=transform_label)
num_parallel_workers2 = int(16 / device_num)
data_set = data_set.batch(batch_size, num_parallel_workers=num_parallel_workers2, drop_remainder=True)
# apply dataset repeat operation
data_set = data_set.repeat(repeat_num)
return data_set
def _get_rank_info():
"""
get rank size and rank id
"""
rank_size = int(os.environ.get("RANK_SIZE", 1))
if rank_size > 1:
rank_size = get_group_size()
rank_id = get_rank()
else:
# rank_size = rank_id = None
rank_size = 1
rank_id = 0
return rank_size, rank_id
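A small standalone check (not in the commit) of the eval-side geometry above: Resize to 256x256 followed by CenterCrop(227) produces the 227x227 input that alexnet_imagenet_cfg expects. The random uint8 arrays below stand in for decoded JPEGs, so Decode and Normalize are omitted.

```python
# Hedged sketch of the eval transform geometry only.
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV

def fake_decoded_images():
    # Random HWC uint8 arrays standing in for decoded JPEGs of arbitrary size.
    for _ in range(2):
        yield (np.random.randint(0, 255, (300, 400, 3), dtype=np.uint8),)

data = ds.GeneratorDataset(fake_decoded_images, ["image"])
data = data.map(operations=[CV.Resize((256, 256)), CV.CenterCrop(227), CV.HWC2CHW()],
                input_columns="image")
for row in data.create_dict_iterator():
    print(row["image"].shape)  # (3, 227, 227)
```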

model_zoo/official/cv/alexnet/src/generator_lr.py Executable file → Normal file
View File

@ -13,10 +13,11 @@
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
from collections import Counter
import numpy as np
def get_lr(current_step, lr_max, total_epochs, steps_per_epoch):
def get_lr_cifar10(current_step, lr_max, total_epochs, steps_per_epoch):
"""
generate learning rate array
@ -42,3 +43,85 @@ def get_lr(current_step, lr_max, total_epochs, steps_per_epoch):
learning_rate = lr_each_step[current_step:]
return learning_rate
def get_lr_imagenet(cfg, steps_per_epoch):
"""generate learning rate array"""
if cfg.lr_scheduler == 'exponential':
lr = warmup_step_lr(cfg.learning_rate,
cfg.lr_epochs,
steps_per_epoch,
cfg.warmup_epochs,
cfg.epoch_size,
gamma=cfg.lr_gamma,
)
elif cfg.lr_scheduler == 'cosine_annealing':
lr = warmup_cosine_annealing_lr(cfg.learning_rate,
steps_per_epoch,
cfg.warmup_epochs,
cfg.epoch_size,
cfg.T_max,
cfg.eta_min)
else:
raise NotImplementedError(cfg.lr_scheduler)
return lr
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
"""Linear learning rate"""
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc * current_step
return lr
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
"""Linear warm up learning rate"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
milestones = lr_epochs
milestones_steps = []
for milestone in milestones:
milestones_step = milestone * steps_per_epoch
milestones_steps.append(milestones_step)
lr_each_step = []
lr = base_lr
milestones_steps_counter = Counter(milestones_steps)
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = lr * gamma**milestones_steps_counter[i]
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
lr_epochs = []
for i in range(1, max_epoch):
if i % epoch_size == 0:
lr_epochs.append(i)
return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
""" Cosine annealing learning rate"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
lr_each_step = []
for i in range(total_steps):
last_epoch = i // steps_per_epoch
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
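A hedged usage sketch of the cosine schedule above, plugging in the alexnet_imagenet_cfg values; it assumes it is run from the alexnet model directory so that src is importable, and that an 8-device ImageNet run sees about 625 steps per epoch.

```python
from src.generator_lr import warmup_cosine_annealing_lr

steps_per_epoch = 625   # ~1.28M ImageNet images / (256 batch * 8 devices)
lr = warmup_cosine_annealing_lr(lr=0.13, steps_per_epoch=steps_per_epoch,
                                warmup_epochs=5, max_epoch=150, T_max=150,
                                eta_min=0.0)
print(lr.shape)                             # (93750,) = 150 epochs * 625 steps
print(lr[5 * steps_per_epoch - 1])          # linear warmup peaks at 0.13
print(lr[-1])                               # decays toward eta_min by epoch 150
```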

View File

@ -0,0 +1,34 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""get parameters for Momentum optimizer"""
def get_param_groups(network):
"""get parameters"""
decay_params = []
no_decay_params = []
for x in network.trainable_params():
parameter_name = x.name
if parameter_name.endswith('.bias'):
# biases do not use weight decay
no_decay_params.append(x)
elif parameter_name.endswith('.gamma'):
# BN gamma does not use weight decay; note that x currently does not include BN
no_decay_params.append(x)
elif parameter_name.endswith('.beta'):
# BN beta does not use weight decay; note that x currently does not include BN
no_decay_params.append(x)
else:
decay_params.append(x)
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
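A minimal sketch (not the commit's code) of how the two groups feed nn.Momentum, mirroring the train.py change further down; the toy cell only exists so the '.bias' suffix check has a named parameter to match.

```python
import mindspore.nn as nn
from src.get_param_groups import get_param_groups

class Head(nn.Cell):
    """Toy stand-in whose parameters are named fc.weight / fc.bias."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Dense(8, 4)
    def construct(self, x):
        return self.fc(x)

groups = get_param_groups(Head())
print([len(g['params']) for g in groups])   # [1, 1]: the bias skips weight decay
opt = nn.Momentum(params=groups, learning_rate=0.13, momentum=0.9,
                  weight_decay=0.0001, loss_scale=1024.0)
```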

View File

@ -20,14 +20,18 @@ python train.py --data_path /YourDataPath
import ast
import argparse
from src.config import alexnet_cfg as cfg
from src.dataset import create_dataset_cifar10
from src.generator_lr import get_lr
import os
from src.config import alexnet_cifar10_cfg, alexnet_imagenet_cfg
from src.dataset import create_dataset_cifar10, create_dataset_imagenet
from src.generator_lr import get_lr_cifar10, get_lr_imagenet
from src.alexnet import AlexNet
from src.get_param_groups import get_param_groups
import mindspore.nn as nn
from mindspore.communication.management import init, get_rank
from mindspore import context
from mindspore import Tensor
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.common import set_seed
@ -36,27 +40,102 @@ set_seed(1)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
help='dataset name.')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=ast.literal_eval,
default=True, help='dataset_sink_mode is False or True')
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: 0)')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.dataset_name == "cifar10":
cfg = alexnet_cifar10_cfg
elif args.dataset_name == "imagenet":
cfg = alexnet_imagenet_cfg
else:
raise ValueError("Unsupported dataset.")
device_target = args.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
context.set_context(save_graphs=False)
device_num = int(os.environ.get("DEVICE_NUM", 1))
if device_target == "Ascend":
context.set_context(device_id=args.device_id)
if device_num > 1:
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
elif device_target == "GPU":
init()
if device_num > 1:
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
else:
raise ValueError("Unsupported platform.")
if args.dataset_name == "cifar10":
ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, target=args.device_target)
elif args.dataset_name == "imagenet":
ds_train = create_dataset_imagenet(args.data_path, cfg.batch_size)
else:
raise ValueError("Unsupported dataset.")
ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, 1)
network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
lr = Tensor(get_lr(0, cfg.learning_rate, cfg.epoch_size, ds_train.get_dataset_size()))
opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum)
model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()})
loss_scale_manager = None
metrics = None
if args.dataset_name == 'cifar10':
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
lr = Tensor(get_lr_cifar10(0, cfg.learning_rate, cfg.epoch_size, ds_train.get_dataset_size()))
opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum)
metrics = {"Accuracy": Accuracy()}
elif args.dataset_name == 'imagenet':
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
lr = Tensor(get_lr_imagenet(cfg, ds_train.get_dataset_size()))
opt = nn.Momentum(params=get_param_groups(network),
learning_rate=lr,
momentum=cfg.momentum,
weight_decay=cfg.weight_decay,
loss_scale=cfg.loss_scale)
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
if cfg.is_dynamic_loss_scale == 1:
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
else:
loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
else:
raise ValueError("Unsupported dataset.")
if device_target == "Ascend":
model = Model(network, loss_fn=loss, optimizer=opt, metrics=metrics, amp_level="O2", keep_batchnorm_fp32=False,
loss_scale_manager=loss_scale_manager)
elif device_target == "GPU":
model = Model(network, loss_fn=loss, optimizer=opt, metrics=metrics, loss_scale_manager=loss_scale_manager)
else:
raise ValueError("Unsupported platform.")
if device_num > 1:
ckpt_save_dir = os.path.join(args.ckpt_path + "_" + str(get_rank()))
else:
ckpt_save_dir = args.ckpt_path
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
config_ck = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(),
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet", directory=args.ckpt_path, config=config_ck)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet", directory=ckpt_save_dir, config=config_ck)
print("============== Starting Training ==============")
model.train(cfg.epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],