!2274 add PadEndOp

Merge pull request !2274 from xunxue/padend
This commit is contained in:
mindspore-ci-bot 2020-06-19 05:05:14 +08:00 committed by Gitee
commit ffc8a3c362
10 changed files with 378 additions and 15 deletions

View File

@ -39,6 +39,7 @@
#include "dataset/kernels/image/uniform_aug_op.h"
#include "dataset/kernels/data/fill_op.h"
#include "dataset/kernels/data/mask_op.h"
#include "dataset/kernels/data/pad_end_op.h"
#include "dataset/kernels/data/slice_op.h"
#include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h"
#include "dataset/kernels/data/type_cast_op.h"
@ -444,6 +445,10 @@ void bindTensorOps2(py::module *m) {
py::arg("interpolation") = RandomRotationOp::kDefInterpolation,
py::arg("expand") = RandomRotationOp::kDefExpand, py::arg("fillR") = RandomRotationOp::kDefFillR,
py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB);
(void)py::class_<PadEndOp, TensorOp, std::shared_ptr<PadEndOp>>(
*m, "PadEndOp", "Tensor operation to pad end of tensor with a pad value.")
.def(py::init<TensorShape, std::shared_ptr<Tensor>>());
}
void bindTensorOps3(py::module *m) {

View File

@ -1,11 +1,12 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(kernels-data OBJECT
data_utils.cc
one_hot_op.cc
type_cast_op.cc
to_float16_op.cc
fill_op.cc
slice_op.cc
mask_op.cc
)
data_utils.cc
one_hot_op.cc
pad_end_op.cc
type_cast_op.cc
to_float16_op.cc
fill_op.cc
slice_op.cc
mask_op.cc
)

View File

@ -347,8 +347,10 @@ Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
CHECK_FAIL_RETURN_UNEXPECTED(src->type().IsNumeric() == pad_val->type().IsNumeric(),
"Source and pad_value tensors are not of the same type.");
if (pad_val->type().IsNumeric()) {
std::shared_ptr<Tensor> float_pad_value;
RETURN_IF_NOT_OK(TypeCast(pad_val, &float_pad_value, DataType(DataType::DE_FLOAT32)));
float val = 0;
RETURN_IF_NOT_OK(pad_val->GetItemAt<float>(&val, {}));
RETURN_IF_NOT_OK(float_pad_value->GetItemAt<float>(&val, {}));
return PadEndNumeric(src, dst, pad_shape, val);
}
std::string_view val;

View File

