forked from OSSInnovation/mindspore
!2553 add mindrecord to mobilenetv2_quant && resnet50_quant
Merge pull request !2553 from wandongdong/r0.3
This commit is contained in:
commit 8c30045178
@@ -22,7 +22,7 @@ from mindspore import nn
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.mobilenetV2_quant import mobilenet_v2_quant
-from src.dataset import create_dataset_py
+from src.dataset import create_dataset
 from src.config import config_ascend
 
 parser = argparse.ArgumentParser(description='Image classification')
@@ -46,11 +46,11 @@ if __name__ == '__main__':
     loss = nn.SoftmaxCrossEntropyWithLogits(
         is_grad=False, sparse=True, reduction='mean')
 
-    dataset = create_dataset_py(dataset_path=args_opt.dataset_path,
-                                do_train=False,
-                                config=config_platform,
-                                platform=args_opt.platform,
-                                batch_size=config_platform.batch_size)
+    dataset = create_dataset(dataset_path=args_opt.dataset_path,
+                             do_train=False,
+                             config=config_platform,
+                             platform=args_opt.platform,
+                             batch_size=config_platform.batch_size)
     step_size = dataset.get_dataset_size()
 
     if args_opt.checkpoint_path:
@@ -22,6 +22,7 @@ config_ascend = ed({
     "image_height": 224,
     "image_width": 224,
     "batch_size": 192,
+    "data_load_mode": "mindrecord",
     "epoch_size": 60,
     "start_epoch": 200,
     "warmup_epochs": 1,
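With "data_load_mode" set to "mindrecord", both models read their training data from MindRecord files instead of a raw image folder; the new create_dataset code below consumes such a file through de.MindDataset with columns ['image', 'label']. For reference, a minimal hedged sketch of producing a compatible file with mindspore.mindrecord.FileWriter (the file names and schema are illustrative assumptions, not part of this PR):

# Hedged sketch: one way to build a MindRecord file that MindDataset can read
# back with columns ['image', 'label']. Paths and the schema are assumptions;
# this PR only consumes an already-converted file.
from mindspore.mindrecord import FileWriter

writer = FileWriter(file_name="imagenet_train.mindrecord", shard_num=1)
schema = {"image": {"type": "bytes"}, "label": {"type": "int32"}}
writer.add_schema(schema, "imagenet_schema")

with open("n01440764_10026.JPEG", "rb") as f:   # hypothetical sample image
    writer.write_raw_data([{"image": f.read(), "label": 0}])
writer.commit()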
@@ -16,11 +16,14 @@
 create train or eval dataset.
 """
 import os
+from functools import partial
 import mindspore.common.dtype as mstype
 import mindspore.dataset.engine as de
 import mindspore.dataset.transforms.vision.c_transforms as C
 import mindspore.dataset.transforms.c_transforms as C2
 import mindspore.dataset.transforms.vision.py_transforms as P
+from src.config import config_ascend
+
 
 def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch_size=32):
     """
@@ -38,14 +41,19 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch
     if platform == "Ascend":
         rank_size = int(os.getenv("RANK_SIZE"))
         rank_id = int(os.getenv("RANK_ID"))
+        columns_list = ['image', 'label']
+        if config_ascend.data_load_mode == "mindrecord":
+            load_func = partial(de.MindDataset, dataset_path, columns_list)
+        else:
+            load_func = partial(de.ImageFolderDatasetV2, dataset_path)
         if do_train:
             if rank_size == 1:
-                ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
+                ds = load_func(num_parallel_workers=8, shuffle=True)
             else:
-                ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
-                                             num_shards=rank_size, shard_id=rank_id)
+                ds = load_func(num_parallel_workers=8, shuffle=True,
+                               num_shards=rank_size, shard_id=rank_id)
         else:
-            ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=False)
+            ds = load_func(num_parallel_workers=8, shuffle=False)
     else:
         raise ValueError("Unsupport platform.")
 
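The key refactor in this hunk is binding the constructor choice once with functools.partial, so every call site stays backend-agnostic. A self-contained sketch of the same pattern, with toy stand-in functions (not MindSpore code) that mirror the two constructors' differing leading arguments:

from functools import partial

# Toy stand-ins mirroring de.MindDataset(path, columns_list, ...) and
# de.ImageFolderDatasetV2(path, ...); only the leading arguments differ.
def mind_dataset(path, columns_list, num_parallel_workers=1, shuffle=False, **kwargs):
    return ("MindDataset", path, columns_list, num_parallel_workers, shuffle, kwargs)

def image_folder_dataset(path, num_parallel_workers=1, shuffle=False, **kwargs):
    return ("ImageFolderDatasetV2", path, num_parallel_workers, shuffle, kwargs)

data_load_mode = "mindrecord"   # from config_ascend above
if data_load_mode == "mindrecord":
    load_func = partial(mind_dataset, "/path/to/train.mindrecord", ['image', 'label'])
else:
    load_func = partial(image_folder_dataset, "/path/to/imagenet/train")

# Every call site now uses the same keyword-only tail, exactly as in the diff:
print(load_func(num_parallel_workers=8, shuffle=True))
print(load_func(num_parallel_workers=8, shuffle=True, num_shards=8, shard_id=0))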
@@ -63,7 +71,8 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch
 
     resize_op = C.Resize(256)
     center_crop = C.CenterCrop(resize_height)
-    normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255])
+    normalize_op = C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
+                               std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
     change_swap_op = C.HWC2CHW()
 
     if do_train:
@@ -84,6 +93,7 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch
 
     return ds
 
+
 def create_dataset_py(dataset_path, do_train, config, platform, repeat_num=1, batch_size=32):
     """
     create a train or eval dataset
@@ -32,7 +32,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
 from mindspore.train.serialization import load_checkpoint
 from mindspore.communication.management import init
 import mindspore.dataset.engine as de
-from src.dataset import create_dataset_py
+from src.dataset import create_dataset
 from src.lr_generator import get_lr
 from src.config import config_ascend
 from src.mobilenetV2_quant import mobilenet_v2_quant
@@ -197,12 +197,12 @@ if __name__ == '__main__':
     else:
         loss = SoftmaxCrossEntropyWithLogits(
             is_grad=False, sparse=True, reduction='mean')
-    dataset = create_dataset_py(dataset_path=args_opt.dataset_path,
-                                do_train=True,
-                                config=config_ascend,
-                                platform=args_opt.platform,
-                                repeat_num=epoch_size,
-                                batch_size=config_ascend.batch_size)
+    dataset = create_dataset(dataset_path=args_opt.dataset_path,
+                             do_train=True,
+                             config=config_ascend,
+                             platform=args_opt.platform,
+                             repeat_num=epoch_size,
+                             batch_size=config_ascend.batch_size)
     step_size = dataset.get_dataset_size()
     if args_opt.pre_trained:
         param_dict = load_checkpoint(args_opt.pre_trained)
@@ -17,7 +17,7 @@ eval.
 """
 import os
 import argparse
-from src.dataset import create_dataset_py
+from src.dataset import create_dataset
 from src.config import config
 from src.crossentropy import CrossEntropy
 from src.utils import _load_param_into_net
@@ -49,8 +49,8 @@ if __name__ == '__main__':
     loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
 
     if args_opt.do_eval:
-        dataset = create_dataset_py(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size,
-                                    target=target)
+        dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size,
+                                 target=target)
         step_size = dataset.get_dataset_size()
 
         if args_opt.checkpoint_path:
@@ -20,7 +20,7 @@ from mindspore import Tensor
 from mindspore.nn import FakeQuantWithMinMax, Conv2dBatchNormQuant
 
 _ema_decay = 0.999
-_symmetric = False
+_symmetric = True
 _fake = True
 _per_channel = True
 
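Unrelated to data loading, this hunk also flips the quantization default from asymmetric to symmetric. As a toy illustration (plain Python, not the FakeQuantWithMinMax implementation), symmetric mode pins the zero-point at 0 and derives a single scale from max(|min|, |max|), at the cost of possibly wasting part of the integer range on one side:

# Toy int8 quantization parameters; symmetric mode fixes zero_point at 0,
# asymmetric mode spends the full integer range on [x_min, x_max] instead.
def quant_params(x_min, x_max, num_bits=8, symmetric=True):
    if symmetric:
        scale = max(abs(x_min), abs(x_max)) / (2 ** (num_bits - 1) - 1)
        zero_point = 0
    else:
        scale = (x_max - x_min) / (2 ** num_bits - 1)
        zero_point = round(-x_min / scale)
    return scale, zero_point

print(quant_params(-0.8, 1.2, symmetric=True))    # zero_point == 0
print(quant_params(-0.8, 1.2, symmetric=False))   # nonzero zero_point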
@@ -28,6 +28,7 @@ config = ed({
     "buffer_size": 1000,
     "image_height": 224,
     "image_width": 224,
+    "data_load_mode": "mindrecord",
     "save_checkpoint": True,
     "save_checkpoint_epochs": 1,
     "keep_checkpoint_max": 50,
@@ -16,12 +16,15 @@
 create train or eval dataset.
 """
 import os
+from functools import partial
 import mindspore.common.dtype as mstype
 import mindspore.dataset.engine as de
 import mindspore.dataset.transforms.vision.c_transforms as C
 import mindspore.dataset.transforms.c_transforms as C2
 import mindspore.dataset.transforms.vision.py_transforms as P
 from mindspore.communication.management import init, get_rank, get_group_size
+from src.config import config
+
 
 def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
     """
@@ -45,11 +48,16 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
         rank_id = get_rank()
         device_num = get_group_size()
 
-    if device_num == 1:
-        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
+    columns_list = ['image', 'label']
+    if config.data_load_mode == "mindrecord":
+        load_func = partial(de.MindDataset, dataset_path, columns_list)
     else:
-        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
-                                     num_shards=device_num, shard_id=rank_id)
+        load_func = partial(de.ImageFolderDatasetV2, dataset_path)
+    if device_num == 1:
+        ds = load_func(num_parallel_workers=8, shuffle=True)
+    else:
+        ds = load_func(num_parallel_workers=8, shuffle=True,
+                       num_shards=device_num, shard_id=rank_id)
 
     image_size = 224
     mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
@@ -66,7 +74,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
     else:
         trans = [
             C.Decode(),
-            C.Resize((256, 256)),
+            C.Resize(256),
             C.CenterCrop(image_size),
             C.Normalize(mean=mean, std=std),
             C.HWC2CHW()
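A small eval-transform fix rides along here: C.Resize((256, 256)) forces a square output, whereas C.Resize(256) scales the shorter edge to 256 and keeps the aspect ratio, the usual pre-crop for CenterCrop(224). A toy shape computation, assuming the single-int argument means shorter-edge resize as in the vision Resize op:

# Shorter-edge resize vs. fixed-size resize (toy shape computation).
def resized_shape(height, width, size):
    if isinstance(size, int):            # shorter edge -> size, ratio preserved
        if height <= width:
            return size, round(width * size / height)
        return round(height * size / width), size
    return size                          # explicit (height, width)

print(resized_shape(480, 640, (256, 256)))  # (256, 256): aspect ratio distorted
print(resized_shape(480, 640, 256))         # (256, 341): aspect ratio preserved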
@@ -85,6 +93,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
 
     return ds
 
+
 def create_dataset_py(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
     """
     create a train or eval dataset
@@ -27,7 +27,7 @@ from mindspore.communication.management import init
 import mindspore.nn as nn
 import mindspore.common.initializer as weight_init
 from models.resnet_quant import resnet50_quant
-from src.dataset import create_dataset_py
+from src.dataset import create_dataset
 from src.lr_generator import get_lr
 from src.config import config
 from src.crossentropy import CrossEntropy
@@ -62,7 +62,6 @@ if __name__ == '__main__':
     epoch_size = config.epoch_size
     net = resnet50_quant(class_num=config.class_num)
     net.set_train(True)
-    print("========resnet50:\r\n{}".format(net))
 
     # weight init
     if args_opt.pre_trained:
@@ -85,8 +84,8 @@ if __name__ == '__main__':
     loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
 
     if args_opt.do_train:
-        dataset = create_dataset_py(dataset_path=args_opt.dataset_path, do_train=True,
-                                    repeat_num=epoch_size, batch_size=config.batch_size, target=target)
+        dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
+                                 repeat_num=epoch_size, batch_size=config.batch_size, target=target)
         step_size = dataset.get_dataset_size()
 
         loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
@@ -94,6 +93,7 @@ if __name__ == '__main__':
                     total_epochs=config.epoch_size, steps_per_epoch=step_size, lr_decay_mode='cosine')
         if args_opt.pre_trained:
             lr = lr[config.pretrained_epoch_size * step_size:]
+        print("========resnet50:\r\n{}\r\nlr: \r\n{}".format(net, lr))
         lr = Tensor(lr)
 
         opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
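The print removed near line 62 reappears here, now also logging the (possibly sliced) learning-rate schedule. The slice itself drops the schedule entries already consumed by the pre-trained checkpoint, so training resumes at the matching point on the decay curve; a toy sketch with assumed numbers:

# Toy illustration of resuming a per-step LR schedule (numbers assumed):
steps_per_epoch = 100                  # stand-in for step_size
total_epochs = 60                      # config.epoch_size in this PR
pretrained_epoch_size = 10             # assumed config.pretrained_epoch_size
lr_full = [0.1 * (1 - i / (total_epochs * steps_per_epoch))
           for i in range(total_epochs * steps_per_epoch)]   # stand-in for get_lr(...)
lr_resumed = lr_full[pretrained_epoch_size * steps_per_epoch:]
print(len(lr_full), len(lr_resumed))   # 6000 5000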
@@ -104,7 +104,7 @@ if __name__ == '__main__':
         loss_cb = LossMonitor()
         cb = [time_cb, loss_cb]
         if config.save_checkpoint:
-            config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*step_size,
+            config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                          keep_checkpoint_max=config.keep_checkpoint_max)
             ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)
             cb += [ckpt_cb]
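The change in this last hunk is purely cosmetic (spacing around the multiplication), but the expression itself is worth a note: CheckpointConfig measures its interval in steps while src/config.py expresses it in epochs, hence the save_checkpoint_epochs * step_size product. Toy arithmetic with an assumed step count:

# Converting an epoch-based checkpoint interval to steps (step count assumed):
save_checkpoint_epochs = 1     # from src/config.py above
keep_checkpoint_max = 50       # from src/config.py above
step_size = 5004               # assumed dataset.get_dataset_size() for one epoch
save_checkpoint_steps = save_checkpoint_epochs * step_size
print(save_checkpoint_steps, keep_checkpoint_max)   # 5004 50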