forked from mindspore-Ecosystem/mindspore
!261 Add summary module python code to extract histogram info from tensor
Merge pull request !261 from wenkai/histogram_dev3cp
This commit is contained in:
commit
18580a7867
|
@ -61,6 +61,30 @@ message Summary {
|
|||
required bytes encoded_image = 4;
|
||||
}
|
||||
|
||||
message Histogram {
    message bucket{
        // Number of values that fall in [left, left + width).
        // For the rightmost bucket the range is closed: [left, left + width].
        required double left = 1;
        required double width = 2;
        required int64 count = 3;
    }

    repeated bucket buckets = 1;
    optional int64 nan_count = 2;
    optional int64 pos_inf_count = 3;
    optional int64 neg_inf_count = 4;

    // max, min and sum do not take nan and inf into account.
    // If there is no valid value in the tensor, max will be nan, min will be nan and sum will be 0.
    // NOTE: for a tensor with no elements at all, the Python writer added in this change
    // records 0 for max and min rather than nan.
    optional double max = 5;
    optional double min = 6;
    optional double sum = 7;

    // Total number of values, including nan and inf.
    optional int64 count = 8;
}
|
||||
|
||||
message Value {
|
||||
// Tag name for the data.
|
||||
required string tag = 1;
|
||||
|
@ -70,6 +94,7 @@ message Summary {
|
|||
float scalar_value = 3;
|
||||
Image image = 4;
|
||||
TensorProto tensor = 8;
|
||||
Histogram histogram = 9;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -71,12 +71,14 @@ class SummaryType(Enum):
|
|||
TENSOR (Number): Summary TENSOR enum.
|
||||
IMAGE (Number): Summary image enum.
|
||||
GRAPH (Number): Summary graph enum.
|
||||
HISTOGRAM (Number): Summary histogram enum.
|
||||
INVALID (Number): Unknown type.
|
||||
"""
|
||||
SCALAR = 1 # Scalar summary
|
||||
TENSOR = 2 # Tensor summary
|
||||
IMAGE = 3 # Image summary
|
||||
GRAPH = 4 # graph
|
||||
HISTOGRAM = 5 # Histogram Summary
|
||||
INVALID = 0xFF # unknow type
|
||||
|
||||
|
||||
|
@ -148,7 +150,7 @@ def package_summary_event(data_id, step):
|
|||
"""
|
||||
data_list = get_summary_data(data_id)
|
||||
if data_list is None:
|
||||
logger.error("The step(%r) does not have record data.", self.step)
|
||||
logger.error("The step(%r) does not have record data.", step)
|
||||
del_summary_data(data_id)
|
||||
# create the event of summary
|
||||
summary_event = Event()
|
||||
|
@ -177,6 +179,12 @@ def package_summary_event(data_id, step):
|
|||
summary_value.tag = tag
|
||||
summary_image = summary_value.image
|
||||
_get_image_summary(tag, data, summary_image, MS_IMAGE_TENSOR_FORMAT)
|
||||
elif summary_type is SummaryType.HISTOGRAM:
|
||||
logger.debug("Now process Histogram summary, tag = %r", tag)
|
||||
summary_value = summary.value.add()
|
||||
summary_value.tag = tag
|
||||
summary_histogram = summary_value.histogram
|
||||
_fill_histogram_summary(tag, data, summary_histogram)
|
||||
else:
|
||||
# The data is invalid, skip it.
|
||||
logger.error("Summary type is error, tag = %r", tag)
|
||||
|
@ -284,6 +292,74 @@ def _get_tensor_summary(tag: str, np_value, summary_tensor):
|
|||
return summary_tensor
|
||||
|
||||
|
||||
def _add_default_bucket(summary_histogram) -> None:
    """Append the single empty placeholder bucket used when a tensor has no valid data."""
    bucket = summary_histogram.buckets.add()
    # The placeholder covers [-0.5, 0.5] and holds no values.
    bucket.left = -0.5
    bucket.width = 1.0
    bucket.count = 0


def _fill_histogram_summary(tag: str, np_value: np.ndarray, summary_histogram) -> None:
    """
    Package the histogram summary.

    Three cases are handled:

    - Tensor with no elements: one empty placeholder bucket is written and every
      counter and statistic (including max and min) is set to 0.
    - Tensor whose elements are all nan/inf: one empty placeholder bucket is
      written, max and min are nan and sum is 0.
    - Otherwise: buckets come from ``np.histogram`` with ``bins='auto'`` over the
      range [min, max] of the finite values; max, min and sum only consider
      finite values.

    Args:
        tag (str): Summary tag describe.
        np_value (np.ndarray): Summary data.
        summary_histogram (summary_pb2.Summary.Histogram): Summary histogram data,
            filled in place.
    """
    logger.debug("Set(%r) the histogram summary value", tag)

    if np_value.size == 0:
        _add_default_bucket(summary_histogram)

        summary_histogram.nan_count = 0
        summary_histogram.pos_inf_count = 0
        summary_histogram.neg_inf_count = 0

        summary_histogram.max = 0
        summary_histogram.min = 0
        summary_histogram.sum = 0

        summary_histogram.count = 0
        return

    summary_histogram.nan_count = np.count_nonzero(np.isnan(np_value))
    summary_histogram.pos_inf_count = np.count_nonzero(np.isposinf(np_value))
    summary_histogram.neg_inf_count = np.count_nonzero(np.isneginf(np_value))
    # count is the total number of elements, nan and inf included.
    summary_histogram.count = np_value.size

    # Mask nan and inf so that the statistics below only see finite values.
    masked_value = np.ma.masked_invalid(np_value)
    tensor_max = masked_value.max()
    tensor_min = masked_value.min()
    tensor_sum = masked_value.sum()

    # No valid value in tensor: reducing a fully masked array yields np.ma.masked.
    if tensor_max is np.ma.masked:
        _add_default_bucket(summary_histogram)

        summary_histogram.max = np.nan
        summary_histogram.min = np.nan
        summary_histogram.sum = 0
        return

    # The explicit finite range keeps nan and inf out of the buckets.
    counts, edges = np.histogram(np_value, bins='auto', range=(tensor_min, tensor_max))

    # edges has one more entry than counts; consecutive pairs delimit each bucket.
    for left_edge, right_edge, count in zip(edges[:-1], edges[1:], counts):
        bucket = summary_histogram.buckets.add()
        bucket.left = left_edge
        bucket.width = right_edge - left_edge
        bucket.count = count

    summary_histogram.max = tensor_max
    summary_histogram.min = tensor_min
    summary_histogram.sum = tensor_sum
|
||||
|
||||
|
||||
def _get_image_summary(tag: str, np_value, summary_image, input_format='NCHW'):
|
||||
"""
|
||||
Package the image summary.
|
||||
|
|
|
@ -23,6 +23,7 @@ from ._summary_adapter import SummaryType, package_summary_event, save_summary_d
|
|||
FORMAT_SCALAR_STR = "Scalar"
|
||||
FORMAT_TENSOR_STR = "Tensor"
|
||||
FORMAT_IMAGE_STR = "Image"
|
||||
FORMAT_HISTOGRAM_STR = "Histogram"
|
||||
FORMAT_BEGIN_SLICE = "[:"
|
||||
FORMAT_END_SLICE = "]"
|
||||
|
||||
|
@ -95,6 +96,8 @@ def _parse_tag_format(tag: str):
|
|||
summary_type = SummaryType.TENSOR
|
||||
elif type_str == FORMAT_IMAGE_STR:
|
||||
summary_type = SummaryType.IMAGE
|
||||
elif type_str == FORMAT_HISTOGRAM_STR:
|
||||
summary_type = SummaryType.HISTOGRAM
|
||||
else:
|
||||
logger.error("The tag(%s) type is invalid.", tag)
|
||||
summary_type = SummaryType.INVALID
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Summary reader."""
|
||||
import struct
|
||||
|
||||
import mindspore.train.summary_pb2 as summary_pb2
|
||||
|
||||
# Number of bytes read as the record header; the reader unpacks them as one
# unsigned 64-bit integer holding the length of the event data.
_HEADER_SIZE = 8
# Number of bytes skipped right after the header (presumably a CRC of the
# header, going by the name; the reader does not check it).
_HEADER_CRC_SIZE = 4
# Number of bytes skipped right after the event data (presumably a CRC of the
# data, going by the name; the reader does not check it).
_DATA_CRC_SIZE = 4
|
||||
|
||||
|
||||
class SummaryReader:
    """
    Read events from summary file.

    The file is opened when the reader is constructed and the first event
    (the version event) is consumed immediately, so the first call to
    `read_event` returns the second event stored in the file.

    The reader owns an open file handle. Release it with `close`, or use
    the reader as a context manager::

        with SummaryReader(path) as reader:
            event = reader.read_event()

    Args:
        file_name (str): Path of the summary event file.
    """

    def __init__(self, file_name):
        self._file_name = file_name
        self._file_handler = open(self._file_name, "rb")
        try:
            # skip version event
            self.read_event()
        except Exception:
            # Do not leak the handle when the file is empty or malformed.
            self._file_handler.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never swallow an exception raised inside the with-block.
        return False

    def close(self):
        """Close the underlying file. Calling it more than once is harmless."""
        self._file_handler.close()

    def read_event(self):
        """
        Read next event.

        A record is read as: an 8-byte length (unpacked with struct format
        'Q', i.e. native byte order), 4 bytes that are skipped, the event
        data of the given length, and 4 more bytes that are skipped. The
        skipped fields are not checked.

        Returns:
            summary_pb2.Event, the event parsed from the record data.

        Raises:
            struct.error: If fewer than 8 header bytes are available, for
                example at the end of the file.
        """
        file_handler = self._file_handler
        header = file_handler.read(_HEADER_SIZE)
        data_len = struct.unpack('Q', header)[0]
        file_handler.read(_HEADER_CRC_SIZE)

        event_str = file_handler.read(data_len)
        file_handler.read(_DATA_CRC_SIZE)
        summary_event = summary_pb2.Event.FromString(event_str)

        return summary_event
|
|
@ -0,0 +1,210 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Test histogram summary."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
|
||||
from .summary_reader import SummaryReader
|
||||
|
||||
CUR_DIR = os.getcwd()
# The second component must be relative: os.path.join discards every earlier
# component when it meets an absolute one, which would place the directory at
# the filesystem root instead of under CUR_DIR.
SUMMARY_DIR = os.path.join(CUR_DIR, "test_temp_summary_event_file/")

LOG = logging.getLogger("test")
LOG.setLevel(level=logging.ERROR)
|
||||
|
||||
|
||||
def _wrap_test_data(input_data: Tensor):
    """
    Build the payload for one histogram entry of the summary cache.

    Args:
        input_data (Tensor): Input data.

    Returns:
        list[dict], a single entry holding the histogram tag name and the data.
    """
    histogram_entry = dict(name="test_data[:Histogram]", data=input_data)
    return [histogram_entry]
|
||||
|
||||
|
||||
def test_histogram_summary():
    """Record a 2x3 tensor and check that the histogram counts all six values."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")

        _cache_summary_tensor_data(_wrap_test_data(Tensor([[1, 2, 3], [4, 5, 6]])))
        writer.record(step=1)
        writer.close()

        event_path = os.path.join(tmp_dir, writer.event_file_name)
        summary_reader = SummaryReader(event_path)
        histogram = summary_reader.read_event().summary.value[0].histogram
        assert histogram.count == 6
|
||||
|
||||
|
||||
def test_histogram_multi_summary():
    """Record one histogram per step and check the element count of each one."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")

        random_state = np.random.RandomState(10)
        sample_count = 50
        step_count = 5

        for step in range(step_count):
            samples = random_state.normal(size=sample_count)
            _cache_summary_tensor_data(_wrap_test_data(Tensor(samples)))
            writer.record(step=step)

        writer.close()

        event_path = os.path.join(tmp_dir, writer.event_file_name)
        summary_reader = SummaryReader(event_path)
        for _ in range(step_count):
            histogram = summary_reader.read_event().summary.value[0].histogram
            assert histogram.count == sample_count
|
||||
|
||||
|
||||
def test_histogram_summary_scalar_tensor():
    """Record a scalar tensor and check that the histogram counts one value."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")

        _cache_summary_tensor_data(_wrap_test_data(Tensor(1)))
        writer.record(step=1)
        writer.close()

        event_path = os.path.join(tmp_dir, writer.event_file_name)
        summary_reader = SummaryReader(event_path)
        histogram = summary_reader.read_event().summary.value[0].histogram
        assert histogram.count == 1
|
||||
|
||||
|
||||
def test_histogram_summary_empty_tensor():
    """Record an empty tensor and check that the histogram counts no values."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")

        _cache_summary_tensor_data(_wrap_test_data(Tensor([])))
        writer.record(step=1)
        writer.close()

        event_path = os.path.join(tmp_dir, writer.event_file_name)
        summary_reader = SummaryReader(event_path)
        histogram = summary_reader.read_event().summary.value[0].histogram
        assert histogram.count == 0
