From 74d5314ef08a40e39ba19b7f18636fe3e372ae01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E7=A8=8B=E6=B5=A9?= Date: Tue, 14 Jun 2022 06:30:32 +0000 Subject: [PATCH] recover baseline and fix resnet_thor --- .../networks/models/resnet50/test_resnet50_imagenet_and_thor.py | 2 +- tests/st/networks/models/resnet50/train_resnet50_thor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/st/networks/models/resnet50/test_resnet50_imagenet_and_thor.py b/tests/st/networks/models/resnet50/test_resnet50_imagenet_and_thor.py index c521316e63d..9cf16bb9f02 100644 --- a/tests/st/networks/models/resnet50/test_resnet50_imagenet_and_thor.py +++ b/tests/st/networks/models/resnet50/test_resnet50_imagenet_and_thor.py @@ -63,7 +63,7 @@ def thor_end(): thor_cost /= 4 print(f"resnet thor_loss: {thor_loss}, thor_cost: {thor_cost}") assert thor_loss < 7 - assert thor_cost < 40 + assert thor_cost < 30 for i in range(4): shutil.rmtree(os.path.join(sh_path, f"train_parallel{i+4}")) diff --git a/tests/st/networks/models/resnet50/train_resnet50_thor.py b/tests/st/networks/models/resnet50/train_resnet50_thor.py index badf7c57418..748f597c78f 100644 --- a/tests/st/networks/models/resnet50/train_resnet50_thor.py +++ b/tests/st/networks/models/resnet50/train_resnet50_thor.py @@ -72,7 +72,7 @@ def run_train(): if device_num > 1: ms.communication.init() ms.context.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, - gradients_mean=True, all_reduce_fusion_config=[85, 160]) + gradients_mean=True, all_reduce_fusion_config=[80, 160]) net = resnet50(thor_config.class_num) if not thor_config.label_smooth: