fix validation of Dither and code examples of mindrecord

This commit is contained in:
Xiao Tianci 2022-08-25 10:53:28 +08:00
parent 15205b063e
commit b04ac43701
3 changed files with 17 additions and 30 deletions

View File

@ -249,7 +249,7 @@ def check_dither(method):
[density_function, noise_shaping], _ = parse_user_args(
method, *args, **kwargs)
type_check(density_function, (DensityFunction), "density_function")
type_check(density_function, (DensityFunction,), "density_function")
type_check(noise_shaping, (bool,), "noise_shaping")
return method(self, *args, **kwargs)

View File

@ -60,14 +60,10 @@ class FileWriter:
... {"file_name": "3.jpg", "label": 99,
... "data": b"\xaf\xafU<\xb8|6\xbd}\xc1\x99[\xeaj+\x8f\x84\xd3\xcc\xa0,i\xbb\xb9-\xcdz\xecp{T\xb1"}]
>>> writer = FileWriter(file_name="test.mindrecord", shard_num=1, overwrite=True)
>>> writer.add_schema(schema_json, "test_schema")
0
>>> writer.add_index(indexes)
MSRStatus.SUCCESS
>>> writer.write_raw_data(data)
MSRStatus.SUCCESS
>>> writer.commit()
MSRStatus.SUCCESS
>>> schema_id = writer.add_schema(schema_json, "test_schema")
>>> status = writer.add_index(indexes)
>>> status = writer.write_raw_data(data)
>>> status = writer.commit()
"""
def __init__(self, file_name, shard_num=1, overwrite=False):
@ -130,17 +126,12 @@ class FileWriter:
>>> data = [{"file_name": "1.jpg", "label": 0,
... "data": b"\x10c\xb3w\xa8\xee$o&<q\x8c\x8e(\xa2\x90\x90\x96\xbc\xb1\x1e\xd4QER\x13?\xff"}]
>>> writer = FileWriter(file_name="test.mindrecord", shard_num=1, overwrite=True)
>>> writer.add_schema(schema_json, "test_schema")
0
>>> writer.write_raw_data(data)
MSRStatus.SUCCESS
>>> writer.commit()
MSRStatus.SUCCESS
>>> schema_id = writer.add_schema(schema_json, "test_schema")
>>> status = writer.write_raw_data(data)
>>> status = writer.commit()
>>> write_append = FileWriter.open_for_append("test.mindrecord")
>>> write_append.write_raw_data(data)
MSRStatus.SUCCESS
>>> write_append.commit()
MSRStatus.SUCCESS
>>> status = write_append.write_raw_data(data)
>>> status = write_append.commit()
"""
if platform.system().lower() == "windows":
file_name = file_name.replace("\\", "/")
@ -352,8 +343,7 @@ class FileWriter:
Examples:
>>> from mindspore.mindrecord import FileWriter
>>> writer = FileWriter(file_name="test.mindrecord", shard_num=1)
>>> writer.set_header_size(1 << 25) # 32MB
MSRStatus.SUCCESS
>>> status = writer.set_header_size(1 << 25) # 32MB
"""
return self._writer.set_header_size(header_size)
@ -378,8 +368,7 @@ class FileWriter:
Examples:
>>> from mindspore.mindrecord import FileWriter
>>> writer = FileWriter(file_name="test.mindrecord", shard_num=1)
>>> writer.set_page_size(1 << 26) # 128MB
MSRStatus.SUCCESS
>>> status = writer.set_page_size(1 << 26) # 128MB
"""
return self._writer.set_page_size(page_size)

View File

@ -30,7 +30,7 @@ def count_unequal_element(data_expected, data_me, rtol, atol):
loss_count = np.count_nonzero(greater)
assert (loss_count / total_count) < rtol, \
"\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
format(data_expected[greater], data_me[greater], error[greater])
format(data_expected[greater], data_me[greater], error[greater])
def test_dither_eager_noise_shaping_false():
@ -145,14 +145,12 @@ def test_invalid_dither_input():
assert error_msg in str(error_info.value)
test_invalid_input("invalid density function parameter value", "TPDF", False, TypeError,
"Argument density_function with value TPDF is not of type"
+ " [<DensityFunction.TPDF: 'TPDF'>, <DensityFunction.RPDF: 'RPDF'>"
+ ", <DensityFunction.GPDF: 'GPDF'>], but got <class 'str'>.")
"Argument density_function with value TPDF is not of type "
+ "[<enum 'DensityFunction'>], but got <class 'str'>.")
test_invalid_input("invalid density_function parameter value", 6, False, TypeError,
"Argument density_function with value 6 is not of type"
+ " [<DensityFunction.TPDF: 'TPDF'>, <DensityFunction.RPDF: 'RPDF'>"
+ ", <DensityFunction.GPDF: 'GPDF'>], but got <class 'int'>.")
"Argument density_function with value 6 is not of type "
+ "[<enum 'DensityFunction'>], but got <class 'int'>.")
test_invalid_input("invalid noise_shaping parameter value", DensityFunction.GPDF, 1, TypeError,
"Argument noise_shaping with value 1 is not of type [<class 'bool'>], but got <class 'int'>.")