From 3dc6f6f2d94bfc745c352348eb2a83c1458259e5 Mon Sep 17 00:00:00 2001 From: ougongchang Date: Wed, 24 Jun 2020 14:33:32 +0800 Subject: [PATCH] add more ut and st for SummaryCollector Has fixed collecting optimizer error when mode is eval --- .../train/callback/_summary_collector.py | 28 +-- tests/st/summary/test_davinci_summary.py | 99 --------- tests/st/summary/test_gpu_summary.py | 89 -------- tests/st/summary/test_summary.py | 194 ++++++++++++++++++ tests/summary_utils.py | 1 + .../train/summary/test_summary_collector.py | 193 +++++++++++++++++ 6 files changed, 402 insertions(+), 202 deletions(-) delete mode 100644 tests/st/summary/test_davinci_summary.py delete mode 100644 tests/st/summary/test_gpu_summary.py create mode 100644 tests/st/summary/test_summary.py diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py index cff03ca398f..7ef890ec385 100644 --- a/mindspore/train/callback/_summary_collector.py +++ b/mindspore/train/callback/_summary_collector.py @@ -161,7 +161,7 @@ class SummaryCollector(Callback): self._check_custom_lineage_data(custom_lineage_data) self._custom_lineage_data = custom_lineage_data - self._optimizer = None + self._temp_optimizer = None self._has_saved_train_network = False self._has_saved_custom_data = False self._is_parse_loss_success = True @@ -369,15 +369,15 @@ class SummaryCollector(Callback): input_data = getattr(cb_params, 'train_dataset_element', None) if input_data is None: self._collect_specified_data['collect_input_data'] = False - logger.info("There is not a `train_dataset_element` in cb_params.") + logger.info("The 'train_dataset_element' in cb_params is None, maybe there is dataset sink mode.") return if isinstance(input_data, (list, tuple)): input_data = input_data[0] try: self._record.add_value(PluginEnum.IMAGE.value, 'input_data/auto', input_data) - except ValueError as ex: - logger.warning(str(ex)) + except ValueError: + logger.warning('The input data of network are not image, so will not collect by SummaryCollector.') self._collect_specified_data['collect_input_data'] = False return @@ -418,8 +418,8 @@ class SummaryCollector(Callback): try: self._record.add_value(PluginEnum.SCALAR.value, 'loss/auto', loss) - except ValueError as exc: - logger.warning(str(exc)) + except ValueError: + logger.warning("The output of network is not a scalar, so will not collect loss in SummaryCollector.") self._collect_specified_data['collect_metric'] = False def _get_loss(self, cb_params): @@ -438,7 +438,7 @@ class SummaryCollector(Callback): output = cb_params.net_outputs if output is None: - logger.warning("Can not find any output by this network.") + logger.warning("Can not find any output by this network, so will not collect loss in SummaryCollector.") self._is_parse_loss_success = False return None @@ -448,7 +448,7 @@ class SummaryCollector(Callback): # If the output is a list, since the default network returns loss first, # we assume that the first one is loss. loss = output[0] - elif isinstance(output, Tensor) and (not output.shape or output.shape == [1]): + elif isinstance(output, Tensor) and (not output.shape or output.shape == (1,)): loss_numpy = output.asnumpy() loss = float(np.atleast_1d(loss_numpy)[0]) else: @@ -473,15 +473,15 @@ class SummaryCollector(Callback): """ # 'optimizer_failed' means find optimizer failed, so we will not collect data about optimizer. optimizer_failed = 'Failed' - if self._optimizer == optimizer_failed: + if self._temp_optimizer == optimizer_failed: return None - if self._optimizer is not None: - return self._optimizer + if self._temp_optimizer is not None: + return self._temp_optimizer optimizer = cb_params.optimizer if optimizer is None: - network = cb_params.train_network if cb_params.mode == 'train' else cb_params.eval_work + network = cb_params.train_network if cb_params.mode == 'train' else cb_params.eval_network optimizer = self._parse_optimizer_by_network(network) if optimizer is None or not isinstance(optimizer, Optimizer): @@ -489,7 +489,7 @@ class SummaryCollector(Callback): "optimizer, so we will not collect data about optimizer in SummaryCollector.") optimizer = None - self._optimizer = optimizer if optimizer is not None else optimizer_failed + self._temp_optimizer = optimizer if optimizer is not None else optimizer_failed return optimizer @@ -765,7 +765,7 @@ class SummaryCollector(Callback): cb_params (_InternalCallbackParam): Callback parameters. Returns: - Union[Loss_fn, None], a Cell object, if parse failed, will return None. + Union[Cell, None], a Cell object, if parse failed, will return None. """ loss_fn = cb_params.loss_fn if loss_fn is not None: diff --git a/tests/st/summary/test_davinci_summary.py b/tests/st/summary/test_davinci_summary.py deleted file mode 100644 index bc93afe364a..00000000000 --- a/tests/st/summary/test_davinci_summary.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" test model train """ -import os -import numpy as np -from apply_momentum import ApplyMomentum -import mindspore.context as context -import mindspore.nn as nn -from mindspore.nn import wrap -from mindspore import Tensor, Model -from mindspore.common.api import ms_function -from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits -from mindspore.ops import operations as P -from mindspore.train.summary.summary_record import SummaryRecord - -CUR_DIR = os.getcwd() -SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/" - -context.set_context(device_target="Ascend") - - -class MsWrapper(nn.Cell): - def __init__(self, network): - super(MsWrapper, self).__init__(auto_prefix=False) - self._network = network - - @ms_function - def construct(self, *args): - return self._network(*args) - - -def me_train_tensor(net, input_np, label_np, epoch_size=2): - context.set_context(mode=context.GRAPH_MODE) - loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) - opt = ApplyMomentum(Tensor(np.array([0.1])), Tensor(np.array([0.9])), - filter(lambda x: x.requires_grad, net.get_parameters())) - Model(net, loss, opt) - _network = wrap.WithLossCell(net, loss) - _train_net = MsWrapper(wrap.TrainOneStepCell(_network, opt)) - _train_net.set_train() - with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=_train_net) as summary_writer: - for epoch in range(0, epoch_size): - print(f"epoch %d" % (epoch)) - output = _train_net(Tensor(input_np), Tensor(label_np)) - summary_writer.record(i) - print("********output***********") - print(output.asnumpy()) - - -def me_infer_tensor(net, input_np): - net.set_train() - net = MsWrapper(net) - output = net(Tensor(input_np)) - return output - - -def test_net(): - class Net(nn.Cell): - def __init__(self, cin, cout): - super(Net, self).__init__() - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") - self.conv = nn.Conv2d(cin, cin, kernel_size=1, stride=1, padding=0, has_bias=False, pad_mode="same") - self.bn = nn.BatchNorm2d(cin, momentum=0.1, eps=0.0001) - self.add = P.TensorAdd() - self.relu = P.ReLU() - self.mean = P.ReduceMean(keep_dims=True) - self.reshape = P.Reshape() - self.dense = nn.Dense(cin, cout) - - def construct(self, input_x): - output = input_x - output = self.maxpool(output) - identity = output - output = self.conv(output) - output = self.bn(output) - output = self.add(output, identity) - output = self.relu(output) - output = self.mean(output, (-2, -1)) - output = self.reshape(output, (32, -1)) - output = self.dense(output) - return output - - net = Net(2048, 1001) - input_np = np.ones([32, 2048, 14, 14]).astype(np.float32) * 0.01 - label_np = np.ones([32]).astype(np.int32) - me_train_tensor(net, input_np, label_np) - # me_infer_tensor(net, input_np) diff --git a/tests/st/summary/test_gpu_summary.py b/tests/st/summary/test_gpu_summary.py deleted file mode 100644 index 9b4095b8d9d..00000000000 --- a/tests/st/summary/test_gpu_summary.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Summary gpu st.""" -import os -import random -import tempfile -import shutil - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore.common.tensor import Tensor -from mindspore.ops import operations as P -from mindspore.train.summary.summary_record import SummaryRecord - -context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - - -class SummaryNet(nn.Cell): - """Summary net.""" - def __init__(self, tag_tuple=None, scalar=1): - super(SummaryNet, self).__init__() - self.summary_s = P.ScalarSummary() - self.summary_i = P.ImageSummary() - self.summary_t = P.TensorSummary() - self.histogram_summary = P.HistogramSummary() - self.add = P.TensorAdd() - self.tag_tuple = tag_tuple - self.scalar = scalar - - def construct(self, x, y, image): - """Run summary net.""" - self.summary_i("image", image) - self.summary_s("x1", x) - z = self.add(x, y) - self.summary_t("z1", z) - self.histogram_summary("histogram", z) - return z - - -def train_summary_record(test_writer, steps): - """Train and record summary.""" - net = SummaryNet() - out_me_dict = {} - for i in range(0, steps): - x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32)) - y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32)) - image = Tensor(np.array([[[[1.2]]]]).astype(np.float32)) - out_put = net(x, y, image) - test_writer.record(i) - out_me_dict[i] = out_put.asnumpy() - return out_me_dict - - -class TestGpuSummary: - """Test Gpu summary.""" - summary_dir = tempfile.mkdtemp(suffix='_gpu_summary') - - def setup_method(self): - """Run before method.""" - if not os.path.exists(self.summary_dir): - os.mkdir(self.summary_dir) - - def teardown_method(self): - """Run after method.""" - if os.path.exists(self.summary_dir): - shutil.rmtree(self.summary_dir) - - @pytest.mark.level0 - @pytest.mark.platform_x86_gpu_training - @pytest.mark.env_onecard - def test_summary_step10_summaryrecord1(self): - """Test record 10 step summary.""" - with SummaryRecord(self.summary_dir) as test_writer: - train_summary_record(test_writer, steps=10) diff --git a/tests/st/summary/test_summary.py b/tests/st/summary/test_summary.py new file mode 100644 index 00000000000..b81d15514af --- /dev/null +++ b/tests/st/summary/test_summary.py @@ -0,0 +1,194 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test model train """ +import os +import re +import tempfile +import shutil + +import pytest + +from mindspore import dataset as ds +from mindspore import nn, Tensor, context +from mindspore.nn.metrics import Accuracy +from mindspore.nn.optim import Momentum +from mindspore.dataset.transforms import c_transforms as C +from mindspore.dataset.transforms.vision import c_transforms as CV +from mindspore.dataset.transforms.vision import Inter +from mindspore.common import dtype as mstype +from mindspore.common.initializer import TruncatedNormal +from mindspore.ops import operations as P +from mindspore.train import Model +from mindspore.train.callback import SummaryCollector + +from tests.summary_utils import SummaryReader + + +def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): + """weight initial for conv layer""" + weight = weight_variable() + return nn.Conv2d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + weight_init=weight, has_bias=False, pad_mode="valid") + + +def fc_with_initialize(input_channels, out_channels): + """weight initial for fc layer""" + weight = weight_variable() + bias = weight_variable() + return nn.Dense(input_channels, out_channels, weight, bias) + + +def weight_variable(): + """weight initial""" + return TruncatedNormal(0.02) + + +class LeNet5(nn.Cell): + """Define LeNet5 network.""" + def __init__(self, num_class=10, channel=1): + super(LeNet5, self).__init__() + self.num_class = num_class + self.conv1 = conv(channel, 6, 5) + self.conv2 = conv(6, 16, 5) + self.fc1 = fc_with_initialize(16 * 5 * 5, 120) + self.fc2 = fc_with_initialize(120, 84) + self.fc3 = fc_with_initialize(84, self.num_class) + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + self.scalar_summary = P.ScalarSummary() + self.image_summary = P.ImageSummary() + self.histogram_summary = P.HistogramSummary() + self.tensor_summary = P.TensorSummary() + self.channel = Tensor(channel) + + def construct(self, data): + """define construct.""" + self.image_summary('image', data) + output = self.conv1(data) + self.histogram_summary('histogram', output) + output = self.relu(output) + self.tensor_summary('tensor', output) + output = self.max_pool2d(output) + output = self.conv2(output) + output = self.relu(output) + output = self.max_pool2d(output) + output = self.flatten(output) + output = self.fc1(output) + output = self.relu(output) + output = self.fc2(output) + output = self.relu(output) + output = self.fc3(output) + self.scalar_summary('scalar', self.channel) + return output + + +def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1): + """create dataset for train or test""" + # define dataset + mnist_ds = ds.MnistDataset(data_path) + + resize_height, resize_width = 32, 32 + rescale = 1.0 / 255.0 + rescale_nml = 1 / 0.3081 + shift_nml = -1 * 0.1307 / 0.3081 + + # define map operations + resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode + rescale_nml_op = CV.Rescale(rescale_nml, shift_nml) + rescale_op = CV.Rescale(rescale, shift=0.0) + hwc2chw_op = CV.HWC2CHW() + type_cast_op = C.TypeCast(mstype.int32) + + # apply map operations on images + mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers) + mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers) + + # apply DatasetOps + mnist_ds = mnist_ds.shuffle(buffer_size=10000) # 10000 as in LeNet train script + mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True) + mnist_ds = mnist_ds.repeat(repeat_size) + + return mnist_ds + + +class TestSummary: + """Test summary collector the basic function.""" + base_summary_dir = '' + mnist_path = '/home/workspace/mindspore_dataset/mnist' + + @classmethod + def setup_class(cls): + """Run before test this class.""" + cls.base_summary_dir = tempfile.mkdtemp(suffix='summary') + + @classmethod + def teardown_class(cls): + """Run after test this class.""" + if os.path.exists(cls.base_summary_dir): + shutil.rmtree(cls.base_summary_dir) + + @pytest.mark.level0 + @pytest.mark.platform_x86_ascend_training + @pytest.mark.env_onecard + def test_summary_ascend(self): + """Test summary ascend.""" + context.set_context(mode=context.GRAPH_MODE) + self._run_network() + + def _run_network(self, dataset_sink_mode=True): + lenet = LeNet5() + loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") + optim = Momentum(lenet.trainable_params(), learning_rate=0.1, momentum=0.9) + model = Model(lenet, loss_fn=loss, optimizer=optim, metrics={'acc': Accuracy()}) + summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir) + summary_collector = SummaryCollector(summary_dir=summary_dir, collect_freq=1) + + ds_train = create_dataset(os.path.join(self.mnist_path, "train")) + model.train(1, ds_train, callbacks=[summary_collector], dataset_sink_mode=dataset_sink_mode) + + ds_eval = create_dataset(os.path.join(self.mnist_path, "test")) + model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode, callbacks=[summary_collector]) + + self._check_summary_result(summary_dir) + + @staticmethod + def _check_summary_result(summary_dir): + summary_file_path = '' + for file in os.listdir(summary_dir): + if re.search("_MS", file): + summary_file_path = os.path.join(summary_dir, file) + break + + assert not summary_file_path + + with SummaryReader(summary_file_path) as summary_reader: + tags = set() + + # Read the event that record by SummaryCollector.begin + summary_reader.read_event() + + summary_event = summary_reader.read_event() + for value in summary_event.summary.value: + tags.add(value.tag) + + # There will not record input data when dataset sink mode is True + expected_tags = ['conv1.weight/auto', 'conv2.weight/auto', 'fc1.weight/auto', 'fc1.bias/auto', + 'fc2.weight/auto', 'histogram', 'image', 'scalar', 'tensor'] + assert set(expected_tags) == tags diff --git a/tests/summary_utils.py b/tests/summary_utils.py index 826a3106e53..cc2070f9c6a 100644 --- a/tests/summary_utils.py +++ b/tests/summary_utils.py @@ -38,6 +38,7 @@ class SummaryReader: def __init__(self, canonical_file_path, ignore_version_event=True): self._file_path = canonical_file_path self._ignore_version_event = ignore_version_event + self._file_handler = None def __enter__(self): self._file_handler = open(self._file_path, "rb") diff --git a/tests/ut/python/train/summary/test_summary_collector.py b/tests/ut/python/train/summary/test_summary_collector.py index 5e7f8e662cf..1390d29bc19 100644 --- a/tests/ut/python/train/summary/test_summary_collector.py +++ b/tests/ut/python/train/summary/test_summary_collector.py @@ -16,9 +16,50 @@ import os import tempfile import shutil +from importlib import import_module +from unittest import mock + +import numpy as np import pytest +from mindspore import Tensor +from mindspore import Parameter from mindspore.train.callback import SummaryCollector +from mindspore.train.callback import _InternalCallbackParam +from mindspore.train.summary.enum import ModeEnum, PluginEnum +from mindspore.train.summary import SummaryRecord +from mindspore.nn import Cell +from mindspore.nn.optim.optimizer import Optimizer +from mindspore.ops.operations import TensorAdd + + +_VALUE_CACHE = list() + + +def add_value(plugin, name, value): + """This function is mock the function in SummaryRecord.""" + global _VALUE_CACHE + _VALUE_CACHE.append((plugin, name, value)) + + +def get_value(): + """Get the value which is added by add_value function.""" + global _VALUE_CACHE + + value = _VALUE_CACHE + _VALUE_CACHE = list() + return value + + +class CustomNet(Cell): + """Define custom netwrok.""" + def __init__(self): + super(CustomNet, self).__init__() + self.add = TensorAdd + self.optimizer = Optimizer(learning_rate=1, parameters=[Parameter(Tensor(1), 'weight')]) + + def construct(self, data): + return data class TestSummaryCollector: @@ -34,6 +75,10 @@ class TestSummaryCollector: if os.path.exists(self.base_summary_dir): shutil.rmtree(self.base_summary_dir) + def teardown_method(self): + """Run after each test function.""" + get_value() + @pytest.mark.parametrize("summary_dir", [1234, None, True, '']) def test_params_with_summary_dir_value_error(self, summary_dir): """Test the exception scenario for summary dir.""" @@ -182,3 +227,151 @@ class TestSummaryCollector: f'bug got {type(param_value).__name__}.' assert expected_msg == str(exc.value) + + def test_check_callback_with_multi_instances(self): + """Use multi SummaryCollector instances to test check_callback function.""" + cb_params = _InternalCallbackParam() + cb_params.list_callback = [ + SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir)), + SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir)) + ] + with pytest.raises(ValueError) as exc: + SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir)))._check_callbacks(cb_params) + assert f"more than one SummaryCollector instance in callback list" in str(exc.value) + + def test_collect_input_data_with_train_dataset_element_none(self): + """Test the param 'train_dataset_element' in cb_params is none.""" + cb_params = _InternalCallbackParam() + cb_params.train_dataset_element = None + summary_collector = SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) + summary_collector._collect_input_data(cb_params) + assert not summary_collector._collect_specified_data['collect_input_data'] + + @mock.patch.object(SummaryRecord, 'add_value') + def test_collect_input_data_success(self, mock_add_value): + """Mock a image data, and collect image data success.""" + mock_add_value.side_effect = add_value + cb_params = _InternalCallbackParam() + image_data = Tensor(np.random.randint(0, 255, size=(1, 1, 1, 1)).astype(np.uint8)) + cb_params.train_dataset_element = image_data + with SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) as summary_collector: + summary_collector._collect_input_data(cb_params) + # Note Here need to asssert the result and expected data + + @mock.patch.object(SummaryRecord, 'add_value') + def test_collect_dataset_graph_success(self, mock_add_value): + """Test collect dataset graph.""" + dataset = import_module('mindspore.dataset') + mock_add_value.side_effect = add_value + cb_params = _InternalCallbackParam() + cb_params.train_dataset = dataset.MnistDataset(dataset_dir=tempfile.mkdtemp(dir=self.base_summary_dir)) + cb_params.mode = ModeEnum.TRAIN.value + with SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) as summary_collector: + summary_collector._collect_dataset_graph(cb_params) + plugin, name, _ = get_value()[0] + assert plugin == 'dataset_graph' + assert name == 'train_dataset' + + @pytest.mark.parametrize("net_output, expected_loss", [ + (1, Tensor(1)), + ([1], Tensor(1)), + ([Tensor(1)], Tensor(1)), + (Tensor([1]), Tensor(1)), + (tuple([1]), Tensor(1)), + (None, None) + ]) + def test_get_loss(self, net_output, expected_loss): + """Test get loss success and failed.""" + cb_params = _InternalCallbackParam() + cb_params.net_outputs = net_output + summary_collector = SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) + + assert summary_collector._is_parse_loss_success + assert summary_collector._get_loss(cb_params) == expected_loss + + if expected_loss is None: + assert not summary_collector._is_parse_loss_success + + def test_get_optimizer_from_cb_params_success(self): + """Test get optimizer success from cb params.""" + cb_params = _InternalCallbackParam() + cb_params.optimizer = Optimizer(learning_rate=0.1, parameters=[Parameter(Tensor(1), 'weight')]) + summary_collector = SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) + optimizer = summary_collector._get_optimizer(cb_params) + assert optimizer == cb_params.optimizer + + # Test get optimizer again + assert summary_collector._get_optimizer(cb_params) == cb_params.optimizer + + @pytest.mark.parametrize('mode', [ModeEnum.TRAIN.value, ModeEnum.EVAL.value]) + def test_get_optimizer_from_network(self, mode): + """Get optimizer from train network""" + cb_params = _InternalCallbackParam() + cb_params.optimizer = None + cb_params.mode = mode + if mode == ModeEnum.TRAIN.value: + cb_params.train_network = CustomNet() + else: + cb_params.eval_network = CustomNet() + summary_collector = SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) + optimizer = summary_collector._get_optimizer(cb_params) + assert isinstance(optimizer, Optimizer) + + def test_get_optimizer_failed(self): + """Test get optimizer failed.""" + class Net(Cell): + """Define net.""" + def __init__(self): + super(Net, self).__init__() + self.add = TensorAdd() + + def construct(self, data): + return data + + cb_params = _InternalCallbackParam() + cb_params.optimizer = None + cb_params.train_network = Net() + cb_params.mode = ModeEnum.TRAIN.value + summary_collector = SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) + optimizer = summary_collector._get_optimizer(cb_params) + assert optimizer is None + assert summary_collector._temp_optimizer == 'Failed' + + # Test get optimizer again + optimizer = summary_collector._get_optimizer(cb_params) + assert optimizer is None + assert summary_collector._temp_optimizer == 'Failed' + + @pytest.mark.parametrize("histogram_regular, expected_names, expected_values", [ + ( + 'conv1|conv2', + ['conv1.weight1/auto', 'conv2.weight2/auto', 'conv1.bias1/auto'], + [1, 2, 3] + ), + ( + None, + ['conv1.weight1/auto', 'conv2.weight2/auto', 'conv1.bias1/auto', 'conv3.bias/auto', 'conv5.bias/auto'], + [1, 2, 3, 4, 5] + ) + ]) + @mock.patch.object(SummaryRecord, 'add_value') + def test_collect_histogram_from_regular(self, mock_add_value, histogram_regular, expected_names, expected_values): + """Test collect histogram from regular success.""" + mock_add_value.side_effect = add_value + cb_params = _InternalCallbackParam() + parameters = [ + Parameter(Tensor(1), 'conv1.weight1'), + Parameter(Tensor(2), 'conv2.weight2'), + Parameter(Tensor(3), 'conv1.bias1'), + Parameter(Tensor(4), 'conv3.bias'), + Parameter(Tensor(5), 'conv5.bias'), + Parameter(Tensor(6), 'conv6.bias'), + ] + cb_params.optimizer = Optimizer(learning_rate=0.1, parameters=parameters) + with SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) as summary_collector: + summary_collector._collect_specified_data['histogram_regular'] = histogram_regular + summary_collector._collect_histogram(cb_params) + result = get_value() + assert PluginEnum.HISTOGRAM.value == result[0][0] + assert expected_names == [data[1] for data in result] + assert expected_values == [data[2] for data in result]