forked from OSSInnovation/mindspore
fix hswishquant and hsigmoidquant validation false bug
This commit is contained in:
parent
3446940142
commit
2ff29f0198
|
@ -920,7 +920,7 @@ class HSwishQuant(_QuantActivation):
|
|||
symmetric=symmetric,
|
||||
narrow_range=narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
if isinstance(activation, nn.HSwish):
|
||||
if issubclass(activation, nn.HSwish):
|
||||
self.act = activation()
|
||||
else:
|
||||
raise ValueError("Activation should be `nn.HSwish`")
|
||||
|
@ -989,7 +989,7 @@ class HSigmoidQuant(_QuantActivation):
|
|||
symmetric=symmetric,
|
||||
narrow_range=narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
if isinstance(activation, nn.HSwish):
|
||||
if issubclass(activation, nn.HSwish):
|
||||
self.act = activation()
|
||||
else:
|
||||
raise ValueError("Activation should be `nn.HSigmoid`")
|
||||
|
|
|
@ -18,6 +18,7 @@ import time
|
|||
import argparse
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore import nn
|
||||
|
@ -32,8 +33,9 @@ from mindspore.train.model import Model, ParallelMode
|
|||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.communication.management import init, get_group_size
|
||||
from mindspore.communication.management import init, get_group_size, get_rank
|
||||
import mindspore.dataset.engine as de
|
||||
|
||||
from src.dataset import create_dataset
|
||||
from src.lr_generator import get_lr
|
||||
from src.config import config_gpu, config_ascend
|
||||
|
@ -60,9 +62,14 @@ if args_opt.platform == "Ascend":
|
|||
device_id=device_id, save_graphs=False)
|
||||
elif args_opt.platform == "GPU":
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="GPU", save_graphs=False)
|
||||
device_target="GPU",
|
||||
save_graphs=False)
|
||||
init("nccl")
|
||||
context.set_auto_parallel_context(device_num=get_group_size(),
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
mirror_mean=True)
|
||||
else:
|
||||
raise ValueError("Unsupport platform.")
|
||||
raise ValueError("Unsupported device target.")
|
||||
|
||||
|
||||
class CrossEntropyWithLabelSmooth(_Loss):
|
||||
|
@ -155,12 +162,8 @@ class Monitor(Callback):
|
|||
if __name__ == '__main__':
|
||||
if args_opt.platform == "GPU":
|
||||
# train on gpu
|
||||
print("train args: ", args_opt, "\ncfg: ", config_gpu)
|
||||
|
||||
init('nccl')
|
||||
context.set_auto_parallel_context(parallel_mode="data_parallel",
|
||||
mirror_mean=True,
|
||||
device_num=get_group_size())
|
||||
print("train args: ", args_opt)
|
||||
print("cfg: ", config_gpu)
|
||||
|
||||
# define net
|
||||
net = mobilenet_v2(num_classes=config_gpu.num_classes, platform="GPU")
|
||||
|
@ -201,13 +204,13 @@ if __name__ == '__main__':
|
|||
loss_scale_manager=loss_scale)
|
||||
|
||||
cb = [Monitor(lr_init=lr.asnumpy())]
|
||||
ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
|
||||
if config_gpu.save_checkpoint:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size,
|
||||
keep_checkpoint_max=config_gpu.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(
|
||||
prefix="mobilenetV2", directory=config_gpu.save_checkpoint_path, config=config_ck)
|
||||
ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
# begine train
|
||||
# begin train
|
||||
model.train(epoch_size, dataset, callbacks=cb)
|
||||
elif args_opt.platform == "Ascend":
|
||||
# train on ascend
|
||||
|
|
|
@ -18,6 +18,7 @@ import time
|
|||
import argparse
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore import nn
|
||||
|
@ -33,7 +34,8 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
|
|||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
import mindspore.dataset.engine as de
|
||||
from mindspore.communication.management import init, get_group_size
|
||||
from mindspore.communication.management import init, get_group_size, get_rank
|
||||
|
||||
from src.dataset import create_dataset
|
||||
from src.lr_generator import get_lr
|
||||
from src.config import config_gpu, config_ascend
|
||||
|
@ -57,10 +59,16 @@ if args_opt.platform == "Ascend":
|
|||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend",
|
||||
device_id=device_id, save_graphs=False)
|
||||
device_id=device_id,
|
||||
save_graphs=False)
|
||||
elif args_opt.platform == "GPU":
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="GPU", save_graphs=False)
|
||||
device_target="GPU",
|
||||
save_graphs=False)
|
||||
init("nccl")
|
||||
context.set_auto_parallel_context(device_num=get_group_size(),
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
mirror_mean=True)
|
||||
else:
|
||||
raise ValueError("Unsupport platform.")
|
||||
|
||||
|
@ -155,12 +163,8 @@ class Monitor(Callback):
|
|||
if __name__ == '__main__':
|
||||
if args_opt.platform == "GPU":
|
||||
# train on gpu
|
||||
print("train args: ", args_opt, "\ncfg: ", config_gpu)
|
||||
|
||||
init('nccl')
|
||||
context.set_auto_parallel_context(parallel_mode="data_parallel",
|
||||
mirror_mean=True,
|
||||
device_num=get_group_size())
|
||||
print("train args: ", args_opt)
|
||||
print("cfg: ", config_gpu)
|
||||
|
||||
# define net
|
||||
net = mobilenet_v3_large(num_classes=config_gpu.num_classes)
|
||||
|
@ -201,11 +205,11 @@ if __name__ == '__main__':
|
|||
loss_scale_manager=loss_scale)
|
||||
|
||||
cb = [Monitor(lr_init=lr.asnumpy())]
|
||||
ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
|
||||
if config_gpu.save_checkpoint:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size,
|
||||
keep_checkpoint_max=config_gpu.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(
|
||||
prefix="mobilenetV3", directory=config_gpu.save_checkpoint_path, config=config_ck)
|
||||
ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=ckpt_save_dir, config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
# begine train
|
||||
model.train(epoch_size, dataset, callbacks=cb)
|
||||
|
|
Loading…
Reference in New Issue