Fix test team issue: ToNumber in eager mode.

And support multiple tensor as input in call method

Signed-off-by: alex-yuyue <yue.yu1@huawei.com>
This commit is contained in:
alex-yuyue 2021-03-02 17:25:50 -05:00
parent 7f71525b0a
commit 66aff93940
8 changed files with 75 additions and 41 deletions

View File

@ -40,6 +40,8 @@ Status TensorOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T
Status TensorOp::Compute(const TensorRow &input, TensorRow *output) {
IO_CHECK_VECTOR(input, output);
if (OneToOne()) {
if (input.size() != 1)
return Status(StatusCode::kMDUnexpectedError, "The op is OneToOne, can only accept one tensor as input.");
output->resize(1);
return Compute(input[0], &(*output)[0]);
}

View File

@ -63,28 +63,23 @@ class TextTensorOperation(TensorOperation):
"""
Base class of Text Tensor Ops
"""
def __call__(self, input_tensor):
if not isinstance(input_tensor, list):
input_list = [input_tensor]
else:
input_list = input_tensor
tensor_list = []
for tensor in input_list:
if not isinstance(tensor, str):
raise TypeError("Input should be string or list of strings, got {}.".format(type(tensor)))
tensor_list.append(cde.Tensor(np.asarray(tensor)))
def __call__(self, *tensor_list):
tensor_array = []
output_list = []
# Combine input tensor_list to a TensorRow
for input_tensor in tensor_list:
if not isinstance(input_tensor, (str, list)):
raise TypeError("Input should be string or list of strings, got {}.".format(type(input_tensor)))
tensor_array.append(cde.Tensor(np.asarray(input_tensor)))
callable_op = cde.Execute(self.parse())
output_list = callable_op(tensor_list)
output_list = callable_op(tensor_array)
for i, element in enumerate(output_list):
arr = element.as_array()
if arr.dtype.char == 'S':
output_list[i] = to_str(arr)
output_list[i] = np.char.decode(arr)
else:
output_list[i] = arr
if not isinstance(input_tensor, list) and len(output_list) == 1:
output_list = output_list[0]
return output_list
return output_list[0] if len(output_list) == 1 else output_list
def parse(self):
raise NotImplementedError("TextTensorOperation has to implement parse() method.")

View File

@ -62,28 +62,24 @@ class ImageTensorOperation(TensorOperation):
"""
Base class of Image Tensor Ops
"""
def __call__(self, input_tensor):
if not isinstance(input_tensor, list):
input_list = [input_tensor]
else:
input_list = input_tensor
tensor_list = []
for tensor in input_list:
if not isinstance(tensor, (np.ndarray, Image.Image)):
raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(tensor)))
tensor_list.append(cde.Tensor(np.asarray(tensor)))
def __call__(self, *tensor_list):
tensor_array = []
output_list = []
# Combine input tensor_list to a TensorRow
for input_tensor in tensor_list:
if not isinstance(input_tensor, (np.ndarray, Image.Image)):
raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(input_tensor)))
tensor_array.append(cde.Tensor(np.asarray(input_tensor)))
callable_op = cde.Execute(self.parse())
output_list = callable_op(tensor_list)
output_list = callable_op(tensor_array)
for i, element in enumerate(output_list):
arr = element.as_array()
if arr.dtype.char == 'S':
output_list[i] = np.char.decode(arr)
else:
output_list[i] = arr
if not isinstance(input_tensor, list) and len(output_list) == 1:
output_list = output_list[0]
return output_list
return output_list[0] if len(output_list) == 1 else output_list
def parse(self):
raise NotImplementedError("ImageTensorOperation has to implement parse() method.")

View File

@ -16,6 +16,7 @@
Testing HWC2CHW op in DE
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.py_transforms
import mindspore.dataset.vision.c_transforms as c_vision
@ -36,11 +37,23 @@ def test_HWC2CHW_callable():
logger.info("Test HWC2CHW callable")
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
img = c_vision.Decode()(img)
img = c_vision.HWC2CHW()(img)
logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))
assert img.shape == (3, 2268, 4032)
assert img.shape == (2268, 4032, 3)
# test one tensor
img1 = c_vision.HWC2CHW()(img)
assert img1.shape == (3, 2268, 4032)
# test input multiple tensors
with pytest.raises(RuntimeError) as info:
imgs = [img, img]
_ = c_vision.HWC2CHW()(*imgs)
assert "The op is OneToOne, can only accept one tensor as input." in str(info.value)
with pytest.raises(RuntimeError) as info:
_ = c_vision.HWC2CHW()(img, img)
assert "The op is OneToOne, can only accept one tensor as input." in str(info.value)
def test_HWC2CHW(plot=False):