@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/kernels/data/pad_end_op.h"
#include "dataset/core/tensor.h"
#include "dataset/kernels/data/data_utils.h"
#include "dataset/kernels/tensor_op.h"
namespace mindspore {
namespace dataset {
Status PadEndOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
Status s = PadEnd(input, output, output_shape_.AsVector(), pad_val_);
return s;
}
Status PadEndOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
outputs.clear();
for (auto s : inputs) {
outputs.emplace_back(TensorShape(output_shape_.AsVector()));
}
CHECK_FAIL_RETURN_UNEXPECTED(!outputs.empty(), "Input has a wrong shape");
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_DATA_PAD_END_OP_H_
#define DATASET_KERNELS_DATA_PAD_END_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
namespace mindspore {
namespace dataset {
class PadEndOp : public TensorOp {
public:
explicit PadEndOp(const TensorShape &pad_shape, const std::shared_ptr<Tensor> &pad_value)
: output_shape_(pad_shape), pad_val_(pad_value) {}
~PadEndOp() override = default;
void Print(std::ostream &out) const override { out << "PadEndOp"; }
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
private:
TensorShape output_shape_;
std::shared_ptr<Tensor> pad_val_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_KERNELS_DATA_PAD_END_OP_H_

View File

@ -22,7 +22,7 @@ import mindspore._c_dataengine as cde
import numpy as np
from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op, check_mask_op
from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op, check_mask_op, check_pad_end
from ..core.datatypes import mstype_to_detype
@ -46,7 +46,7 @@ class Fill(cde.FillOp):
The output tensor will have the same shape and type as the input tensor.
Args:
fill_value (python types (str, int, float, or bool)) : scalar value
fill_value (python types (str, bytes, int, float, or bool)) : scalar value
to fill created tensor with.
"""
@ -158,3 +158,32 @@ class Mask(cde.MaskOp):
dtype = mstype_to_detype(dtype)
constant = cde.Tensor(np.array(constant))
super().__init__(DE_C_RELATIONAL[operator], constant, dtype)
class PadEnd(cde.PadEndOp):
"""
Pad input tensor according to `pad_shape`, need to have same rank.
Args:
pad_shape (list of `int`): list on integers representing the shape needed. Dimensions that set to `None` will
not be padded (i.e., original dim will be used). Shorter dimensions will truncate the values.
pad_value (str, bytes, int, float, or bool, optional): value used to pad. Default to 0 or empty string in case
of Tensors of strings.
Examples:
>>> # Data before
>>> # | col |
>>> # +---------+
>>> # | [1,2,3] |
>>> # +---------|
>>> data = data.map(operations=PadEnd(pad_shape=[4], pad_value=10))
>>> # Data after
>>> # | col |
>>> # +------------+
>>> # | [1,2,3,10] |
>>> # +------------|
"""
@check_pad_end
def __init__(self, pad_shape, pad_value=None):
if pad_value is not None:
pad_value = cde.Tensor(np.array(pad_value))
super().__init__(cde.TensorShape(pad_shape), pad_value)

View File

@ -169,8 +169,8 @@ def check_fill_value(method):
fill_value = kwargs.get("fill_value")
if fill_value is None:
raise ValueError("fill_value is not provided.")
if not isinstance(fill_value, (str, float, bool, int)):
raise TypeError("fill_value must be either a primitive python str, float, bool, or int")
if not isinstance(fill_value, (str, float, bool, int, bytes)):
raise TypeError("fill_value must be either a primitive python str, float, bool, bytes or int")
kwargs["fill_value"] = fill_value
return method(self, **kwargs)
@ -237,8 +237,8 @@ def check_mask_op(method):
if not isinstance(operator, Relational):
raise TypeError("operator is not a Relational operator enum.")
if not isinstance(constant, (str, float, bool, int)):
raise TypeError("constant must be either a primitive python str, float, bool, or int")
if not isinstance(constant, (str, float, bool, int, bytes)):
raise TypeError("constant must be either a primitive python str, float, bool, bytes or int")
if not isinstance(dtype, typing.Type):
raise TypeError("dtype is not a MindSpore data type.")
@ -250,3 +250,35 @@ def check_mask_op(method):
return method(self, **kwargs)
return new_method
def check_pad_end(method):
"""Wrapper method to check the parameters of PadEnd."""
@wraps(method)
def new_method(self, *args, **kwargs):
pad_shape, pad_value = (list(args) + 2 * [None])[:2]
if "pad_shape" in kwargs:
pad_shape = kwargs.get("pad_shape")
if "pad_value" in kwargs:
pad_value = kwargs.get("pad_value")
if pad_shape is None:
raise ValueError("pad_shape is not provided.")
if pad_value is not None and not isinstance(pad_value, (str, float, bool, int, bytes)):
raise TypeError("pad_value must be either a primitive python str, float, bool, bytes or int")
if not isinstance(pad_shape, list):
raise TypeError("pad_shape must be a list")
for dim in pad_shape:
if dim is not None:
check_pos_int64(dim)
kwargs["pad_shape"] = pad_shape
kwargs["pad_value"] = pad_value
return method(self, **kwargs)
return new_method

View File

@ -27,6 +27,7 @@ SET(DE_UT_SRCS
memory_pool_test.cc
normalize_op_test.cc
one_hot_op_test.cc
pad_end_op_test.cc
path_test.cc
project_op_test.cc
queue_test.cc
@ -74,6 +75,8 @@ SET(DE_UT_SRCS
gnn_graph_test.cc
coco_op_test.cc
fill_op_test.cc
mask_test.cc
trucate_pair_test.cc
)
add_executable(de_ut_tests ${DE_UT_SRCS})

View File

@ -0,0 +1,140 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "dataset/kernels/data/pad_end_op.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;
class MindDataTestPadEndOp : public UT::Common {
protected:
MindDataTestPadEndOp() {}
};
TEST_F(MindDataTestPadEndOp, TestOp) {
MS_LOG(INFO) << "Doing MindDataTestPadEndOp.";
// first set of testunits for numeric values
TensorShape pad_data_shape({1});
// prepare input tensor
float_t orig1[4] = {1, 1, 1, 1};
TensorShape input_shape1({2, 2});
std::vector<TensorShape> input_shape1_vector = {input_shape1};
std::shared_ptr<Tensor> input1 =
std::make_shared<Tensor>(input_shape1, DataType(DataType::DE_FLOAT32), reinterpret_cast<unsigned char *>(orig1));
// pad_shape
TensorShape pad_shape1[3] = {TensorShape({3, 3}), TensorShape({2, 4}), TensorShape({4, 2})};
// value to pad
float_t pad_data1[3][1] = {0, 3.5, 3.5};
std::shared_ptr<Tensor> expected1[3];
// expected tensor output for testunit 1
float_t out1[9] = {1, 1, 0, 1, 1, 0, 0, 0, 0};
expected1[0] =
std::make_shared<Tensor>(pad_shape1[0], DataType(DataType::DE_FLOAT32), reinterpret_cast<unsigned char *>(out1));
// expected tensor output for testunit 2
float_t out2[8] = {1, 1, 3.5, 3.5, 1, 1, 3.5, 3.5};
expected1[1] =
std::make_shared<Tensor>(pad_shape1[1], DataType(DataType::DE_FLOAT32), reinterpret_cast<unsigned char *>(out2));
// expected tensor output for testunit 3
float_t out3[8] = {1, 1, 1, 1, 3.5, 3.5, 3.5, 3.5};
expected1[2] =
std::make_shared<Tensor>(pad_shape1[2], DataType(DataType::DE_FLOAT32), reinterpret_cast<unsigned char *>(out3));
// run the PadEndOp
for (auto i = 0; i < 3; i++) {
std::shared_ptr<Tensor> output;
std::vector<TensorShape> output_shape = {TensorShape({})};
std::shared_ptr<Tensor> pad_value1 = std::make_shared<Tensor>(pad_data_shape, DataType(DataType::DE_FLOAT32),
reinterpret_cast<unsigned char *>(pad_data1[i]));
std::unique_ptr<PadEndOp> op(new PadEndOp(pad_shape1[i], pad_value1));
Status s = op->Compute(input1, &output);
EXPECT_TRUE(s.IsOk());
ASSERT_TRUE(output->shape() == expected1[i]->shape());
ASSERT_TRUE(output->type() == expected1[i]->type());
MS_LOG(DEBUG) << *output << std::endl;
MS_LOG(DEBUG) << *expected1[i] << std::endl;
ASSERT_TRUE(*output == *expected1[i]);
s = op->OutputShape(input_shape1_vector, output_shape);
EXPECT_TRUE(s.IsOk());
ASSERT_TRUE(output_shape.size() == 1);
ASSERT_TRUE(output->shape() == output_shape[0]);
}
// second set of testunits for string
// input tensor
std::vector<std::string> orig2 = {"this", "is"};
TensorShape input_shape2({2});
std::vector<TensorShape> input_shape2_vector = {input_shape2};
std::shared_ptr<Tensor> input2;
Tensor::CreateTensor(&input2, orig2, input_shape2);
// pad_shape
TensorShape pad_shape2[3] = {TensorShape({5}), TensorShape({2}), TensorShape({10})};
// pad value
std::vector<std::string> pad_data2[3] = {{""}, {"P"}, {" "}};
std::shared_ptr<Tensor> pad_value2[3];
// expected output for 3 testunits
std::shared_ptr<Tensor> expected2[3];
std::vector<std::string> outstring[3] = {
{"this", "is", "", "", ""}, {"this", "is"}, {"this", "is", " ", " ", " ", " ", " ", " ", " ", " "}};
for (auto i = 0; i < 3; i++) {
// pad value
Tensor::CreateTensor(&pad_value2[i], pad_data2[i], pad_data_shape);
std::shared_ptr<Tensor> output;
std::vector<TensorShape> output_shape = {TensorShape({})};
std::unique_ptr<PadEndOp> op(new PadEndOp(pad_shape2[i], pad_value2[i]));
Status s = op->Compute(input2, &output);
Tensor::CreateTensor(&expected2[i], outstring[i], pad_shape2[i]);
EXPECT_TRUE(s.IsOk());
ASSERT_TRUE(output->shape() == expected2[i]->shape());
ASSERT_TRUE(output->type() == expected2[i]->type());
MS_LOG(DEBUG) << *output << std::endl;
MS_LOG(DEBUG) << *expected2[i] << std::endl;
ASSERT_TRUE(*output == *expected2[i]);
s = op->OutputShape(input_shape2_vector, output_shape);
EXPECT_TRUE(s.IsOk());
ASSERT_TRUE(output_shape.size() == 1);
ASSERT_TRUE(output->shape() == output_shape[0]);
}
MS_LOG(INFO) << "MindDataTestPadEndOp end.";
}

View File

@ -0,0 +1,64 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing PadEnd op in DE
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as ops
def pad_compare(array, pad_shape, pad_value, res):
data = ds.NumpySlicesDataset([array])
if pad_value is not None:
data = data.map(operations=ops.PadEnd(pad_shape, pad_value))
else:
data = data.map(operations=ops.PadEnd(pad_shape))
for d in data:
np.testing.assert_array_equal(res, d[0])
# Extensive testing of PadEnd is already done in batch with Pad test cases
def test_pad_end_basics():
pad_compare([1, 2], [3], -1, [1, 2, -1])
pad_compare([1, 2, 3], [3], -1, [1, 2, 3])
pad_compare([1, 2, 3], [2], -1, [1, 2])
pad_compare([1, 2, 3], [5], None, [1, 2, 3, 0, 0])
def test_pad_end_str():
pad_compare([b"1", b"2"], [3], b"-1", [b"1", b"2", b"-1"])
pad_compare([b"1", b"2", b"3"], [3], b"-1", [b"1", b"2", b"3"])
pad_compare([b"1", b"2", b"3"], [2], b"-1", [b"1", b"2"])
pad_compare([b"1", b"2", b"3"], [5], None, [b"1", b"2", b"3", b"", b""])
def test_pad_end_exceptions():
with pytest.raises(RuntimeError) as info:
pad_compare([1, 2], [3], "-1", [])
assert "Source and pad_value tensors are not of the same type." in str(info.value)
with pytest.raises(RuntimeError) as info:
pad_compare([b"1", b"2", b"3", b"4", b"5"], [2], 1, [])
assert "Source and pad_value tensors are not of the same type." in str(info.value)
if __name__ == "__main__":
test_pad_end_basics()
test_pad_end_str()
test_pad_end_exceptions()