!6142 MindData fix batch ops issue

Merge pull request !6142 from xiefangqi/md_fix_map_issue
This commit is contained in:
mindspore-ci-bot 2020-09-14 16:00:23 +08:00 committed by Gitee
commit 9824d0b20f
2 changed files with 15 additions and 4 deletions

View File

@ -510,15 +510,15 @@ def check_batch(method):
for k, v in param_dict.get('pad_info').items():
check_pad_info(k, v)
if (per_batch_map is None) != (input_columns is None):
# These two parameters appear together.
raise ValueError("per_batch_map and input_columns need to be passed in together.")
if input_columns is not None:
check_columns(input_columns, "input_columns")
if len(input_columns) != (len(ins.signature(per_batch_map).parameters) - 1):
raise ValueError("the signature of per_batch_map should match with input columns")
if (per_batch_map is None) != (input_columns is None):
# These two parameters appear together.
raise ValueError("per_batch_map and input_columns need to be passed in together.")
if output_columns is not None:
raise ValueError("output_columns is currently not implemented.")

View File

@ -466,6 +466,16 @@ def test_batch_exception_13():
logger.info("Got an exception in DE: {}".format(str(e)))
assert "column_order is currently not implemented." in str(e)
def test_batch_exception_14():
batch_size = 2
input_columns = ["num"]
data1 = ds.TFRecordDataset(DATA_DIR)
try:
_ = data1.batch(batch_size=batch_size, input_columns=input_columns)
except ValueError as e:
assert "per_batch_map and input_columns need to be passed in together." in str(e)
if __name__ == '__main__':
test_batch_01()
test_batch_02()
@ -491,4 +501,5 @@ if __name__ == '__main__':
test_batch_exception_11()
test_batch_exception_12()
test_batch_exception_13()
test_batch_exception_14()
logger.info('\n')