recover baseline and fix resnet_thor

This commit is contained in:
王程浩 2022-06-14 06:30:32 +00:00 committed by cheng-hao-wang
parent 518d9589ac
commit 74d5314ef0
2 changed files with 2 additions and 2 deletions

View File

@@ -63,7 +63,7 @@ def thor_end():
thor_cost /= 4
print(f"resnet thor_loss: {thor_loss}, thor_cost: {thor_cost}")
assert thor_loss < 7
-assert thor_cost < 40
+assert thor_cost < 30
for i in range(4):
shutil.rmtree(os.path.join(sh_path, f"train_parallel{i+4}"))

View File

@@ -72,7 +72,7 @@ def run_train():
if device_num > 1:
ms.communication.init()
ms.context.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL,
-gradients_mean=True, all_reduce_fusion_config=[85, 160])
+gradients_mean=True, all_reduce_fusion_config=[80, 160])
net = resnet50(thor_config.class_num)
if not thor_config.label_smooth: