forked from mindspore-Ecosystem/mindspore
!749 fix np.histogram sometimes calc very large bucket number
Merge pull request !749 from wenkai/wkmaster
This commit is contained in:
commit
6241814320
|
@ -15,6 +15,7 @@
|
|||
"""Generate the summary event which conform to proto format."""
|
||||
import time
|
||||
import socket
|
||||
import math
|
||||
from enum import Enum, unique
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
@ -292,6 +293,36 @@ def _get_tensor_summary(tag: str, np_value, summary_tensor):
|
|||
return summary_tensor
|
||||
|
||||
|
||||
def _calc_histogram_bins(count):
|
||||
"""
|
||||
Calculates experience-based optimal bins number for histogram.
|
||||
|
||||
There should be enough number in each bin. So we calc bin numbers according to count. For very small count(1 -
|
||||
10), we assign carefully chosen number. For large count, we tried to make sure there are 9-10 numbers in each
|
||||
bucket on average. Too many bins will slow down performance, so we set max number of bins to 90.
|
||||
|
||||
Args:
|
||||
count (int): Valid number count for the tensor.
|
||||
|
||||
Returns:
|
||||
int, number of histogram bins.
|
||||
"""
|
||||
number_per_bucket = 10
|
||||
max_bins = 90
|
||||
|
||||
if not count:
|
||||
return 1
|
||||
if count <= 5:
|
||||
return 2
|
||||
if count <= 10:
|
||||
return 3
|
||||
if count <= 880:
|
||||
# note that math.ceil(881/10) + 1 equals 90
|
||||
return int(math.ceil(count / number_per_bucket) + 1)
|
||||
|
||||
return max_bins
|
||||
|
||||
|
||||
def _fill_histogram_summary(tag: str, np_value: np.array, summary_histogram) -> None:
|
||||
"""
|
||||
Package the histogram summary.
|
||||
|
@ -347,7 +378,8 @@ def _fill_histogram_summary(tag: str, np_value: np.array, summary_histogram) ->
|
|||
|
||||
return
|
||||
|
||||
counts, edges = np.histogram(np_value, bins='auto', range=(tensor_min, tensor_max))
|
||||
bin_number = _calc_histogram_bins(masked_value.count())
|
||||
counts, edges = np.histogram(np_value, bins=bin_number, range=(tensor_min, tensor_max))
|
||||
|
||||
for ind, count in enumerate(counts):
|
||||
bucket = summary_histogram.buckets.add()
|
||||
|
|
|
@ -22,6 +22,7 @@ import numpy as np
|
|||
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
|
||||
from mindspore.train.summary._summary_adapter import _calc_histogram_bins
|
||||
from .summary_reader import SummaryReader
|
||||
|
||||
CUR_DIR = os.getcwd()
|
||||
|
@ -139,7 +140,7 @@ def test_histogram_summary_same_value():
|
|||
event = reader.read_event()
|
||||
LOG.debug(event)
|
||||
|
||||
assert len(event.summary.value[0].histogram.buckets) == 1
|
||||
assert len(event.summary.value[0].histogram.buckets) == _calc_histogram_bins(dim1 * dim2)
|
||||
|
||||
|
||||
def test_histogram_summary_high_dims():
|
||||
|
|
Loading…
Reference in New Issue