diff --git a/mindspore/train/summary/_writer_pool.py b/mindspore/train/summary/_writer_pool.py index da0c0b255e..667f90f0b7 100644 --- a/mindspore/train/summary/_writer_pool.py +++ b/mindspore/train/summary/_writer_pool.py @@ -14,10 +14,13 @@ # ============================================================================ """Write events to disk in a base directory.""" import os +import sys import time import signal from collections import deque +import psutil + import mindspore.log as logger from mindspore.train.summary.enums import PluginEnum, WriterPluginEnum @@ -78,6 +81,7 @@ class WriterPool(ctx.Process): self._queue, self._writers_ = ctx.Queue(ctx.cpu_count() * 2), None self._max_file_size = max_file_size self._raise_exception = raise_exception + self._training_pid = os.getpid() self.start() def run(self): @@ -97,10 +101,7 @@ class WriterPool(ctx.Process): with ctx.Pool(min(ctx.cpu_count(), 32)) as pool: deq = deque() while True: - if not self._writers: - logger.warning("Can not find any writer to write summary data, " - "so SummaryRecord will not record data.") - break + self._check_heartbeat() while deq and deq[0].ready(): for plugin, data in deq.popleft().get(): @@ -163,6 +164,7 @@ class WriterPool(ctx.Process): """Close the writers in the subprocess.""" for writer in self._writers: writer.close() + super().close() def write(self, data) -> None: """ @@ -180,4 +182,19 @@ class WriterPool(ctx.Process): def close(self) -> None: """Close the writer.""" self._queue.put(('END', None)) - self.join() + + def _check_heartbeat(self): + """Check if the summary process should survive.""" + is_exit = False + if not psutil.pid_exists(self._training_pid): + logger.warning("The training process %d is killed, summary process will exit.", self._training_pid) + is_exit = True + + if not self._writers: + logger.warning("Can not find any writer to write summary data, " + "so SummaryRecord will not record data.") + is_exit = True + + if is_exit: + self._close() + sys.exit(1) diff --git a/mindspore/train/summary/summary_record.py b/mindspore/train/summary/summary_record.py index 3ca30ed493..aed25602ad 100644 --- a/mindspore/train/summary/summary_record.py +++ b/mindspore/train/summary/summary_record.py @@ -399,6 +399,7 @@ class SummaryRecord: logger.info('Please wait it may take quite some time to finish writing and closing.') atexit.unregister(self.close) self._event_writer.close() + self._event_writer.join() self._closed = True @staticmethod diff --git a/requirements.txt b/requirements.txt index ee3acb9d70..57943f3efd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,3 +17,4 @@ astunparse >= 1.6.3 packaging >= 20.0 pycocotools >= 2.0.0 # for st test tables >= 3.6.1 # for st test +psutil >= 5.6.1 diff --git a/setup.py b/setup.py index 58aa652355..15872ad239 100644 --- a/setup.py +++ b/setup.py @@ -121,7 +121,8 @@ required_package = [ 'decorator >= 4.4.0', 'setuptools >= 40.8.0', 'astunparse >= 1.6.3', - 'packaging >= 20.0' + 'packaging >= 20.0', + 'psutil >= 5.6.1' ] package_data = { diff --git a/tests/ut/python/train/summary/test_summary_ops_params_valid_check.py b/tests/ut/python/train/summary/test_summary_ops_params_valid_check.py deleted file mode 100644 index 7beb13dd81..0000000000 --- a/tests/ut/python/train/summary/test_summary_ops_params_valid_check.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Test summary function of ops params valid check.""" -import os -import tempfile -import shutil -from enum import Enum - -import numpy as np -import pytest - -import mindspore.nn as nn -from mindspore.common.tensor import Tensor -from mindspore.ops import operations as P -from mindspore.train.summary.summary_record import SummaryRecord - - -class SummaryEnum(Enum): - """Summary enum.""" - IMAGE = P.ImageSummary.__name__ - SCALAR = P.ScalarSummary.__name__ - TENSOR = P.TensorSummary.__name__ - HISTOGRAM = P.HistogramSummary.__name__ - - -class SummaryNet(nn.Cell): - """Summary net definition.""" - def __init__(self, summary_type, tag, data): - super(SummaryNet, self).__init__() - self.tag = tag - self.data = data - self.summary_fn = getattr(P, summary_type)() - self.one = Tensor(np.array([1]).astype(np.float32)) - self.add = P.Add() - - def construct(self): - self.summary_fn(self.tag, self.data) - return self.add(self.one, self.one) - - -class TestSummaryOps: - """Test summary operators.""" - summary_dir = '' - - @classmethod - def run_case(cls, net): - """ run_case """ - net.set_train() - steps = 10 - with SummaryRecord(cls.summary_dir) as test_writer: - for i in range(1, steps): - net() - test_writer.record(i) - - @classmethod - def setup_class(cls): - """Run before class.""" - if not os.path.exists(cls.summary_dir): - cls.summary_dir = tempfile.mkdtemp(suffix='_summary') - - @classmethod - def teardown_class(cls): - """Run after class.""" - if os.path.exists(cls.summary_dir): - shutil.rmtree(cls.summary_dir) - - @pytest.mark.parametrize( - "summary_type, value", - [ - (SummaryEnum.SCALAR.value, Tensor(1)), - (SummaryEnum.SCALAR.value, Tensor(np.array([1]))), - (SummaryEnum.IMAGE.value, Tensor(np.array([[[[1], [2], [3], [4]]]]))), - (SummaryEnum.TENSOR.value, Tensor(np.array([[1], [2], [3], [4]]))), - (SummaryEnum.HISTOGRAM.value, Tensor(np.array([[1], [2], [3], [4]]))), - ]) - def test_summary_success(self, summary_type, value): - """Test summary success with valid tag and valid data.""" - net = SummaryNet(summary_type, tag='tag', data=value) - TestSummaryOps.run_case(net) - - @pytest.mark.parametrize( - "summary_type", - [ - SummaryEnum.SCALAR.value, - SummaryEnum.IMAGE.value, - SummaryEnum.HISTOGRAM.value, - SummaryEnum.TENSOR.value - ]) - def test_summary_tag_is_none(self, summary_type): - """Test summary tag is None, all summary operator validation rules are consistent.""" - net = SummaryNet(summary_type, tag=None, data=Tensor(0)) - with pytest.raises(TypeError): - TestSummaryOps.run_case(net) - - - @pytest.mark.parametrize( - "summary_type", - [ - SummaryEnum.SCALAR.value, - SummaryEnum.IMAGE.value, - SummaryEnum.HISTOGRAM.value, - SummaryEnum.TENSOR.value - ]) - def test_summary_tag_is_empty_string(self, summary_type): - """Test summary tag is a empty string, all summary operator validation rules are consistent.""" - net = SummaryNet(summary_type, tag='', data=Tensor(0)) - with pytest.raises(ValueError): - TestSummaryOps.run_case(net) - - @pytest.mark.parametrize("tag", [123, True, Tensor(0)]) - def test_summary_tag_is_not_string(self, tag): - """Test summary tag is not a string, all summary operator validation rules are consistent.""" - # All summary operator validation rules are consistent, so we only test scalar summary. - net = SummaryNet(SummaryEnum.SCALAR.value, tag=tag, data=Tensor(0)) - with pytest.raises(TypeError): - TestSummaryOps.run_case(net) - - @pytest.mark.parametrize("value", [123, True, 'data']) - def test_summary_value_type_invalid(self, value): - """Test the type of summary value is invalid, all summary operator validation rules are consistent.""" - # All summary operator validation rules are consistent, so we only test scalar summary. - net = SummaryNet(SummaryEnum.SCALAR.value, tag='tag', data=value) - with pytest.raises(TypeError): - TestSummaryOps.run_case(net) - - @pytest.mark.parametrize( - "summary_type, value", - [ - (SummaryEnum.IMAGE.value, Tensor(np.array([1, 2]))), - (SummaryEnum.SCALAR.value, Tensor(np.array([1, 2]))), - (SummaryEnum.TENSOR.value, Tensor(0)), - (SummaryEnum.HISTOGRAM.value, Tensor(0)) - ]) - - def test_value_shape_invalid(self, summary_type, value): - """Test invalid shape of every summary operators.""" - net = SummaryNet(summary_type, tag='tag', data=value) - with pytest.raises(ValueError): - TestSummaryOps.run_case(net)