forked from OSSInnovation/mindspore
cpu support summary
This commit is contained in:
parent
69cbf58517
commit
ab6b6add0b
|
@ -22,10 +22,23 @@ namespace device {
|
||||||
namespace cpu {
|
namespace cpu {
|
||||||
bool CPUDeviceAddress::SyncDeviceToHost(const std::vector<int> & /*shape*/, size_t size, TypeId type,
|
bool CPUDeviceAddress::SyncDeviceToHost(const std::vector<int> & /*shape*/, size_t size, TypeId type,
|
||||||
void *host_ptr) const {
|
void *host_ptr) const {
|
||||||
if (type == kNumberTypeFloat16) {
|
MS_EXCEPTION_IF_NULL(ptr_);
|
||||||
|
|
||||||
|
if (host_ptr == ptr_) {
|
||||||
|
MS_LOG(DEBUG) << "host_ptr is equal to ptr_, request ignored.";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type == type_id_) {
|
||||||
|
(void)memcpy_s(host_ptr, size, ptr_, size);
|
||||||
|
} else if (type == kNumberTypeFloat16) {
|
||||||
FloatToHalf(host_ptr, ptr_, size / 2);
|
FloatToHalf(host_ptr, ptr_, size / 2);
|
||||||
} else if (type == kNumberTypeFloat64) {
|
} else if (type == kNumberTypeFloat64) {
|
||||||
FloatToDouble(host_ptr, ptr_, size / sizeof(double));
|
FloatToDouble(host_ptr, ptr_, size / sizeof(double));
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "Types not match. Device type: " << TypeIdLabel(type_id_) << ", host type: " << TypeIdLabel(type)
|
||||||
|
<< ".";
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,6 +27,7 @@
|
||||||
#include "utils/config_manager.h"
|
#include "utils/config_manager.h"
|
||||||
#include "common/utils.h"
|
#include "common/utils.h"
|
||||||
#include "session/anf_runtime_algorithm.h"
|
#include "session/anf_runtime_algorithm.h"
|
||||||
|
#include "session/session_basic.h"
|
||||||
#include "operator/ops.h"
|
#include "operator/ops.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -234,9 +235,18 @@ void CPUKernelRuntime::AddRuntimeAddress(DeviceAddress *address, std::vector<ker
|
||||||
input_list->push_back(input);
|
input_list->push_back(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CPUKernelRuntime::IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) {
|
||||||
|
resource_manager_.IncreaseSummaryRefCount(summary_outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CPUKernelRuntime::DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) {
|
||||||
|
resource_manager_.DecreaseSummaryRefCount(summary_outputs);
|
||||||
|
}
|
||||||
|
|
||||||
bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph) {
|
bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
resource_manager_.ResetAddressRefCount(kernel_graph);
|
resource_manager_.IncreaseAddressRefCount(kernel_graph);
|
||||||
|
|
||||||
auto kernels = kernel_graph->execution_order();
|
auto kernels = kernel_graph->execution_order();
|
||||||
for (const auto &kernel : kernels) {
|
for (const auto &kernel : kernels) {
|
||||||
std::vector<kernel::AddressPtr> kernel_inputs;
|
std::vector<kernel::AddressPtr> kernel_inputs;
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include "device/kernel_runtime.h"
|
#include "device/kernel_runtime.h"
|
||||||
#include "session/kernel_graph.h"
|
#include "session/kernel_graph.h"
|
||||||
|
#include "session/session_basic.h"
|
||||||
#include "device/cpu/cpu_resource_manager.h"
|
#include "device/cpu/cpu_resource_manager.h"
|
||||||
#include "utils/any.h"
|
#include "utils/any.h"
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -37,6 +38,8 @@ class CPUKernelRuntime : public KernelRuntime {
|
||||||
void AssignKernelAddress(session::KernelGraph *kernel_graph);
|
void AssignKernelAddress(session::KernelGraph *kernel_graph);
|
||||||
void BindInputOutput(const session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
|
void BindInputOutput(const session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
|
||||||
VectorRef *outputs);
|
VectorRef *outputs);
|
||||||
|
void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);
|
||||||
|
void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
bool SyncStream() override { return true; };
|
bool SyncStream() override { return true; };
|
||||||
|
|
|
@ -76,7 +76,47 @@ void CPUResourceManager::MemFree(void *ptr) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void CPUResourceManager::ResetAddressRefCount(const session::KernelGraph *graph) {
|
void CPUResourceManager::IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) {
|
||||||
|
if (!dynamic_malloc_) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (summary_outputs.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &output_item : summary_outputs) {
|
||||||
|
auto node = output_item.second.first;
|
||||||
|
size_t index = IntToSize(output_item.second.second);
|
||||||
|
auto address = AnfAlgo::GetMutableOutputAddr(node, index);
|
||||||
|
MS_EXCEPTION_IF_NULL(address);
|
||||||
|
address->ref_count_++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void CPUResourceManager::DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) {
|
||||||
|
if (!dynamic_malloc_) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (summary_outputs.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &output_item : summary_outputs) {
|
||||||
|
auto node = output_item.second.first;
|
||||||
|
size_t index = IntToSize(output_item.second.second);
|
||||||
|
auto address = AnfAlgo::GetMutableOutputAddr(node, index);
|
||||||
|
MS_EXCEPTION_IF_NULL(address);
|
||||||
|
address->ref_count_--;
|
||||||
|
if (address->ref_count_ == 0 && address->ptr_ != nullptr) {
|
||||||
|
MemFree(address->ptr_);
|
||||||
|
address->ptr_ = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void CPUResourceManager::IncreaseAddressRefCount(const session::KernelGraph *graph) {
|
||||||
if (!dynamic_malloc_) {
|
if (!dynamic_malloc_) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include "session/kernel_graph.h"
|
#include "session/kernel_graph.h"
|
||||||
|
#include "session/session_basic.h"
|
||||||
#include "device/device_address.h"
|
#include "device/device_address.h"
|
||||||
#include "device/cpu/cpu_simple_mem_plan.h"
|
#include "device/cpu/cpu_simple_mem_plan.h"
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -31,10 +32,12 @@ class CPUResourceManager {
|
||||||
|
|
||||||
void MemPlan(const session::KernelGraph *graph);
|
void MemPlan(const session::KernelGraph *graph);
|
||||||
void MemMalloc(const session::KernelGraph *graph);
|
void MemMalloc(const session::KernelGraph *graph);
|
||||||
void ResetAddressRefCount(const session::KernelGraph *graph);
|
void IncreaseAddressRefCount(const session::KernelGraph *graph);
|
||||||
void DecreaseAddressRefCount(const AnfNodePtr &kernel);
|
void DecreaseAddressRefCount(const AnfNodePtr &kernel);
|
||||||
void *MemMalloc(size_t mem_size);
|
void *MemMalloc(size_t mem_size);
|
||||||
void MemFree(void *ptr);
|
void MemFree(void *ptr);
|
||||||
|
void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);
|
||||||
|
void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void MemFree();
|
void MemFree();
|
||||||
|
|
|
@ -68,11 +68,25 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
|
||||||
predictmodel::StepConvertWeight(inputs);
|
predictmodel::StepConvertWeight(inputs);
|
||||||
auto execution_order = kernel_graph->execution_order();
|
auto execution_order = kernel_graph->execution_order();
|
||||||
Reorder(&execution_order);
|
Reorder(&execution_order);
|
||||||
|
|
||||||
|
bool enable_summary = summary_callback_ != nullptr;
|
||||||
kernel_graph->set_execution_order(execution_order);
|
kernel_graph->set_execution_order(execution_order);
|
||||||
|
NamedSummaryOutputs summary_outputs;
|
||||||
|
if (enable_summary) {
|
||||||
|
GetSummaryNodes(kernel_graph.get(), &summary_outputs);
|
||||||
|
runtime_.IncreaseSummaryRefCount(summary_outputs);
|
||||||
|
}
|
||||||
|
|
||||||
bool ret = runtime_.Run(kernel_graph.get());
|
bool ret = runtime_.Run(kernel_graph.get());
|
||||||
if (!ret) {
|
if (!ret) {
|
||||||
MS_LOG(EXCEPTION) << "Run graph failed";
|
MS_LOG(EXCEPTION) << "Run graph failed";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (enable_summary) {
|
||||||
|
Summary(kernel_graph.get());
|
||||||
|
runtime_.DecreaseSummaryRefCount(summary_outputs);
|
||||||
|
}
|
||||||
|
|
||||||
MS_LOG(INFO) << "Run graph end";
|
MS_LOG(INFO) << "Run graph end";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -745,8 +745,7 @@ void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) {
|
||||||
(void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list));
|
(void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list));
|
||||||
}
|
}
|
||||||
|
|
||||||
void SessionBasic::GetSummaryNodes(const KernelGraph *graph,
|
void SessionBasic::GetSummaryNodes(const KernelGraph *graph, NamedSummaryOutputs *summary) {
|
||||||
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) {
|
|
||||||
MS_LOG(DEBUG) << "Update summary Start";
|
MS_LOG(DEBUG) << "Update summary Start";
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
MS_EXCEPTION_IF_NULL(summary);
|
MS_EXCEPTION_IF_NULL(summary);
|
||||||
|
@ -780,7 +779,7 @@ void SessionBasic::Summary(KernelGraph *graph) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> summary_outputs;
|
NamedSummaryOutputs summary_outputs;
|
||||||
GetSummaryNodes(graph, &summary_outputs);
|
GetSummaryNodes(graph, &summary_outputs);
|
||||||
// do not exist summary node
|
// do not exist summary node
|
||||||
if (summary_outputs.empty()) {
|
if (summary_outputs.empty()) {
|
||||||
|
|
|
@ -130,6 +130,7 @@ class SessionBasic {
|
||||||
};
|
};
|
||||||
|
|
||||||
using SessionPtr = std::shared_ptr<session::SessionBasic>;
|
using SessionPtr = std::shared_ptr<session::SessionBasic>;
|
||||||
|
using NamedSummaryOutputs = std::unordered_map<std::string, std::pair<AnfNodePtr, int>>;
|
||||||
} // namespace session
|
} // namespace session
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H
|
#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H
|
||||||
|
|
|
@ -0,0 +1,79 @@
|
||||||
|
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Summary cpu st."""
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from tests.summary_utils import SummaryReader
|
||||||
|
from mindspore.train.summary.summary_record import SummaryRecord
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||||
|
|
||||||
|
|
||||||
|
class SummaryNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.scalar_summary = P.ScalarSummary()
|
||||||
|
self.image_summary = P.ImageSummary()
|
||||||
|
self.tensor_summary = P.TensorSummary()
|
||||||
|
self.histogram_summary = P.HistogramSummary()
|
||||||
|
|
||||||
|
def construct(self, image_tensor):
|
||||||
|
self.image_summary("image", image_tensor)
|
||||||
|
self.tensor_summary("tensor", image_tensor)
|
||||||
|
self.histogram_summary("histogram", image_tensor)
|
||||||
|
scalar = image_tensor[0][0][0][0]
|
||||||
|
self.scalar_summary("scalar", scalar)
|
||||||
|
return scalar
|
||||||
|
|
||||||
|
|
||||||
|
def train_summary_record(test_writer, steps):
|
||||||
|
"""Train and record summary."""
|
||||||
|
net = SummaryNet()
|
||||||
|
out_me_dict = {}
|
||||||
|
for i in range(0, steps):
|
||||||
|
image_tensor = Tensor(np.array([[[[i]]]]).astype(np.float32))
|
||||||
|
out_put = net(image_tensor)
|
||||||
|
test_writer.record(i)
|
||||||
|
out_me_dict[i] = out_put.asnumpy()
|
||||||
|
return out_me_dict
|
||||||
|
|
||||||
|
|
||||||
|
class TestCpuSummary:
|
||||||
|
"""Test cpu summary."""
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_summary_step2_summary_record1(self):
|
||||||
|
"""Test record 10 step summary."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
steps = 2
|
||||||
|
with SummaryRecord(tmp_dir) as test_writer:
|
||||||
|
train_summary_record(test_writer, steps=steps)
|
||||||
|
|
||||||
|
file_name = os.path.realpath(test_writer.full_file_name)
|
||||||
|
with SummaryReader(file_name) as summary_writer:
|
||||||
|
for _ in range(steps):
|
||||||
|
event = summary_writer.read_event()
|
||||||
|
tags = set(value.tag for value in event.summary.value)
|
||||||
|
assert tags == {'tensor', 'histogram', 'scalar', 'image'}
|
|
@ -22,22 +22,44 @@ _HEADER_CRC_SIZE = 4
|
||||||
_DATA_CRC_SIZE = 4
|
_DATA_CRC_SIZE = 4
|
||||||
|
|
||||||
|
|
||||||
class SummaryReader:
|
class _EndOfSummaryFileException(Exception):
|
||||||
"""Read events from summary file."""
|
"""Indicates the summary file is exhausted."""
|
||||||
|
|
||||||
def __init__(self, file_name):
|
|
||||||
self._file_name = file_name
|
class SummaryReader:
|
||||||
self._file_handler = open(self._file_name, "rb")
|
"""
|
||||||
# skip version event
|
Basic summary read function.
|
||||||
self.read_event()
|
|
||||||
|
Args:
|
||||||
|
canonical_file_path (str): The canonical summary file path.
|
||||||
|
ignore_version_event (bool): Whether ignore the version event at the beginning of summary file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, canonical_file_path, ignore_version_event=True):
|
||||||
|
self._file_path = canonical_file_path
|
||||||
|
self._ignore_version_event = ignore_version_event
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self._file_handler = open(self._file_path, "rb")
|
||||||
|
if self._ignore_version_event:
|
||||||
|
self.read_event()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *unused_args):
|
||||||
|
self._file_handler.close()
|
||||||
|
return False
|
||||||
|
|
||||||
def read_event(self):
|
def read_event(self):
|
||||||
"""Read next event."""
|
"""Read next event."""
|
||||||
file_handler = self._file_handler
|
file_handler = self._file_handler
|
||||||
header = file_handler.read(_HEADER_SIZE)
|
header = file_handler.read(_HEADER_SIZE)
|
||||||
data_len = struct.unpack('Q', header)[0]
|
data_len = struct.unpack('Q', header)[0]
|
||||||
|
# Ignore crc check.
|
||||||
file_handler.read(_HEADER_CRC_SIZE)
|
file_handler.read(_HEADER_CRC_SIZE)
|
||||||
|
|
||||||
event_str = file_handler.read(data_len)
|
event_str = file_handler.read(data_len)
|
||||||
|
# Ignore crc check.
|
||||||
file_handler.read(_DATA_CRC_SIZE)
|
file_handler.read(_DATA_CRC_SIZE)
|
||||||
summary_event = summary_pb2.Event.FromString(event_str)
|
summary_event = summary_pb2.Event.FromString(event_str)
|
||||||
|
|
||||||
return summary_event
|
return summary_event
|
|
@ -22,7 +22,7 @@ import numpy as np
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore.train.summary._summary_adapter import _calc_histogram_bins
|
from mindspore.train.summary._summary_adapter import _calc_histogram_bins
|
||||||
from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
|
from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
|
||||||
from .summary_reader import SummaryReader
|
from tests.summary_utils import SummaryReader
|
||||||
|
|
||||||
CUR_DIR = os.getcwd()
|
CUR_DIR = os.getcwd()
|
||||||
SUMMARY_DIR = os.path.join(CUR_DIR, "/test_temp_summary_event_file/")
|
SUMMARY_DIR = os.path.join(CUR_DIR, "/test_temp_summary_event_file/")
|
||||||
|
@ -57,9 +57,9 @@ def test_histogram_summary():
|
||||||
test_writer.record(step=1)
|
test_writer.record(step=1)
|
||||||
|
|
||||||
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
||||||
reader = SummaryReader(file_name)
|
with SummaryReader(file_name) as reader:
|
||||||
event = reader.read_event()
|
event = reader.read_event()
|
||||||
assert event.summary.value[0].histogram.count == 6
|
assert event.summary.value[0].histogram.count == 6
|
||||||
|
|
||||||
|
|
||||||
def test_histogram_multi_summary():
|
def test_histogram_multi_summary():
|
||||||
|
@ -79,10 +79,10 @@ def test_histogram_multi_summary():
|
||||||
test_writer.record(step=i)
|
test_writer.record(step=i)
|
||||||
|
|
||||||
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
||||||
reader = SummaryReader(file_name)
|
with SummaryReader(file_name) as reader:
|
||||||
for _ in range(num_step):
|
for _ in range(num_step):
|
||||||
event = reader.read_event()
|
event = reader.read_event()
|
||||||
assert event.summary.value[0].histogram.count == size
|
assert event.summary.value[0].histogram.count == size
|
||||||
|
|
||||||
|
|
||||||
def test_histogram_summary_scalar_tensor():
|
def test_histogram_summary_scalar_tensor():
|
||||||
|
@ -94,9 +94,9 @@ def test_histogram_summary_scalar_tensor():
|
||||||
test_writer.record(step=1)
|
test_writer.record(step=1)
|
||||||
|
|
||||||
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
||||||
reader = SummaryReader(file_name)
|
with SummaryReader(file_name) as reader:
|
||||||
event = reader.read_event()
|
event = reader.read_event()
|
||||||
assert event.summary.value[0].histogram.count == 1
|
assert event.summary.value[0].histogram.count == 1
|
||||||
|
|
||||||
|
|
||||||
def test_histogram_summary_empty_tensor():
|
def test_histogram_summary_empty_tensor():
|
||||||
|
@ -108,9 +108,9 @@ def test_histogram_summary_empty_tensor():
|
||||||
test_writer.record(step=1)
|
test_writer.record(step=1)
|
||||||
|
|
||||||
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
||||||
reader = SummaryReader(file_name)
|
with SummaryReader(file_name) as reader:
|
||||||
event = reader.read_event()
|
event = reader.read_event()
|
||||||
assert event.summary.value[0].histogram.count == 0
|
assert event.summary.value[0].histogram.count == 0
|
||||||
|
|
||||||
|
|
||||||
def test_histogram_summary_same_value():
|
def test_histogram_summary_same_value():
|
||||||
|
@ -125,11 +125,11 @@ def test_histogram_summary_same_value():
|
||||||
test_writer.record(step=1)
|
test_writer.record(step=1)
|
||||||
|
|
||||||
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
||||||
reader = SummaryReader(file_name)
|
with SummaryReader(file_name) as reader:
|
||||||
event = reader.read_event()
|
event = reader.read_event()
|
||||||
LOG.debug(event)
|
LOG.debug(event)
|
||||||
|
|
||||||
assert len(event.summary.value[0].histogram.buckets) == _calc_histogram_bins(dim1 * dim2)
|
assert len(event.summary.value[0].histogram.buckets) == _calc_histogram_bins(dim1 * dim2)
|
||||||
|
|
||||||
|
|
||||||
def test_histogram_summary_high_dims():
|
def test_histogram_summary_high_dims():
|
||||||
|
@ -145,11 +145,11 @@ def test_histogram_summary_high_dims():
|
||||||
test_writer.record(step=1)
|
test_writer.record(step=1)
|
||||||
|
|
||||||
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
||||||
reader = SummaryReader(file_name)
|
with SummaryReader(file_name) as reader:
|
||||||
event = reader.read_event()
|
event = reader.read_event()
|
||||||
LOG.debug(event)
|
LOG.debug(event)
|
||||||
|
|
||||||
assert event.summary.value[0].histogram.count == tensor_data.size
|
assert event.summary.value[0].histogram.count == tensor_data.size
|
||||||
|
|
||||||
|
|
||||||
def test_histogram_summary_nan_inf():
|
def test_histogram_summary_nan_inf():
|
||||||
|
@ -169,11 +169,11 @@ def test_histogram_summary_nan_inf():
|
||||||
test_writer.record(step=1)
|
test_writer.record(step=1)
|
||||||
|
|
||||||
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
||||||
reader = SummaryReader(file_name)
|
with SummaryReader(file_name) as reader:
|
||||||
event = reader.read_event()
|
event = reader.read_event()
|
||||||
LOG.debug(event)
|
LOG.debug(event)
|
||||||
|
|
||||||
assert event.summary.value[0].histogram.nan_count == 1
|
assert event.summary.value[0].histogram.nan_count == 1
|
||||||
|
|
||||||
|
|
||||||
def test_histogram_summary_all_nan_inf():
|
def test_histogram_summary_all_nan_inf():
|
||||||
|
@ -185,11 +185,11 @@ def test_histogram_summary_all_nan_inf():
|
||||||
test_writer.record(step=1)
|
test_writer.record(step=1)
|
||||||
|
|
||||||
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
|
||||||
reader = SummaryReader(file_name)
|
with SummaryReader(file_name) as reader:
|
||||||
event = reader.read_event()
|
event = reader.read_event()
|
||||||
LOG.debug(event)
|
LOG.debug(event)
|
||||||
|
|
||||||
histogram = event.summary.value[0].histogram
|
histogram = event.summary.value[0].histogram
|
||||||
assert histogram.nan_count == 3
|
assert histogram.nan_count == 3
|
||||||
assert histogram.pos_inf_count == 1
|
assert histogram.pos_inf_count == 1
|
||||||
assert histogram.neg_inf_count == 1
|
assert histogram.neg_inf_count == 1
|
||||||
|
|
Loading…
Reference in New Issue