fix hswishquant and hsigmoidquant validation false bug

This commit is contained in:
chenzomi 2020-07-09 09:10:24 +08:00
parent 3446940142
commit 2ff29f0198
3 changed files with 32 additions and 25 deletions

View File

@ -920,7 +920,7 @@ class HSwishQuant(_QuantActivation):
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
if isinstance(activation, nn.HSwish):
if issubclass(activation, nn.HSwish):
self.act = activation()
else:
raise ValueError("Activation should be `nn.HSwish`")
@ -989,7 +989,7 @@ class HSigmoidQuant(_QuantActivation):
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
if isinstance(activation, nn.HSwish):
if issubclass(activation, nn.HSwish):
self.act = activation()
else:
raise ValueError("Activation should be `nn.HSigmoid`")

View File

@ -18,6 +18,7 @@ import time
import argparse
import random
import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore import nn
@ -32,8 +33,9 @@ from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_group_size
from mindspore.communication.management import init, get_group_size, get_rank
import mindspore.dataset.engine as de
from src.dataset import create_dataset
from src.lr_generator import get_lr
from src.config import config_gpu, config_ascend
@ -60,9 +62,14 @@ if args_opt.platform == "Ascend":
device_id=device_id, save_graphs=False)
elif args_opt.platform == "GPU":
context.set_context(mode=context.GRAPH_MODE,
device_target="GPU", save_graphs=False)
device_target="GPU",
save_graphs=False)
init("nccl")
context.set_auto_parallel_context(device_num=get_group_size(),
parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
else:
raise ValueError("Unsupport platform.")
raise ValueError("Unsupported device target.")
class CrossEntropyWithLabelSmooth(_Loss):
@ -155,12 +162,8 @@ class Monitor(Callback):
if __name__ == '__main__':
if args_opt.platform == "GPU":
# train on gpu
print("train args: ", args_opt, "\ncfg: ", config_gpu)
init('nccl')
context.set_auto_parallel_context(parallel_mode="data_parallel",
mirror_mean=True,
device_num=get_group_size())
print("train args: ", args_opt)
print("cfg: ", config_gpu)
# define net
net = mobilenet_v2(num_classes=config_gpu.num_classes, platform="GPU")
@ -201,13 +204,13 @@ if __name__ == '__main__':
loss_scale_manager=loss_scale)
cb = [Monitor(lr_init=lr.asnumpy())]
ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
if config_gpu.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config_gpu.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(
prefix="mobilenetV2", directory=config_gpu.save_checkpoint_path, config=config_ck)
ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck)
cb += [ckpt_cb]
# begine train
# begin train
model.train(epoch_size, dataset, callbacks=cb)
elif args_opt.platform == "Ascend":
# train on ascend

View File

@ -18,6 +18,7 @@ import time
import argparse
import random
import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore import nn
@ -33,7 +34,8 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset.engine as de
from mindspore.communication.management import init, get_group_size
from mindspore.communication.management import init, get_group_size, get_rank
from src.dataset import create_dataset
from src.lr_generator import get_lr
from src.config import config_gpu, config_ascend
@ -57,10 +59,16 @@ if args_opt.platform == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
device_id=device_id, save_graphs=False)
device_id=device_id,
save_graphs=False)
elif args_opt.platform == "GPU":
context.set_context(mode=context.GRAPH_MODE,
device_target="GPU", save_graphs=False)
device_target="GPU",
save_graphs=False)
init("nccl")
context.set_auto_parallel_context(device_num=get_group_size(),
parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
else:
raise ValueError("Unsupport platform.")
@ -155,12 +163,8 @@ class Monitor(Callback):
if __name__ == '__main__':
if args_opt.platform == "GPU":
# train on gpu
print("train args: ", args_opt, "\ncfg: ", config_gpu)
init('nccl')
context.set_auto_parallel_context(parallel_mode="data_parallel",
mirror_mean=True,
device_num=get_group_size())
print("train args: ", args_opt)
print("cfg: ", config_gpu)
# define net
net = mobilenet_v3_large(num_classes=config_gpu.num_classes)
@ -201,11 +205,11 @@ if __name__ == '__main__':
loss_scale_manager=loss_scale)
cb = [Monitor(lr_init=lr.asnumpy())]
ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
if config_gpu.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config_gpu.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(
prefix="mobilenetV3", directory=config_gpu.save_checkpoint_path, config=config_ck)
ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=ckpt_save_dir, config=config_ck)
cb += [ckpt_cb]
# begine train
model.train(epoch_size, dataset, callbacks=cb)