!9277 fix schema & zip validation

From: @luoyang42
Reviewed-by: @pandoublefeng, @liucunwei
Signed-off-by: @liucunwei
This commit is contained in:
mindspore-ci-bot 2020-12-01 16:28:48 +08:00 committed by Gitee
commit 920bbf1541
6 changed files with 92 additions and 14 deletions

View File

@ -782,6 +782,7 @@ Status SchemaObj::from_json(nlohmann::json json_obj) {
return Status::OK();
}
Status SchemaObj::FromJSONString(const std::string &json_string) {
try {
nlohmann::json js = nlohmann::json::parse(json_string);
@ -794,6 +795,16 @@ Status SchemaObj::FromJSONString(const std::string &json_string) {
return Status::OK();
}
// \brief Parse a JSON string of column definitions and add them to this schema
//   by delegating to parse_column().
// \param[in] json_string Serialized JSON describing the columns.
// \return Status::OK() on success; a syntax-error Status if the string is not
//   valid JSON, or the Status propagated from parse_column() on semantic errors.
Status SchemaObj::ParseColumnString(const std::string &json_string) {
  try {
    // nlohmann::json::parse throws on malformed input; caught below so the
    // exception never escapes across the API boundary.
    nlohmann::json js = nlohmann::json::parse(json_string);
    RETURN_IF_NOT_OK(parse_column(js));
  } catch (const std::exception &err) {
    // Malformed JSON is reported as a syntax-error Status instead of rethrowing.
    RETURN_STATUS_SYNTAX_ERROR("JSON string is failed to parse");
  }
  return Status::OK();
}
// OTHER FUNCTIONS
#ifndef ENABLE_ANDROID

View File

@ -42,6 +42,8 @@ PYBIND_REGISTER(
[](SchemaObj &self, std::string name, TypeId de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); })
.def("add_column", [](SchemaObj &self, std::string name,
std::string de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); })
.def("parse_columns",
[](SchemaObj &self, std::string json_string) { THROW_IF_ERROR(self.ParseColumnString(json_string)); })
.def("to_json", &SchemaObj::to_json)
.def("to_string", &SchemaObj::to_string)
.def("from_string",

View File

@ -424,6 +424,8 @@ class SchemaObj {
Status FromJSONString(const std::string &json_string);
Status ParseColumnString(const std::string &json_string);
private:
/// \brief Parse the columns and add it to columns
/// \param[in] columns dataset attribution information, decoded from schema file.

View File

@ -96,6 +96,9 @@ def zip(datasets):
if len(datasets) <= 1:
raise ValueError(
"Can't zip empty or just one dataset!")
for dataset in datasets:
if not isinstance(dataset, Dataset):
raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(dataset))
return ZipDataset(datasets)
@ -2452,9 +2455,6 @@ class ZipDataset(Dataset):
def __init__(self, datasets):
super().__init__(children=datasets)
for dataset in datasets:
if not isinstance(dataset, Dataset):
raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(dataset))
self.datasets = datasets
def parse(self, children=None):
@ -4480,6 +4480,33 @@ class Schema:
else:
self.cpp_schema.add_column(name, col_type, shape)
def parse_columns(self, columns):
    """
    Parse the columns and add them to the schema.

    Args:
        columns (Union[dict, list[dict]]): Dataset attribute information, decoded from schema file.

            - list[dict], 'name' and 'type' must be in keys, 'shape' optional.
            - dict, columns.keys() as name, columns.values() is dict, and 'type' inside, 'shape' optional.

    Raises:
        RuntimeError: If failed to parse columns.
        RuntimeError: If unknown items in columns.
        RuntimeError: If column's name field is missing.
        RuntimeError: If column's type field is missing.

    Example:
        >>> schema = Schema()
        >>> columns1 = [{'name': 'image', 'type': 'int8', 'shape': [3, 3]},
        ...             {'name': 'label', 'type': 'int8', 'shape': [1]}]
        >>> schema.parse_columns(columns1)
        >>> columns2 = {'image': {'shape': [3, 3], 'type': 'int8'}, 'label': {'shape': [1], 'type': 'int8'}}
        >>> schema.parse_columns(columns2)
    """
    # Serialize to JSON and hand off to the C++ SchemaObj binding, which does
    # the actual validation (missing name/type fields, unknown items).
    self.cpp_schema.parse_columns(json.dumps(columns, indent=2))
def to_json(self):
"""
Get a JSON string of the schema.

View File

@ -50,6 +50,12 @@ def test_schema_exception():
ds.Schema(1)
assert "Argument schema_file with value 1 is not of type (<class 'str'>,)" in str(info.value)
with pytest.raises(RuntimeError) as info:
schema = ds.Schema(SCHEMA_FILE)
columns = [{'type': 'int8', 'shape': [3, 3]}]
schema.parse_columns(columns)
assert "Column's name is missing" in str(info.value)
if __name__ == '__main__':
test_schema_simple()

View File

@ -250,16 +250,46 @@ def test_zip_exception_06():
logger.info("Got an exception in DE: {}".format(str(e)))
def test_zip_exception_07():
    """
    Test zip: zip with a tuple of strings (non-Dataset elements) as parameter.

    Both the free function ds.zip() and the Dataset.zip() method are expected
    to reject non-Dataset inputs; the except blocks log the expected error.
    """
    logger.info("test_zip_exception_07")
    try:
        dataz = ds.zip(('dataset1', 'dataset2'))
        num_iter = 0
        for _ in dataz.create_dict_iterator(num_epochs=1, output_numpy=True):
            num_iter += 1
        assert False
    except AssertionError:
        # BUGFIX: `assert False` raises AssertionError, a subclass of Exception,
        # so the generic handler below used to swallow it and the test could
        # never fail. Re-raise so a missing error actually fails the test.
        raise
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
    try:
        data = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
        dataz = data.zip(('dataset1',))
        num_iter = 0
        for _ in dataz.create_dict_iterator(num_epochs=1, output_numpy=True):
            num_iter += 1
        assert False
    except AssertionError:
        # Same fix as above: do not let the expected-failure sentinel be swallowed.
        raise
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
if __name__ == '__main__':
    # NOTE(review): this span is a diff hunk whose +/- markers were lost in
    # rendering — the commented-out calls below were the removed lines and the
    # uncommented calls the added lines of the commit. As rendered, the
    # commented lines are inert and each test still runs exactly once.
    test_zip_01()
    #test_zip_02()
    #test_zip_03()
    #test_zip_04()
    #test_zip_05()
    #test_zip_06()
    #test_zip_exception_01()
    #test_zip_exception_02()
    #test_zip_exception_03()
    #test_zip_exception_04()
    #test_zip_exception_05()
    #test_zip_exception_06()
    test_zip_02()
    test_zip_03()
    test_zip_04()
    test_zip_05()
    test_zip_06()
    test_zip_exception_01()
    test_zip_exception_02()
    test_zip_exception_03()
    test_zip_exception_04()
    test_zip_exception_05()
    test_zip_exception_06()
    test_zip_exception_07()