forked from mindspore-Ecosystem/mindspore
!833 remove hccl seeting in mobilenetv2 eval script
Merge pull request !833 from wandongdong/master
This commit is contained in:
commit
122c9bc7f0
|
@ -27,6 +27,7 @@ config = ed({
|
|||
"lr": 0.4,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 4e-5,
|
||||
"label_smooth": 0.1,
|
||||
"loss_scale": 1024,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 1,
|
||||
|
|
|
@ -53,8 +53,8 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
|
|||
|
||||
# define map operations
|
||||
decode_op = C.Decode()
|
||||
resize_crop_op = C.RandomResizedCrop(resize_height, scale=(0.2, 1.0))
|
||||
horizontal_flip_op = C.RandomHorizontalFlip()
|
||||
resize_crop_op = C.RandomResizedCrop(resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333))
|
||||
horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5)
|
||||
|
||||
resize_op = C.Resize((256, 256))
|
||||
center_crop = C.CenterCrop(resize_width)
|
||||
|
|
|
@ -38,8 +38,6 @@ context.set_context(enable_loop_sink=True)
|
|||
context.set_context(enable_mem_reuse=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
context.set_context(enable_hccl=False)
|
||||
|
||||
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
|
||||
net = mobilenet_v2()
|
||||
|
||||
|
|
|
@ -28,6 +28,10 @@ from mindspore.model_zoo.mobilenet import mobilenet_v2
|
|||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
from mindspore.nn.optim.momentum import Momentum
|
||||
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from mindspore.train.model import Model, ParallelMode
|
||||
|
||||
|
@ -54,6 +58,35 @@ context.set_context(enable_task_sink=True)
|
|||
context.set_context(enable_loop_sink=True)
|
||||
context.set_context(enable_mem_reuse=True)
|
||||
|
||||
class CrossEntropyWithLabelSmooth(_Loss):
|
||||
"""
|
||||
CrossEntropyWith LabelSmooth.
|
||||
|
||||
Args:
|
||||
smooth_factor (float): smooth factor, default=0.
|
||||
num_classes (int): num classes
|
||||
|
||||
Returns:
|
||||
None.
|
||||
|
||||
Examples:
|
||||
>>> CrossEntropyWithLabelSmooth(smooth_factor=0., num_classes=1000)
|
||||
"""
|
||||
|
||||
def __init__(self, smooth_factor=0., num_classes=1000):
|
||||
super(CrossEntropyWithLabelSmooth, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
||||
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||
self.mean = P.ReduceMean(False)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, logit, label):
|
||||
one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(logit)[1], self.on_value, self.off_value)
|
||||
out_loss = self.ce(logit, one_hot_label)
|
||||
out_loss = self.mean(out_loss, 0)
|
||||
return out_loss
|
||||
|
||||
class Monitor(Callback):
|
||||
"""
|
||||
|
@ -63,7 +96,7 @@ class Monitor(Callback):
|
|||
lr_init (numpy array): train lr
|
||||
|
||||
Returns:
|
||||
None.
|
||||
None
|
||||
|
||||
Examples:
|
||||
>>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy())
|
||||
|
@ -122,7 +155,10 @@ if __name__ == '__main__':
|
|||
for _, cell in net.cells_and_names():
|
||||
if isinstance(cell, nn.Dense):
|
||||
cell.add_flags_recursive(fp32=True)
|
||||
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
|
||||
if config.label_smooth > 0:
|
||||
loss = CrossEntropyWithLabelSmooth(smooth_factor=config.label_smooth, num_classes=config.num_classes)
|
||||
else:
|
||||
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
|
||||
|
||||
print("train args: ", args_opt, "\ncfg: ", config,
|
||||
"\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))
|
||||
|
|
Loading…
Reference in New Issue