From d9254903010f42f7ffbeb357d916a8df38744afb Mon Sep 17 00:00:00 2001 From: huangbingjian Date: Thu, 18 Mar 2021 17:22:02 +0800 Subject: [PATCH] use same network in TrainOneStepCell --- .../official/cv/ctpn/src/network_define.py | 22 +++++-------------- model_zoo/official/cv/ctpn/train.py | 4 ++-- .../cv/deeptext/src/network_define.py | 2 +- .../cv/faster_rcnn/src/network_define.py | 2 +- .../cv/maskrcnn/src/network_define.py | 2 +- .../src/network_define.py | 2 +- 6 files changed, 12 insertions(+), 22 deletions(-) diff --git a/model_zoo/official/cv/ctpn/src/network_define.py b/model_zoo/official/cv/ctpn/src/network_define.py index 2ab20e8aa2d..e1458bdbac0 100644 --- a/model_zoo/official/cv/ctpn/src/network_define.py +++ b/model_zoo/official/cv/ctpn/src/network_define.py @@ -46,8 +46,6 @@ class LossCallBack(Callback): self._per_print_times = per_print_times self.count = 0 self.rpn_loss_sum = 0 - self.rpn_cls_loss_sum = 0 - self.rpn_reg_loss_sum = 0 self.rank_id = rank_id global time_stamp_init, time_stamp_first @@ -57,14 +55,10 @@ class LossCallBack(Callback): def step_end(self, run_context): cb_params = run_context.original_args() - rpn_loss = cb_params.net_outputs[0].asnumpy() - rpn_cls_loss = cb_params.net_outputs[1].asnumpy() - rpn_reg_loss = cb_params.net_outputs[2].asnumpy() + rpn_loss = cb_params.net_outputs.asnumpy() self.count += 1 self.rpn_loss_sum += float(rpn_loss) - self.rpn_cls_loss_sum += float(rpn_cls_loss) - self.rpn_reg_loss_sum += float(rpn_reg_loss) cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 @@ -72,12 +66,10 @@ class LossCallBack(Callback): global time_stamp_first time_stamp_current = time.time() rpn_loss = self.rpn_loss_sum / self.count - rpn_cls_loss = self.rpn_cls_loss_sum / self.count - rpn_reg_loss = self.rpn_reg_loss_sum / self.count loss_file = open("./loss_{}.log".format(self.rank_id), "a+") - loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f, rpn_cls_loss: %.5f, rpn_reg_loss: %.5f"% + loss_file.write("%lu epoch: %s step: %s rpn_loss: %.5f"% (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch, - rpn_loss, rpn_cls_loss, rpn_reg_loss)) + rpn_loss)) loss_file.write("\n") loss_file.close() @@ -123,18 +115,16 @@ class TrainOneStepCell(nn.Cell): Args: network (Cell): The training network. - network_backbone (Cell): The forward network. optimizer (Cell): Optimizer for updating the weights. sens (Number): The adjust parameter. Default value is 1.0. reduce_flag (bool): The reduce flag. Default value is False. mean (bool): Allreduce method. Default value is False. degree (int): Device number. Default value is None. """ - def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None): + def __init__(self, network, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None): super(TrainOneStepCell, self).__init__(auto_prefix=False) self.network = network self.network.set_grad() - self.backbone = network_backbone self.weights = ParameterTuple(network.trainable_params()) self.optimizer = optimizer self.grad = C.GradOperation(get_by_list=True, @@ -146,8 +136,8 @@ class TrainOneStepCell(nn.Cell): def construct(self, x, gt_bbox, gt_label, gt_num, img_shape=None): weights = self.weights - rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self.backbone(x, gt_bbox, gt_label, gt_num, img_shape) + loss = self.network(x, gt_bbox, gt_label, gt_num, img_shape) grads = self.grad(self.network, weights)(x, gt_bbox, gt_label, gt_num, img_shape, self.sens) if self.reduce_flag: grads = self.grad_reducer(grads) - return F.depend(rpn_loss, self.optimizer(grads)), rpn_cls_loss, rpn_reg_loss + return F.depend(loss, self.optimizer(grads)) diff --git a/model_zoo/official/cv/ctpn/train.py b/model_zoo/official/cv/ctpn/train.py index ab21edd72d2..f90d7f9a40f 100644 --- a/model_zoo/official/cv/ctpn/train.py +++ b/model_zoo/official/cv/ctpn/train.py @@ -100,10 +100,10 @@ if __name__ == '__main__': weight_decay=config.weight_decay, loss_scale=config.loss_scale) net_with_loss = WithLossCell(net, loss) if args_opt.run_distribute: - net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale, reduce_flag=True, + net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale, reduce_flag=True, mean=True, degree=device_num) else: - net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale) + net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale) time_cb = TimeMonitor(data_size=dataset_size) loss_cb = LossCallBack(rank_id=rank) diff --git a/model_zoo/official/cv/deeptext/src/network_define.py b/model_zoo/official/cv/deeptext/src/network_define.py index dc068fc129f..2fcd9bb6c44 100644 --- a/model_zoo/official/cv/deeptext/src/network_define.py +++ b/model_zoo/official/cv/deeptext/src/network_define.py @@ -69,7 +69,7 @@ class LossCallBack(Callback): total_loss = self.loss_sum / self.count loss_file = open("./loss_{}.log".format(self.rank_id), "a+") - loss_file.write("%lu epoch: %s step: %s ,total_loss: %.5f" % + loss_file.write("%lu epoch: %s step: %s total_loss: %.5f" % (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch, total_loss)) loss_file.write("\n") diff --git a/model_zoo/official/cv/faster_rcnn/src/network_define.py b/model_zoo/official/cv/faster_rcnn/src/network_define.py index 9c2b679d97b..531cd32c6e5 100644 --- a/model_zoo/official/cv/faster_rcnn/src/network_define.py +++ b/model_zoo/official/cv/faster_rcnn/src/network_define.py @@ -69,7 +69,7 @@ class LossCallBack(Callback): total_loss = self.loss_sum / self.count loss_file = open("./loss_{}.log".format(self.rank_id), "a+") - loss_file.write("%lu epoch: %s step: %s ,total_loss: %.5f" % + loss_file.write("%lu epoch: %s step: %s total_loss: %.5f" % (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch, total_loss)) loss_file.write("\n") diff --git a/model_zoo/official/cv/maskrcnn/src/network_define.py b/model_zoo/official/cv/maskrcnn/src/network_define.py index dcec62cd768..662cd99cefb 100644 --- a/model_zoo/official/cv/maskrcnn/src/network_define.py +++ b/model_zoo/official/cv/maskrcnn/src/network_define.py @@ -68,7 +68,7 @@ class LossCallBack(Callback): total_loss = self.loss_sum / self.count loss_file = open("./loss_{}.log".format(self.rank_id), "a+") - loss_file.write("%lu epoch: %s step: %s ,total_loss: %.5f" % + loss_file.write("%lu epoch: %s step: %s total_loss: %.5f" % (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch, total_loss)) loss_file.write("\n") diff --git a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/network_define.py b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/network_define.py index ce72b2957ed..1bb4efa5e98 100644 --- a/model_zoo/official/cv/maskrcnn_mobilenetv1/src/network_define.py +++ b/model_zoo/official/cv/maskrcnn_mobilenetv1/src/network_define.py @@ -97,7 +97,7 @@ class LossCallBack(Callback): total_loss = self.loss_sum/self.count loss_file = open("./loss_{}.log".format(self.rank_id), "a+") - loss_file.write("%lu epoch: %s step: %s ,total_loss: %.5f" % + loss_file.write("%lu epoch: %s step: %s total_loss: %.5f" % (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch, total_loss)) loss_file.write("\n")