forked from mindspore-Ecosystem/mindspore
!23005 fix mobilenetv2 bugs on GPU
Merge pull request !23005 from zhaoting/master
commit a4ee2145a6
@@ -38,6 +38,7 @@ device_id: 0
 rank_id: 0
 rank_size: 1
 run_distribute: False
+run_eval: False
 activation: "Softmax"
 
 # Image classification trian. train_parse_args():return train_args
@@ -86,6 +87,7 @@ file_name: "output file name."
 result_path: "result files path."
 label_path: "label path."
 enable_profiling: 'Whether enable profiling while training, default: False'
+run_eval: 'Whether run evaluation while training, default is false.'
 run_distribute: 'Run distribute, default is false.'
 device_id: 'Device id, default is 0.'
 rank_id: 'Rank id, default is 0.'
@@ -38,6 +38,7 @@ device_id: 0
 rank_id: 0
 rank_size: 1
 run_distribute: False
+run_eval: True
 activation: "Softmax"
 
 # Image classification trian. train_parse_args():return train_args
@@ -86,6 +87,7 @@ file_name: "output file name."
 result_path: "result files path."
 label_path: "label path."
 enable_profiling: 'Whether enable profiling while training, default: False'
+run_eval: 'Whether run evaluation while training, default is false.'
 run_distribute: 'Run distribute, default is false.'
 device_id: 'Device id, default is 0.'
 rank_id: 'Rank id, default is 0.'
@@ -34,7 +34,11 @@ save_checkpoint_epochs: 1
 keep_checkpoint_max: 20
 save_checkpoint_path: "./"
 platform: 'CPU'
+device_id: 0
+rank_id: 0
+rank_size: 1
 run_distribute: False
+run_eval: False
 activation: "Softmax"
 
 # Image classification trian. train_parse_args():return train_args
@@ -83,6 +87,7 @@ file_name: "output file name."
 result_path: "result files path."
 label_path: "label path."
 enable_profiling: 'Whether enable profiling while training, default: False'
+run_eval: 'Whether run evaluation while training, default is false.'
 run_distribute: 'Run distribute, default is false.'
 device_id: 'Device id, default is 0.'
 rank_id: 'Rank id, default is 0.'
@@ -34,6 +34,9 @@ save_checkpoint_epochs: 1
 keep_checkpoint_max: 200
 save_checkpoint_path: "./"
 platform: 'GPU'
+device_id: 0
+rank_id: 0
+rank_size: 1
 run_distribute: True
 activation: "Softmax"
 
@@ -57,6 +60,7 @@ ckpt_file: "/cache/train/mobilenetv2-200_625.ckpt"
 file_name: "mobilenetv2"
 file_format: "MINDIR"
 is_training_export: False
+run_eval: False
 run_distribute_export: False
 
 # postprocess.py / mobilenetv2 acc calculation
@@ -83,6 +87,7 @@ file_name: "output file name."
 result_path: "result files path."
 label_path: "label path."
 enable_profiling: 'Whether enable profiling while training, default: False'
+run_eval: 'Whether run evaluation while training, default is false.'
 run_distribute: 'Run distribute, default is false.'
 device_id: 'Device id, default is 0.'
 rank_id: 'Rank id, default is 0.'
@@ -19,21 +19,14 @@ import time
 import os
 from mindspore import nn
 from mindspore.train.model import Model
-from mindspore.common import dtype as mstype
 
 from src.dataset import create_dataset
 from src.models import define_net, load_ckpt
-from src.utils import switch_precision, set_context
+from src.utils import context_device_init
 from src.model_utils.config import config
 from src.model_utils.moxing_adapter import moxing_wrapper
-from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
+from src.model_utils.device_adapter import get_device_id, get_device_num
 
-
 config.is_training = config.is_training_eval
-config.device_id = get_device_id()
-config.rank_id = get_rank_id()
-config.rank_size = get_device_num()
-config.run_distribute = config.rank_size > 1.
-
 def modelarts_process():
     """ modelarts process """
@@ -96,13 +89,13 @@ def modelarts_process():
 def eval_mobilenetv2():
     config.dataset_path = os.path.join(config.dataset_path, 'validation_preprocess')
     print('\nconfig: \n', config)
-    set_context(config)
+    if not config.device_id:
+        config.device_id = get_device_id()
+    context_device_init(config)
     _, _, net = define_net(config, config.is_training)
 
     load_ckpt(net, config.pretrain_ckpt)
 
-    switch_precision(net, mstype.float16, config)
-
     dataset = create_dataset(dataset_path=config.dataset_path, do_train=False, config=config)
     step_size = dataset.get_dataset_size()
     if step_size == 0:
@@ -16,25 +16,16 @@
 mobilenetv2 export file.
 """
 import numpy as np
-from mindspore import Tensor, export, context
+from mindspore import Tensor, export
 from src.models import define_net, load_ckpt
-from src.utils import set_context
+from src.utils import context_device_init
 from src.model_utils.config import config
-from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
+from src.model_utils.device_adapter import get_device_id
 from src.model_utils.moxing_adapter import moxing_wrapper
 
 
-config.device_id = get_device_id()
-config.rank_id = get_rank_id()
-config.rank_size = get_device_num()
-config.run_distribute = config.rank_size > 1.
 config.batch_size = config.batch_size_export
 config.is_training = config.is_training_export
 
-context.set_context(mode=context.GRAPH_MODE, device_target=config.platform)
-if config.platform == "Ascend":
-    context.set_context(device_id=get_device_id())
-
 def modelarts_process():
     pass
@@ -42,7 +33,9 @@ def modelarts_process():
 def export_mobilenetv2():
     """ export_mobilenetv2 """
     print('\nconfig: \n', config)
-    set_context(config)
+    if not config.device_id:
+        config.device_id = get_device_id()
+    context_device_init(config)
     _, _, net = define_net(config, config.is_training)
 
     load_ckpt(net, config.ckpt_file)
@@ -46,7 +46,10 @@ run_ascend()
         echo "error: DATASET_PATH=$6 is not a directory or file"
         exit 1
     fi
+    RUN_DISTRIBUTE=True
+    if [ $2 -eq 1 ] ; then
+        RUN_DISTRIBUTE=False
+    fi
     BASEPATH=$(cd "`dirname $0`" || exit; pwd)
     CONFIG_FILE="${BASEPATH}/../$2"
 
@@ -85,6 +88,7 @@ run_ascend()
         echo "start training for rank $RANK_ID, device $DEVICE_ID"
         env > env.log
         taskset -c $cmdopt python train.py \
+        --run_distribute=$RUN_DISTRIBUTE \
         --config_path=$CONFIG_FILE \
         --platform=$1 \
         --dataset_path=$6 \
@@ -49,31 +49,8 @@ def create_dataset(dataset_path, do_train, config, repeat_num=1, enable_cache=Fa
         nfs_dataset_cache = None
 
     num_workers = config.num_workers
-    if config.platform == "Ascend":
-        rank_size = int(os.getenv("RANK_SIZE", '1'))
-        rank_id = int(os.getenv("RANK_ID", '0'))
-        if rank_size == 1:
-            data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, shuffle=True,
-                                             cache=nfs_dataset_cache)
-        else:
-            data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, shuffle=True,
-                                             num_shards=rank_size, shard_id=rank_id, cache=nfs_dataset_cache)
-    elif config.platform == "GPU":
-        if do_train:
-            if config.run_distribute:
-                from mindspore.communication.management import get_rank, get_group_size
-                data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, shuffle=True,
-                                                 num_shards=get_group_size(), shard_id=get_rank(),
-                                                 cache=nfs_dataset_cache)
-            else:
-                data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, shuffle=True,
-                                                 cache=nfs_dataset_cache)
-        else:
-            data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, shuffle=True,
-                                             cache=nfs_dataset_cache)
-    elif config.platform == "CPU":
-        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, \
-                                         shuffle=True, cache=nfs_dataset_cache)
+    data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_workers, shuffle=do_train,
+                                     num_shards=config.rank_size, shard_id=config.rank_id, cache=nfs_dataset_cache)
 
     resize_height = config.image_height
     resize_width = config.image_width
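
The single ImageFolderDataset call above can cover every platform because the patch guarantees config.rank_size and config.rank_id are always populated: the YAML files now ship rank_size: 1 / rank_id: 0 defaults, and context_device_init overwrites them from the communication group when run_distribute is enabled. A minimal sketch of that contract, assuming a stand-in config object and a placeholder dataset path (neither is repository code):

import mindspore.dataset as ds
from types import SimpleNamespace

# Stand-in for the repo's config object after context_device_init() has run.
# A distributed run would instead hold get_group_size() / get_rank() here.
cfg = SimpleNamespace(num_workers=8, rank_size=1, rank_id=0)   # YAML defaults

data_set = ds.ImageFolderDataset("/path/to/dataset/train",     # placeholder path
                                 num_parallel_workers=cfg.num_workers,
                                 shuffle=True,
                                 num_shards=cfg.rank_size,     # num_shards=1 means no sharding
                                 shard_id=cfg.rank_id)
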
@@ -42,14 +42,16 @@ class ClassifyCorrectCell(nn.Cell):
         >>> eval_net = nn.ClassifyCorrectCell(net)
     """
 
-    def __init__(self, network):
+    def __init__(self, network, run_distribute):
         super(ClassifyCorrectCell, self).__init__(auto_prefix=False)
         self._network = network
         self.argmax = P.Argmax()
         self.equal = P.Equal()
         self.cast = P.Cast()
         self.reduce_sum = P.ReduceSum()
-        self.allreduce = P.AllReduce(P.ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
+        self.run_distribute = run_distribute
+        if run_distribute:
+            self.allreduce = P.AllReduce(P.ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
 
     def construct(self, data, label):
         outputs = self._network(data)
@@ -58,8 +60,9 @@ class ClassifyCorrectCell(nn.Cell):
         y_correct = self.equal(y_pred, label)
         y_correct = self.cast(y_correct, mstype.float32)
         y_correct = self.reduce_sum(y_correct)
-        total_correct = self.allreduce(y_correct)
-        return (total_correct,)
+        if self.run_distribute:
+            y_correct = self.allreduce(y_correct)
+        return (y_correct,)
 
 
 class DistAccuracy(nn.Metric):
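
With this change the AllReduce is only created and applied when run_distribute is true, so the same metric path also works for single-device evaluation. A condensed sketch of how train.py (later in this patch) wires it up; config, net, loss and opt are assumed to exist as in that file:

from mindspore.train.model import Model
from src.metric import DistAccuracy, ClassifyCorrectCell

metrics = {"acc"}
dist_eval_network = None
if config.run_eval:
    # switch to the distributed-aware accuracy only when eval-while-training is requested
    metrics = {'acc': DistAccuracy(batch_size=config.batch_size, device_num=config.rank_size)}
    dist_eval_network = ClassifyCorrectCell(net, config.run_distribute)
model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics,
              eval_network=dist_eval_network, acc_level=config.acc_mode)
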
@@ -14,72 +14,40 @@
 # ============================================================================
 
 from mindspore import context
-from mindspore import nn
-from mindspore.common import dtype as mstype
 from mindspore.context import ParallelMode
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
 from mindspore.communication.management import get_rank, init, get_group_size
 
 from src.models import Monitor
 
 
-def switch_precision(net, data_type, config):
-    if config.platform == "Ascend":
-        net.to_float(data_type)
-        for _, cell in net.cells_and_names():
-            if isinstance(cell, nn.Dense):
-                cell.to_float(mstype.float32)
-
-
 def context_device_init(config):
+    if config.platform == "GPU" and config.run_distribute:
+        config.device_id = 0
+        config.rank_id = 0
+        config.rank_size = 1
     if config.platform == "CPU":
         context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False)
-
-    elif config.platform == "GPU":
-        context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False)
-        if config.run_distribute:
-            init()
-            context.set_auto_parallel_context(device_num=get_group_size(),
-                                              parallel_mode=ParallelMode.DATA_PARALLEL,
-                                              gradients_mean=True)
-
-    elif config.platform == "Ascend":
+    elif config.platform in ["Ascend", "GPU"]:
         context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, device_id=config.device_id,
                             save_graphs=False)
         if config.run_distribute:
+            init()
+            config.rank_id = get_rank()
+            config.rank_size = get_group_size()
             context.set_auto_parallel_context(device_num=config.rank_size,
                                               parallel_mode=ParallelMode.DATA_PARALLEL,
-                                              gradients_mean=True, all_reduce_fusion_config=[140])
-            init()
+                                              gradients_mean=True)
     else:
         raise ValueError("Only support CPU, GPU and Ascend.")
 
 
-def set_context(config):
-    if config.platform == "CPU":
-        context.set_context(mode=context.GRAPH_MODE, device_target=config.platform,
-                            save_graphs=False)
-    elif config.platform == "Ascend":
-        context.set_context(mode=context.GRAPH_MODE, device_target=config.platform,
-                            device_id=config.device_id, save_graphs=False)
-    elif config.platform == "GPU":
-        context.set_context(mode=context.GRAPH_MODE,
-                            device_target=config.platform, save_graphs=False)
-
-
 def config_ckpoint(config, lr, step_size, model=None, eval_dataset=None):
     cb = [Monitor(lr_init=lr.asnumpy(), model=model, eval_dataset=eval_dataset)]
-    if config.platform in ("CPU", "GPU") or config.rank_id == 0:
-        if config.save_checkpoint:
-            config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
-                                         keep_checkpoint_max=config.keep_checkpoint_max)
-            rank = 0
-            if config.run_distribute:
-                rank = get_rank()
-
-            ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(rank) + "/"
-            ckpt_cb = ModelCheckpoint(prefix="mobilenetv2", directory=ckpt_save_dir, config=config_ck)
-            cb += [ckpt_cb]
+    if config.save_checkpoint and config.rank_id == 0:
+        config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
+                                     keep_checkpoint_max=config.keep_checkpoint_max)
+        ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(config.rank_id) + "/"
+        ckpt_cb = ModelCheckpoint(prefix="mobilenetv2", directory=ckpt_save_dir, config=config_ck)
+        cb += [ckpt_cb]
     return cb
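
With switch_precision and set_context removed, context_device_init is the single entry point for device setup. Callers touched by this patch (eval.py, export.py, train.py) follow the same pattern; a minimal sketch of that call sequence, assuming config carries the updated YAML defaults:

from src.utils import context_device_init
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id

if not config.device_id:            # YAML default is 0
    config.device_id = get_device_id()
context_device_init(config)         # sets GRAPH_MODE; with run_distribute it also calls
                                    # init() and fills config.rank_id / config.rank_size
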
@@ -24,7 +24,6 @@ from mindspore import Tensor
 from mindspore.nn import WithLossCell, TrainOneStepCell
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
-from mindspore.communication.management import get_rank
 from mindspore.train.model import Model
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.serialization import save_checkpoint
|
@ -37,7 +36,7 @@ from src.models import CrossEntropyWithLabelSmooth, define_net, load_ckpt
|
||||||
from src.metric import DistAccuracy, ClassifyCorrectCell
|
from src.metric import DistAccuracy, ClassifyCorrectCell
|
||||||
from src.model_utils.config import config
|
from src.model_utils.config import config
|
||||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||||
from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
|
from src.model_utils.device_adapter import get_device_id, get_device_num
|
||||||
|
|
||||||
|
|
||||||
set_seed(1)
|
set_seed(1)
|
||||||
|
@@ -116,23 +115,16 @@ def build_params_groups(net):
 def train_mobilenetv2():
     config.train_dataset_path = os.path.join(config.dataset_path, 'train')
     config.eval_dataset_path = os.path.join(config.dataset_path, 'validation_preprocess')
-    config.device_id = get_device_id()
-    config.rank_id = get_rank_id()
-    config.rank_size = get_device_num()
-    if config.platform == 'Ascend':
-        config.run_distribute = config.rank_size > 1.
-
-    print('\nconfig: {} \n'.format(config))
+    if not config.device_id:
+        config.device_id = get_device_id()
     start = time.time()
     # set context and device init
     context_device_init(config)
+    print('\nconfig: {} \n'.format(config))
     # define network
     backbone_net, head_net, net = define_net(config, config.is_training)
     dataset = create_dataset(dataset_path=config.train_dataset_path, do_train=True, config=config,
                              enable_cache=config.enable_cache, cache_session_id=config.cache_session_id)
-    eval_dataset = create_dataset(dataset_path=config.eval_dataset_path, do_train=False, config=config)
     step_size = dataset.get_dataset_size()
     if config.platform == "GPU":
         context.set_context(enable_graph_kernel=True)
@@ -165,23 +157,27 @@ def train_mobilenetv2():
                                warmup_epochs=config.warmup_epochs,
                                total_epochs=epoch_size,
                                steps_per_epoch=step_size))
+    metrics = {"acc"}
+    dist_eval_network = None
+    eval_dataset = None
+    if config.run_eval:
+        metrics = {'acc': DistAccuracy(batch_size=config.batch_size, device_num=config.rank_size)}
+        dist_eval_network = ClassifyCorrectCell(net, config.run_distribute)
+        eval_dataset = create_dataset(dataset_path=config.eval_dataset_path, do_train=False, config=config)
     if config.pretrain_ckpt == "" or config.freeze_layer != "backbone":
-        loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
-        group_params = build_params_groups(net)
-        opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
-
-        metrics = {"acc"}
-        dist_eval_network = None
-        if config.run_distribute:
-            metrics = {'acc': DistAccuracy(batch_size=config.batch_size, device_num=config.rank_size)}
-            dist_eval_network = ClassifyCorrectCell(net)
-
-        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale,
-                      metrics=metrics, eval_network=dist_eval_network,
-                      amp_level="O2", keep_batchnorm_fp32=False,
-                      acc_level=config.acc_mode)
-
+        if config.platform == "Ascend":
+            loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
+            group_params = build_params_groups(net)
+            opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
+            model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale,
+                          metrics=metrics, eval_network=dist_eval_network,
+                          amp_level="O2", keep_batchnorm_fp32=False,
+                          acc_level=config.acc_mode)
+        else:
+            opt = Momentum(net.trainable_params(), lr, config.momentum, config.weight_decay)
+            model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics, eval_network=dist_eval_network,
+                          acc_level=config.acc_mode)
     cb = config_ckpoint(config, lr, step_size, model, eval_dataset)
     print("============== Starting Training ==============")
     model.train(epoch_size, dataset, callbacks=cb)
|
||||||
|
|
||||||
features_path = config.train_dataset_path + '_features'
|
features_path = config.train_dataset_path + '_features'
|
||||||
idx_list = list(range(step_size))
|
idx_list = list(range(step_size))
|
||||||
rank = 0
|
rank = config.rank_id
|
||||||
if config.run_distribute:
|
|
||||||
rank = get_rank()
|
|
||||||
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
|
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
|
||||||
if not os.path.isdir(save_ckpt_path):
|
if not os.path.isdir(save_ckpt_path):
|
||||||
os.mkdir(save_ckpt_path)
|
os.mkdir(save_ckpt_path)
|
||||||
|
|