!26801 add loss landscape ut and st

Merge pull request !26801 from Songyuanwei/landscape_ci
i-robot 2021-12-01 01:49:43 +00:00 committed by Gitee
commit fcac7351aa
6 changed files with 526 additions and 55 deletions

View File

@ -33,7 +33,7 @@ from mindspore.train.summary import SummaryRecord
from mindspore.train.summary.enums import PluginEnum
from mindspore.train.anf_ir_pb2 import DataType
from mindspore.train._utils import check_value_type, _make_directory
from mindspore.train.dataset_helper import DatasetHelper, connect_network_with_dataset
from mindspore.train.dataset_helper import DatasetHelper
from mindspore.nn.metrics import get_metrics
from mindspore import context
@ -170,9 +170,10 @@ class SummaryLandscape:
1. SummaryLandscape only supports Linux systems.
Args:
summary_dir(str): The path of summary is used to save the model weight,
summary_dir (str): The path of the summary directory, which is used to save the model weights,
metadata and other data required to create the landscape.
intervals (Union[List[List[int]], None]): Specifies the intervals in which the loss landscape is created.
Default: None.
Examples:
>>> from mindspore.train.callback import SummaryLandscape
>>> import mindspore.nn as nn
@ -224,7 +225,27 @@ class SummaryLandscape:
self._epoch_group[i].append(j)
def save_loss_and_model_params(self, cur_num, unit, backbone, loss):
"""Save model params and loss."""
"""
Save model params and loss.
Args:
cur_num (int): The serial number of the current model and loss to be saved.
unit (str): The interval unit at which the model and loss are saved. Optional: epoch/step.
backbone (Cell): A backbone network.
loss (float): The loss of the current model to be saved.
Examples:
>>> from mindspore.train.callback import SummaryLandscape
>>>
>>> if __name__ == '__main__':
... summary_dir = './summary/'
... summary_landscape = SummaryLandscape(summary_dir=summary_dir)
... cur_num = 1
... unit = 'epoch'
... backbone = LeNet5(10)
... loss = 1.5
... summary_landscape.save_loss_and_model_params(cur_num, unit, backbone, loss)
"""
self._save_model_params(cur_num, unit, backbone, loss)
def _save_model_params(self, cur_num, unit, backbone, loss):
@ -260,7 +281,30 @@ class SummaryLandscape:
shutil.rmtree(self._ckpt_dir, ignore_errors=True)
def save_metadata(self, step_per_epoch, unit, num_samples, landscape_size, create_landscape):
"""Save meta data to json file."""
"""
Save metadata to a JSON file.
Args:
step_per_epoch (int): The steps of an epoch for current model.
unit (str): The interval unit at which the model is saved. Optional: epoch/step.
num_samples (int): The size of the dataset used to create the loss landscape.
landscape_size (int): Specify the image resolution of the generated loss landscape.
create_landscape (dict): Select how to create loss landscape, covering the
training process loss landscape (train) and the training result loss landscape (result).
Examples:
>>> from mindspore.train.callback import SummaryLandscape
>>>
>>> if __name__ == '__main__':
... summary_dir = './summary/'
... summary_landscape = SummaryLandscape(summary_dir=summary_dir)
... step_per_epoch = 175
... unit = 'epoch'
... num_samples = 2048
... landscape_size = 40
... create_landscape = {"train": True, "result": False}
... summary_landscape.save_metadata(step_per_epoch, unit, num_samples, landscape_size, create_landscape)
"""
data = {
"epoch_group": self._epoch_group,
"model_params_file_map": self._model_params_file_map,
@ -276,12 +320,12 @@ class SummaryLandscape:
def gen_landscapes_with_multi_process(self, callback_fn, collect_landscape=None,
device_ids=None, device_target='Ascend', output=None):
"""
r"""
Use multiple processes to generate the landscape.
Args:
callback_fn (python function): A python function object. User needs to write a function,
callback_ fn, it has no input, and the return requirements are as follows.
it has no input, and the return requirements are as follows.
- mindspore.train.Model: User's model object.
- mindspore.nn.Cell: User's network object.
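For illustration, a minimal callback_fn satisfying this contract, mirroring the one added in the tests below (LeNet5 and create_mnist_dataset come from the test helpers; any user-defined network and dataset can take their place):

from mindspore import nn
from mindspore.nn.metrics import Loss
from mindspore.train import Model

def callback_fn():
    """Build and return the objects required by SummaryLandscape."""
    network = LeNet5()
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    metrics = {"Loss": Loss()}
    model = Model(network, loss, metrics=metrics)
    ds_train = create_mnist_dataset("train")
    return model, network, ds_train, metrics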
@ -342,16 +386,17 @@ class SummaryLandscape:
with open(json_path, 'w') as file:
json.dump(data, file)
for interval, landscape in self.list_landscapes(callback_fn=callback_fn,
device_ids=device_ids,
device_target=device_target):
for interval, landscape in self._list_landscapes(callback_fn=callback_fn,
device_ids=device_ids,
device_target=device_target):
summary_record.add_value(PluginEnum.LANDSCAPE.value, f'landscape_{str(interval)}', landscape)
summary_record.record(0)
summary_record.flush()
summary_record.close()
def list_landscapes(self, callback_fn, device_ids=None, device_target='Ascend'):
def _list_landscapes(self, callback_fn, device_ids=None, device_target='Ascend'):
"""Create landscape with single device and list all landscape."""
json_path = os.path.join(self._ckpt_dir, 'train_metadata.json')
if not os.path.exists(json_path):
raise FileNotFoundError(f'train_metadata json file path does not exist,'
@ -708,7 +753,6 @@ class SummaryLandscape:
"""
logger.info("start to cont loss")
vals = list()
dataset_sink_mode = (context.get_context('device_target') == 'Ascend')
al_item = 0
for i, _ in enumerate(alph):
@ -725,7 +769,7 @@ class SummaryLandscape:
load_param_into_net(network, parameters_dict)
del parameters_dict
loss = self._loss_compute(model, ds_eval, metrics, dataset_sink_mode)
loss = self._loss_compute(model, ds_eval, metrics)
logger.info("%s/%s loss: %s." % (i+1, len(alph), loss))
vals = np.append(vals, loss['Loss'])
@ -740,23 +784,21 @@ class SummaryLandscape:
parameter.set_data(Tensor(data_target))
return parameter
def _loss_compute(self, model, data, metrics, dataset_sink_mode=False):
def _loss_compute(self, model, data, metrics):
"""Compute loss."""
dataset_sink_mode = False
self._metric_fns = get_metrics(metrics)
for metric in self._metric_fns.values():
metric.clear()
network = model.train_network
dataset_helper = DatasetHelper(data, dataset_sink_mode)
if dataset_sink_mode:
network = connect_network_with_dataset(network, dataset_helper)
network.set_train(True)
network.phase = 'train'
for inputs in dataset_helper:
if not dataset_sink_mode:
inputs = transfer_tensor_to_tuple(inputs)
inputs = transfer_tensor_to_tuple(inputs)
outputs = network(*inputs)
self._update_metrics(outputs)
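Taken together, once training has written the checkpoints and metadata under summary_dir, the public entry point is driven roughly as follows (a sketch mirroring the system test below; the device settings are illustrative):

import os
from mindspore import context
from mindspore.train.callback import SummaryLandscape

device_target = context.get_context("device_target")
device_id = int(os.getenv('DEVICE_ID', '0'))  # illustrative default
summary_landscape = SummaryLandscape('./summary/')
summary_landscape.gen_landscapes_with_multi_process(callback_fn,
                                                    device_ids=[device_id],
                                                    device_target=device_target)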

View File

@ -38,7 +38,6 @@ from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.loss.loss import LossBase
from mindspore.train._utils import check_value_type, _make_directory
from ..._c_expression import security
from ...common.api import _cell_graph_executor
HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
HYPER_CONFIG_LEN_LIMIT = 100000
@ -115,23 +114,23 @@ class SummaryCollector(Callback):
Default: None, it means only the first five parameters are collected.
- collect_landscape (Union[dict,None]): Collect the parameters needed to create the loss landscape.
- landscape_size (int): Specify the image resolution of the generated loss landscape.
For example, if it is set to 128, the resolution of the landscape is 128 * 128.
The calculation time increases with the increase of resolution.
Default: 40. Optional values: between 3 and 256.
- unit (str): Specify the interval strength of the training process. Optional: epoch/step.
- create_landscape (List[bool, bool]): Select how to create loss landscape.
Training process loss landscape(train) and Training result loss landscape(result).
Default: {"train": True, "result": True}. Optional: True/False.
- num_samples (int): The size of the dataset used to create the loss landscape.
For example, in image dataset, You can set num_samples is 128,
which means that 128 images are used to create loss landscape.
Default: 128.
- intervals (List[List[int]]): Specifies the interval
in which the loss landscape. For example: If the user wants to
crate loss landscape of two training processes, they are 1-5 epoch
and 6-10 epoch respectively. They anc set [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]].
Note: Each interval have at least three epochs.
- landscape_size (int): Specify the image resolution of the generated loss landscape.
For example, if it is set to 128, the resolution of the landscape is 128 * 128.
The calculation time increases with the increase of resolution.
Default: 40. Optional values: between 3 and 256.
- unit (str): Specify the interval unit of the training process. Optional: epoch/step.
- create_landscape (dict): Select how to create loss landscape, covering the
training process loss landscape (train) and the training result loss landscape (result).
Default: {"train": True, "result": True}. Optional: True/False.
- num_samples (int): The size of the dataset used to create the loss landscape.
For example, for an image dataset, you can set num_samples to 128,
which means that 128 images are used to create the loss landscape.
Default: 128.
- intervals (List[List[int]]): Specifies the intervals in which the loss
landscape is created. For example, if the user wants to create the loss
landscape of two training processes, covering epochs 1-5 and 6-10
respectively, they can set [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]].
Note: Each interval must have at least three epochs.
keep_default_action (bool): This field affects the collection behavior of the 'collect_specified_data' field.
True: it means that after specified data is set, non-specified data is collected as the default behavior.
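For reference, a collect_landscape configuration combining the options above might look like the following (a sketch; the values are illustrative and echo the defaults and the tests added in this change):

from mindspore.train.callback import SummaryCollector

summary_collector = SummaryCollector(
    summary_dir='./summary/',
    collect_specified_data={
        'collect_landscape': {
            'landscape_size': 40,
            'unit': 'epoch',
            'create_landscape': {'train': True, 'result': True},
            'num_samples': 128,
            'intervals': [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]],
        }
    })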
@ -495,14 +494,15 @@ class SummaryCollector(Callback):
self._collect_at_step_end(cb_params, lambda plugin: plugin != PluginEnum.TENSOR.value)
collect_landscape = self._collect_specified_data.get('collect_landscape')
intervals = collect_landscape.get('intervals')
collect_interval = False
for interval in intervals:
if "cur_step_num" in cb_params:
if cb_params.cur_step_num in interval:
collect_interval = True
break
break
if collect_landscape is not None:
intervals = collect_landscape.get('intervals')
collect_interval = False
for interval in intervals:
if "cur_step_num" in cb_params:
if cb_params.cur_step_num in interval:
collect_interval = True
break
if collect_landscape and collect_landscape.get('unit', 'step') == 'step' and collect_interval:
self._save_model_params_for_landscape(cb_params)
@ -533,14 +533,15 @@ class SummaryCollector(Callback):
def epoch_end(self, run_context):
cb_params = run_context.original_args()
collect_landscape = self._collect_specified_data.get('collect_landscape')
intervals = collect_landscape.get('intervals')
collect_interval = False
for interval in intervals:
if "cur_epoch_num" in cb_params:
if cb_params.cur_epoch_num in interval:
collect_interval = True
break
break
if collect_landscape is not None:
intervals = collect_landscape.get('intervals')
collect_interval = False
for interval in intervals:
if "cur_epoch_num" in cb_params:
if cb_params.cur_epoch_num in interval:
collect_interval = True
break
if collect_landscape and collect_landscape.get('unit', 'step') == 'epoch' and collect_interval:
self._save_model_params_for_landscape(cb_params)
self._record.flush()
@ -686,11 +687,12 @@ class SummaryCollector(Callback):
return
network = cb_params.train_network if cb_params.mode == ModeEnum.TRAIN.value else cb_params.eval_network
graph_proto = _cell_graph_executor.get_optimize_graph_proto(network)
graph_proto = network.get_func_graph_proto()
if graph_proto is None:
logger.warning("Can not get graph proto, it may not be 'GRAPH_MODE' in context currently, "
"so SummaryCollector will not collect graph.")
return
self._record.add_value(PluginEnum.GRAPH.value, 'train_network/auto', graph_proto)
def _collect_metric(self, cb_params):

View File

@ -17,21 +17,34 @@ import os
import re
import shutil
import tempfile
import json
from collections import Counter
import numpy as np
import pytest
from mindspore.common import set_seed
from mindspore import nn, Tensor, context
from mindspore.common.initializer import Normal
from mindspore.nn.metrics import Loss
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.train.callback import SummaryCollector
from mindspore.train.callback import SummaryCollector, SummaryLandscape
from tests.st.summary.dataset import create_mnist_dataset
from tests.summary_utils import SummaryReader
from tests.security_utils import security_off_wrap
set_seed(1)
def callback_fn():
"""A python function job"""
network = LeNet5()
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
metrics = {"Loss": Loss()}
model = Model(network, loss, metrics=metrics)
ds_train = create_mnist_dataset("train")
return model, network, ds_train, metrics
class LeNet5(nn.Cell):
"""
@ -224,3 +237,106 @@ class TestSummary:
tensors.append(file)
return tensors
def _train_network(self, epoch=3, dataset_sink_mode=False, num_samples=2, **kwargs):
"""run network."""
lenet = LeNet5()
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
optim = Momentum(lenet.trainable_params(), learning_rate=0.1, momentum=0.9)
model = Model(lenet, loss_fn=loss, optimizer=optim, metrics={'loss': Loss()})
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
summary_collector = SummaryCollector(summary_dir=summary_dir, collect_freq=2, **kwargs)
ds_train = create_mnist_dataset("train", num_samples=num_samples)
model.train(epoch, ds_train, callbacks=[summary_collector], dataset_sink_mode=dataset_sink_mode)
return summary_dir
@staticmethod
def _list_summary_collect_landscape_tags(summary_dir):
"""list summary landscape tags."""
summary_dir_path = ''
for file in os.listdir(summary_dir):
if re.search("ckpt_dir", file):
summary_dir_path = os.path.join(summary_dir, file)
break
assert summary_dir_path
summary_file_path = ''
for file in os.listdir(summary_dir_path):
if re.search(".json", file):
summary_file_path = os.path.join(summary_dir_path, file)
break
assert summary_file_path
tags = list()
with open(summary_file_path, 'r') as file:
data = json.load(file)
for key, value in data.items():
tags.append(key)
assert value
return tags
@staticmethod
def _list_landscape_tags(summary_dir):
"""list landscape tags."""
expected_tags = {'landscape_[1, 3]', 'landscape_[3]'}
summary_list = []
for file in os.listdir(summary_dir):
if re.search("_MS", file):
summary_file_path = os.path.join(summary_dir, file)
summary_list = summary_list + [summary_file_path]
assert summary_list
tags = []
for summary_path in summary_list:
with SummaryReader(summary_path) as summary_reader:
while True:
summary_event = summary_reader.read_event()
if not summary_event:
break
for value in summary_event.summary.value:
if value.tag in expected_tags:
tags.append(value.loss_landscape.landscape.z.float_data)
break
return tags
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@security_off_wrap
def test_summary_collector_landscape(self):
"""Test summary collector with landscape."""
interval_1 = [1, 2, 3]
num_samples = 2
summary_dir = self._train_network(epoch=3, num_samples=num_samples,
collect_specified_data={'collect_landscape':
{'landscape_size': 4,
'unit': 'epoch',
'create_landscape': {'train': True,
'result': True},
'num_samples': num_samples,
'intervals': [interval_1]}})
tag_list = self._list_summary_collect_landscape_tags(summary_dir)
expected_tags = {'epoch_group', 'model_params_file_map', 'step_per_epoch', 'unit', 'num_samples',
'landscape_size', 'create_landscape', 'loss_map'}
assert set(expected_tags) == set(tag_list)
device_target = context.get_context("device_target")
device_id = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
summary_landscape = SummaryLandscape(summary_dir)
summary_landscape.gen_landscapes_with_multi_process(callback_fn, device_ids=[device_id],
device_target=device_target)
expected_pca_value = np.array([2.0876417, 2.0871262, 2.0866107, 2.0860953, 2.0871796, 2.0866641, 2.0861477,
2.0856318, 2.0867180, 2.0862016, 2.0856854, 2.0851683, 2.0862572, 2.0857398,
2.0852231, 2.0847058])
expected_random_value = np.array([2.0066809, 1.9905004, 1.9798302, 1.9742643, 2.0754160, 2.0571522, 2.0442397,
2.0365926, 2.1506545, 2.1299571, 2.1143755, 2.1042551, 2.2315959, 2.2083559,
2.1895625, 2.1762595])
tag_list_landscape = self._list_landscape_tags(summary_dir)
assert np.all(np.abs(expected_pca_value - tag_list_landscape[0]) < 1.e-3)
assert np.all(np.abs(expected_random_value - tag_list_landscape[1]) < 1.e-3)

View File

@ -0,0 +1,110 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset base and LeNet."""
import os
from mindspore import dataset as ds
from mindspore.common import dtype as mstype
from mindspore.dataset.transforms import c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore.dataset.vision import c_transforms as CV
from mindspore import nn, Tensor
from mindspore.common.initializer import Normal
from mindspore.ops import operations as P
def create_mnist_dataset(mode='train', num_samples=2, batch_size=2):
"""create dataset for train or test"""
mnist_path = '/home/workspace/mindspore_dataset/mnist'
num_parallel_workers = 1
# define dataset
mnist_ds = ds.MnistDataset(os.path.join(mnist_path, mode), num_samples=num_samples, shuffle=False)
resize_height, resize_width = 32, 32
# define map operations
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode
rescale_nml_op = CV.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081)
rescale_op = CV.Rescale(1.0 / 255.0, shift=0.0)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
# apply map operations on images
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
# apply DatasetOps
mnist_ds = mnist_ds.batch(batch_size=batch_size, drop_remainder=True)
return mnist_ds
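# A quick smoke check of the helper (illustrative; assumes the MNIST files
# are present under mnist_path):
#   ds = create_mnist_dataset('train', num_samples=4, batch_size=2)
#   assert ds.get_dataset_size() == 2  # 4 samples, batched in twos with drop_remainder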
class LeNet5(nn.Cell):
"""
LeNet network
Args:
num_class (int): Number of classes. Default: 10.
num_channel (int): Number of channels. Default: 1.
Returns:
Tensor, output tensor
Examples:
>>> LeNet5(num_class=10)
"""
def __init__(self, num_class=10, num_channel=1, include_top=True):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.include_top = include_top
if self.include_top:
self.flatten = nn.Flatten()
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.scalar_summary = P.ScalarSummary()
self.image_summary = P.ImageSummary()
self.histogram_summary = P.HistogramSummary()
self.tensor_summary = P.TensorSummary()
self.channel = Tensor(num_channel)
def construct(self, x):
"""construct."""
self.image_summary('image', x)
x = self.conv1(x)
self.histogram_summary('histogram', x)
x = self.relu(x)
self.tensor_summary('tensor', x)
x = self.relu(x)
x = self.max_pool2d(x)
self.scalar_summary('scalar', self.channel)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
if not self.include_top:
return x
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x

View File

@ -477,3 +477,95 @@ class TestSummaryCollector:
with pytest.raises(ValueError) as exc:
SummaryCollector(summary_dir=summary_dir)
assert str(exc.value) == 'The Summary is not supported, please without `-s on` and recompile source.'
@security_off_wrap
@pytest.mark.parametrize("collect_specified_data", [
{
'collect_landscape': 123
}
])
def test_params_collect_specified_data_value_type_error(self, collect_specified_data):
"""Test the value of collect_specified_data_param."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)
param_name = list(collect_specified_data)[0]
param_value = collect_specified_data[param_name]
expected_type = "['dict', 'NoneType']"
expected_msg = f'For `{param_name}` the type should be a valid type of {expected_type}, ' \
f'but got {type(param_value).__name__}.'
assert expected_msg == str(exc.value)
@security_off_wrap
def test_params_with_collect_landscape_unexpected_key(self):
"""Test the collect landscape parameter with unexpected key."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
data = {'unexpected_key': "value"}
with pytest.raises(ValueError) as exc:
SummaryCollector(summary_dir, collect_specified_data={'collect_landscape': data})
expected_msg = f"For `collect_landscape` the keys {set(data)} are unsupported"
assert expected_msg in str(exc.value)
@security_off_wrap
@pytest.mark.parametrize("collect_specified_data", [
{
'collect_landscape': {'landscape_size': None}
},
{
'collect_landscape': {'unit': 123}
},
{
'collect_landscape': {'create_landscape': 123}
},
{
'collect_landscape': {'num_samples': None}
},
{
'collect_landscape': {'intervals': None}
},
])
def test_params_collect_landscape_value_type_error(self, collect_specified_data):
"""Test the value of collect_landscape_param."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)
param_name = list(collect_specified_data["collect_landscape"])[0]
param_value = collect_specified_data["collect_landscape"][param_name]
if param_name in ['landscape_size', 'num_samples']:
expected_type = "['int']"
elif param_name == 'unit':
expected_type = "['str']"
elif param_name == 'create_landscape':
expected_type = "['dict']"
else:
expected_type = "['list']"
expected_msg = f'For `{param_name}` the type should be a valid type of {expected_type}, ' \
f'but got {type(param_value).__name__}.'
assert expected_msg == str(exc.value)
@security_off_wrap
@pytest.mark.parametrize("collect_specified_data", [
{
'collect_landscape': {'create_landscape': {'train': None}}
},
{
'collect_landscape': {'create_landscape': {'result': 123}}
}
])
def test_params_create_landscape_value_type_error(self, collect_specified_data):
"""Test the value of create_landscape_param."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)
param_name = list(collect_specified_data["collect_landscape"]["create_landscape"])[0]
param_value = collect_specified_data["collect_landscape"]["create_landscape"][param_name]
expected_type = "['bool']"
expected_msg = f'For `{param_name}` the type should be a valid type of {expected_type}, ' \
f'but got {type(param_value).__name__}.'
assert expected_msg == str(exc.value)

View File

@ -0,0 +1,109 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test create landscape."""
import os
import shutil
import tempfile
import pytest
from mindspore.common import set_seed
from mindspore import nn, context
from mindspore.nn.metrics import Loss
from mindspore.train import Model
from mindspore.train.callback import SummaryLandscape
from tests.security_utils import security_off_wrap
from tests.ut.python.train.dataset import create_mnist_dataset, LeNet5
set_seed(1)
_VALUE_CACHE = list()
def get_value():
"""Get the value which is added by add_value function."""
global _VALUE_CACHE
value = _VALUE_CACHE
_VALUE_CACHE = list()
return value
def callback_fn():
"""A python function job"""
network = LeNet5()
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
metrics = {"Loss": Loss()}
model = Model(network, loss, metrics=metrics)
ds_train = create_mnist_dataset("train")
return model, network, ds_train, metrics
class TestLandscape:
"""Test the exception parameter for landscape."""
base_summary_dir = ''
def setup_class(self):
"""Run before test this class."""
self.base_summary_dir = tempfile.mkdtemp(suffix='summary')
def teardown_class(self):
"""Run after test this class."""
if os.path.exists(self.base_summary_dir):
shutil.rmtree(self.base_summary_dir)
def teardown_method(self):
"""Run after each test function."""
get_value()
@security_off_wrap
@pytest.mark.parametrize("collect_landscape", [
{
'landscape_size': None
},
{
'create_landscape': None
},
{
'num_samples': None
},
{
'intervals': None
},
])
def test_params_gen_landscape_with_multi_process_value_type_error(self, collect_landscape):
"""Test the value of gen_landscape_with_multi_process param."""
device_target = context.get_context("device_target")
device_id = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
summary_landscape = SummaryLandscape(summary_dir)
with pytest.raises(TypeError) as exc:
summary_landscape.gen_landscapes_with_multi_process(
callback_fn,
collect_landscape=collect_landscape,
device_ids=[device_id],
device_target=device_target
)
param_name = list(collect_landscape)[0]
param_value = collect_landscape[param_name]
if param_name in ['landscape_size', 'num_samples']:
expected_type = "['int']"
elif param_name == 'unit':
expected_type = "['str']"
elif param_name == 'create_landscape':
expected_type = "['dict']"
else:
expected_type = "['list']"
expected_msg = f'For `{param_name}` the type should be a valid type of {expected_type}, ' \
f'but got {type(param_value).__name__}.'
assert expected_msg == str(exc.value)