forked from mindspore-Ecosystem/mindspore
!6142 MindData fix batch ops issue
Merge pull request !6142 from xiefangqi/md_fix_map_issue
This commit is contained in:
commit
9824d0b20f
|
@ -510,15 +510,15 @@ def check_batch(method):
|
|||
for k, v in param_dict.get('pad_info').items():
|
||||
check_pad_info(k, v)
|
||||
|
||||
if (per_batch_map is None) != (input_columns is None):
|
||||
# These two parameters appear together.
|
||||
raise ValueError("per_batch_map and input_columns need to be passed in together.")
|
||||
|
||||
if input_columns is not None:
|
||||
check_columns(input_columns, "input_columns")
|
||||
if len(input_columns) != (len(ins.signature(per_batch_map).parameters) - 1):
|
||||
raise ValueError("the signature of per_batch_map should match with input columns")
|
||||
|
||||
if (per_batch_map is None) != (input_columns is None):
|
||||
# These two parameters appear together.
|
||||
raise ValueError("per_batch_map and input_columns need to be passed in together.")
|
||||
|
||||
if output_columns is not None:
|
||||
raise ValueError("output_columns is currently not implemented.")
|
||||
|
||||
|
|
|
@ -466,6 +466,16 @@ def test_batch_exception_13():
|
|||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "column_order is currently not implemented." in str(e)
|
||||
|
||||
def test_batch_exception_14():
|
||||
batch_size = 2
|
||||
input_columns = ["num"]
|
||||
data1 = ds.TFRecordDataset(DATA_DIR)
|
||||
try:
|
||||
_ = data1.batch(batch_size=batch_size, input_columns=input_columns)
|
||||
except ValueError as e:
|
||||
assert "per_batch_map and input_columns need to be passed in together." in str(e)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_batch_01()
|
||||
test_batch_02()
|
||||
|
@ -491,4 +501,5 @@ if __name__ == '__main__':
|
|||
test_batch_exception_11()
|
||||
test_batch_exception_12()
|
||||
test_batch_exception_13()
|
||||
test_batch_exception_14()
|
||||
logger.info('\n')
|
||||
|
|
Loading…
Reference in New Issue