From ff38eff9ae94adb72604be0ae47ee07dab5c1498 Mon Sep 17 00:00:00 2001 From: ms_yan <6576637+ms_yan@user.noreply.gitee.com> Date: Thu, 2 Apr 2020 21:56:48 +0800 Subject: [PATCH] add parameter check for Class Schema --- mindspore/dataset/engine/datasets.py | 23 ++++++++---- mindspore/dataset/engine/validators.py | 50 ++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 7 deletions(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index db2b5169d2d..de604a67e9d 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -38,7 +38,7 @@ from .iterators import DictIterator, TupleIterator from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_zip, check_rename, \ check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \ - check_zip_dataset + check_zip_dataset, check_add_column from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist try: @@ -2334,13 +2334,20 @@ class Schema: self.dataset_type = '' self.num_rows = 0 else: + if not os.path.isfile(schema_file) or not os.access(schema_file, os.R_OK): + raise ValueError("The file %s does not exist or permission denied!" % schema_file) try: with open(schema_file, 'r') as load_f: json_obj = json.load(load_f) - self.from_json(json_obj) except json.decoder.JSONDecodeError: - raise RuntimeError("Schema file failed to load") + raise RuntimeError("Schema file failed to load.") + except UnicodeDecodeError: + raise RuntimeError("Schema file failed to decode.") + except Exception: + raise RuntimeError("Schema file failed to open.") + self.from_json(json_obj) + @check_add_column def add_column(self, name, de_type, shape=None): """ Add new column to the schema. @@ -2359,10 +2366,8 @@ class Schema: if isinstance(de_type, typing.Type): de_type = mstype_to_detype(de_type) new_column["type"] = str(de_type) - elif isinstance(de_type, str): - new_column["type"] = str(DataType(de_type)) else: - raise ValueError("Unknown column type") + new_column["type"] = str(DataType(de_type)) if shape is not None: new_column["shape"] = shape @@ -2391,7 +2396,7 @@ class Schema: Parse the columns and add it to self. Args: - columns (list[str]): names of columns. + columns (dict or list[str]): names of columns. Raises: RuntimeError: If failed to parse schema file. @@ -2399,6 +2404,8 @@ class Schema: RuntimeError: If column's name field is missing. RuntimeError: If column's type field is missing. """ + if columns is None: + raise TypeError("Expected non-empty dict or string list.") self.columns = [] for col in columns: name = None @@ -2443,6 +2450,8 @@ class Schema: RuntimeError: if dataset type is missing in the object. RuntimeError: if columns are missing in the object. """ + if not isinstance(json_obj, dict) or json_obj is None: + raise ValueError("Expected non-empty dict.") for k, v in json_obj.items(): if k == "datasetType": self.dataset_type = v diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index b4d22a4a013..26d62419451 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -19,10 +19,15 @@ import inspect as ins import os from functools import wraps from multiprocessing import cpu_count +from mindspore._c_expression import typing from . import samplers from . import datasets INT32_MAX = 2147483647 +valid_detype = [ + "bool", "int8", "int16", "int32", "int64", "uint8", "uint16", + "uint32", "uint64", "float16", "float32", "float64" +] def check(method): @@ -188,6 +193,12 @@ def check(method): return wrapper +def check_valid_detype(type_): + if type_ not in valid_detype: + raise ValueError("Unknown column type") + return True + + def check_filename(path): """ check the filename in the path @@ -743,3 +754,42 @@ def check_project(method): return method(*args, **kwargs) return new_method + + +def check_shape(shape, name): + if isinstance(shape, list): + for element in shape: + if not isinstance(element, int): + raise TypeError( + "Each element in {0} should be of type int. Got {1}.".format(name, type(element))) + else: + raise TypeError("Expected int list.") + + +def check_add_column(method): + """check the input arguments of add_column.""" + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + # check name; required argument + name = param_dict.get("name") + if not isinstance(name, str) or not name: + raise TypeError("Expected non-empty string.") + + # check type; required argument + de_type = param_dict.get("de_type") + if de_type is not None: + if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type): + raise ValueError("Unknown column type.") + else: + raise TypeError("Expected non-empty string.") + + # check shape + shape = param_dict.get("shape") + if shape is not None: + check_shape(shape, "shape") + + return method(*args, **kwargs) + + return new_method