forked from mindspore-Ecosystem/mindspore
SummaryRecord as context manager
This commit is contained in:
parent
0695bb4c18
commit
32c1d558f4
|
@ -14,91 +14,74 @@
|
|||
# ============================================================================
|
||||
"""Writes events to disk in a logdir."""
|
||||
import os
|
||||
import time
|
||||
import stat
|
||||
from mindspore import log as logger
|
||||
from collections import deque
|
||||
from multiprocessing import Pool, Process, Queue, cpu_count
|
||||
|
||||
from ..._c_expression import EventWriter_
|
||||
from ._summary_adapter import package_init_event
|
||||
from ._summary_adapter import package_summary_event
|
||||
|
||||
|
||||
class _WrapEventWriter(EventWriter_):
|
||||
def _pack(result, step):
|
||||
summary_event = package_summary_event(result, step)
|
||||
return summary_event.SerializeToString()
|
||||
|
||||
|
||||
class EventWriter(Process):
|
||||
"""
|
||||
Wrap the c++ EventWriter object.
|
||||
Creates a `EventWriter` and write event to file.
|
||||
|
||||
Args:
|
||||
full_file_name (str): Include directory and file name.
|
||||
filepath (str): Summary event file path and file name.
|
||||
flush_interval (int): The flush seconds to flush the pending events to disk. Default: 120.
|
||||
"""
|
||||
def __init__(self, full_file_name):
|
||||
if full_file_name is not None:
|
||||
EventWriter_.__init__(self, full_file_name)
|
||||
|
||||
def __init__(self, filepath: str, flush_interval: int) -> None:
|
||||
super().__init__()
|
||||
with open(filepath, 'w'):
|
||||
os.chmod(filepath, stat.S_IWUSR | stat.S_IRUSR)
|
||||
self._writer = EventWriter_(filepath)
|
||||
self._queue = Queue(cpu_count() * 2)
|
||||
self.start()
|
||||
|
||||
class EventRecord:
|
||||
"""
|
||||
Creates a `EventFileWriter` and write event to file.
|
||||
def run(self):
|
||||
|
||||
Args:
|
||||
full_file_name (str): Summary event file path and file name.
|
||||
flush_time (int): The flush seconds to flush the pending events to disk. Default: 120.
|
||||
"""
|
||||
def __init__(self, full_file_name: str, flush_time: int = 120):
|
||||
self.full_file_name = full_file_name
|
||||
with Pool() as pool:
|
||||
deq = deque()
|
||||
while True:
|
||||
while deq and deq[0].ready():
|
||||
self._writer.Write(deq.popleft().get())
|
||||
|
||||
# The first event will be flushed immediately.
|
||||
self.flush_time = flush_time
|
||||
self.next_flush_time = 0
|
||||
if not self._queue.empty():
|
||||
action, data = self._queue.get()
|
||||
if action == 'WRITE':
|
||||
if not isinstance(data, (str, bytes)):
|
||||
deq.append(pool.apply_async(_pack, data))
|
||||
else:
|
||||
self._writer.Write(data)
|
||||
elif action == 'FLUSH':
|
||||
self._writer.Flush()
|
||||
elif action == 'END':
|
||||
break
|
||||
for res in deq:
|
||||
self._writer.Write(res.get())
|
||||
|
||||
# create event write object
|
||||
self.event_writer = self._create_event_file()
|
||||
self._init_event_file()
|
||||
self._writer.Shut()
|
||||
|
||||
# count the events
|
||||
self.event_count = 0
|
||||
def write(self, data) -> None:
|
||||
"""
|
||||
Write the event to file.
|
||||
|
||||
def _create_event_file(self):
|
||||
"""Create the event write file."""
|
||||
with open(self.full_file_name, 'w'):
|
||||
os.chmod(self.full_file_name, stat.S_IWUSR | stat.S_IRUSR)
|
||||
|
||||
# create c++ event write object
|
||||
event_writer = _WrapEventWriter(self.full_file_name)
|
||||
return event_writer
|
||||
|
||||
def _init_event_file(self):
|
||||
"""Send the init event to file."""
|
||||
self.event_writer.Write((package_init_event()).SerializeToString())
|
||||
self.flush()
|
||||
return True
|
||||
|
||||
def write_event_to_file(self, event_str):
|
||||
"""Write the event to file."""
|
||||
self.event_writer.Write(event_str)
|
||||
|
||||
def get_data_count(self):
|
||||
"""Return the event count."""
|
||||
return self.event_count
|
||||
|
||||
def flush_cycle(self):
|
||||
"""Flush file by timer."""
|
||||
self.event_count = self.event_count + 1
|
||||
# Flush the event writer every so often.
|
||||
now = int(time.time())
|
||||
if now > self.next_flush_time:
|
||||
self.flush()
|
||||
# update the flush time
|
||||
self.next_flush_time = now + self.flush_time
|
||||
|
||||
def count_event(self):
|
||||
"""Count event."""
|
||||
logger.debug("Write the event count is %r", self.event_count)
|
||||
self.event_count = self.event_count + 1
|
||||
return self.event_count
|
||||
Args:
|
||||
data (Optional[str, Tuple[list, int]]): The data to write.
|
||||
"""
|
||||
self._queue.put(('WRITE', data))
|
||||
|
||||
def flush(self):
|
||||
"""Flush the event file to disk."""
|
||||
self.event_writer.Flush()
|
||||
"""Flush the writer."""
|
||||
self._queue.put(('FLUSH', None))
|
||||
|
||||
def close(self):
|
||||
"""Flush the event file to disk and close the file."""
|
||||
self.flush()
|
||||
self.event_writer.Shut()
|
||||
def close(self) -> None:
|
||||
"""Close the writer."""
|
||||
self._queue.put(('END', None))
|
||||
self.join()
|
||||
|
|
|
@ -13,17 +13,17 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Generate the summary event which conform to proto format."""
|
||||
import time
|
||||
import socket
|
||||
import math
|
||||
from enum import Enum, unique
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from mindspore import log as logger
|
||||
from ..summary_pb2 import Event
|
||||
from ..anf_ir_pb2 import ModelProto, DataType
|
||||
|
||||
from ..._checkparam import _check_str_by_regular
|
||||
from ..anf_ir_pb2 import DataType, ModelProto
|
||||
from ..summary_pb2 import Event
|
||||
|
||||
# define the MindSpore image format
|
||||
MS_IMAGE_TENSOR_FORMAT = 'NCHW'
|
||||
|
@ -32,55 +32,6 @@ EVENT_FILE_NAME_MARK = ".out.events.summary."
|
|||
# Set the init event of version and mark
|
||||
EVENT_FILE_INIT_VERSION_MARK = "Mindspore.Event:"
|
||||
EVENT_FILE_INIT_VERSION = 1
|
||||
# cache the summary data dict
|
||||
# {id: SummaryData}
|
||||
# |---[{"name": tag_name, "data": numpy}, {"name": tag_name, "data": numpy},...]
|
||||
g_summary_data_dict = {}
|
||||
|
||||
def save_summary_data(data_id, data):
|
||||
"""Save the global summary cache."""
|
||||
global g_summary_data_dict
|
||||
g_summary_data_dict[data_id] = data
|
||||
|
||||
|
||||
def del_summary_data(data_id):
|
||||
"""Save the global summary cache."""
|
||||
global g_summary_data_dict
|
||||
if data_id in g_summary_data_dict:
|
||||
del g_summary_data_dict[data_id]
|
||||
else:
|
||||
logger.warning("Can't del the data because data_id(%r) "
|
||||
"does not have data in g_summary_data_dict", data_id)
|
||||
|
||||
def get_summary_data(data_id):
|
||||
"""Save the global summary cache."""
|
||||
ret = None
|
||||
global g_summary_data_dict
|
||||
if data_id in g_summary_data_dict:
|
||||
ret = g_summary_data_dict.get(data_id)
|
||||
else:
|
||||
logger.warning("The data_id(%r) does not have data in g_summary_data_dict", data_id)
|
||||
return ret
|
||||
|
||||
@unique
|
||||
class SummaryType(Enum):
|
||||
"""
|
||||
Summary type.
|
||||
|
||||
Args:
|
||||
SCALAR (Number): Summary Scalar enum.
|
||||
TENSOR (Number): Summary TENSOR enum.
|
||||
IMAGE (Number): Summary image enum.
|
||||
GRAPH (Number): Summary graph enum.
|
||||
HISTOGRAM (Number): Summary histogram enum.
|
||||
INVALID (Number): Unknow type.
|
||||
"""
|
||||
SCALAR = 1 # Scalar summary
|
||||
TENSOR = 2 # Tensor summary
|
||||
IMAGE = 3 # Image summary
|
||||
GRAPH = 4 # graph
|
||||
HISTOGRAM = 5 # Histogram Summary
|
||||
INVALID = 0xFF # unknow type
|
||||
|
||||
|
||||
def get_event_file_name(prefix, suffix):
|
||||
|
@ -138,7 +89,7 @@ def package_graph_event(data):
|
|||
return graph_event
|
||||
|
||||
|
||||
def package_summary_event(data_id, step):
|
||||
def package_summary_event(data_list, step):
|
||||
"""
|
||||
Package the summary to event protobuffer.
|
||||
|
||||
|
@ -149,50 +100,37 @@ def package_summary_event(data_id, step):
|
|||
Returns:
|
||||
Summary, the summary event.
|
||||
"""
|
||||
data_list = get_summary_data(data_id)
|
||||
if data_list is None:
|
||||
logger.error("The step(%r) does not have record data.", step)
|
||||
del_summary_data(data_id)
|
||||
# create the event of summary
|
||||
summary_event = Event()
|
||||
summary = summary_event.summary
|
||||
summary_event.wall_time = time.time()
|
||||
summary_event.step = int(step)
|
||||
|
||||
for value in data_list:
|
||||
tag = value["name"]
|
||||
summary_type = value["_type"]
|
||||
data = value["data"]
|
||||
summary_type = value["type"]
|
||||
tag = value["name"]
|
||||
|
||||
logger.debug("Now process %r summary, tag = %r", summary_type, tag)
|
||||
|
||||
summary_value = summary.value.add()
|
||||
summary_value.tag = tag
|
||||
# get the summary type and parse the tag
|
||||
if summary_type is SummaryType.SCALAR:
|
||||
logger.debug("Now process Scalar summary, tag = %r", tag)
|
||||
summary_value = summary.value.add()
|
||||
summary_value.tag = tag
|
||||
if summary_type == 'Scalar':
|
||||
summary_value.scalar_value = _get_scalar_summary(tag, data)
|
||||
elif summary_type is SummaryType.TENSOR:
|
||||
logger.debug("Now process Tensor summary, tag = %r", tag)
|
||||
summary_value = summary.value.add()
|
||||
summary_value.tag = tag
|
||||
elif summary_type == 'Tensor':
|
||||
summary_tensor = summary_value.tensor
|
||||
_get_tensor_summary(tag, data, summary_tensor)
|
||||
elif summary_type is SummaryType.IMAGE:
|
||||
logger.debug("Now process Image summary, tag = %r", tag)
|
||||
summary_value = summary.value.add()
|
||||
summary_value.tag = tag
|
||||
elif summary_type == 'Image':
|
||||
summary_image = summary_value.image
|
||||
_get_image_summary(tag, data, summary_image, MS_IMAGE_TENSOR_FORMAT)
|
||||
elif summary_type is SummaryType.HISTOGRAM:
|
||||
logger.debug("Now process Histogram summary, tag = %r", tag)
|
||||
summary_value = summary.value.add()
|
||||
summary_value.tag = tag
|
||||
elif summary_type == 'Histogram':
|
||||
summary_histogram = summary_value.histogram
|
||||
_fill_histogram_summary(tag, data, summary_histogram)
|
||||
else:
|
||||
# The data is invalid ,jump the data
|
||||
logger.error("Summary type is error, tag = %r", tag)
|
||||
continue
|
||||
logger.error("Summary type(%r) is error, tag = %r", summary_type, tag)
|
||||
|
||||
summary_event.wall_time = time.time()
|
||||
summary_event.step = int(step)
|
||||
return summary_event
|
||||
|
||||
|
||||
|
@ -255,11 +193,11 @@ def _get_scalar_summary(tag: str, np_value):
|
|||
# So consider the dim = 1, shape = (1,) tensor is scalar
|
||||
scalar_value = np_value[0]
|
||||
if np_value.shape != (1,):
|
||||
logger.error("The tensor is not Scalar, tag = %r, Value = %r", tag, np_value)
|
||||
logger.error("The tensor is not Scalar, tag = %r, Shape = %r", tag, np_value.shape)
|
||||
else:
|
||||
np_list = np_value.reshape(-1).tolist()
|
||||
scalar_value = np_list[0]
|
||||
logger.error("The value is not Scalar, tag = %r, Value = %r", tag, np_value)
|
||||
logger.error("The value is not Scalar, tag = %r, ndim = %r", tag, np_value.ndim)
|
||||
|
||||
logger.debug("The tag(%r) value is: %r", tag, scalar_value)
|
||||
return scalar_value
|
||||
|
@ -307,8 +245,7 @@ def _calc_histogram_bins(count):
|
|||
Returns:
|
||||
int, number of histogram bins.
|
||||
"""
|
||||
number_per_bucket = 10
|
||||
max_bins = 90
|
||||
max_bins, max_per_bin = 90, 10
|
||||
|
||||
if not count:
|
||||
return 1
|
||||
|
@ -318,78 +255,50 @@ def _calc_histogram_bins(count):
|
|||
return 3
|
||||
if count <= 880:
|
||||
# note that math.ceil(881/10) + 1 equals 90
|
||||
return int(math.ceil(count / number_per_bucket) + 1)
|
||||
return count // max_per_bin + 1
|
||||
|
||||
return max_bins
|
||||
|
||||
|
||||
def _fill_histogram_summary(tag: str, np_value: np.array, summary_histogram) -> None:
|
||||
def _fill_histogram_summary(tag: str, np_value: np.ndarray, summary) -> None:
|
||||
"""
|
||||
Package the histogram summary.
|
||||
|
||||
Args:
|
||||
tag (str): Summary tag describe.
|
||||
np_value (np.array): Summary data.
|
||||
summary_histogram (summary_pb2.Summary.Histogram): Summary histogram data.
|
||||
np_value (np.ndarray): Summary data.
|
||||
summary (summary_pb2.Summary.Histogram): Summary histogram data.
|
||||
"""
|
||||
logger.debug("Set(%r) the histogram summary value", tag)
|
||||
# Default bucket for tensor with no valid data.
|
||||
default_bucket_left = -0.5
|
||||
default_bucket_width = 1.0
|
||||
ma_value = np.ma.masked_invalid(np_value)
|
||||
total, valid = np_value.size, ma_value.count()
|
||||
invalids = []
|
||||
for isfn in np.isnan, np.isposinf, np.isneginf:
|
||||
if total - valid > sum(invalids):
|
||||
count = np.count_nonzero(isfn(np_value))
|
||||
invalids.append(count)
|
||||
else:
|
||||
invalids.append(0)
|
||||
|
||||
if np_value.size == 0:
|
||||
bucket = summary_histogram.buckets.add()
|
||||
bucket.left = default_bucket_left
|
||||
bucket.width = default_bucket_width
|
||||
bucket.count = 0
|
||||
summary.count = total
|
||||
summary.nan_count, summary.pos_inf_count, summary.neg_inf_count = invalids
|
||||
if not valid:
|
||||
logger.warning('There are no valid values in the ndarray(size=%d, shape=%d)', total, np_value.shape)
|
||||
# summary.{min, max, sum} are 0s by default, no need to explicitly set
|
||||
else:
|
||||
summary.min = ma_value.min()
|
||||
summary.max = ma_value.max()
|
||||
summary.sum = ma_value.sum()
|
||||
bins = _calc_histogram_bins(valid)
|
||||
range_ = summary.min, summary.max
|
||||
hists, edges = np.histogram(np_value, bins=bins, range=range_)
|
||||
|
||||
summary_histogram.nan_count = 0
|
||||
summary_histogram.pos_inf_count = 0
|
||||
summary_histogram.neg_inf_count = 0
|
||||
|
||||
summary_histogram.max = 0
|
||||
summary_histogram.min = 0
|
||||
summary_histogram.sum = 0
|
||||
|
||||
summary_histogram.count = 0
|
||||
|
||||
return
|
||||
|
||||
summary_histogram.nan_count = np.count_nonzero(np.isnan(np_value))
|
||||
summary_histogram.pos_inf_count = np.count_nonzero(np.isposinf(np_value))
|
||||
summary_histogram.neg_inf_count = np.count_nonzero(np.isneginf(np_value))
|
||||
summary_histogram.count = np_value.size
|
||||
|
||||
masked_value = np.ma.masked_invalid(np_value)
|
||||
tensor_max = masked_value.max()
|
||||
tensor_min = masked_value.min()
|
||||
tensor_sum = masked_value.sum()
|
||||
|
||||
# No valid value in tensor.
|
||||
if tensor_max is np.ma.masked:
|
||||
bucket = summary_histogram.buckets.add()
|
||||
bucket.left = default_bucket_left
|
||||
bucket.width = default_bucket_width
|
||||
bucket.count = 0
|
||||
|
||||
summary_histogram.max = np.nan
|
||||
summary_histogram.min = np.nan
|
||||
summary_histogram.sum = 0
|
||||
|
||||
return
|
||||
|
||||
bin_number = _calc_histogram_bins(masked_value.count())
|
||||
counts, edges = np.histogram(np_value, bins=bin_number, range=(tensor_min, tensor_max))
|
||||
|
||||
for ind, count in enumerate(counts):
|
||||
bucket = summary_histogram.buckets.add()
|
||||
bucket.left = edges[ind]
|
||||
bucket.width = edges[ind + 1] - edges[ind]
|
||||
bucket.count = count
|
||||
|
||||
summary_histogram.max = tensor_max
|
||||
summary_histogram.min = tensor_min
|
||||
summary_histogram.sum = tensor_sum
|
||||
for hist, edge1, edge2 in zip(hists, edges, edges[1:]):
|
||||
bucket = summary.buckets.add()
|
||||
bucket.width = edge2 - edge1
|
||||
bucket.count = hist
|
||||
bucket.left = edge1
|
||||
|
||||
|
||||
def _get_image_summary(tag: str, np_value, summary_image, input_format='NCHW'):
|
||||
|
@ -407,7 +316,7 @@ def _get_image_summary(tag: str, np_value, summary_image, input_format='NCHW'):
|
|||
"""
|
||||
logger.debug("Set(%r) the image summary value", tag)
|
||||
if np_value.ndim != 4:
|
||||
logger.error("The value is not Image, tag = %r, Value = %r", tag, np_value)
|
||||
logger.error("The value is not Image, tag = %r, ndim = %r", tag, np_value.ndim)
|
||||
|
||||
# convert the tensor format
|
||||
tensor = _convert_image_format(np_value, input_format)
|
||||
|
@ -469,8 +378,8 @@ def _convert_image_format(np_tensor, input_format, out_format='HWC'):
|
|||
"""
|
||||
out_tensor = None
|
||||
if np_tensor.ndim != len(input_format):
|
||||
logger.error("The tensor(%r) can't convert the format(%r) because dim not same",
|
||||
np_tensor, input_format)
|
||||
logger.error("The tensor with dim(%r) can't convert the format(%r) because dim not same", np_tensor.ndim,
|
||||
input_format)
|
||||
return out_tensor
|
||||
|
||||
input_format = input_format.upper()
|
||||
|
@ -512,7 +421,7 @@ def _make_canvas_for_imgs(tensor, col_imgs=8):
|
|||
|
||||
# check the tensor format
|
||||
if tensor.ndim != 4 or tensor.shape[1] != 3:
|
||||
logger.error("The image tensor(%r) is not 'NCHW' format", tensor)
|
||||
logger.error("The image tensor with ndim(%r) and shape(%r) is not 'NCHW' format", tensor.ndim, tensor.shape)
|
||||
return out_canvas
|
||||
|
||||
# expand the N
|
||||
|
|
|
@ -1,308 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Schedule the event writer process."""
|
||||
import multiprocessing as mp
|
||||
from enum import Enum, unique
|
||||
from mindspore import log as logger
|
||||
from ..._c_expression import Tensor
|
||||
from ._summary_adapter import SummaryType, package_summary_event, save_summary_data
|
||||
|
||||
# define the type of summary
|
||||
FORMAT_SCALAR_STR = "Scalar"
|
||||
FORMAT_TENSOR_STR = "Tensor"
|
||||
FORMAT_IMAGE_STR = "Image"
|
||||
FORMAT_HISTOGRAM_STR = "Histogram"
|
||||
FORMAT_BEGIN_SLICE = "[:"
|
||||
FORMAT_END_SLICE = "]"
|
||||
|
||||
# cache the summary data dict
|
||||
# {id: SummaryData}
|
||||
# |---[{"name": tag_name, "data": numpy}, {"name": tag_name, "data": numpy},...]
|
||||
g_summary_data_id = 0
|
||||
g_summary_data_dict = {}
|
||||
# cache the summary data file
|
||||
g_summary_writer_id = 0
|
||||
g_summary_file = {}
|
||||
|
||||
|
||||
@unique
|
||||
class ScheduleMethod(Enum):
|
||||
"""Schedule method type."""
|
||||
FORMAL_WORKER = 0 # use the formal worker that receive small size data by queue
|
||||
TEMP_WORKER = 1 # use the Temp worker that receive big size data by the global value(avoid copy)
|
||||
CACHE_DATA = 2 # Cache data util have idle worker to process it
|
||||
|
||||
|
||||
@unique
|
||||
class WorkerStatus(Enum):
|
||||
"""Worker status."""
|
||||
WORKER_INIT = 0 # data is exist but not process
|
||||
WORKER_PROCESSING = 1 # data is processing
|
||||
WORKER_PROCESSED = 2 # data already processed
|
||||
|
||||
|
||||
def _parse_tag_format(tag: str):
|
||||
"""
|
||||
Parse the tag.
|
||||
|
||||
Args:
|
||||
tag (str): Format: xxx[:Scalar] xxx[:Image] xxx[:Tensor].
|
||||
|
||||
Returns:
|
||||
Tuple, (SummaryType, summary_tag).
|
||||
"""
|
||||
|
||||
summary_type = SummaryType.INVALID
|
||||
summary_tag = tag
|
||||
if tag is None:
|
||||
logger.error("The tag is None")
|
||||
return summary_type, summary_tag
|
||||
|
||||
# search the slice
|
||||
slice_begin = FORMAT_BEGIN_SLICE
|
||||
slice_end = FORMAT_END_SLICE
|
||||
index = tag.rfind(slice_begin)
|
||||
if index is -1:
|
||||
logger.error("The tag(%s) have not the key slice.", tag)
|
||||
return summary_type, summary_tag
|
||||
|
||||
# slice the tag
|
||||
summary_tag = tag[:index]
|
||||
|
||||
# check the slice end
|
||||
if tag[-1:] != slice_end:
|
||||
logger.error("The tag(%s) end format is error", tag)
|
||||
return summary_type, summary_tag
|
||||
|
||||
# check the type
|
||||
type_str = tag[index + 2: -1]
|
||||
logger.debug("The summary_tag is = %r", summary_tag)
|
||||
logger.debug("The type_str value is = %r", type_str)
|
||||
if type_str == FORMAT_SCALAR_STR:
|
||||
summary_type = SummaryType.SCALAR
|
||||
elif type_str == FORMAT_TENSOR_STR:
|
||||
summary_type = SummaryType.TENSOR
|
||||
elif type_str == FORMAT_IMAGE_STR:
|
||||
summary_type = SummaryType.IMAGE
|
||||
elif type_str == FORMAT_HISTOGRAM_STR:
|
||||
summary_type = SummaryType.HISTOGRAM
|
||||
else:
|
||||
logger.error("The tag(%s) type is invalid.", tag)
|
||||
summary_type = SummaryType.INVALID
|
||||
|
||||
return summary_type, summary_tag
|
||||
|
||||
|
||||
class SummaryDataManager:
|
||||
"""Manage the summary global data cache."""
|
||||
def __init__(self):
|
||||
global g_summary_data_dict
|
||||
self.size = len(g_summary_data_dict)
|
||||
|
||||
@classmethod
|
||||
def summary_data_save(cls, data):
|
||||
"""Save the global summary cache."""
|
||||
global g_summary_data_id
|
||||
data_id = g_summary_data_id
|
||||
save_summary_data(data_id, data)
|
||||
g_summary_data_id += 1
|
||||
return data_id
|
||||
|
||||
@classmethod
|
||||
def summary_file_set(cls, event_writer):
|
||||
"""Support the many event_writer."""
|
||||
global g_summary_file, g_summary_writer_id
|
||||
g_summary_writer_id += 1
|
||||
g_summary_file[g_summary_writer_id] = event_writer
|
||||
return g_summary_writer_id
|
||||
|
||||
@classmethod
|
||||
def summary_file_get(cls, writer_id=1):
|
||||
ret = None
|
||||
global g_summary_file
|
||||
if writer_id in g_summary_file:
|
||||
ret = g_summary_file.get(writer_id)
|
||||
return ret
|
||||
|
||||
|
||||
class WorkerScheduler:
|
||||
"""
|
||||
Create worker and schedule data to worker.
|
||||
|
||||
Args:
|
||||
writer_id (int): The index of writer.
|
||||
"""
|
||||
def __init__(self, writer_id):
|
||||
# Create the process of write event file
|
||||
self.write_lock = mp.Lock()
|
||||
# Schedule info for all worker
|
||||
# Format: {worker: (step, WorkerStatus)}
|
||||
self.schedule_table = {}
|
||||
# write id
|
||||
self.writer_id = writer_id
|
||||
self.has_graph = False
|
||||
|
||||
def dispatch(self, step, data):
|
||||
"""
|
||||
Select schedule strategy and dispatch data.
|
||||
|
||||
Args:
|
||||
step (Number): The number of step index.
|
||||
data (Object): The data of recode for summary.
|
||||
|
||||
Retruns:
|
||||
bool, run successfully or not.
|
||||
"""
|
||||
# save the data to global cache , convert the tensor to numpy
|
||||
result, size, data = self._data_convert(data)
|
||||
if result is False:
|
||||
logger.error("The step(%r) summary data(%r) is invalid.", step, size)
|
||||
return False
|
||||
|
||||
data_id = SummaryDataManager.summary_data_save(data)
|
||||
self._start_worker(step, data_id)
|
||||
return True
|
||||
|
||||
def _start_worker(self, step, data_id):
|
||||
"""
|
||||
Start worker.
|
||||
|
||||
Args:
|
||||
step (Number): The index of recode.
|
||||
data_id (str): The id of work.
|
||||
|
||||
Return:
|
||||
bool, run successfully or not.
|
||||
"""
|
||||
# assign the worker
|
||||
policy = self._make_policy()
|
||||
if policy == ScheduleMethod.TEMP_WORKER:
|
||||
worker = SummaryDataProcess(step, data_id, self.write_lock, self.writer_id)
|
||||
# update the schedule table
|
||||
self.schedule_table[worker] = (step, data_id, WorkerStatus.WORKER_INIT)
|
||||
# start the worker
|
||||
worker.start()
|
||||
else:
|
||||
logger.error("Do not support the other scheduler policy now.")
|
||||
|
||||
# update the scheduler infor
|
||||
self._update_scheduler()
|
||||
return True
|
||||
|
||||
def _data_convert(self, data_list):
|
||||
"""Convert the data."""
|
||||
if data_list is None:
|
||||
logger.warning("The step does not have record data.")
|
||||
return False, 0, None
|
||||
|
||||
# convert the summary to numpy
|
||||
size = 0
|
||||
for v_dict in data_list:
|
||||
tag = v_dict["name"]
|
||||
data = v_dict["data"]
|
||||
# confirm the data is valid
|
||||
summary_type, summary_tag = _parse_tag_format(tag)
|
||||
if summary_type == SummaryType.INVALID:
|
||||
logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data)
|
||||
return False, 0, None
|
||||
if isinstance(data, Tensor):
|
||||
# get the summary type and parse the tag
|
||||
v_dict["name"] = summary_tag
|
||||
v_dict["type"] = summary_type
|
||||
v_dict["data"] = data.asnumpy()
|
||||
size += v_dict["data"].size
|
||||
else:
|
||||
logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data)
|
||||
return False, 0, None
|
||||
|
||||
return True, size, data_list
|
||||
|
||||
def _update_scheduler(self):
|
||||
"""Check the worker status and update schedule table."""
|
||||
workers = list(self.schedule_table.keys())
|
||||
for worker in workers:
|
||||
if not worker.is_alive():
|
||||
# update the table
|
||||
worker.join()
|
||||
del self.schedule_table[worker]
|
||||
|
||||
def close(self):
|
||||
"""Confirm all worker is end."""
|
||||
workers = self.schedule_table.keys()
|
||||
for worker in workers:
|
||||
if worker.is_alive():
|
||||
worker.join()
|
||||
|
||||
def _make_policy(self):
|
||||
"""Select the schedule strategy by data."""
|
||||
# now only support the temp worker
|
||||
return ScheduleMethod.TEMP_WORKER
|
||||
|
||||
|
||||
class SummaryDataProcess(mp.Process):
|
||||
"""
|
||||
Process that consume the summarydata.
|
||||
|
||||
Args:
|
||||
step (int): The index of step.
|
||||
data_id (int): The index of summary data.
|
||||
write_lock (Lock): The process lock for writer same file.
|
||||
writer_id (int): The index of writer.
|
||||
"""
|
||||
def __init__(self, step, data_id, write_lock, writer_id):
|
||||
super(SummaryDataProcess, self).__init__()
|
||||
self.daemon = True
|
||||
self.writer_id = writer_id
|
||||
self.writer = SummaryDataManager.summary_file_get(self.writer_id)
|
||||
if self.writer is None:
|
||||
logger.error("The writer_id(%r) does not have writer", writer_id)
|
||||
self.step = step
|
||||
self.data_id = data_id
|
||||
self.write_lock = write_lock
|
||||
self.name = "SummaryDataConsumer_" + str(self.step)
|
||||
|
||||
def run(self):
|
||||
"""The consumer is process the step data and exit."""
|
||||
# convert the data to event
|
||||
# All exceptions need to be caught and end the queue
|
||||
try:
|
||||
logger.debug("process(%r) process a data(%r)", self.name, self.step)
|
||||
# package the summary event
|
||||
summary_event = package_summary_event(self.data_id, self.step)
|
||||
# send the event to file
|
||||
self._write_summary(summary_event)
|
||||
except Exception as e:
|
||||
logger.error("Summary data mq consumer exception occurred, value = %r", e)
|
||||
|
||||
def _write_summary(self, summary_event):
|
||||
"""
|
||||
Write the summary to event file.
|
||||
|
||||
Note:
|
||||
The write record format:
|
||||
1 uint64 : data length.
|
||||
2 uint32 : mask crc value of data length.
|
||||
3 bytes : data.
|
||||
4 uint32 : mask crc value of data.
|
||||
|
||||
Args:
|
||||
summary_event (Event): The summary event of proto.
|
||||
|
||||
"""
|
||||
event_str = summary_event.SerializeToString()
|
||||
self.write_lock.acquire()
|
||||
self.writer.write_event_to_file(event_str)
|
||||
self.writer.flush()
|
||||
self.write_lock.release()
|
|
@ -14,17 +14,22 @@
|
|||
# ============================================================================
|
||||
"""Record the summary event."""
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from mindspore import log as logger
|
||||
from ._summary_scheduler import WorkerScheduler, SummaryDataManager
|
||||
from ._summary_adapter import get_event_file_name, package_graph_event
|
||||
from ._event_writer import EventRecord
|
||||
from .._utils import _make_directory
|
||||
from ..._checkparam import _check_str_by_regular
|
||||
|
||||
from mindspore import log as logger
|
||||
|
||||
from ..._c_expression import Tensor
|
||||
from ..._checkparam import _check_str_by_regular
|
||||
from .._utils import _make_directory
|
||||
from ._event_writer import EventWriter
|
||||
from ._summary_adapter import get_event_file_name, package_graph_event, package_init_event
|
||||
|
||||
# for the moment, this lock is for caution's sake,
|
||||
# there are actually no any concurrencies happening.
|
||||
_summary_lock = threading.Lock()
|
||||
# cache the summary data
|
||||
_summary_tensor_cache = {}
|
||||
_summary_lock = threading.Lock()
|
||||
|
||||
|
||||
def _cache_summary_tensor_data(summary):
|
||||
|
@ -34,14 +39,18 @@ def _cache_summary_tensor_data(summary):
|
|||
Args:
|
||||
summary (list): [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...].
|
||||
"""
|
||||
_summary_lock.acquire()
|
||||
if "SummaryRecord" in _summary_tensor_cache:
|
||||
for record in summary:
|
||||
_summary_tensor_cache["SummaryRecord"].append(record)
|
||||
else:
|
||||
_summary_tensor_cache["SummaryRecord"] = summary
|
||||
_summary_lock.release()
|
||||
return True
|
||||
with _summary_lock:
|
||||
for item in summary:
|
||||
_summary_tensor_cache[item['name']] = item['data']
|
||||
return True
|
||||
|
||||
|
||||
def _get_summary_tensor_data():
|
||||
global _summary_tensor_cache
|
||||
with _summary_lock:
|
||||
data = _summary_tensor_cache
|
||||
_summary_tensor_cache = {}
|
||||
return data
|
||||
|
||||
|
||||
class SummaryRecord:
|
||||
|
@ -53,7 +62,7 @@ class SummaryRecord:
|
|||
It writes the event log to a file by executing the record method. In addition,
|
||||
if the SummaryRecord object is created and the summary operator is used in the network,
|
||||
even if the record method is not called, the event in the cache will be written to the
|
||||
file at the end of execution or when the summary is closed.
|
||||
file at the end of execution. Make sure to close the SummaryRecord object at the end.
|
||||
|
||||
Args:
|
||||
log_dir (str): The log_dir is a directory location to save the summary.
|
||||
|
@ -68,9 +77,10 @@ class SummaryRecord:
|
|||
RuntimeError: If the log_dir can not be resolved to a canonicalized absolute pathname.
|
||||
|
||||
Examples:
|
||||
>>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
|
||||
>>> file_prefix="xxx_", file_suffix="_yyy")
|
||||
>>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
|
||||
>>> pass
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
log_dir,
|
||||
queue_max_size=0,
|
||||
|
@ -101,26 +111,36 @@ class SummaryRecord:
|
|||
|
||||
self.prefix = file_prefix
|
||||
self.suffix = file_suffix
|
||||
self.network = network
|
||||
self.has_graph = False
|
||||
self._closed = False
|
||||
|
||||
# create the summary writer file
|
||||
self.event_file_name = get_event_file_name(self.prefix, self.suffix)
|
||||
if self.log_path[-1:] == '/':
|
||||
self.full_file_name = self.log_path + self.event_file_name
|
||||
else:
|
||||
self.full_file_name = self.log_path + '/' + self.event_file_name
|
||||
|
||||
try:
|
||||
self.full_file_name = os.path.realpath(self.full_file_name)
|
||||
self.full_file_name = os.path.join(self.log_path, self.event_file_name)
|
||||
except Exception as ex:
|
||||
raise RuntimeError(ex)
|
||||
self.event_writer = EventRecord(self.full_file_name, self.flush_time)
|
||||
self.writer_id = SummaryDataManager.summary_file_set(self.event_writer)
|
||||
self.worker_scheduler = WorkerScheduler(self.writer_id)
|
||||
|
||||
self.step = 0
|
||||
self._closed = False
|
||||
self.network = network
|
||||
self.has_graph = False
|
||||
self._event_writer = None
|
||||
|
||||
def _init_event_writer(self):
|
||||
"""Init event writer and write metadata."""
|
||||
event_writer = EventWriter(self.full_file_name, self.flush_time)
|
||||
event_writer.write(package_init_event().SerializeToString())
|
||||
return event_writer
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter the context manager."""
|
||||
if not self._event_writer:
|
||||
self._event_writer = self._init_event_writer()
|
||||
if self._closed:
|
||||
raise ValueError('SummaryRecord has been closed.')
|
||||
return self
|
||||
|
||||
def __exit__(self, extype, exvalue, traceback):
|
||||
"""Exit the context manager."""
|
||||
self.close()
|
||||
|
||||
def record(self, step, train_network=None):
|
||||
"""
|
||||
|
@ -131,9 +151,8 @@ class SummaryRecord:
|
|||
train_network (Cell): The network that called the callback.
|
||||
|
||||
Examples:
|
||||
>>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
|
||||
>>> file_prefix="xxx_", file_suffix="_yyy")
|
||||
>>> summary_record.record(step=2)
|
||||
>>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
|
||||
>>> summary_record.record(step=2)
|
||||
|
||||
Returns:
|
||||
bool, whether the record process is successful or not.
|
||||
|
@ -145,42 +164,37 @@ class SummaryRecord:
|
|||
if not isinstance(step, int) or isinstance(step, bool):
|
||||
raise ValueError("`step` should be int")
|
||||
# Set the current summary of train step
|
||||
self.step = step
|
||||
if not self._event_writer:
|
||||
self._event_writer = self._init_event_writer()
|
||||
logger.warning('SummaryRecord should be used as context manager for a with statement.')
|
||||
|
||||
if self.network is not None and self.has_graph is False:
|
||||
if self.network is not None and not self.has_graph:
|
||||
graph_proto = self.network.get_func_graph_proto()
|
||||
if graph_proto is None and train_network is not None:
|
||||
graph_proto = train_network.get_func_graph_proto()
|
||||
if graph_proto is None:
|
||||
logger.error("Failed to get proto for graph")
|
||||
else:
|
||||
self.event_writer.write_event_to_file(
|
||||
package_graph_event(graph_proto).SerializeToString())
|
||||
self.event_writer.flush()
|
||||
self._event_writer.write(package_graph_event(graph_proto).SerializeToString())
|
||||
self.has_graph = True
|
||||
data = _summary_tensor_cache.get("SummaryRecord")
|
||||
if data is None:
|
||||
if not _summary_tensor_cache:
|
||||
return True
|
||||
|
||||
data = _summary_tensor_cache.get("SummaryRecord")
|
||||
if data is None:
|
||||
logger.error("The step(%r) does not have record data.", self.step)
|
||||
data = _get_summary_tensor_data()
|
||||
if not data:
|
||||
logger.error("The step(%r) does not have record data.", step)
|
||||
return False
|
||||
if self.queue_max_size > 0 and len(data) > self.queue_max_size:
|
||||
logger.error("The size of data record is %r, which is greater than queue_max_size %r.", len(data),
|
||||
self.queue_max_size)
|
||||
|
||||
# clean the data of cache
|
||||
del _summary_tensor_cache["SummaryRecord"]
|
||||
|
||||
# process the data
|
||||
self.worker_scheduler.dispatch(self.step, data)
|
||||
|
||||
# count & flush
|
||||
self.event_writer.count_event()
|
||||
self.event_writer.flush_cycle()
|
||||
|
||||
logger.debug("Send the summary data to scheduler for saving, step = %d", self.step)
|
||||
result = self._data_convert(data)
|
||||
if not result:
|
||||
logger.error("The step(%r) summary data is invalid.", step)
|
||||
return False
|
||||
self._event_writer.write((result, step))
|
||||
logger.debug("Send the summary data to scheduler for saving, step = %d", step)
|
||||
return True
|
||||
|
||||
@property
|
||||
|
@ -189,14 +203,13 @@ class SummaryRecord:
|
|||
Get the full path of the log file.
|
||||
|
||||
Examples:
|
||||
>>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
|
||||
>>> file_prefix="xxx_", file_suffix="_yyy")
|
||||
>>> print(summary_record.log_dir)
|
||||
>>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
|
||||
>>> print(summary_record.log_dir)
|
||||
|
||||
Returns:
|
||||
String, the full path of log file.
|
||||
"""
|
||||
return self.event_writer.full_file_name
|
||||
return self.full_file_name
|
||||
|
||||
def flush(self):
|
||||
"""
|
||||
|
@ -205,39 +218,64 @@ class SummaryRecord:
|
|||
Call it to make sure that all pending events have been written to disk.
|
||||
|
||||
Examples:
|
||||
>>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
|
||||
>>> file_prefix="xxx_", file_suffix="_yyy")
|
||||
>>> summary_record.flush()
|
||||
>>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
|
||||
>>> summary_record.flush()
|
||||
"""
|
||||
if self._closed:
|
||||
logger.error("The record writer is closed and can not flush.")
|
||||
else:
|
||||
self.event_writer.flush()
|
||||
elif self._event_writer:
|
||||
self._event_writer.flush()
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Flush all events and close summary records.
|
||||
Flush all events and close summary records. Please use with statement to autoclose.
|
||||
|
||||
Examples:
|
||||
>>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
|
||||
>>> file_prefix="xxx_", file_suffix="_yyy")
|
||||
>>> summary_record.close()
|
||||
>>> with SummaryRecord(log_dir="/opt/log", file_prefix="xxx_", file_suffix="_yyy") as summary_record:
|
||||
>>> pass # summary_record autoclosed
|
||||
"""
|
||||
if not self._closed:
|
||||
self._check_data_before_close()
|
||||
self.worker_scheduler.close()
|
||||
if not self._closed and self._event_writer:
|
||||
# event writer flush and close
|
||||
self.event_writer.close()
|
||||
self._event_writer.close()
|
||||
self._closed = True
|
||||
|
||||
def __del__(self):
|
||||
"""Process exit is called."""
|
||||
if hasattr(self, "worker_scheduler"):
|
||||
if self.worker_scheduler:
|
||||
self.close()
|
||||
def __del__(self) -> None:
|
||||
self.close()
|
||||
|
||||
def _check_data_before_close(self):
|
||||
"Check whether there is any data in the cache, and if so, call record"
|
||||
data = _summary_tensor_cache.get("SummaryRecord")
|
||||
if data is not None:
|
||||
self.record(self.step)
|
||||
def _data_convert(self, summary):
|
||||
"""Convert the data."""
|
||||
# convert the summary to numpy
|
||||
result = []
|
||||
for name, data in summary.items():
|
||||
# confirm the data is valid
|
||||
summary_tag, summary_type = SummaryRecord._parse_from(name)
|
||||
if summary_tag is None:
|
||||
logger.error("The data type is invalid, name = %r, tensor = %r", name, data)
|
||||
return None
|
||||
if isinstance(data, Tensor):
|
||||
result.append({'name': summary_tag, 'data': data.asnumpy(), '_type': summary_type})
|
||||
else:
|
||||
logger.error("The data type is invalid, name = %r, tensor = %r", name, data)
|
||||
return None
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _parse_from(name: str = None):
|
||||
"""
|
||||
Parse the tag and type from name.
|
||||
|
||||
Args:
|
||||
name (str): Format: TAG[:TYPE].
|
||||
|
||||
Returns:
|
||||
Tuple, (summary_tag, summary_type).
|
||||
"""
|
||||
if name is None:
|
||||
logger.error("The name is None")
|
||||
return None, None
|
||||
match = re.match(r'(.+)\[:(.+)\]', name)
|
||||
if match:
|
||||
return match.groups()
|
||||
logger.error("The name(%r) format is invalid, expected 'TAG[:TYPE]'.", name)
|
||||
return None, None
|
||||
|
|
|
@ -53,14 +53,13 @@ def me_train_tensor(net, input_np, label_np, epoch_size=2):
|
|||
_network = wrap.WithLossCell(net, loss)
|
||||
_train_net = MsWrapper(wrap.TrainOneStepCell(_network, opt))
|
||||
_train_net.set_train()
|
||||
summary_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=_train_net)
|
||||
for epoch in range(0, epoch_size):
|
||||
print(f"epoch %d" % (epoch))
|
||||
output = _train_net(Tensor(input_np), Tensor(label_np))
|
||||
summary_writer.record(i)
|
||||
print("********output***********")
|
||||
print(output.asnumpy())
|
||||
summary_writer.close()
|
||||
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=_train_net) as summary_writer:
|
||||
for epoch in range(0, epoch_size):
|
||||
print(f"epoch %d" % (epoch))
|
||||
output = _train_net(Tensor(input_np), Tensor(label_np))
|
||||
summary_writer.record(i)
|
||||
print("********output***********")
|
||||
print(output.asnumpy())
|
||||
|
||||
|
||||
def me_infer_tensor(net, input_np):
|
||||
|
|
|
@ -91,15 +91,14 @@ def train_summary_record_scalar_for_1(test_writer, steps, fwd_x, fwd_y):
|
|||
|
||||
|
||||
def me_scalar_summary(steps, tag=None, value=None):
|
||||
test_writer = SummaryRecord(SUMMARY_DIR_ME_TEMP)
|
||||
with SummaryRecord(SUMMARY_DIR_ME_TEMP) as test_writer:
|
||||
|
||||
x = Tensor(np.array([1.1]).astype(np.float32))
|
||||
y = Tensor(np.array([1.2]).astype(np.float32))
|
||||
x = Tensor(np.array([1.1]).astype(np.float32))
|
||||
y = Tensor(np.array([1.2]).astype(np.float32))
|
||||
|
||||
out_me_dict = train_summary_record_scalar_for_1(test_writer, steps, x, y)
|
||||
out_me_dict = train_summary_record_scalar_for_1(test_writer, steps, x, y)
|
||||
|
||||
test_writer.close()
|
||||
return out_me_dict
|
||||
return out_me_dict
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
|
|
@ -106,18 +106,17 @@ def test_graph_summary_sample():
|
|||
optim = Momentum(net.trainable_params(), 0.1, 0.9)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
|
||||
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network)
|
||||
model.train(2, dataset)
|
||||
# step 2: create the Event
|
||||
for i in range(1, 5):
|
||||
test_writer.record(i)
|
||||
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) as test_writer:
|
||||
model.train(2, dataset)
|
||||
# step 2: create the Event
|
||||
for i in range(1, 5):
|
||||
test_writer.record(i)
|
||||
|
||||
# step 3: send the event to mq
|
||||
# step 3: send the event to mq
|
||||
|
||||
# step 4: accept the event and write the file
|
||||
test_writer.close()
|
||||
# step 4: accept the event and write the file
|
||||
|
||||
log.debug("finished test_graph_summary_sample")
|
||||
log.debug("finished test_graph_summary_sample")
|
||||
|
||||
|
||||
def test_graph_summary_callback():
|
||||
|
@ -127,9 +126,9 @@ def test_graph_summary_callback():
|
|||
optim = Momentum(net.trainable_params(), 0.1, 0.9)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
|
||||
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network)
|
||||
summary_cb = SummaryStep(test_writer, 1)
|
||||
model.train(2, dataset, callbacks=summary_cb)
|
||||
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) as test_writer:
|
||||
summary_cb = SummaryStep(test_writer, 1)
|
||||
model.train(2, dataset, callbacks=summary_cb)
|
||||
|
||||
|
||||
def test_graph_summary_callback2():
|
||||
|
@ -139,6 +138,6 @@ def test_graph_summary_callback2():
|
|||
optim = Momentum(net.trainable_params(), 0.1, 0.9)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
|
||||
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=net)
|
||||
summary_cb = SummaryStep(test_writer, 1)
|
||||
model.train(2, dataset, callbacks=summary_cb)
|
||||
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=net) as test_writer:
|
||||
summary_cb = SummaryStep(test_writer, 1)
|
||||
model.train(2, dataset, callbacks=summary_cb)
|
||||
|
|
|
@ -52,12 +52,11 @@ def _wrap_test_data(input_data: Tensor):
|
|||
def test_histogram_summary():
|
||||
"""Test histogram summary."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
|
||||
with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
|
||||
|
||||
test_data = _wrap_test_data(Tensor([[1, 2, 3], [4, 5, 6]]))
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(step=1)
|
||||
test_writer.close()
|
||||
test_data = _wrap_test_data(Tensor([[1, 2, 3], [4, 5, 6]]))
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(step=1)
|
||||
|
||||
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
||||
reader = SummaryReader(file_name)
|
||||
|
@ -68,20 +67,18 @@ def test_histogram_summary():
|
|||
def test_histogram_multi_summary():
|
||||
"""Test histogram multiple step."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
|
||||
with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
|
||||
|
||||
rng = np.random.RandomState(10)
|
||||
size = 50
|
||||
num_step = 5
|
||||
rng = np.random.RandomState(10)
|
||||
size = 50
|
||||
num_step = 5
|
||||
|
||||
for i in range(num_step):
|
||||
arr = rng.normal(size=size)
|
||||
for i in range(num_step):
|
||||
arr = rng.normal(size=size)
|
||||
|
||||
test_data = _wrap_test_data(Tensor(arr))
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(step=i)
|
||||
|
||||
test_writer.close()
|
||||
test_data = _wrap_test_data(Tensor(arr))
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(step=i)
|
||||
|
||||
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
||||
reader = SummaryReader(file_name)
|
||||
|
@ -93,12 +90,11 @@ def test_histogram_multi_summary():
|
|||
def test_histogram_summary_scalar_tensor():
|
||||
"""Test histogram summary, input is a scalar tensor."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
|
||||
with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
|
||||
|
||||
test_data = _wrap_test_data(Tensor(1))
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(step=1)
|
||||
test_writer.close()
|
||||
test_data = _wrap_test_data(Tensor(1))
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(step=1)
|
||||
|
||||
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
||||
reader = SummaryReader(file_name)
|
||||
|
@ -109,12 +105,11 @@ def test_histogram_summary_scalar_tensor():
|
|||
def test_histogram_summary_empty_tensor():
|
||||
"""Test histogram summary, input is an empty tensor."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
|
||||
with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
|
||||
|
||||
test_data = _wrap_test_data(Tensor([]))
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(step=1)
|
||||
test_writer.close()
|
||||
test_data = _wrap_test_data(Tensor([]))
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(step=1)
|
||||
|
||||
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
||||
reader = SummaryReader(file_name)
|
||||
|
@ -125,15 +120,14 @@ def test_histogram_summary_empty_tensor():
|
|||
def test_histogram_summary_same_value():
|
||||
"""Test histogram summary, input is an ones tensor."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
|
||||
with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
|
||||
|
||||
dim1 = 100
|
||||
dim2 = 100
|
||||
dim1 = 100
|
||||
dim2 = 100
|
||||
|
||||
test_data = _wrap_test_data(Tensor(np.ones([dim1, dim2])))
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(step=1)
|
||||
test_writer.close()
|
||||
test_data = _wrap_test_data(Tensor(np.ones([dim1, dim2])))
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(step=1)
|
||||
|
||||
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
||||
reader = SummaryReader(file_name)
|
||||
|
@ -146,15 +140,14 @@ def test_histogram_summary_same_value():
|
|||
def test_histogram_summary_high_dims():
|
||||
"""Test histogram summary, input is a 4-dimension tensor."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
|
||||
dim = 10
|
||||
with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
|
||||
dim = 10
|
||||
|
||||
rng = np.random.RandomState(0)
|
||||
tensor_data = rng.normal(size=[dim, dim, dim, dim])
|
||||
test_data = _wrap_test_data(Tensor(tensor_data))
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(step=1)
|
||||
test_writer.close()
|
||||
rng = np.random.RandomState(0)
|
||||
tensor_data = rng.normal(size=[dim, dim, dim, dim])
|
||||
test_data = _wrap_test_data(Tensor(tensor_data))
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(step=1)
|
||||
|
||||
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
||||
reader = SummaryReader(file_name)
|
||||
|
@ -167,20 +160,19 @@ def test_histogram_summary_high_dims():
|
|||
def test_histogram_summary_nan_inf():
|
||||
"""Test histogram summary, input tensor has nan."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
|
||||
with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
|
||||
|
||||
dim1 = 100
|
||||
dim2 = 100
|
||||
dim1 = 100
|
||||
dim2 = 100
|
||||
|
||||
arr = np.ones([dim1, dim2])
|
||||
arr[0][0] = np.nan
|
||||
arr[0][1] = np.inf
|
||||
arr[0][2] = -np.inf
|
||||
test_data = _wrap_test_data(Tensor(arr))
|
||||
arr = np.ones([dim1, dim2])
|
||||
arr[0][0] = np.nan
|
||||
arr[0][1] = np.inf
|
||||
arr[0][2] = -np.inf
|
||||
test_data = _wrap_test_data(Tensor(arr))
|
||||
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(step=1)
|
||||
test_writer.close()
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(step=1)
|
||||
|
||||
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
||||
reader = SummaryReader(file_name)
|
||||
|
@ -193,12 +185,11 @@ def test_histogram_summary_nan_inf():
|
|||
def test_histogram_summary_all_nan_inf():
|
||||
"""Test histogram summary, input tensor has no valid number."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
|
||||
with SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM") as test_writer:
|
||||
|
||||
test_data = _wrap_test_data(Tensor(np.array([np.nan, np.nan, np.nan, np.inf, -np.inf])))
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(step=1)
|
||||
test_writer.close()
|
||||
test_data = _wrap_test_data(Tensor(np.array([np.nan, np.nan, np.nan, np.inf, -np.inf])))
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(step=1)
|
||||
|
||||
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
||||
reader = SummaryReader(file_name)
|
||||
|
|
|
@ -74,23 +74,21 @@ def test_image_summary_sample():
|
|||
""" test_image_summary_sample """
|
||||
log.debug("begin test_image_summary_sample")
|
||||
# step 0: create the thread
|
||||
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE")
|
||||
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer:
|
||||
|
||||
# step 1: create the test data for summary
|
||||
# step 1: create the test data for summary
|
||||
|
||||
# step 2: create the Event
|
||||
for i in range(1, 5):
|
||||
test_data = get_test_data(i)
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(i)
|
||||
test_writer.flush()
|
||||
# step 2: create the Event
|
||||
for i in range(1, 5):
|
||||
test_data = get_test_data(i)
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(i)
|
||||
test_writer.flush()
|
||||
|
||||
# step 3: send the event to mq
|
||||
# step 3: send the event to mq
|
||||
|
||||
# step 4: accept the event and write the file
|
||||
test_writer.close()
|
||||
|
||||
log.debug("finished test_image_summary_sample")
|
||||
# step 4: accept the event and write the file
|
||||
log.debug("finished test_image_summary_sample")
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
@ -174,23 +172,21 @@ def test_image_summary_train():
|
|||
|
||||
log.debug("begin test_image_summary_sample")
|
||||
# step 0: create the thread
|
||||
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE")
|
||||
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer:
|
||||
|
||||
# step 1: create the test data for summary
|
||||
# step 1: create the test data for summary
|
||||
|
||||
# step 2: create the Event
|
||||
# step 2: create the Event
|
||||
|
||||
model = get_model()
|
||||
fn = ImageSummaryCallback(test_writer)
|
||||
summary_recode = SummaryStep(fn, 1)
|
||||
model.train(2, dataset, callbacks=summary_recode)
|
||||
model = get_model()
|
||||
fn = ImageSummaryCallback(test_writer)
|
||||
summary_recode = SummaryStep(fn, 1)
|
||||
model.train(2, dataset, callbacks=summary_recode)
|
||||
|
||||
# step 3: send the event to mq
|
||||
# step 3: send the event to mq
|
||||
|
||||
# step 4: accept the event and write the file
|
||||
test_writer.close()
|
||||
|
||||
log.debug("finished test_image_summary_sample")
|
||||
# step 4: accept the event and write the file
|
||||
log.debug("finished test_image_summary_sample")
|
||||
|
||||
|
||||
def test_image_summary_data():
|
||||
|
@ -209,18 +205,12 @@ def test_image_summary_data():
|
|||
|
||||
log.debug("begin test_image_summary_sample")
|
||||
# step 0: create the thread
|
||||
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE")
|
||||
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_IMAGE") as test_writer:
|
||||
|
||||
# step 1: create the test data for summary
|
||||
# step 1: create the test data for summary
|
||||
|
||||
# step 2: create the Event
|
||||
_cache_summary_tensor_data(test_data_list)
|
||||
test_writer.record(1)
|
||||
test_writer.flush()
|
||||
# step 2: create the Event
|
||||
_cache_summary_tensor_data(test_data_list)
|
||||
test_writer.record(1)
|
||||
|
||||
# step 3: send the event to mq
|
||||
|
||||
# step 4: accept the event and write the file
|
||||
test_writer.close()
|
||||
|
||||
log.debug("finished test_image_summary_sample")
|
||||
log.debug("finished test_image_summary_sample")
|
||||
|
|
|
@ -65,22 +65,21 @@ def test_scalar_summary_sample():
|
|||
""" test_scalar_summary_sample """
|
||||
log.debug("begin test_scalar_summary_sample")
|
||||
# step 0: create the thread
|
||||
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR")
|
||||
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
|
||||
|
||||
# step 1: create the test data for summary
|
||||
# step 1: create the test data for summary
|
||||
|
||||
# step 2: create the Event
|
||||
for i in range(1, 500):
|
||||
test_data = get_test_data(i)
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(i)
|
||||
# step 2: create the Event
|
||||
for i in range(1, 500):
|
||||
test_data = get_test_data(i)
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(i)
|
||||
|
||||
# step 3: send the event to mq
|
||||
# step 3: send the event to mq
|
||||
|
||||
# step 4: accept the event and write the file
|
||||
test_writer.close()
|
||||
# step 4: accept the event and write the file
|
||||
|
||||
log.debug("finished test_scalar_summary_sample")
|
||||
log.debug("finished test_scalar_summary_sample")
|
||||
|
||||
|
||||
def get_test_data_shape_1(step):
|
||||
|
@ -110,22 +109,21 @@ def test_scalar_summary_sample_with_shape_1():
|
|||
""" test_scalar_summary_sample_with_shape_1 """
|
||||
log.debug("begin test_scalar_summary_sample_with_shape_1")
|
||||
# step 0: create the thread
|
||||
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR")
|
||||
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
|
||||
|
||||
# step 1: create the test data for summary
|
||||
# step 1: create the test data for summary
|
||||
|
||||
# step 2: create the Event
|
||||
for i in range(1, 100):
|
||||
test_data = get_test_data_shape_1(i)
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(i)
|
||||
# step 2: create the Event
|
||||
for i in range(1, 100):
|
||||
test_data = get_test_data_shape_1(i)
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(i)
|
||||
|
||||
# step 3: send the event to mq
|
||||
# step 3: send the event to mq
|
||||
|
||||
# step 4: accept the event and write the file
|
||||
test_writer.close()
|
||||
# step 4: accept the event and write the file
|
||||
|
||||
log.debug("finished test_scalar_summary_sample")
|
||||
log.debug("finished test_scalar_summary_sample")
|
||||
|
||||
|
||||
# Test: test with ge
|
||||
|
@ -152,26 +150,24 @@ def test_scalar_summary_with_ge():
|
|||
log.debug("begin test_scalar_summary_with_ge")
|
||||
|
||||
# step 0: create the thread
|
||||
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR")
|
||||
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
|
||||
|
||||
# step 1: create the network for summary
|
||||
x = Tensor(np.array([1.1]).astype(np.float32))
|
||||
y = Tensor(np.array([1.2]).astype(np.float32))
|
||||
net = SummaryDemo()
|
||||
net.set_train()
|
||||
# step 1: create the network for summary
|
||||
x = Tensor(np.array([1.1]).astype(np.float32))
|
||||
y = Tensor(np.array([1.2]).astype(np.float32))
|
||||
net = SummaryDemo()
|
||||
net.set_train()
|
||||
|
||||
# step 2: create the Event
|
||||
steps = 100
|
||||
for i in range(1, steps):
|
||||
x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32))
|
||||
y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32))
|
||||
net(x, y)
|
||||
test_writer.record(i)
|
||||
# step 2: create the Event
|
||||
steps = 100
|
||||
for i in range(1, steps):
|
||||
x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32))
|
||||
y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32))
|
||||
net(x, y)
|
||||
test_writer.record(i)
|
||||
|
||||
# step 3: close the writer
|
||||
test_writer.close()
|
||||
|
||||
log.debug("finished test_scalar_summary_with_ge")
|
||||
log.debug("finished test_scalar_summary_with_ge")
|
||||
|
||||
|
||||
# test the problem of two consecutive use cases going wrong
|
||||
|
@ -180,55 +176,52 @@ def test_scalar_summary_with_ge_2():
|
|||
log.debug("begin test_scalar_summary_with_ge_2")
|
||||
|
||||
# step 0: create the thread
|
||||
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR")
|
||||
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_SCALAR") as test_writer:
|
||||
|
||||
# step 1: create the network for summary
|
||||
x = Tensor(np.array([1.1]).astype(np.float32))
|
||||
y = Tensor(np.array([1.2]).astype(np.float32))
|
||||
net = SummaryDemo()
|
||||
net.set_train()
|
||||
|
||||
# step 2: create the Event
|
||||
steps = 100
|
||||
for i in range(1, steps):
|
||||
# step 1: create the network for summary
|
||||
x = Tensor(np.array([1.1]).astype(np.float32))
|
||||
y = Tensor(np.array([1.2]).astype(np.float32))
|
||||
net(x, y)
|
||||
test_writer.record(i)
|
||||
net = SummaryDemo()
|
||||
net.set_train()
|
||||
|
||||
# step 3: close the writer
|
||||
test_writer.close()
|
||||
# step 2: create the Event
|
||||
steps = 100
|
||||
for i in range(1, steps):
|
||||
x = Tensor(np.array([1.1]).astype(np.float32))
|
||||
y = Tensor(np.array([1.2]).astype(np.float32))
|
||||
net(x, y)
|
||||
test_writer.record(i)
|
||||
|
||||
log.debug("finished test_scalar_summary_with_ge_2")
|
||||
|
||||
log.debug("finished test_scalar_summary_with_ge_2")
|
||||
|
||||
|
||||
def test_validate():
|
||||
sr = SummaryRecord(SUMMARY_DIR)
|
||||
with SummaryRecord(SUMMARY_DIR) as sr:
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, 0)
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, -1)
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, 1.2)
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, True)
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, "str")
|
||||
sr.record(1)
|
||||
with pytest.raises(ValueError):
|
||||
sr.record(False)
|
||||
with pytest.raises(ValueError):
|
||||
sr.record(2.0)
|
||||
with pytest.raises(ValueError):
|
||||
sr.record((1, 3))
|
||||
with pytest.raises(ValueError):
|
||||
sr.record([2, 3])
|
||||
with pytest.raises(ValueError):
|
||||
sr.record("str")
|
||||
with pytest.raises(ValueError):
|
||||
sr.record(sr)
|
||||
sr.close()
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, 0)
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, -1)
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, 1.2)
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, True)
|
||||
with pytest.raises(ValueError):
|
||||
SummaryStep(sr, "str")
|
||||
sr.record(1)
|
||||
with pytest.raises(ValueError):
|
||||
sr.record(False)
|
||||
with pytest.raises(ValueError):
|
||||
sr.record(2.0)
|
||||
with pytest.raises(ValueError):
|
||||
sr.record((1, 3))
|
||||
with pytest.raises(ValueError):
|
||||
sr.record([2, 3])
|
||||
with pytest.raises(ValueError):
|
||||
sr.record("str")
|
||||
with pytest.raises(ValueError):
|
||||
sr.record(sr)
|
||||
|
||||
SummaryStep(sr, 1)
|
||||
with pytest.raises(ValueError):
|
||||
|
|
|
@ -126,23 +126,21 @@ class HistogramSummaryNet(nn.Cell):
|
|||
def run_case(net):
|
||||
""" run_case """
|
||||
# step 0: create the thread
|
||||
test_writer = SummaryRecord(SUMMARY_DIR)
|
||||
with SummaryRecord(SUMMARY_DIR) as test_writer:
|
||||
|
||||
# step 1: create the network for summary
|
||||
x = Tensor(np.array([1.1]).astype(np.float32))
|
||||
y = Tensor(np.array([1.2]).astype(np.float32))
|
||||
net.set_train()
|
||||
# step 1: create the network for summary
|
||||
x = Tensor(np.array([1.1]).astype(np.float32))
|
||||
y = Tensor(np.array([1.2]).astype(np.float32))
|
||||
net.set_train()
|
||||
|
||||
# step 2: create the Event
|
||||
steps = 100
|
||||
for i in range(1, steps):
|
||||
x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32))
|
||||
y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32))
|
||||
net(x, y)
|
||||
test_writer.record(i)
|
||||
# step 2: create the Event
|
||||
steps = 100
|
||||
for i in range(1, steps):
|
||||
x = Tensor(np.array([1.1 + random.uniform(1, 10)]).astype(np.float32))
|
||||
y = Tensor(np.array([1.2 + random.uniform(1, 10)]).astype(np.float32))
|
||||
net(x, y)
|
||||
test_writer.record(i)
|
||||
|
||||
# step 3: close the writer
|
||||
test_writer.close()
|
||||
|
||||
|
||||
# Test 1: use the repeat tag
|
||||
|
|
|
@ -80,19 +80,18 @@ def test_tensor_summary_sample():
|
|||
""" test_tensor_summary_sample """
|
||||
log.debug("begin test_tensor_summary_sample")
|
||||
# step 0: create the thread
|
||||
test_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_TENSOR")
|
||||
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_TENSOR") as test_writer:
|
||||
|
||||
# step 1: create the Event
|
||||
for i in range(1, 100):
|
||||
test_data = get_test_data(i)
|
||||
# step 1: create the Event
|
||||
for i in range(1, 100):
|
||||
test_data = get_test_data(i)
|
||||
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(i)
|
||||
_cache_summary_tensor_data(test_data)
|
||||
test_writer.record(i)
|
||||
|
||||
# step 2: accept the event and write the file
|
||||
test_writer.close()
|
||||
# step 2: accept the event and write the file
|
||||
|
||||
log.debug("finished test_tensor_summary_sample")
|
||||
log.debug("finished test_tensor_summary_sample")
|
||||
|
||||
|
||||
def get_test_data_check(step):
|
||||
|
@ -131,23 +130,20 @@ def test_tensor_summary_with_ge():
|
|||
log.debug("begin test_tensor_summary_with_ge")
|
||||
|
||||
# step 0: create the thread
|
||||
test_writer = SummaryRecord(SUMMARY_DIR)
|
||||
with SummaryRecord(SUMMARY_DIR) as test_writer:
|
||||
|
||||
# step 1: create the network for summary
|
||||
x = Tensor(np.array([1.1]).astype(np.float32))
|
||||
y = Tensor(np.array([1.2]).astype(np.float32))
|
||||
net = SummaryDemo()
|
||||
net.set_train()
|
||||
# step 1: create the network for summary
|
||||
x = Tensor(np.array([1.1]).astype(np.float32))
|
||||
y = Tensor(np.array([1.2]).astype(np.float32))
|
||||
net = SummaryDemo()
|
||||
net.set_train()
|
||||
|
||||
# step 2: create the Event
|
||||
steps = 100
|
||||
for i in range(1, steps):
|
||||
x = Tensor(np.array([[i], [i]]).astype(np.float32))
|
||||
y = Tensor(np.array([[i + 1], [i + 1]]).astype(np.float32))
|
||||
net(x, y)
|
||||
test_writer.record(i)
|
||||
# step 2: create the Event
|
||||
steps = 100
|
||||
for i in range(1, steps):
|
||||
x = Tensor(np.array([[i], [i]]).astype(np.float32))
|
||||
y = Tensor(np.array([[i + 1], [i + 1]]).astype(np.float32))
|
||||
net(x, y)
|
||||
test_writer.record(i)
|
||||
|
||||
# step 3: close the writer
|
||||
test_writer.close()
|
||||
|
||||
log.debug("finished test_tensor_summary_with_ge")
|
||||
log.debug("finished test_tensor_summary_with_ge")
|
||||
|
|
Loading…
Reference in New Issue