!6358 fix loss print in bert, change epoch number to integer

Merge pull request !6358 from chenhaozhe/fix-loss-print
mindspore-ci-bot 2020-09-17 20:27:52 +08:00 committed by Gitee
commit cc4859f0b2
2 changed files with 12 additions and 1 deletion


@@ -381,6 +381,11 @@ epoch: 0.0, current epoch percent: 0.000, step: 2, outputs are (Tensor(shape=[1],
> ```
> This will extend the hccl timeout from the default 120 seconds to 600 seconds.
> **Attention** If you are running with a big BERT model, a protobuf error may occur while saving checkpoints; if so, try setting the following environment variable.
> ```
> export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
> ```
### Distributed Training
#### Running on Ascend
```
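The `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION` switch only takes effect if it is set before protobuf is first imported. As an illustration (not part of this commit), the same workaround can be applied from inside a Python launcher, and the chosen backend can be verified with protobuf's `api_implementation` module:

```python
import os

# The pure-Python backend is selected once, when protobuf is first imported,
# so this must run before any module that pulls in protobuf.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

from google.protobuf.internal import api_implementation

# Reports which backend protobuf picked; with the variable above it prints "python".
print(api_implementation.Type())
```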


@@ -145,11 +145,17 @@ class LossCallBack(Callback):
         super(LossCallBack, self).__init__()
         self._dataset_size = dataset_size
     def step_end(self, run_context):
+        """
+        Print loss after each step
+        """
         cb_params = run_context.original_args()
         if self._dataset_size > 0:
             percent, epoch_num = math.modf(cb_params.cur_step_num / self._dataset_size)
+            if percent == 0:
+                percent = 1
+                epoch_num -= 1
             print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
-                  .format(epoch_num, "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)))
+                  .format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)))
         else:
             print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
                                                                str(cb_params.net_outputs)))
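The core of the fix is the `math.modf` bookkeeping above. A minimal standalone sketch (with `dataset_size = 100` as an assumed example value, not taken from the real training config) shows why the boundary adjustment and the `int()` cast are needed:

```python
import math

dataset_size = 100  # assumed example value; in training this is the real dataset size

for cur_step_num in (150, 200):
    # modf splits steps/dataset_size into (fractional, integral) parts.
    percent, epoch_num = math.modf(cur_step_num / dataset_size)
    # At an exact epoch boundary (step 200) modf returns (0.0, 2.0), but that
    # step completed 100% of epoch 1, so the patch rolls the values back.
    if percent == 0:
        percent = 1
        epoch_num -= 1
    # int() turns the float 1.0 into 1, so the log reads "epoch: 1".
    print("epoch: {}, current epoch percent: {}, step: {}".format(
        int(epoch_num), "%.3f" % percent, cur_step_num))
```

Run as-is, this prints `epoch: 1, current epoch percent: 0.500, step: 150` followed by `epoch: 1, current epoch percent: 1.000, step: 200`, matching the integer epoch numbers the commit title promises.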