forked from mindspore-Ecosystem/mindspore
fix security wrap
This commit is contained in:
parent
dbbe512db4
commit
13fc005562
|
@ -29,6 +29,8 @@ from mindspore.train.node_strategy_pb2 import ParallelStrategyMap as ckpt_strate
|
|||
|
||||
from .lineage_pb2 import DatasetGraph, TrainLineage, EvaluationLineage, UserDefinedInfo
|
||||
|
||||
MAX_PATH_LENGTH = 1024
|
||||
|
||||
|
||||
def _convert_type(types):
|
||||
"""
|
||||
|
@ -75,16 +77,27 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_inf
|
|||
return exec_dataset
|
||||
|
||||
|
||||
def _make_directory(path: str):
|
||||
def _make_directory(path, arg_name='path'):
|
||||
"""Make directory."""
|
||||
if path is None or not isinstance(path, str) or path.strip() == "":
|
||||
logger.error("The path(%r) is invalid type.", path)
|
||||
raise TypeError("Input path is invalid type")
|
||||
if not isinstance(path, str):
|
||||
logger.error("The %s is invalid, the type should be string.", arg_name)
|
||||
raise TypeError("The {} is invalid, the type should be string.".format(arg_name))
|
||||
if path.strip() == "":
|
||||
logger.error("The %s is invalid, it should be non-blank.", arg_name)
|
||||
raise ValueError("The {} is invalid, it should be non-blank.".format(arg_name))
|
||||
|
||||
path = os.path.realpath(path)
|
||||
|
||||
if len(path) > MAX_PATH_LENGTH:
|
||||
logger.error("The %s length is too long, it should be limited in %s.", arg_name, MAX_PATH_LENGTH)
|
||||
raise ValueError("The {} length is too long, it should be limited in {}.".format(arg_name, MAX_PATH_LENGTH))
|
||||
|
||||
logger.debug("The abs path is %r", path)
|
||||
|
||||
if os.path.exists(path):
|
||||
if not os.path.isdir(path):
|
||||
logger.error("The path(%r) is a file path, it should be a directory path.", path)
|
||||
raise NotADirectoryError("The path({}) is a file path, it should be a directory path.".format(path))
|
||||
real_path = path
|
||||
else:
|
||||
logger.debug("The directory(%s) doesn't exist, will create it", path)
|
||||
|
|
|
@ -35,7 +35,7 @@ from mindspore.train import lineage_pb2
|
|||
from mindspore.train.callback._dataset_graph import DatasetGraph
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
from mindspore.nn.loss.loss import LossBase
|
||||
from mindspore.train._utils import check_value_type
|
||||
from mindspore.train._utils import check_value_type, _make_directory
|
||||
from ..._c_expression import security
|
||||
|
||||
HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
|
||||
|
@ -201,7 +201,7 @@ class SummaryCollector(Callback):
|
|||
|
||||
super(SummaryCollector, self).__init__()
|
||||
|
||||
self._summary_dir = self._process_summary_dir(summary_dir)
|
||||
self._summary_dir = _make_directory(summary_dir, "summary_dir")
|
||||
self._record = None
|
||||
|
||||
self._check_positive('collect_freq', collect_freq)
|
||||
|
@ -242,23 +242,6 @@ class SummaryCollector(Callback):
|
|||
def __exit__(self, *err):
|
||||
self._record.close()
|
||||
|
||||
@staticmethod
|
||||
def _process_summary_dir(summary_dir):
|
||||
"""Check the summary dir, and create a new directory if it not exists."""
|
||||
check_value_type('summary_dir', summary_dir, str)
|
||||
summary_dir = summary_dir.strip()
|
||||
if not summary_dir:
|
||||
raise ValueError('For `summary_dir` the value should be a valid string of path, but got empty string.')
|
||||
|
||||
summary_dir = os.path.realpath(summary_dir)
|
||||
if not os.path.exists(summary_dir):
|
||||
os.makedirs(summary_dir, exist_ok=True)
|
||||
else:
|
||||
if not os.path.isdir(summary_dir):
|
||||
raise NotADirectoryError('For `summary_dir` it should be a directory path.')
|
||||
|
||||
return summary_dir
|
||||
|
||||
@staticmethod
|
||||
def _check_positive(name, value, allow_none=False):
|
||||
"""Check if the value to be int type and positive."""
|
||||
|
@ -333,13 +316,13 @@ class SummaryCollector(Callback):
|
|||
|
||||
def _process_specified_data(self, specified_data, action):
|
||||
"""Check specified data type and value."""
|
||||
check_value_type('collect_specified_data', specified_data, [dict, type(None)])
|
||||
|
||||
if specified_data is None:
|
||||
if action:
|
||||
return dict(self._DEFAULT_SPECIFIED_DATA)
|
||||
return dict()
|
||||
|
||||
check_value_type('collect_specified_data', specified_data, [dict, type(None)])
|
||||
|
||||
for param_name in specified_data:
|
||||
check_value_type(param_name, param_name, [str])
|
||||
|
||||
|
|
|
@ -118,7 +118,7 @@ class WriterPool(ctx.Process):
|
|||
elif action == 'END':
|
||||
break
|
||||
except queue.Empty:
|
||||
pass
|
||||
continue
|
||||
|
||||
for result in deq:
|
||||
for plugin, data in result.get():
|
||||
|
|
|
@ -162,7 +162,7 @@ class SummaryRecord:
|
|||
Validator.check_str_by_regular(file_prefix)
|
||||
Validator.check_str_by_regular(file_suffix)
|
||||
|
||||
log_path = _make_directory(log_dir)
|
||||
log_path = _make_directory(log_dir, "log_dir")
|
||||
|
||||
if not isinstance(max_file_size, (int, type(None))):
|
||||
raise TypeError("The 'max_file_size' should be int type.")
|
||||
|
|
|
@ -21,6 +21,7 @@ from shutil import disk_usage
|
|||
import numpy as np
|
||||
|
||||
from mindspore.train.summary.enums import PluginEnum, WriterPluginEnum
|
||||
from mindspore import log as logger
|
||||
|
||||
from .._utils import _make_directory
|
||||
from ._summary_adapter import package_init_event
|
||||
|
@ -83,7 +84,7 @@ class BaseWriter:
|
|||
try:
|
||||
os.chmod(self._filepath, FILE_MODE)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
logger.debug("The summary file %r has been removed.", self._filepath)
|
||||
if self._writer is not None:
|
||||
self._writer.Shut()
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ from mindspore.common import dtype as mstype
|
|||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.train.summary.summary_record import SummaryRecord
|
||||
from tests.summary_utils import SummaryReader
|
||||
from tests.security_utils import security_off_wrap
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
@ -384,7 +385,7 @@ def train_summary_record(test_writer, steps):
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.security_off
|
||||
@security_off_wrap
|
||||
def test_summary():
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
steps = 2
|
||||
|
|
|
@ -22,6 +22,7 @@ import mindspore.nn as nn
|
|||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.profiler import Profiler
|
||||
from tests.security_utils import security_off_wrap
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
@ -41,7 +42,7 @@ y = np.random.randn(1, 3, 3, 4).astype(np.float32)
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.security_off
|
||||
@security_off_wrap
|
||||
def test_ascend_profiling():
|
||||
if os.path.isdir("./data_ascend_profiler"):
|
||||
shutil.rmtree("./data_ascend_profiler")
|
||||
|
|
|
@ -25,6 +25,7 @@ import mindspore.nn as nn
|
|||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.profiler import Profiler
|
||||
from tests.security_utils import security_off_wrap
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
@ -39,7 +40,7 @@ class Net(nn.Cell):
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.security_off
|
||||
@security_off_wrap
|
||||
def test_cpu_profiling():
|
||||
if sys.platform != 'linux':
|
||||
return
|
||||
|
|
|
@ -26,6 +26,7 @@ from mindspore import Tensor
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.train.summary.summary_record import SummaryRecord
|
||||
from tests.summary_utils import SummaryReader
|
||||
from tests.security_utils import security_off_wrap
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
@ -62,7 +63,7 @@ def train_summary_record(test_writer, steps):
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.security_off
|
||||
@security_off_wrap
|
||||
def test_summary_step2_summary_record1():
|
||||
"""Test record 10 step summary."""
|
||||
if platform.system() == "Windows":
|
||||
|
|
|
@ -30,6 +30,7 @@ from mindspore.train import Model
|
|||
from mindspore.train.callback import SummaryCollector
|
||||
from tests.st.summary.dataset import create_mnist_dataset
|
||||
from tests.summary_utils import SummaryReader
|
||||
from tests.security_utils import security_off_wrap
|
||||
|
||||
|
||||
class LeNet5(nn.Cell):
|
||||
|
@ -126,7 +127,7 @@ class TestSummary:
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.security_off
|
||||
@security_off_wrap
|
||||
def test_summary_with_sink_mode_false(self):
|
||||
"""Test summary with sink mode false, and num samples is 64."""
|
||||
summary_dir = self._run_network(num_samples=10)
|
||||
|
@ -149,7 +150,7 @@ class TestSummary:
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.security_off
|
||||
@security_off_wrap
|
||||
def test_summary_with_sink_mode_true(self):
|
||||
"""Test summary with sink mode true, and num samples is 64."""
|
||||
summary_dir = self._run_network(dataset_sink_mode=True, num_samples=10)
|
||||
|
@ -169,7 +170,7 @@ class TestSummary:
|
|||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.security_off
|
||||
@security_off_wrap
|
||||
def test_summarycollector_user_defind(self):
|
||||
"""Test SummaryCollector with user-defined."""
|
||||
summary_dir = self._run_network(dataset_sink_mode=True, num_samples=2,
|
||||
|
|
|
@ -28,6 +28,7 @@ from mindspore.ops import operations as P
|
|||
from mindspore.train import Model
|
||||
from mindspore.train.summary.summary_record import _get_summary_tensor_data
|
||||
from tests.st.summary.dataset import create_mnist_dataset
|
||||
from tests.security_utils import security_off_wrap
|
||||
|
||||
|
||||
class LeNet5(nn.Cell):
|
||||
|
@ -93,7 +94,7 @@ class TestSummaryOps:
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.security_off
|
||||
@security_off_wrap
|
||||
def test_summary_ops(self):
|
||||
"""Test summary operators."""
|
||||
ds_train = create_mnist_dataset('train', num_samples=1, batch_size=1)
|
||||
|
|
|
@ -92,12 +92,11 @@ class TestSummaryCollector:
|
|||
if isinstance(summary_dir, str):
|
||||
with pytest.raises(ValueError) as exc:
|
||||
SummaryCollector(summary_dir=summary_dir)
|
||||
assert str(exc.value) == 'For `summary_dir` the value should be a valid string of path, ' \
|
||||
'but got empty string.'
|
||||
assert str(exc.value) == "The summary_dir is invalid, it should be non-blank."
|
||||
else:
|
||||
with pytest.raises(TypeError) as exc:
|
||||
SummaryCollector(summary_dir=summary_dir)
|
||||
assert 'For `summary_dir` the type should be a valid type' in str(exc.value)
|
||||
assert "The summary_dir is invalid, the type should be string." in str(exc.value)
|
||||
|
||||
@security_off_wrap
|
||||
def test_params_with_summary_dir_not_dir(self):
|
||||
|
|
|
@ -62,7 +62,7 @@ class TestSummaryRecord:
|
|||
@security_off_wrap
|
||||
@pytest.mark.parametrize("log_dir", ["", None, 32])
|
||||
def test_log_dir_with_type_error(self, log_dir):
|
||||
with pytest.raises(TypeError):
|
||||
with pytest.raises((TypeError, ValueError)):
|
||||
with SummaryRecord(log_dir):
|
||||
pass
|
||||
|
||||
|
@ -70,7 +70,7 @@ class TestSummaryRecord:
|
|||
@pytest.mark.parametrize("raise_exception", ["", None, 32])
|
||||
def test_raise_exception_with_type_error(self, raise_exception):
|
||||
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
|
||||
with pytest.raises(TypeError) as exc:
|
||||
with pytest.raises((TypeError, ValueError)) as exc:
|
||||
with SummaryRecord(log_dir=summary_dir, raise_exception=raise_exception):
|
||||
pass
|
||||
|
||||
|
|
Loading…
Reference in New Issue