|
||||
|
||||
|
||||
def test_histogram_summary_same_value():
    """Record an all-ones tensor and check that it yields a single bucket."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")

        rows = 100
        cols = 100

        ones_tensor = Tensor(np.ones([rows, cols]))
        _cache_summary_tensor_data(_wrap_test_data(ones_tensor))
        writer.record(step=1)
        writer.close()

        event_path = os.path.join(tmp_dir, writer.event_file_name)
        summary_reader = SummaryReader(event_path)
        recorded_event = summary_reader.read_event()
        LOG.debug(recorded_event)

        bucket_list = recorded_event.summary.value[0].histogram.buckets
        assert len(bucket_list) == 1
|
||||
|
||||
|
||||
def test_histogram_summary_high_dims():
    """Record a 4-dimension tensor and check that every element is counted."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
        side = 10

        random_state = np.random.RandomState(0)
        values = random_state.normal(size=[side, side, side, side])
        _cache_summary_tensor_data(_wrap_test_data(Tensor(values)))
        writer.record(step=1)
        writer.close()

        event_path = os.path.join(tmp_dir, writer.event_file_name)
        summary_reader = SummaryReader(event_path)
        recorded_event = summary_reader.read_event()
        LOG.debug(recorded_event)

        histogram = recorded_event.summary.value[0].histogram
        assert histogram.count == values.size
