forked from mindspore-Ecosystem/mindspore
!9277 fix schema & zip validation
From: @luoyang42 Reviewed-by: @pandoublefeng,@liucunwei Signed-off-by: @liucunwei
This commit is contained in:
commit
920bbf1541
|
@ -782,6 +782,7 @@ Status SchemaObj::from_json(nlohmann::json json_obj) {
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SchemaObj::FromJSONString(const std::string &json_string) {
|
||||
try {
|
||||
nlohmann::json js = nlohmann::json::parse(json_string);
|
||||
|
@ -794,6 +795,16 @@ Status SchemaObj::FromJSONString(const std::string &json_string) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Parse a JSON-encoded column description and hand the parsed document to
// parse_column. Any std::exception raised inside the try block (for example
// by nlohmann::json::parse on malformed text) is reported through
// RETURN_STATUS_SYNTAX_ERROR instead of propagating to the caller.
Status SchemaObj::ParseColumnString(const std::string &json_string) {
  try {
    nlohmann::json columns = nlohmann::json::parse(json_string);
    RETURN_IF_NOT_OK(parse_column(columns));
  } catch (const std::exception &) {
    RETURN_STATUS_SYNTAX_ERROR("JSON string is failed to parse");
  }
  return Status::OK();
}
|
||||
|
||||
// OTHER FUNCTIONS
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
|
|
|
@ -42,6 +42,8 @@ PYBIND_REGISTER(
|
|||
[](SchemaObj &self, std::string name, TypeId de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); })
|
||||
.def("add_column", [](SchemaObj &self, std::string name,
|
||||
std::string de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); })
|
||||
.def("parse_columns",
|
||||
[](SchemaObj &self, std::string json_string) { THROW_IF_ERROR(self.ParseColumnString(json_string)); })
|
||||
.def("to_json", &SchemaObj::to_json)
|
||||
.def("to_string", &SchemaObj::to_string)
|
||||
.def("from_string",
|
||||
|
|
|
@ -424,6 +424,8 @@ class SchemaObj {
|
|||
|
||||
/// \brief Parse json_string with nlohmann::json::parse.
/// NOTE(review): the part of the definition that consumes the parsed JSON is not visible here;
/// presumably it loads this schema from the document - confirm against the definition.
/// \param[in] json_string JSON text to parse.
/// \return Status of the operation.
Status FromJSONString(const std::string &json_string);
|
||||
|
||||
/// \brief Parse json_string as JSON and pass the parsed document to parse_column.
/// \param[in] json_string JSON text describing the columns.
/// \return Status::OK() on success; a syntax-error Status if a std::exception is raised while
///     parsing the string or processing the columns.
Status ParseColumnString(const std::string &json_string);
|
||||
|
||||
private:
|
||||
/// \brief Parse the columns and add it to columns
|
||||
/// \param[in] columns dataset attribution information, decoded from schema file.
|
||||
|
|
|
@ -96,6 +96,9 @@ def zip(datasets):
|
|||
if len(datasets) <= 1:
|
||||
raise ValueError(
|
||||
"Can't zip empty or just one dataset!")
|
||||
for dataset in datasets:
|
||||
if not isinstance(dataset, Dataset):
|
||||
raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(dataset))
|
||||
return ZipDataset(datasets)
|
||||
|
||||
|
||||
|
@ -2452,9 +2455,6 @@ class ZipDataset(Dataset):
|
|||
|
||||
def __init__(self, datasets):
    """
    Register ``datasets`` as children through the base constructor, then
    check that every element is a Dataset before storing them.

    Args:
        datasets: Iterable of the datasets to be zipped together.

    Raises:
        TypeError: If any element of ``datasets`` is not a Dataset instance.
    """
    super().__init__(children=datasets)
    for candidate in datasets:
        if isinstance(candidate, Dataset):
            continue
        raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(candidate))
    self.datasets = datasets
|
||||
|
||||
def parse(self, children=None):
|
||||
|
@ -4480,6 +4480,33 @@ class Schema:
|
|||
else:
|
||||
self.cpp_schema.add_column(name, col_type, shape)
|
||||
|
||||
def parse_columns(self, columns):
    """
    Parse column descriptions and add them to this schema.

    Args:
        columns (Union[dict, list[dict]]): Dataset attribute information, decoded from schema file.

            - list[dict]: every entry must contain the 'name' and 'type' keys; 'shape' is optional.

            - dict: the keys are the column names and every value is a dict that
              contains 'type' and, optionally, 'shape'.

    Raises:
        RuntimeError: If failed to parse columns.
        RuntimeError: If unknown items in columns.
        RuntimeError: If column's name field is missing.
        RuntimeError: If column's type field is missing.

    Example:
        >>> schema = Schema()
        >>> columns1 = [{'name': 'image', 'type': 'int8', 'shape': [3, 3]},
        ...             {'name': 'label', 'type': 'int8', 'shape': [1]}]
        >>> schema.parse_columns(columns1)
        >>> columns2 = {'image': {'shape': [3, 3], 'type': 'int8'}, 'label': {'shape': [1], 'type': 'int8'}}
        >>> schema.parse_columns(columns2)
    """
    # The underlying cpp_schema object receives the columns as a JSON string.
    serialized_columns = json.dumps(columns, indent=2)
    self.cpp_schema.parse_columns(serialized_columns)
|
||||
|
||||
def to_json(self):
|
||||
"""
|
||||
Get a JSON string of the schema.
|
||||
|
|
|
@ -50,6 +50,12 @@ def test_schema_exception():
|
|||
ds.Schema(1)
|
||||
assert "Argument schema_file with value 1 is not of type (<class 'str'>,)" in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
schema = ds.Schema(SCHEMA_FILE)
|
||||
columns = [{'type': 'int8', 'shape': [3, 3]}]
|
||||
schema.parse_columns(columns)
|
||||
assert "Column's name is missing" in str(info.value)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_schema_simple()
|
||||
|
|
|
@ -250,16 +250,46 @@ def test_zip_exception_06():
|
|||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
|
||||
|
||||
def test_zip_exception_07():
    """
    Test zip: zip with string as parameter

    Both ds.zip() and Dataset.zip() are expected to reject string inputs.
    The "nothing was raised" failure is placed in the ``else`` clause of each
    ``try`` so that the broad ``except Exception`` handler cannot swallow it
    (AssertionError is itself a subclass of Exception).
    """
    logger.info("test_zip_exception_07")

    try:
        dataz = ds.zip(('dataset1', 'dataset2'))

        # Iterate as well, in case the invalid input is only detected lazily.
        for _ in dataz.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
    else:
        raise AssertionError("ds.zip() accepted string arguments without raising")

    # Fixture setup stays outside the guarded region: a failure here is a
    # setup problem, not the rejection this test is checking for.
    data = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)

    try:
        dataz = data.zip(('dataset1',))

        # Iterate as well, in case the invalid input is only detected lazily.
        for _ in dataz.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
    else:
        raise AssertionError("Dataset.zip() accepted a string argument without raising")
|
||||
|
||||
if __name__ == '__main__':
    # Script entry point: run every zip test once, in declaration order.
    for run_test in (
            test_zip_01,
            test_zip_02,
            test_zip_03,
            test_zip_04,
            test_zip_05,
            test_zip_06,
            test_zip_exception_01,
            test_zip_exception_02,
            test_zip_exception_03,
            test_zip_exception_04,
            test_zip_exception_05,
            test_zip_exception_06,
            test_zip_exception_07,
    ):
        run_test()
|
||||
|
|
Loading…
Reference in New Issue