recover baseline and fix resnet_thor
This commit is contained in:
parent
518d9589ac
commit
74d5314ef0
|
@ -63,7 +63,7 @@ def thor_end():
|
|||
thor_cost /= 4
|
||||
print(f"resnet thor_loss: {thor_loss}, thor_cost: {thor_cost}")
|
||||
assert thor_loss < 7
|
||||
assert thor_cost < 40
|
||||
assert thor_cost < 30
|
||||
for i in range(4):
|
||||
shutil.rmtree(os.path.join(sh_path, f"train_parallel{i+4}"))
|
||||
|
||||
|
|
|
@ -72,7 +72,7 @@ def run_train():
|
|||
if device_num > 1:
|
||||
ms.communication.init()
|
||||
ms.context.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True, all_reduce_fusion_config=[85, 160])
|
||||
gradients_mean=True, all_reduce_fusion_config=[80, 160])
|
||||
net = resnet50(thor_config.class_num)
|
||||
|
||||
if not thor_config.label_smooth:
|
||||
|
|
Loading…
Reference in New Issue