!13526 modify network_define for fasterrcnn/maskrcnn/maskrcnn_mobilenetv/deeptext

From: @huangbingjian
Reviewed-by: @zh_qh,@zhunaipan
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-03-18 16:53:48 +08:00 committed by Gitee
commit eaecc83ec2
8 changed files with 52 additions and 206 deletions

View File

@ -47,12 +47,7 @@ class LossCallBack(Callback):
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.count = 0
self.rpn_loss_sum = 0
self.rcnn_loss_sum = 0
self.rpn_cls_loss_sum = 0
self.rpn_reg_loss_sum = 0
self.rcnn_cls_loss_sum = 0
self.rcnn_reg_loss_sum = 0
self.loss_sum = 0
self.rank_id = rank_id
global time_stamp_init, time_stamp_first
@ -62,54 +57,26 @@ class LossCallBack(Callback):
def step_end(self, run_context):
cb_params = run_context.original_args()
rpn_loss = cb_params.net_outputs[0].asnumpy()
rcnn_loss = cb_params.net_outputs[1].asnumpy()
rpn_cls_loss = cb_params.net_outputs[2].asnumpy()
rpn_reg_loss = cb_params.net_outputs[3].asnumpy()
rcnn_cls_loss = cb_params.net_outputs[4].asnumpy()
rcnn_reg_loss = cb_params.net_outputs[5].asnumpy()
loss = cb_params.net_outputs.asnumpy()
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
self.count += 1
self.rpn_loss_sum += float(rpn_loss)
self.rcnn_loss_sum += float(rcnn_loss)
self.rpn_cls_loss_sum += float(rpn_cls_loss)
self.rpn_reg_loss_sum += float(rpn_reg_loss)
self.rcnn_cls_loss_sum += float(rcnn_cls_loss)
self.rcnn_reg_loss_sum += float(rcnn_reg_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
self.loss_sum += float(loss)
if self.count >= 1:
global time_stamp_first
time_stamp_current = time.time()
rpn_loss = self.rpn_loss_sum / self.count
rcnn_loss = self.rcnn_loss_sum / self.count
rpn_cls_loss = self.rpn_cls_loss_sum / self.count
rpn_reg_loss = self.rpn_reg_loss_sum / self.count
rcnn_cls_loss = self.rcnn_cls_loss_sum / self.count
rcnn_reg_loss = self.rcnn_reg_loss_sum / self.count
total_loss = rpn_loss + rcnn_loss
total_loss = self.loss_sum / self.count
loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f, rcnn_loss: %.5f, rpn_cls_loss: %.5f, "
"rpn_reg_loss: %.5f, rcnn_cls_loss: %.5f, rcnn_reg_loss: %.5f, total_loss: %.5f" %
loss_file.write("%lu epoch: %s step: %s ,total_loss: %.5f" %
(time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss,
rcnn_cls_loss, rcnn_reg_loss, total_loss))
total_loss))
loss_file.write("\n")
loss_file.close()
self.count = 0
self.rpn_loss_sum = 0
self.rcnn_loss_sum = 0
self.rpn_cls_loss_sum = 0
self.rpn_reg_loss_sum = 0
self.rcnn_cls_loss_sum = 0
self.rcnn_reg_loss_sum = 0
self.loss_sum = 0
class LossNet(nn.Cell):
@ -157,7 +124,6 @@ class TrainOneStepCell(nn.Cell):
Args:
network (Cell): The training network.
network_backbone (Cell): The forward network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default value is 1.0.
reduce_flag (bool): The reduce flag. Default value is False.
@ -165,11 +131,10 @@ class TrainOneStepCell(nn.Cell):
degree (int): Device number. Default value is None.
"""
def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
def __init__(self, network, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.backbone = network_backbone
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
@ -181,8 +146,8 @@ class TrainOneStepCell(nn.Cell):
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
weights = self.weights
loss1, loss2, loss3, loss4, loss5, loss6 = self.backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
loss = self.network(x, img_shape, gt_bboxe, gt_label, gt_num)
grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens)
if self.reduce_flag:
grads = self.grad_reducer(grads)
return F.depend(loss1, self.optimizer(grads)), loss2, loss3, loss4, loss5, loss6
return F.depend(loss, self.optimizer(grads))

View File

@ -120,10 +120,10 @@ if __name__ == '__main__':
weight_decay=config.weight_decay, loss_scale=config.loss_scale)
net_with_loss = WithLossCell(net, loss)
if args_opt.run_distribute:
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale, reduce_flag=True,
net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale, reduce_flag=True,
mean=True, degree=device_num)
else:
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale)
net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)
time_cb = TimeMonitor(data_size=dataset_size)
loss_cb = LossCallBack(rank_id=rank)

View File

@ -47,12 +47,7 @@ class LossCallBack(Callback):
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.count = 0
self.rpn_loss_sum = 0
self.rcnn_loss_sum = 0
self.rpn_cls_loss_sum = 0
self.rpn_reg_loss_sum = 0
self.rcnn_cls_loss_sum = 0
self.rcnn_reg_loss_sum = 0
self.loss_sum = 0
self.rank_id = rank_id
global time_stamp_init, time_stamp_first
@ -62,54 +57,26 @@ class LossCallBack(Callback):
def step_end(self, run_context):
cb_params = run_context.original_args()
rpn_loss = cb_params.net_outputs[0].asnumpy()
rcnn_loss = cb_params.net_outputs[1].asnumpy()
rpn_cls_loss = cb_params.net_outputs[2].asnumpy()
rpn_reg_loss = cb_params.net_outputs[3].asnumpy()
rcnn_cls_loss = cb_params.net_outputs[4].asnumpy()
rcnn_reg_loss = cb_params.net_outputs[5].asnumpy()
loss = cb_params.net_outputs.asnumpy()
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
self.count += 1
self.rpn_loss_sum += float(rpn_loss)
self.rcnn_loss_sum += float(rcnn_loss)
self.rpn_cls_loss_sum += float(rpn_cls_loss)
self.rpn_reg_loss_sum += float(rpn_reg_loss)
self.rcnn_cls_loss_sum += float(rcnn_cls_loss)
self.rcnn_reg_loss_sum += float(rcnn_reg_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
self.loss_sum += float(loss)
if self.count >= 1:
global time_stamp_first
time_stamp_current = time.time()
rpn_loss = self.rpn_loss_sum/self.count
rcnn_loss = self.rcnn_loss_sum/self.count
rpn_cls_loss = self.rpn_cls_loss_sum/self.count
rpn_reg_loss = self.rpn_reg_loss_sum/self.count
rcnn_cls_loss = self.rcnn_cls_loss_sum/self.count
rcnn_reg_loss = self.rcnn_reg_loss_sum/self.count
total_loss = rpn_loss + rcnn_loss
total_loss = self.loss_sum / self.count
loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f, rcnn_loss: %.5f, rpn_cls_loss: %.5f, "
"rpn_reg_loss: %.5f, rcnn_cls_loss: %.5f, rcnn_reg_loss: %.5f, total_loss: %.5f" %
loss_file.write("%lu epoch: %s step: %s ,total_loss: %.5f" %
(time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss,
rcnn_cls_loss, rcnn_reg_loss, total_loss))
total_loss))
loss_file.write("\n")
loss_file.close()
self.count = 0
self.rpn_loss_sum = 0
self.rcnn_loss_sum = 0
self.rpn_cls_loss_sum = 0
self.rpn_reg_loss_sum = 0
self.rcnn_cls_loss_sum = 0
self.rcnn_reg_loss_sum = 0
self.loss_sum = 0
class LossNet(nn.Cell):
@ -155,18 +122,16 @@ class TrainOneStepCell(nn.Cell):
Args:
network (Cell): The training network.
network_backbone (Cell): The forward network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default value is 1.0.
reduce_flag (bool): The reduce flag. Default value is False.
mean (bool): Allreduce method. Default value is False.
degree (int): Device number. Default value is None.
"""
def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
def __init__(self, network, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.backbone = network_backbone
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
@ -178,8 +143,8 @@ class TrainOneStepCell(nn.Cell):
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num):
weights = self.weights
loss1, loss2, loss3, loss4, loss5, loss6 = self.backbone(x, img_shape, gt_bboxe, gt_label, gt_num)
loss = self.network(x, img_shape, gt_bboxe, gt_label, gt_num)
grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens)
if self.reduce_flag:
grads = self.grad_reducer(grads)
return F.depend(loss1, self.optimizer(grads)), loss2, loss3, loss4, loss5, loss6
return F.depend(loss, self.optimizer(grads))

View File

@ -159,10 +159,10 @@ if __name__ == '__main__':
weight_decay=config.weight_decay, loss_scale=config.loss_scale)
net_with_loss = WithLossCell(net, loss)
if args_opt.run_distribute:
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale, reduce_flag=True,
net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale, reduce_flag=True,
mean=True, degree=device_num)
else:
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale)
net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)
time_cb = TimeMonitor(data_size=dataset_size)
loss_cb = LossCallBack(rank_id=rank)

View File

@ -46,13 +46,7 @@ class LossCallBack(Callback):
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.count = 0
self.rpn_loss_sum = 0
self.rcnn_loss_sum = 0
self.rpn_cls_loss_sum = 0
self.rpn_reg_loss_sum = 0
self.rcnn_cls_loss_sum = 0
self.rcnn_reg_loss_sum = 0
self.rcnn_mask_loss_sum = 0
self.loss_sum = 0
self.rank_id = rank_id
global time_stamp_init, time_stamp_first
@ -62,59 +56,26 @@ class LossCallBack(Callback):
def step_end(self, run_context):
cb_params = run_context.original_args()
rpn_loss = cb_params.net_outputs[0].asnumpy()
rcnn_loss = cb_params.net_outputs[1].asnumpy()
rpn_cls_loss = cb_params.net_outputs[2].asnumpy()
rpn_reg_loss = cb_params.net_outputs[3].asnumpy()
rcnn_cls_loss = cb_params.net_outputs[4].asnumpy()
rcnn_reg_loss = cb_params.net_outputs[5].asnumpy()
rcnn_mask_loss = cb_params.net_outputs[6].asnumpy()
loss = cb_params.net_outputs.asnumpy()
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
self.count += 1
self.rpn_loss_sum += float(rpn_loss)
self.rcnn_loss_sum += float(rcnn_loss)
self.rpn_cls_loss_sum += float(rpn_cls_loss)
self.rpn_reg_loss_sum += float(rpn_reg_loss)
self.rcnn_cls_loss_sum += float(rcnn_cls_loss)
self.rcnn_reg_loss_sum += float(rcnn_reg_loss)
self.rcnn_mask_loss_sum += float(rcnn_mask_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
self.loss_sum += float(loss)
if self.count >= 1:
global time_stamp_first
time_stamp_current = time.time()
rpn_loss = self.rpn_loss_sum/self.count
rcnn_loss = self.rcnn_loss_sum/self.count
rpn_cls_loss = self.rpn_cls_loss_sum/self.count
rpn_reg_loss = self.rpn_reg_loss_sum/self.count
rcnn_cls_loss = self.rcnn_cls_loss_sum/self.count
rcnn_reg_loss = self.rcnn_reg_loss_sum/self.count
rcnn_mask_loss = self.rcnn_mask_loss_sum/self.count
total_loss = rpn_loss + rcnn_loss
total_loss = self.loss_sum / self.count
loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f, rcnn_loss: %.5f, rpn_cls_loss: %.5f, "
"rpn_reg_loss: %.5f, rcnn_cls_loss: %.5f, rcnn_reg_loss: %.5f, rcnn_mask_loss: %.5f, "
"total_loss: %.5f" %
loss_file.write("%lu epoch: %s step: %s ,total_loss: %.5f" %
(time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss,
rcnn_cls_loss, rcnn_reg_loss, rcnn_mask_loss, total_loss))
total_loss))
loss_file.write("\n")
loss_file.close()
self.count = 0
self.rpn_loss_sum = 0
self.rcnn_loss_sum = 0
self.rpn_cls_loss_sum = 0
self.rpn_reg_loss_sum = 0
self.rcnn_cls_loss_sum = 0
self.rcnn_reg_loss_sum = 0
self.rcnn_mask_loss_sum = 0
self.loss_sum = 0
class LossNet(nn.Cell):
"""MaskRcnn loss method"""
@ -159,18 +120,16 @@ class TrainOneStepCell(nn.Cell):
Args:
network (Cell): The training network.
network_backbone (Cell): The forward network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default value is 1.0.
reduce_flag (bool): The reduce flag. Default value is False.
mean (bool): Allreduce method. Default value is False.
degree (int): Device number. Default value is None.
"""
def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
def __init__(self, network, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.backbone = network_backbone
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
@ -183,10 +142,9 @@ class TrainOneStepCell(nn.Cell):
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num, gt_mask):
weights = self.weights
loss1, loss2, loss3, loss4, loss5, loss6, loss7 = self.backbone(x, img_shape, gt_bboxe, gt_label,
gt_num, gt_mask)
loss = self.network(x, img_shape, gt_bboxe, gt_label, gt_num, gt_mask)
grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, gt_mask, self.sens)
if self.reduce_flag:
grads = self.grad_reducer(grads)
return F.depend(loss1, self.optimizer(grads)), loss2, loss3, loss4, loss5, loss6, loss7
return F.depend(loss, self.optimizer(grads))

View File

@ -124,10 +124,10 @@ if __name__ == '__main__':
net_with_loss = WithLossCell(net, loss)
if args_opt.run_distribute:
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale, reduce_flag=True,
net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale, reduce_flag=True,
mean=True, degree=device_num)
else:
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale)
net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)
time_cb = TimeMonitor(data_size=dataset_size)
loss_cb = LossCallBack(rank_id=rank)

View File

@ -75,13 +75,7 @@ class LossCallBack(Callback):
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.count = 0
self.rpn_loss_sum = 0
self.rcnn_loss_sum = 0
self.rpn_cls_loss_sum = 0
self.rpn_reg_loss_sum = 0
self.rcnn_cls_loss_sum = 0
self.rcnn_reg_loss_sum = 0
self.rcnn_mask_loss_sum = 0
self.loss_sum = 0
self.rank_id = rank_id
global time_stamp_init, time_stamp_first
@ -91,59 +85,26 @@ class LossCallBack(Callback):
def step_end(self, run_context):
cb_params = run_context.original_args()
rpn_loss = cb_params.net_outputs[0].asnumpy()
rcnn_loss = cb_params.net_outputs[1].asnumpy()
rpn_cls_loss = cb_params.net_outputs[2].asnumpy()
rpn_reg_loss = cb_params.net_outputs[3].asnumpy()
rcnn_cls_loss = cb_params.net_outputs[4].asnumpy()
rcnn_reg_loss = cb_params.net_outputs[5].asnumpy()
rcnn_mask_loss = cb_params.net_outputs[6].asnumpy()
loss = cb_params.net_outputs.asnumpy()
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
self.count += 1
self.rpn_loss_sum += float(rpn_loss)
self.rcnn_loss_sum += float(rcnn_loss)
self.rpn_cls_loss_sum += float(rpn_cls_loss)
self.rpn_reg_loss_sum += float(rpn_reg_loss)
self.rcnn_cls_loss_sum += float(rcnn_cls_loss)
self.rcnn_reg_loss_sum += float(rcnn_reg_loss)
self.rcnn_mask_loss_sum += float(rcnn_mask_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
self.loss_sum += float(loss)
if self.count >= 1:
global time_stamp_first
time_stamp_current = time.time()
rpn_loss = self.rpn_loss_sum/self.count
rcnn_loss = self.rcnn_loss_sum/self.count
rpn_cls_loss = self.rpn_cls_loss_sum/self.count
rpn_reg_loss = self.rpn_reg_loss_sum/self.count
rcnn_cls_loss = self.rcnn_cls_loss_sum/self.count
rcnn_reg_loss = self.rcnn_reg_loss_sum/self.count
rcnn_mask_loss = self.rcnn_mask_loss_sum/self.count
total_loss = rpn_loss + rcnn_loss
total_loss = self.loss_sum/self.count
loss_file = open("./loss_{}.log".format(self.rank_id), "a+")
loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f, rcnn_loss: %.5f, rpn_cls_loss: %.5f, "
"rpn_reg_loss: %.5f, rcnn_cls_loss: %.5f, rcnn_reg_loss: %.5f, rcnn_mask_loss: %.5f, "
"total_loss: %.5f" %
loss_file.write("%lu epoch: %s step: %s ,total_loss: %.5f" %
(time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
rpn_loss, rcnn_loss, rpn_cls_loss, rpn_reg_loss,
rcnn_cls_loss, rcnn_reg_loss, rcnn_mask_loss, total_loss))
total_loss))
loss_file.write("\n")
loss_file.close()
self.count = 0
self.rpn_loss_sum = 0
self.rcnn_loss_sum = 0
self.rpn_cls_loss_sum = 0
self.rpn_reg_loss_sum = 0
self.rcnn_cls_loss_sum = 0
self.rcnn_reg_loss_sum = 0
self.rcnn_mask_loss_sum = 0
self.loss_sum = 0
class LossNet(nn.Cell):
"""MaskRcnn loss method"""
@ -188,18 +149,16 @@ class TrainOneStepCell(nn.Cell):
Args:
network (Cell): The training network.
network_backbone (Cell): The forward network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default value is 1.0.
reduce_flag (bool): The reduce flag. Default value is False.
mean (bool): Allreduce method. Default value is False.
degree (int): Device number. Default value is None.
"""
def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
def __init__(self, network, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.backbone = network_backbone
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
@ -212,10 +171,9 @@ class TrainOneStepCell(nn.Cell):
def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num, gt_mask):
weights = self.weights
loss1, loss2, loss3, loss4, loss5, loss6, loss7 = self.backbone(x, img_shape, gt_bboxe, gt_label,
gt_num, gt_mask)
loss = self.network(x, img_shape, gt_bboxe, gt_label, gt_num, gt_mask)
grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, gt_mask, self.sens)
if self.reduce_flag:
grads = self.grad_reducer(grads)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
return F.depend(loss1, self.optimizer(grads)), loss2, loss3, loss4, loss5, loss6, loss7
return F.depend(loss, self.optimizer(grads))

View File

@ -123,10 +123,10 @@ if __name__ == '__main__':
net_with_loss = WithLossCell(net, loss)
if args_opt.run_distribute:
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale, reduce_flag=True,
net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale, reduce_flag=True,
mean=True, degree=device_num)
else:
net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale)
net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)
time_cb = TimeMonitor(data_size=dataset_size)
loss_cb = LossCallBack(rank_id=rank)