fix security wrap

This commit is contained in:
jiangshuqiang 2021-09-26 10:20:02 +08:00
parent dbbe512db4
commit 13fc005562
13 changed files with 43 additions and 41 deletions

View File

@ -29,6 +29,8 @@ from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strate
from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo
MAX_PATH_LENGTH = 1024
def _convert_type(types):
"""
@ -75,16 +77,27 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_inf
return exec_dataset
def _make_directory(path: str):
def _make_directory(path, arg_name='path'):
"""Make directory."""
if path is None or not isinstance(path, str) or path.strip() == "":
logger.error("The path(%r) is invalid type.", path)
raise TypeError("Input path is invalid type")
if not isinstance(path, str):
logger.error("The %s is invalid, the type should be string.", arg_name)
raise TypeError("The {} is invalid, the type should be string.".format(arg_name))
if path.strip() == "":
logger.error("The %s is invalid, it should be non-blank.", arg_name)
raise ValueError("The {} is invalid, it should be non-blank.".format(arg_name))
path = os.path.realpath(path)
if len(path) > MAX_PATH_LENGTH:
logger.error("The %s length is too long, it should be limited in %s.", arg_name, MAX_PATH_LENGTH)
raise ValueError("The {} length is too long, it should be limited in {}.".format(arg_name, MAX_PATH_LENGTH))
logger.debug("The abs path is %r", path)
if os.path.exists(path):
if not os.path.isdir(path):
logger.error("The path(%r) is a file path, it should be a directory path.", path)
raise NotADirectoryError("The path({}) is a file path, it should be a directory path.".format(path))
real_path = path
else:
logger.debug("The directory(%s) doesn't exist, will create it", path)

View File

@ -35,7 +35,7 @@ from mindspore.train import lineage_pb2
from mindspore.train.callback._dataset_graph import DatasetGraph
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.loss.loss import LossBase
from mindspore.train._utils import check_value_type
from mindspore.train._utils import check_value_type, _make_directory
from ..._c_expression import security
HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
@ -201,7 +201,7 @@ class SummaryCollector(Callback):
super(SummaryCollector, self).__init__()
self._summary_dir = self._process_summary_dir(summary_dir)
self._summary_dir = _make_directory(summary_dir, "summary_dir")
self._record = None
self._check_positive('collect_freq', collect_freq)
@ -242,23 +242,6 @@ class SummaryCollector(Callback):
def __exit__(self, *err):
self._record.close()
@staticmethod
def _process_summary_dir(summary_dir):
"""Check the summary dir, and create a new directory if it not exists."""
check_value_type('summary_dir', summary_dir, str)
summary_dir = summary_dir.strip()
if not summary_dir:
raise ValueError('For `summary_dir` the value should be a valid string of path, but got empty string.')
summary_dir = os.path.realpath(summary_dir)
if not os.path.exists(summary_dir):
os.makedirs(summary_dir, exist_ok=True)
else:
if not os.path.isdir(summary_dir):
raise NotADirectoryError('For `summary_dir` it should be a directory path.')
return summary_dir
@staticmethod
def _check_positive(name, value, allow_none=False):
"""Check if the value to be int type and positive."""
@ -333,13 +316,13 @@ class SummaryCollector(Callback):
def _process_specified_data(self, specified_data, action):
"""Check specified data type and value."""
check_value_type('collect_specified_data', specified_data, [dict, type(None)])
if specified_data is None:
if action:
return dict(self._DEFAULT_SPECIFIED_DATA)
return dict()
check_value_type('collect_specified_data', specified_data, [dict, type(None)])
for param_name in specified_data:
check_value_type(param_name, param_name, [str])

View File

@ -118,7 +118,7 @@ class WriterPool(ctx.Process):
elif action == 'END':
break
except queue.Empty:
pass
continue
for result in deq:
for plugin, data in result.get():

View File

@ -162,7 +162,7 @@ class SummaryRecord:
Validator.check_str_by_regular(file_prefix)
Validator.check_str_by_regular(file_suffix)
log_path = _make_directory(log_dir)
log_path = _make_directory(log_dir, "log_dir")
if not isinstance(max_file_size, (int, type(None))):
raise TypeError("The 'max_file_size' should be int type.")

View File

@ -21,6 +21,7 @@ from shutil import disk_usage
import numpy as np
from mindspore.train.summary.enums import PluginEnum, WriterPluginEnum
from mindspore import log as logger
from .._utils import _make_directory
from ._summary_adapter import package_init_event
@ -83,7 +84,7 @@ class BaseWriter:
try:
os.chmod(self._filepath, FILE_MODE)
except FileNotFoundError:
pass
logger.debug("The summary file %r has been removed.", self._filepath)
if self._writer is not None:
self._writer.Shut()

View File

@ -24,6 +24,7 @@ from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from mindspore.train.summary.summary_record import SummaryRecord
from tests.summary_utils import SummaryReader
from tests.security_utils import security_off_wrap
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
@ -384,7 +385,7 @@ def train_summary_record(test_writer, steps):
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.security_off
@security_off_wrap
def test_summary():
with tempfile.TemporaryDirectory() as tmp_dir:
steps = 2

View File

@ -22,6 +22,7 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.profiler import Profiler
from tests.security_utils import security_off_wrap
class Net(nn.Cell):
@ -41,7 +42,7 @@ y = np.random.randn(1, 3, 3, 4).astype(np.float32)
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.security_off
@security_off_wrap
def test_ascend_profiling():
if os.path.isdir("./data_ascend_profiler"):
shutil.rmtree("./data_ascend_profiler")

View File

@ -25,6 +25,7 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.profiler import Profiler
from tests.security_utils import security_off_wrap
class Net(nn.Cell):
@ -39,7 +40,7 @@ class Net(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.security_off
@security_off_wrap
def test_cpu_profiling():
if sys.platform != 'linux':
return

View File

@ -26,6 +26,7 @@ from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.train.summary.summary_record import SummaryRecord
from tests.summary_utils import SummaryReader
from tests.security_utils import security_off_wrap
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
@ -62,7 +63,7 @@ def train_summary_record(test_writer, steps):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.security_off
@security_off_wrap
def test_summary_step2_summary_record1():
"""Test record 10 step summary."""
if platform.system() == "Windows":

View File

@ -30,6 +30,7 @@ from mindspore.train import Model
from mindspore.train.callback import SummaryCollector
from tests.st.summary.dataset import create_mnist_dataset
from tests.summary_utils import SummaryReader
from tests.security_utils import security_off_wrap
class LeNet5(nn.Cell):
@ -126,7 +127,7 @@ class TestSummary:
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.security_off
@security_off_wrap
def test_summary_with_sink_mode_false(self):
"""Test summary with sink mode false, and num samples is 64."""
summary_dir = self._run_network(num_samples=10)
@ -149,7 +150,7 @@ class TestSummary:
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.security_off
@security_off_wrap
def test_summary_with_sink_mode_true(self):
"""Test summary with sink mode true, and num samples is 64."""
summary_dir = self._run_network(dataset_sink_mode=True, num_samples=10)
@ -169,7 +170,7 @@ class TestSummary:
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.security_off
@security_off_wrap
def test_summarycollector_user_defind(self):
"""Test SummaryCollector with user-defined."""
summary_dir = self._run_network(dataset_sink_mode=True, num_samples=2,

View File

@ -28,6 +28,7 @@ from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.train.summary.summary_record import _get_summary_tensor_data
from tests.st.summary.dataset import create_mnist_dataset
from tests.security_utils import security_off_wrap
class LeNet5(nn.Cell):
@ -93,7 +94,7 @@ class TestSummaryOps:
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.security_off
@security_off_wrap
def test_summary_ops(self):
"""Test summary operators."""
ds_train = create_mnist_dataset('train', num_samples=1, batch_size=1)

View File

@ -92,12 +92,11 @@ class TestSummaryCollector:
if isinstance(summary_dir, str):
with pytest.raises(ValueError) as exc:
SummaryCollector(summary_dir=summary_dir)
assert str(exc.value) == 'For `summary_dir` the value should be a valid string of path, ' \
'but got empty string.'
assert str(exc.value) == "The summary_dir is invalid, it should be non-blank."
else:
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir=summary_dir)
assert 'For `summary_dir` the type should be a valid type' in str(exc.value)
assert "The summary_dir is invalid, the type should be string." in str(exc.value)
@security_off_wrap
def test_params_with_summary_dir_not_dir(self):

View File

@ -62,7 +62,7 @@ class TestSummaryRecord:
@security_off_wrap
@pytest.mark.parametrize("log_dir", ["", None, 32])
def test_log_dir_with_type_error(self, log_dir):
with pytest.raises(TypeError):
with pytest.raises((TypeError, ValueError)):
with SummaryRecord(log_dir):
pass
@ -70,7 +70,7 @@ class TestSummaryRecord:
@pytest.mark.parametrize("raise_exception", ["", None, 32])
def test_raise_exception_with_type_error(self, raise_exception):
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(TypeError) as exc:
with pytest.raises((TypeError, ValueError)) as exc:
with SummaryRecord(log_dir=summary_dir, raise_exception=raise_exception):
pass