forked from mindspore-Ecosystem/mindspore
!31679 If the GPU memory is set to true, the test case fails.
Merge pull request !31679 from liuchuting/testcase
This commit is contained in:
commit
3b73182032
|
@ -685,15 +685,8 @@ class Profiler:
|
|||
logger.warning(err.message)
|
||||
|
||||
logger.warning(
|
||||
'\nMemory Usage is not supported on GPU currently.\n'
|
||||
'Please running on Ascend if you would like to see memory analysis, '
|
||||
'otherwise, this warning can be ignored.'
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
'\nProfile communication is not supported on GPU currently.\n'
|
||||
'Please running on Ascend if you would like to see cluster communication analysis, '
|
||||
'otherwise, this warning can be ignored.'
|
||||
'\nThe training and inference process does not support profiler currently, '
|
||||
'only individual training or inference is supported.'
|
||||
)
|
||||
|
||||
def _get_step_reduce_op_type(self):
|
||||
|
|
|
@ -151,7 +151,7 @@ class TestProfiler:
|
|||
def test_cpu_profiler(self):
|
||||
if sys.platform != 'linux':
|
||||
return
|
||||
self._train_with_profiler(device_target="CPU")
|
||||
self._train_with_profiler(device_target="CPU", profile_memory=False)
|
||||
self._check_cpu_profiling_file()
|
||||
|
||||
@pytest.mark.level1
|
||||
|
@ -159,7 +159,7 @@ class TestProfiler:
|
|||
@pytest.mark.env_onecard
|
||||
@security_off_wrap
|
||||
def test_gpu_profiler(self):
|
||||
self._train_with_profiler(device_target="GPU")
|
||||
self._train_with_profiler(device_target="GPU", profile_memory=False)
|
||||
self._check_gpu_profiling_file()
|
||||
|
||||
@pytest.mark.level1
|
||||
|
@ -168,12 +168,12 @@ class TestProfiler:
|
|||
@pytest.mark.env_onecard
|
||||
@security_off_wrap
|
||||
def test_ascend_profiler(self):
|
||||
self._train_with_profiler(device_target="Ascend")
|
||||
self._train_with_profiler(device_target="Ascend", profile_memory=True)
|
||||
self._check_d_profiling_file()
|
||||
|
||||
def _train_with_profiler(self, device_target):
|
||||
def _train_with_profiler(self, device_target, profile_memory):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
profiler = Profiler(profile_memory=True, output_path='data')
|
||||
profiler = Profiler(profile_memory=profile_memory, output_path='data')
|
||||
profiler_name = os.listdir(os.path.join(os.getcwd(), 'data'))[0]
|
||||
self.profiler_path = os.path.join(os.getcwd(), f'data/{profiler_name}/')
|
||||
ds_train = create_dataset(os.path.join(self.mnist_path, "train"))
|
||||
|
|
Loading…
Reference in New Issue