|
||||
|
||||
|
||||
def test_histogram_summary_nan_inf():
    """Test histogram summary, input tensor has one nan, one +inf and one -inf."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")

        dim1 = 100
        dim2 = 100

        arr = np.ones([dim1, dim2])
        arr[0][0] = np.nan
        arr[0][1] = np.inf
        arr[0][2] = -np.inf
        test_data = _wrap_test_data(Tensor(arr))

        _cache_summary_tensor_data(test_data)
        test_writer.record(step=1)
        test_writer.close()

        file_name = os.path.join(tmp_dir, test_writer.event_file_name)
        reader = SummaryReader(file_name)
        event = reader.read_event()
        LOG.debug(event)

        histogram = event.summary.value[0].histogram
        assert histogram.nan_count == 1
        # The tensor also holds one +inf and one -inf; each has its own counter.
        assert histogram.pos_inf_count == 1
        assert histogram.neg_inf_count == 1
|
||||
|
||||
|
||||
def test_histogram_summary_all_nan_inf():
    """Record a tensor holding only nan and inf and check the invalid-value counters."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")

        invalid_values = np.array([np.nan, np.nan, np.nan, np.inf, -np.inf])
        _cache_summary_tensor_data(_wrap_test_data(Tensor(invalid_values)))
        writer.record(step=1)
        writer.close()

        event_path = os.path.join(tmp_dir, writer.event_file_name)
        summary_reader = SummaryReader(event_path)
        recorded_event = summary_reader.read_event()
        LOG.debug(recorded_event)

        recorded = recorded_event.summary.value[0].histogram
        assert recorded.nan_count == 3
        assert recorded.pos_inf_count == 1
        assert recorded.neg_inf_count == 1
|
Loading…
Reference in New Issue