add summary st for ops

This commit is contained in:
ougongchang 2020-12-26 11:49:41 +08:00
parent f8df9e1be6
commit f2978e0721
3 changed files with 180 additions and 48 deletions

View File

@ -0,0 +1,52 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset base."""
import os
from mindspore import dataset as ds
from mindspore.common import dtype as mstype
from mindspore.dataset.transforms import c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore.dataset.vision import c_transforms as CV
def create_mnist_dataset(mode='train', num_samples=2, batch_size=2):
"""create dataset for train or test"""
mnist_path = '/home/workspace/mindspore_dataset/mnist'
num_parallel_workers = 1
# define dataset
mnist_ds = ds.MnistDataset(os.path.join(mnist_path, mode), num_samples=num_samples, shuffle=False)
resize_height, resize_width = 32, 32
# define map operations
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode
rescale_nml_op = CV.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081)
rescale_op = CV.Rescale(1.0 / 255.0, shift=0.0)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
# apply map operations on images
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
# apply DatasetOps
mnist_ds = mnist_ds.batch(batch_size=batch_size, drop_remainder=True)
return mnist_ds

View File

@ -12,29 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test model train """
"""test SummaryCollector."""
import os
import re
import tempfile
import shutil
import tempfile
from collections import Counter
import pytest
from mindspore import dataset as ds
from mindspore import nn, Tensor, context
from mindspore.common.initializer import Normal
from mindspore.nn.metrics import Loss
from mindspore.nn.optim import Momentum
from mindspore.dataset.transforms import c_transforms as C
from mindspore.dataset.vision import c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.initializer import Normal
from mindspore.train import Model
from mindspore.train.callback import SummaryCollector
from tests.st.summary.dataset import create_mnist_dataset
from tests.summary_utils import SummaryReader
@ -52,6 +46,7 @@ class LeNet5(nn.Cell):
>>> LeNet(num_class=10)
"""
def __init__(self, num_class=10, num_channel=1, include_top=True):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
@ -72,6 +67,7 @@ class LeNet5(nn.Cell):
self.channel = Tensor(num_channel)
def construct(self, x):
"""construct."""
self.image_summary('image', x)
x = self.conv1(x)
self.histogram_summary('histogram', x)
@ -92,43 +88,9 @@ class LeNet5(nn.Cell):
return x
def create_dataset(data_path, num_samples=2):
"""create dataset for train or test"""
num_parallel_workers = 1
# define dataset
mnist_ds = ds.MnistDataset(data_path, num_samples=num_samples)
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
rescale_nml = 1 / 0.3081
shift_nml = -1 * 0.1307 / 0.3081
# define map operations
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
rescale_op = CV.Rescale(rescale, shift=0.0)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
# apply map operations on images
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
# apply DatasetOps
mnist_ds = mnist_ds.shuffle(buffer_size=10000) # 10000 as in LeNet train script
mnist_ds = mnist_ds.batch(batch_size=2, drop_remainder=True)
return mnist_ds
class TestSummary:
"""Test summary collector the basic function."""
base_summary_dir = ''
mnist_path = '/home/workspace/mindspore_dataset/mnist'
@classmethod
def setup_class(cls):
@ -144,6 +106,7 @@ class TestSummary:
shutil.rmtree(cls.base_summary_dir)
def _run_network(self, dataset_sink_mode=False, num_samples=2, **kwargs):
"""run network."""
lenet = LeNet5()
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
optim = Momentum(lenet.trainable_params(), learning_rate=0.1, momentum=0.9)
@ -151,10 +114,10 @@ class TestSummary:
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
summary_collector = SummaryCollector(summary_dir=summary_dir, collect_freq=2, **kwargs)
ds_train = create_dataset(os.path.join(self.mnist_path, "train"), num_samples=num_samples)
ds_train = create_mnist_dataset("train", num_samples=num_samples)
model.train(1, ds_train, callbacks=[summary_collector], dataset_sink_mode=dataset_sink_mode)
ds_eval = create_dataset(os.path.join(self.mnist_path, "test"))
ds_eval = create_mnist_dataset("test")
model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode, callbacks=[summary_collector])
return summary_dir
@ -202,10 +165,12 @@ class TestSummary:
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_summarycollector_user_defind(self):
"""Test SummaryCollector with user defind."""
summary_dir = self._run_network(dataset_sink_mode=True, num_samples=2, user_defind={'test': 'self test'})
summary_dir = self._run_network(dataset_sink_mode=True, num_samples=2,
custom_lineage_data={'test': 'self test'})
tag_list = self._list_summary_tags(summary_dir)
# There will not record input data when dataset sink mode is True
@ -213,9 +178,9 @@ class TestSummary:
'fc2.weight/auto', 'loss/auto', 'histogram', 'image', 'scalar', 'tensor'}
assert set(expected_tags) == set(tag_list)
@staticmethod
def _list_summary_tags(summary_dir):
"""list summary tags."""
summary_file_path = ''
for file in os.listdir(summary_dir):
if re.search("_MS", file):

View File

@ -0,0 +1,115 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test summary ops."""
import os
import shutil
import tempfile
import numpy as np
import pytest
from mindspore import nn, Tensor, context
from mindspore.common.initializer import Normal
from mindspore.nn.metrics import Loss
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.train.summary.summary_record import _get_summary_tensor_data
from tests.st.summary.dataset import create_mnist_dataset
class LeNet5(nn.Cell):
"""LeNet network"""
def __init__(self, num_class=10, num_channel=1, include_top=True):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.include_top = include_top
if self.include_top:
self.flatten = nn.Flatten()
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.scalar_summary = P.ScalarSummary()
self.image_summary = P.ImageSummary()
self.tensor_summary = P.TensorSummary()
self.channel = Tensor(num_channel)
def construct(self, x):
"""construct"""
self.image_summary('x', x)
self.tensor_summary('x', x)
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
if not self.include_top:
return x
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
self.scalar_summary('x_fc3', x[0][0])
return x
class TestSummaryOps:
"""Test summary ops."""
base_summary_dir = ''
@classmethod
def setup_class(cls):
"""Run before test this class."""
device_id = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
context.set_context(mode=context.GRAPH_MODE, device_id=device_id)
cls.base_summary_dir = tempfile.mkdtemp(suffix='summary')
@classmethod
def teardown_class(cls):
"""Run after test this class."""
if os.path.exists(cls.base_summary_dir):
shutil.rmtree(cls.base_summary_dir)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_summary_ops(self):
"""Test summary operators."""
ds_train = create_mnist_dataset('train', num_samples=1, batch_size=1)
ds_train_iter = ds_train.create_dict_iterator()
expected_data = next(ds_train_iter)['image'].asnumpy()
net = LeNet5()
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
model = Model(net, loss_fn=loss, optimizer=optim, metrics={'loss': Loss()})
model.train(1, ds_train, dataset_sink_mode=False)
summary_data = _get_summary_tensor_data()
image_data = summary_data['x[:Image]'].asnumpy()
tensor_data = summary_data['x[:Tensor]'].asnumpy()
x_fc3 = summary_data['x_fc3[:Scalar]'].asnumpy()
assert np.allclose(expected_data, image_data)
assert np.allclose(expected_data, tensor_data)
assert not np.allclose(0, x_fc3)