diff --git a/mindspore/ccsrc/device/cpu/cpu_device_address.cc b/mindspore/ccsrc/device/cpu/cpu_device_address.cc index 56e9b6d36e..1edb248341 100644 --- a/mindspore/ccsrc/device/cpu/cpu_device_address.cc +++ b/mindspore/ccsrc/device/cpu/cpu_device_address.cc @@ -22,10 +22,23 @@ namespace device { namespace cpu { bool CPUDeviceAddress::SyncDeviceToHost(const std::vector & /*shape*/, size_t size, TypeId type, void *host_ptr) const { - if (type == kNumberTypeFloat16) { + MS_EXCEPTION_IF_NULL(ptr_); + + if (host_ptr == ptr_) { + MS_LOG(DEBUG) << "host_ptr is equal to ptr_, request ignored."; + return true; + } + + if (type == type_id_) { + (void)memcpy_s(host_ptr, size, ptr_, size); + } else if (type == kNumberTypeFloat16) { FloatToHalf(host_ptr, ptr_, size / 2); } else if (type == kNumberTypeFloat64) { FloatToDouble(host_ptr, ptr_, size / sizeof(double)); + } else { + MS_LOG(ERROR) << "Types not match. Device type: " << TypeIdLabel(type_id_) << ", host type: " << TypeIdLabel(type) + << "."; + return false; } return true; } diff --git a/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc index 67328f04c2..f10568d3d9 100644 --- a/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc +++ b/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc @@ -27,6 +27,7 @@ #include "utils/config_manager.h" #include "common/utils.h" #include "session/anf_runtime_algorithm.h" +#include "session/session_basic.h" #include "operator/ops.h" namespace mindspore { @@ -234,9 +235,18 @@ void CPUKernelRuntime::AddRuntimeAddress(DeviceAddress *address, std::vectorpush_back(input); } +void CPUKernelRuntime::IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { + resource_manager_.IncreaseSummaryRefCount(summary_outputs); +} + +void CPUKernelRuntime::DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { + resource_manager_.DecreaseSummaryRefCount(summary_outputs); +} + bool 
CPUKernelRuntime::Run(session::KernelGraph *kernel_graph) { MS_EXCEPTION_IF_NULL(kernel_graph); - resource_manager_.ResetAddressRefCount(kernel_graph); + resource_manager_.IncreaseAddressRefCount(kernel_graph); + auto kernels = kernel_graph->execution_order(); for (const auto &kernel : kernels) { std::vector kernel_inputs; diff --git a/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.h b/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.h index 28e61c1479..ac63f55d3e 100644 --- a/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.h +++ b/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.h @@ -22,6 +22,7 @@ #include #include "device/kernel_runtime.h" #include "session/kernel_graph.h" +#include "session/session_basic.h" #include "device/cpu/cpu_resource_manager.h" #include "utils/any.h" namespace mindspore { @@ -37,6 +38,8 @@ class CPUKernelRuntime : public KernelRuntime { void AssignKernelAddress(session::KernelGraph *kernel_graph); void BindInputOutput(const session::KernelGraph *kernel_graph, const std::vector &inputs, VectorRef *outputs); + void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); + void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); protected: bool SyncStream() override { return true; }; diff --git a/mindspore/ccsrc/device/cpu/cpu_resource_manager.cc b/mindspore/ccsrc/device/cpu/cpu_resource_manager.cc index 45b9ea5bed..c69ef35305 100644 --- a/mindspore/ccsrc/device/cpu/cpu_resource_manager.cc +++ b/mindspore/ccsrc/device/cpu/cpu_resource_manager.cc @@ -76,7 +76,47 @@ void CPUResourceManager::MemFree(void *ptr) { } } -void CPUResourceManager::ResetAddressRefCount(const session::KernelGraph *graph) { +void CPUResourceManager::IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { + if (!dynamic_malloc_) { + return; + } + + if (summary_outputs.empty()) { + return; + } + + for (auto &output_item : summary_outputs) { + auto node = output_item.second.first; + size_t index = 
IntToSize(output_item.second.second); + auto address = AnfAlgo::GetMutableOutputAddr(node, index); + MS_EXCEPTION_IF_NULL(address); + address->ref_count_++; + } +} + +void CPUResourceManager::DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { + if (!dynamic_malloc_) { + return; + } + + if (summary_outputs.empty()) { + return; + } + + for (auto &output_item : summary_outputs) { + auto node = output_item.second.first; + size_t index = IntToSize(output_item.second.second); + auto address = AnfAlgo::GetMutableOutputAddr(node, index); + MS_EXCEPTION_IF_NULL(address); + address->ref_count_--; + if (address->ref_count_ == 0 && address->ptr_ != nullptr) { + MemFree(address->ptr_); + address->ptr_ = nullptr; + } + } +} + +void CPUResourceManager::IncreaseAddressRefCount(const session::KernelGraph *graph) { if (!dynamic_malloc_) { return; } diff --git a/mindspore/ccsrc/device/cpu/cpu_resource_manager.h b/mindspore/ccsrc/device/cpu/cpu_resource_manager.h index 96cf00f3d8..d130241464 100644 --- a/mindspore/ccsrc/device/cpu/cpu_resource_manager.h +++ b/mindspore/ccsrc/device/cpu/cpu_resource_manager.h @@ -19,6 +19,7 @@ #include #include #include "session/kernel_graph.h" +#include "session/session_basic.h" #include "device/device_address.h" #include "device/cpu/cpu_simple_mem_plan.h" namespace mindspore { @@ -31,10 +32,12 @@ class CPUResourceManager { void MemPlan(const session::KernelGraph *graph); void MemMalloc(const session::KernelGraph *graph); - void ResetAddressRefCount(const session::KernelGraph *graph); + void IncreaseAddressRefCount(const session::KernelGraph *graph); void DecreaseAddressRefCount(const AnfNodePtr &kernel); void *MemMalloc(size_t mem_size); void MemFree(void *ptr); + void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); + void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); private: void MemFree(); diff --git a/mindspore/ccsrc/session/cpu_session.cc 
b/mindspore/ccsrc/session/cpu_session.cc index 447845480d..c3caf512ac 100644 --- a/mindspore/ccsrc/session/cpu_session.cc +++ b/mindspore/ccsrc/session/cpu_session.cc @@ -68,11 +68,25 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vectorexecution_order(); Reorder(&execution_order); + + bool enable_summary = summary_callback_ != nullptr; kernel_graph->set_execution_order(execution_order); + NamedSummaryOutputs summary_outputs; + if (enable_summary) { + GetSummaryNodes(kernel_graph.get(), &summary_outputs); + runtime_.IncreaseSummaryRefCount(summary_outputs); + } + bool ret = runtime_.Run(kernel_graph.get()); if (!ret) { MS_LOG(EXCEPTION) << "Run graph failed"; } + + if (enable_summary) { + Summary(kernel_graph.get()); + runtime_.DecreaseSummaryRefCount(summary_outputs); + } + MS_LOG(INFO) << "Run graph end"; } diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index 93cfc6bbcd..d11446a8ba 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -745,8 +745,7 @@ void SessionBasic::Reorder(std::vector *node_list) { (void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list)); } -void SessionBasic::GetSummaryNodes(const KernelGraph *graph, - std::unordered_map> *summary) { +void SessionBasic::GetSummaryNodes(const KernelGraph *graph, NamedSummaryOutputs *summary) { MS_LOG(DEBUG) << "Update summary Start"; MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(summary); @@ -780,7 +779,7 @@ void SessionBasic::Summary(KernelGraph *graph) { return; } MS_EXCEPTION_IF_NULL(graph); - std::unordered_map> summary_outputs; + NamedSummaryOutputs summary_outputs; GetSummaryNodes(graph, &summary_outputs); // do not exist summary node if (summary_outputs.empty()) { diff --git a/mindspore/ccsrc/session/session_basic.h b/mindspore/ccsrc/session/session_basic.h index 142c5b68be..4620bc763d 100755 --- a/mindspore/ccsrc/session/session_basic.h +++ 
b/mindspore/ccsrc/session/session_basic.h @@ -130,6 +130,7 @@ class SessionBasic { }; using SessionPtr = std::shared_ptr; +using NamedSummaryOutputs = std::unordered_map>; } // namespace session } // namespace mindspore #endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H diff --git a/tests/st/summary/test_cpu_summary.py b/tests/st/summary/test_cpu_summary.py new file mode 100644 index 0000000000..8d88003866 --- /dev/null +++ b/tests/st/summary/test_cpu_summary.py @@ -0,0 +1,79 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Summary cpu st.""" +import os +import tempfile + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from tests.summary_utils import SummaryReader +from mindspore.train.summary.summary_record import SummaryRecord + +context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + + +class SummaryNet(nn.Cell): + def __init__(self): + super().__init__() + self.scalar_summary = P.ScalarSummary() + self.image_summary = P.ImageSummary() + self.tensor_summary = P.TensorSummary() + self.histogram_summary = P.HistogramSummary() + + def construct(self, image_tensor): + self.image_summary("image", image_tensor) + self.tensor_summary("tensor", image_tensor) + self.histogram_summary("histogram", image_tensor) + scalar = image_tensor[0][0][0][0] + self.scalar_summary("scalar", scalar) + return scalar + + +def train_summary_record(test_writer, steps): + """Train and record summary.""" + net = SummaryNet() + out_me_dict = {} + for i in range(0, steps): + image_tensor = Tensor(np.array([[[[i]]]]).astype(np.float32)) + out_put = net(image_tensor) + test_writer.record(i) + out_me_dict[i] = out_put.asnumpy() + return out_me_dict + + +class TestCpuSummary: + """Test cpu summary.""" + + @pytest.mark.level0 + @pytest.mark.platform_x86_cpu_training + @pytest.mark.env_onecard + def test_summary_step2_summary_record1(self): + """Test record 2 step summary.""" + with tempfile.TemporaryDirectory() as tmp_dir: + steps = 2 + with SummaryRecord(tmp_dir) as test_writer: + train_summary_record(test_writer, steps=steps) + + file_name = os.path.realpath(test_writer.full_file_name) + with SummaryReader(file_name) as summary_writer: + for _ in range(steps): + event = summary_writer.read_event() + tags = set(value.tag for value in event.summary.value) + assert tags == {'tensor', 'histogram', 
'scalar', 'image'} diff --git a/tests/ut/python/train/summary/summary_reader.py b/tests/summary_utils.py similarity index 60% rename from tests/ut/python/train/summary/summary_reader.py rename to tests/summary_utils.py index 647c25f25c..826a3106e5 100644 --- a/tests/ut/python/train/summary/summary_reader.py +++ b/tests/summary_utils.py @@ -22,22 +22,44 @@ _HEADER_CRC_SIZE = 4 _DATA_CRC_SIZE = 4 -class SummaryReader: - """Read events from summary file.""" +class _EndOfSummaryFileException(Exception): + """Indicates the summary file is exhausted.""" - def __init__(self, file_name): - self._file_name = file_name - self._file_handler = open(self._file_name, "rb") - # skip version event - self.read_event() + +class SummaryReader: + """ + Basic summary read function. + + Args: + canonical_file_path (str): The canonical summary file path. + ignore_version_event (bool): Whether to ignore the version event at the beginning of the summary file. + """ + + def __init__(self, canonical_file_path, ignore_version_event=True): + self._file_path = canonical_file_path + self._ignore_version_event = ignore_version_event + + def __enter__(self): + self._file_handler = open(self._file_path, "rb") + if self._ignore_version_event: + self.read_event() + return self + + def __exit__(self, *unused_args): + self._file_handler.close() + return False def read_event(self): """Read next event.""" file_handler = self._file_handler header = file_handler.read(_HEADER_SIZE) data_len = struct.unpack('Q', header)[0] + # Ignore crc check. file_handler.read(_HEADER_CRC_SIZE) + event_str = file_handler.read(data_len) + # Ignore crc check. 
file_handler.read(_DATA_CRC_SIZE) summary_event = summary_pb2.Event.FromString(event_str) + return summary_event diff --git a/tests/ut/python/train/summary/test_histogram_summary.py b/tests/ut/python/train/summary/test_histogram_summary.py index dc0892167c..e304146a2e 100644 --- a/tests/ut/python/train/summary/test_histogram_summary.py +++ b/tests/ut/python/train/summary/test_histogram_summary.py @@ -22,7 +22,7 @@ import numpy as np from mindspore.common.tensor import Tensor from mindspore.train.summary._summary_adapter import _calc_histogram_bins from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data -from .summary_reader import SummaryReader +from tests.summary_utils import SummaryReader CUR_DIR = os.getcwd() SUMMARY_DIR = os.path.join(CUR_DIR, "/test_temp_summary_event_file/") @@ -57,9 +57,9 @@ def test_histogram_summary(): test_writer.record(step=1) file_name = os.path.join(tmp_dir, test_writer.event_file_name) - reader = SummaryReader(file_name) - event = reader.read_event() - assert event.summary.value[0].histogram.count == 6 + with SummaryReader(file_name) as reader: + event = reader.read_event() + assert event.summary.value[0].histogram.count == 6 def test_histogram_multi_summary(): @@ -79,10 +79,10 @@ def test_histogram_multi_summary(): test_writer.record(step=i) file_name = os.path.join(tmp_dir, test_writer.event_file_name) - reader = SummaryReader(file_name) - for _ in range(num_step): - event = reader.read_event() - assert event.summary.value[0].histogram.count == size + with SummaryReader(file_name) as reader: + for _ in range(num_step): + event = reader.read_event() + assert event.summary.value[0].histogram.count == size def test_histogram_summary_scalar_tensor(): @@ -94,9 +94,9 @@ def test_histogram_summary_scalar_tensor(): test_writer.record(step=1) file_name = os.path.join(tmp_dir, test_writer.event_file_name) - reader = SummaryReader(file_name) - event = reader.read_event() - assert 
event.summary.value[0].histogram.count == 1 + with SummaryReader(file_name) as reader: + event = reader.read_event() + assert event.summary.value[0].histogram.count == 1 def test_histogram_summary_empty_tensor(): @@ -108,9 +108,9 @@ def test_histogram_summary_empty_tensor(): test_writer.record(step=1) file_name = os.path.join(tmp_dir, test_writer.event_file_name) - reader = SummaryReader(file_name) - event = reader.read_event() - assert event.summary.value[0].histogram.count == 0 + with SummaryReader(file_name) as reader: + event = reader.read_event() + assert event.summary.value[0].histogram.count == 0 def test_histogram_summary_same_value(): @@ -125,11 +125,11 @@ def test_histogram_summary_same_value(): test_writer.record(step=1) file_name = os.path.join(tmp_dir, test_writer.event_file_name) - reader = SummaryReader(file_name) - event = reader.read_event() - LOG.debug(event) + with SummaryReader(file_name) as reader: + event = reader.read_event() + LOG.debug(event) - assert len(event.summary.value[0].histogram.buckets) == _calc_histogram_bins(dim1 * dim2) + assert len(event.summary.value[0].histogram.buckets) == _calc_histogram_bins(dim1 * dim2) def test_histogram_summary_high_dims(): @@ -145,11 +145,11 @@ def test_histogram_summary_high_dims(): test_writer.record(step=1) file_name = os.path.join(tmp_dir, test_writer.event_file_name) - reader = SummaryReader(file_name) - event = reader.read_event() - LOG.debug(event) + with SummaryReader(file_name) as reader: + event = reader.read_event() + LOG.debug(event) - assert event.summary.value[0].histogram.count == tensor_data.size + assert event.summary.value[0].histogram.count == tensor_data.size def test_histogram_summary_nan_inf(): @@ -169,11 +169,11 @@ def test_histogram_summary_nan_inf(): test_writer.record(step=1) file_name = os.path.join(tmp_dir, test_writer.event_file_name) - reader = SummaryReader(file_name) - event = reader.read_event() - LOG.debug(event) + with SummaryReader(file_name) as reader: + event = 
reader.read_event() + LOG.debug(event) - assert event.summary.value[0].histogram.nan_count == 1 + assert event.summary.value[0].histogram.nan_count == 1 def test_histogram_summary_all_nan_inf(): @@ -185,11 +185,11 @@ def test_histogram_summary_all_nan_inf(): test_writer.record(step=1) file_name = os.path.join(tmp_dir, test_writer.event_file_name) - reader = SummaryReader(file_name) - event = reader.read_event() - LOG.debug(event) + with SummaryReader(file_name) as reader: + event = reader.read_event() + LOG.debug(event) - histogram = event.summary.value[0].histogram - assert histogram.nan_count == 3 - assert histogram.pos_inf_count == 1 - assert histogram.neg_inf_count == 1 + histogram = event.summary.value[0].histogram + assert histogram.nan_count == 3 + assert histogram.pos_inf_count == 1 + assert histogram.neg_inf_count == 1