View File

@ -44,10 +44,12 @@ def test_random_crop_and_resize_callable():
decode_op = c_vision.Decode()
img = decode_op(img)
assert img.shape == (2268, 4032, 3)
random_crop_and_resize_op = c_vision.RandomResizedCrop((256, 512), (2, 2), (1, 3))
img = random_crop_and_resize_op(img)
assert np.shape(img) == (256, 512, 3)
# test one tensor
random_crop_and_resize_op1 = c_vision.RandomResizedCrop((256, 512), (2, 2), (1, 3))
img1 = random_crop_and_resize_op1(img)
assert img1.shape == (256, 512, 3)
def test_random_crop_and_resize_op_c(plot=False):

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.dataset as ds
from mindspore.dataset.text import JiebaTokenizer
from mindspore.dataset.text import JiebaMode, to_str
@ -33,14 +34,19 @@ def test_jieba_callable():
jieba_op1 = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP)
jieba_op2 = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.HMM)
# test one tensor
text1 = "今天天气太好了我们一起去外面玩吧"
text2 = "男默女泪市长江大桥"
assert np.array_equal(jieba_op1(text1), ['今天天气', '太好了', '我们', '一起', '', '外面', '玩吧'])
assert np.array_equal(jieba_op2(text1), ['今天', '天气', '', '', '', '我们', '一起', '', '外面', '', ''])
jieba_op1.add_word("男默女泪")
assert np.array_equal(jieba_op1(text2), ['男默女泪', '', '长江大桥'])
# test input multiple tensors
with pytest.raises(RuntimeError) as info:
_ = jieba_op1(text1, text2)
assert "JiebaTokenizer: input only support one column data." in str(info.value)
def test_jieba_1():
"""Test jieba tokenizer with MP mode"""

View File

@ -33,6 +33,24 @@ def string_dataset_generator(strings):
yield (np.array(string, dtype='S'),)
def test_to_number_eager():
"""
Test ToNumber op is callable
"""
input_strings = [["1", "2", "3"], ["4", "5", "6"]]
op = text.ToNumber(mstype.int8)
# test input_strings as one 2D tensor
result1 = op(input_strings) # np array: [[1 2 3] [4 5 6]]
assert np.array_equal(result1, np.array([[1, 2, 3], [4, 5, 6]], dtype='i'))
# test input multiple tensors
with pytest.raises(RuntimeError) as info:
# test input_strings as two 1D tensor. It's error because to_number is an OneToOne op
_ = op(*input_strings)
assert "The op is OneToOne, can only accept one tensor as input." in str(info.value)
def test_to_number_typical_case_integral():
input_strings = [["-121", "14"], ["-2219", "7623"], ["-8162536", "162371864"],
["-1726483716", "98921728421"]]
@ -194,6 +212,7 @@ def test_to_number_invalid_type():
if __name__ == '__main__':
test_to_number_eager()
test_to_number_typical_case_integral()
test_to_number_typical_case_non_integral()
test_to_number_boundaries_integral()

View File

@ -38,12 +38,13 @@ def test_uniform_augment_callable(num_ops=2):
decode_op = C.Decode()
img = decode_op(img)
assert img.shape == (2268, 4032, 3)
transforms_ua = [C.RandomCrop(size=[400, 400], padding=[32, 32, 32, 32]),
C.RandomCrop(size=[400, 400], padding=[32, 32, 32, 32])]
uni_aug = C.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
img = uni_aug([img, img])
assert ((np.shape(img) == (2, 2268, 4032, 3)) or (np.shape(img) == (1, 400, 400, 3)))
img = uni_aug(img)
assert img.shape == (2268, 4032, 3) or img.shape == (400, 400, 3)
def test_uniform_augment(plot=False, num_ops=2):