forked from mindspore-Ecosystem/mindspore
commit
ffc8a3c362
|
@ -39,6 +39,7 @@
|
||||||
#include "dataset/kernels/image/uniform_aug_op.h"
|
#include "dataset/kernels/image/uniform_aug_op.h"
|
||||||
#include "dataset/kernels/data/fill_op.h"
|
#include "dataset/kernels/data/fill_op.h"
|
||||||
#include "dataset/kernels/data/mask_op.h"
|
#include "dataset/kernels/data/mask_op.h"
|
||||||
|
#include "dataset/kernels/data/pad_end_op.h"
|
||||||
#include "dataset/kernels/data/slice_op.h"
|
#include "dataset/kernels/data/slice_op.h"
|
||||||
#include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h"
|
#include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h"
|
||||||
#include "dataset/kernels/data/type_cast_op.h"
|
#include "dataset/kernels/data/type_cast_op.h"
|
||||||
|
@ -444,6 +445,10 @@ void bindTensorOps2(py::module *m) {
|
||||||
py::arg("interpolation") = RandomRotationOp::kDefInterpolation,
|
py::arg("interpolation") = RandomRotationOp::kDefInterpolation,
|
||||||
py::arg("expand") = RandomRotationOp::kDefExpand, py::arg("fillR") = RandomRotationOp::kDefFillR,
|
py::arg("expand") = RandomRotationOp::kDefExpand, py::arg("fillR") = RandomRotationOp::kDefFillR,
|
||||||
py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB);
|
py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB);
|
||||||
|
|
||||||
|
(void)py::class_<PadEndOp, TensorOp, std::shared_ptr<PadEndOp>>(
|
||||||
|
*m, "PadEndOp", "Tensor operation to pad end of tensor with a pad value.")
|
||||||
|
.def(py::init<TensorShape, std::shared_ptr<Tensor>>());
|
||||||
}
|
}
|
||||||
|
|
||||||
void bindTensorOps3(py::module *m) {
|
void bindTensorOps3(py::module *m) {
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||||
add_library(kernels-data OBJECT
|
add_library(kernels-data OBJECT
|
||||||
data_utils.cc
|
data_utils.cc
|
||||||
one_hot_op.cc
|
one_hot_op.cc
|
||||||
type_cast_op.cc
|
pad_end_op.cc
|
||||||
to_float16_op.cc
|
type_cast_op.cc
|
||||||
fill_op.cc
|
to_float16_op.cc
|
||||||
slice_op.cc
|
fill_op.cc
|
||||||
mask_op.cc
|
slice_op.cc
|
||||||
)
|
mask_op.cc
|
||||||
|
)
|
||||||
|
|
|
@ -347,8 +347,10 @@ Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(src->type().IsNumeric() == pad_val->type().IsNumeric(),
|
CHECK_FAIL_RETURN_UNEXPECTED(src->type().IsNumeric() == pad_val->type().IsNumeric(),
|
||||||
"Source and pad_value tensors are not of the same type.");
|
"Source and pad_value tensors are not of the same type.");
|
||||||
if (pad_val->type().IsNumeric()) {
|
if (pad_val->type().IsNumeric()) {
|
||||||
|
std::shared_ptr<Tensor> float_pad_value;
|
||||||
|
RETURN_IF_NOT_OK(TypeCast(pad_val, &float_pad_value, DataType(DataType::DE_FLOAT32)));
|
||||||
float val = 0;
|
float val = 0;
|
||||||
RETURN_IF_NOT_OK(pad_val->GetItemAt<float>(&val, {}));
|
RETURN_IF_NOT_OK(float_pad_value->GetItemAt<float>(&val, {}));
|
||||||
return PadEndNumeric(src, dst, pad_shape, val);
|
return PadEndNumeric(src, dst, pad_shape, val);
|
||||||
}
|
}
|
||||||
std::string_view val;
|
std::string_view val;
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "dataset/kernels/data/pad_end_op.h"
|
||||||
|
|
||||||
|
#include "dataset/core/tensor.h"
|
||||||
|
#include "dataset/kernels/data/data_utils.h"
|
||||||
|
#include "dataset/kernels/tensor_op.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
Status PadEndOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||||
|
IO_CHECK(input, output);
|
||||||
|
Status s = PadEnd(input, output, output_shape_.AsVector(), pad_val_);
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PadEndOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
|
||||||
|
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
|
||||||
|
outputs.clear();
|
||||||
|
for (auto s : inputs) {
|
||||||
|
outputs.emplace_back(TensorShape(output_shape_.AsVector()));
|
||||||
|
}
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(!outputs.empty(), "Input has a wrong shape");
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,47 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef DATASET_KERNELS_DATA_PAD_END_OP_H_
|
||||||
|
#define DATASET_KERNELS_DATA_PAD_END_OP_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "dataset/core/tensor.h"
|
||||||
|
#include "dataset/kernels/tensor_op.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
class PadEndOp : public TensorOp {
|
||||||
|
public:
|
||||||
|
explicit PadEndOp(const TensorShape &pad_shape, const std::shared_ptr<Tensor> &pad_value)
|
||||||
|
: output_shape_(pad_shape), pad_val_(pad_value) {}
|
||||||
|
|
||||||
|
~PadEndOp() override = default;
|
||||||
|
|
||||||
|
void Print(std::ostream &out) const override { out << "PadEndOp"; }
|
||||||
|
|
||||||
|
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||||
|
|
||||||
|
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
TensorShape output_shape_;
|
||||||
|
std::shared_ptr<Tensor> pad_val_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // DATASET_KERNELS_DATA_PAD_END_OP_H_
|
|
@ -22,7 +22,7 @@ import mindspore._c_dataengine as cde
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op, check_mask_op
|
from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op, check_mask_op, check_pad_end
|
||||||
from ..core.datatypes import mstype_to_detype
|
from ..core.datatypes import mstype_to_detype
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,7 +46,7 @@ class Fill(cde.FillOp):
|
||||||
The output tensor will have the same shape and type as the input tensor.
|
The output tensor will have the same shape and type as the input tensor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fill_value (python types (str, int, float, or bool)) : scalar value
|
fill_value (python types (str, bytes, int, float, or bool)) : scalar value
|
||||||
to fill created tensor with.
|
to fill created tensor with.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -158,3 +158,32 @@ class Mask(cde.MaskOp):
|
||||||
dtype = mstype_to_detype(dtype)
|
dtype = mstype_to_detype(dtype)
|
||||||
constant = cde.Tensor(np.array(constant))
|
constant = cde.Tensor(np.array(constant))
|
||||||
super().__init__(DE_C_RELATIONAL[operator], constant, dtype)
|
super().__init__(DE_C_RELATIONAL[operator], constant, dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class PadEnd(cde.PadEndOp):
|
||||||
|
"""
|
||||||
|
Pad input tensor according to `pad_shape`, need to have same rank.
|
||||||
|
Args:
|
||||||
|
pad_shape (list of `int`): list on integers representing the shape needed. Dimensions that set to `None` will
|
||||||
|
not be padded (i.e., original dim will be used). Shorter dimensions will truncate the values.
|
||||||
|
pad_value (str, bytes, int, float, or bool, optional): value used to pad. Default to 0 or empty string in case
|
||||||
|
of Tensors of strings.
|
||||||
|
Examples:
|
||||||
|
>>> # Data before
|
||||||
|
>>> # | col |
|
||||||
|
>>> # +---------+
|
||||||
|
>>> # | [1,2,3] |
|
||||||
|
>>> # +---------|
|
||||||
|
>>> data = data.map(operations=PadEnd(pad_shape=[4], pad_value=10))
|
||||||
|
>>> # Data after
|
||||||
|
>>> # | col |
|
||||||
|
>>> # +------------+
|
||||||
|
>>> # | [1,2,3,10] |
|
||||||
|
>>> # +------------|
|
||||||
|
"""
|
||||||
|
|
||||||
|
@check_pad_end
|
||||||
|
def __init__(self, pad_shape, pad_value=None):
|
||||||
|
if pad_value is not None:
|
||||||
|
pad_value = cde.Tensor(np.array(pad_value))
|
||||||
|
super().__init__(cde.TensorShape(pad_shape), pad_value)
|
||||||
|
|
|
@ -169,8 +169,8 @@ def check_fill_value(method):
|
||||||
fill_value = kwargs.get("fill_value")
|
fill_value = kwargs.get("fill_value")
|
||||||
if fill_value is None:
|
if fill_value is None:
|
||||||
raise ValueError("fill_value is not provided.")
|
raise ValueError("fill_value is not provided.")
|
||||||
if not isinstance(fill_value, (str, float, bool, int)):
|
if not isinstance(fill_value, (str, float, bool, int, bytes)):
|
||||||
raise TypeError("fill_value must be either a primitive python str, float, bool, or int")
|
raise TypeError("fill_value must be either a primitive python str, float, bool, bytes or int")
|
||||||
kwargs["fill_value"] = fill_value
|
kwargs["fill_value"] = fill_value
|
||||||
|
|
||||||
return method(self, **kwargs)
|
return method(self, **kwargs)
|
||||||
|
@ -237,8 +237,8 @@ def check_mask_op(method):
|
||||||
if not isinstance(operator, Relational):
|
if not isinstance(operator, Relational):
|
||||||
raise TypeError("operator is not a Relational operator enum.")
|
raise TypeError("operator is not a Relational operator enum.")
|
||||||
|
|
||||||
if not isinstance(constant, (str, float, bool, int)):
|
if not isinstance(constant, (str, float, bool, int, bytes)):
|
||||||
raise TypeError("constant must be either a primitive python str, float, bool, or int")
|
raise TypeError("constant must be either a primitive python str, float, bool, bytes or int")
|
||||||
|
|
||||||
if not isinstance(dtype, typing.Type):
|
if not isinstance(dtype, typing.Type):
|
||||||
raise TypeError("dtype is not a MindSpore data type.")
|
raise TypeError("dtype is not a MindSpore data type.")
|
||||||
|
@ -250,3 +250,35 @@ def check_mask_op(method):
|
||||||
return method(self, **kwargs)
|
return method(self, **kwargs)
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
|
||||||
|
def check_pad_end(method):
|
||||||
|
"""Wrapper method to check the parameters of PadEnd."""
|
||||||
|
|
||||||
|
@wraps(method)
|
||||||
|
def new_method(self, *args, **kwargs):
|
||||||
|
pad_shape, pad_value = (list(args) + 2 * [None])[:2]
|
||||||
|
if "pad_shape" in kwargs:
|
||||||
|
pad_shape = kwargs.get("pad_shape")
|
||||||
|
if "pad_value" in kwargs:
|
||||||
|
pad_value = kwargs.get("pad_value")
|
||||||
|
|
||||||
|
if pad_shape is None:
|
||||||
|
raise ValueError("pad_shape is not provided.")
|
||||||
|
|
||||||
|
if pad_value is not None and not isinstance(pad_value, (str, float, bool, int, bytes)):
|
||||||
|
raise TypeError("pad_value must be either a primitive python str, float, bool, bytes or int")
|
||||||
|
|
||||||
|
if not isinstance(pad_shape, list):
|
||||||
|
raise TypeError("pad_shape must be a list")
|
||||||
|
|
||||||
|
for dim in pad_shape:
|
||||||
|
if dim is not None:
|
||||||
|
check_pos_int64(dim)
|
||||||
|
|
||||||
|
kwargs["pad_shape"] = pad_shape
|
||||||
|
kwargs["pad_value"] = pad_value
|
||||||
|
|
||||||
|
return method(self, **kwargs)
|
||||||
|
|
||||||
|
return new_method
|
||||||
|
|
|
@ -27,6 +27,7 @@ SET(DE_UT_SRCS
|
||||||
memory_pool_test.cc
|
memory_pool_test.cc
|
||||||
normalize_op_test.cc
|
normalize_op_test.cc
|
||||||
one_hot_op_test.cc
|
one_hot_op_test.cc
|
||||||
|
pad_end_op_test.cc
|
||||||
path_test.cc
|
path_test.cc
|
||||||
project_op_test.cc
|
project_op_test.cc
|
||||||
queue_test.cc
|
queue_test.cc
|
||||||
|
@ -74,6 +75,8 @@ SET(DE_UT_SRCS
|
||||||
gnn_graph_test.cc
|
gnn_graph_test.cc
|
||||||
coco_op_test.cc
|
coco_op_test.cc
|
||||||
fill_op_test.cc
|
fill_op_test.cc
|
||||||
|
mask_test.cc
|
||||||
|
trucate_pair_test.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
add_executable(de_ut_tests ${DE_UT_SRCS})
|
add_executable(de_ut_tests ${DE_UT_SRCS})
|
||||||
|
|
|
@ -0,0 +1,140 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "common/common.h"
|
||||||
|
#include "dataset/kernels/data/pad_end_op.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
using namespace mindspore::dataset;
|
||||||
|
using mindspore::LogStream;
|
||||||
|
using mindspore::ExceptionType::NoExceptionType;
|
||||||
|
using mindspore::MsLogLevel::INFO;
|
||||||
|
|
||||||
|
class MindDataTestPadEndOp : public UT::Common {
|
||||||
|
protected:
|
||||||
|
MindDataTestPadEndOp() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(MindDataTestPadEndOp, TestOp) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPadEndOp.";
|
||||||
|
|
||||||
|
// first set of testunits for numeric values
|
||||||
|
|
||||||
|
TensorShape pad_data_shape({1});
|
||||||
|
|
||||||
|
// prepare input tensor
|
||||||
|
float_t orig1[4] = {1, 1, 1, 1};
|
||||||
|
TensorShape input_shape1({2, 2});
|
||||||
|
std::vector<TensorShape> input_shape1_vector = {input_shape1};
|
||||||
|
std::shared_ptr<Tensor> input1 =
|
||||||
|
std::make_shared<Tensor>(input_shape1, DataType(DataType::DE_FLOAT32), reinterpret_cast<unsigned char *>(orig1));
|
||||||
|
|
||||||
|
// pad_shape
|
||||||
|
TensorShape pad_shape1[3] = {TensorShape({3, 3}), TensorShape({2, 4}), TensorShape({4, 2})};
|
||||||
|
|
||||||
|
// value to pad
|
||||||
|
float_t pad_data1[3][1] = {0, 3.5, 3.5};
|
||||||
|
|
||||||
|
std::shared_ptr<Tensor> expected1[3];
|
||||||
|
|
||||||
|
// expected tensor output for testunit 1
|
||||||
|
float_t out1[9] = {1, 1, 0, 1, 1, 0, 0, 0, 0};
|
||||||
|
|
||||||
|
expected1[0] =
|
||||||
|
std::make_shared<Tensor>(pad_shape1[0], DataType(DataType::DE_FLOAT32), reinterpret_cast<unsigned char *>(out1));
|
||||||
|
|
||||||
|
// expected tensor output for testunit 2
|
||||||
|
float_t out2[8] = {1, 1, 3.5, 3.5, 1, 1, 3.5, 3.5};
|
||||||
|
|
||||||
|
expected1[1] =
|
||||||
|
std::make_shared<Tensor>(pad_shape1[1], DataType(DataType::DE_FLOAT32), reinterpret_cast<unsigned char *>(out2));
|
||||||
|
|
||||||
|
// expected tensor output for testunit 3
|
||||||
|
float_t out3[8] = {1, 1, 1, 1, 3.5, 3.5, 3.5, 3.5};
|
||||||
|
|
||||||
|
expected1[2] =
|
||||||
|
std::make_shared<Tensor>(pad_shape1[2], DataType(DataType::DE_FLOAT32), reinterpret_cast<unsigned char *>(out3));
|
||||||
|
|
||||||
|
// run the PadEndOp
|
||||||
|
for (auto i = 0; i < 3; i++) {
|
||||||
|
std::shared_ptr<Tensor> output;
|
||||||
|
std::vector<TensorShape> output_shape = {TensorShape({})};
|
||||||
|
std::shared_ptr<Tensor> pad_value1 = std::make_shared<Tensor>(pad_data_shape, DataType(DataType::DE_FLOAT32),
|
||||||
|
reinterpret_cast<unsigned char *>(pad_data1[i]));
|
||||||
|
std::unique_ptr<PadEndOp> op(new PadEndOp(pad_shape1[i], pad_value1));
|
||||||
|
Status s = op->Compute(input1, &output);
|
||||||
|
|
||||||
|
EXPECT_TRUE(s.IsOk());
|
||||||
|
ASSERT_TRUE(output->shape() == expected1[i]->shape());
|
||||||
|
ASSERT_TRUE(output->type() == expected1[i]->type());
|
||||||
|
MS_LOG(DEBUG) << *output << std::endl;
|
||||||
|
MS_LOG(DEBUG) << *expected1[i] << std::endl;
|
||||||
|
ASSERT_TRUE(*output == *expected1[i]);
|
||||||
|
|
||||||
|
s = op->OutputShape(input_shape1_vector, output_shape);
|
||||||
|
EXPECT_TRUE(s.IsOk());
|
||||||
|
ASSERT_TRUE(output_shape.size() == 1);
|
||||||
|
ASSERT_TRUE(output->shape() == output_shape[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// second set of testunits for string
|
||||||
|
|
||||||
|
// input tensor
|
||||||
|
std::vector<std::string> orig2 = {"this", "is"};
|
||||||
|
TensorShape input_shape2({2});
|
||||||
|
std::vector<TensorShape> input_shape2_vector = {input_shape2};
|
||||||
|
std::shared_ptr<Tensor> input2;
|
||||||
|
Tensor::CreateTensor(&input2, orig2, input_shape2);
|
||||||
|
|
||||||
|
// pad_shape
|
||||||
|
TensorShape pad_shape2[3] = {TensorShape({5}), TensorShape({2}), TensorShape({10})};
|
||||||
|
|
||||||
|
// pad value
|
||||||
|
std::vector<std::string> pad_data2[3] = {{""}, {"P"}, {" "}};
|
||||||
|
std::shared_ptr<Tensor> pad_value2[3];
|
||||||
|
|
||||||
|
// expected output for 3 testunits
|
||||||
|
std::shared_ptr<Tensor> expected2[3];
|
||||||
|
std::vector<std::string> outstring[3] = {
|
||||||
|
{"this", "is", "", "", ""}, {"this", "is"}, {"this", "is", " ", " ", " ", " ", " ", " ", " ", " "}};
|
||||||
|
|
||||||
|
for (auto i = 0; i < 3; i++) {
|
||||||
|
// pad value
|
||||||
|
Tensor::CreateTensor(&pad_value2[i], pad_data2[i], pad_data_shape);
|
||||||
|
|
||||||
|
std::shared_ptr<Tensor> output;
|
||||||
|
std::vector<TensorShape> output_shape = {TensorShape({})};
|
||||||
|
|
||||||
|
std::unique_ptr<PadEndOp> op(new PadEndOp(pad_shape2[i], pad_value2[i]));
|
||||||
|
|
||||||
|
Status s = op->Compute(input2, &output);
|
||||||
|
|
||||||
|
Tensor::CreateTensor(&expected2[i], outstring[i], pad_shape2[i]);
|
||||||
|
|
||||||
|
EXPECT_TRUE(s.IsOk());
|
||||||
|
ASSERT_TRUE(output->shape() == expected2[i]->shape());
|
||||||
|
ASSERT_TRUE(output->type() == expected2[i]->type());
|
||||||
|
MS_LOG(DEBUG) << *output << std::endl;
|
||||||
|
MS_LOG(DEBUG) << *expected2[i] << std::endl;
|
||||||
|
ASSERT_TRUE(*output == *expected2[i]);
|
||||||
|
|
||||||
|
s = op->OutputShape(input_shape2_vector, output_shape);
|
||||||
|
EXPECT_TRUE(s.IsOk());
|
||||||
|
ASSERT_TRUE(output_shape.size() == 1);
|
||||||
|
ASSERT_TRUE(output->shape() == output_shape[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "MindDataTestPadEndOp end.";
|
||||||
|
}
|
|
@ -0,0 +1,64 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""
|
||||||
|
Testing PadEnd op in DE
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
import mindspore.dataset.transforms.c_transforms as ops
|
||||||
|
|
||||||
|
|
||||||
|
def pad_compare(array, pad_shape, pad_value, res):
|
||||||
|
data = ds.NumpySlicesDataset([array])
|
||||||
|
if pad_value is not None:
|
||||||
|
data = data.map(operations=ops.PadEnd(pad_shape, pad_value))
|
||||||
|
else:
|
||||||
|
data = data.map(operations=ops.PadEnd(pad_shape))
|
||||||
|
for d in data:
|
||||||
|
np.testing.assert_array_equal(res, d[0])
|
||||||
|
|
||||||
|
|
||||||
|
# Extensive testing of PadEnd is already done in batch with Pad test cases
|
||||||
|
|
||||||
|
def test_pad_end_basics():
|
||||||
|
pad_compare([1, 2], [3], -1, [1, 2, -1])
|
||||||
|
pad_compare([1, 2, 3], [3], -1, [1, 2, 3])
|
||||||
|
pad_compare([1, 2, 3], [2], -1, [1, 2])
|
||||||
|
pad_compare([1, 2, 3], [5], None, [1, 2, 3, 0, 0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_pad_end_str():
|
||||||
|
pad_compare([b"1", b"2"], [3], b"-1", [b"1", b"2", b"-1"])
|
||||||
|
pad_compare([b"1", b"2", b"3"], [3], b"-1", [b"1", b"2", b"3"])
|
||||||
|
pad_compare([b"1", b"2", b"3"], [2], b"-1", [b"1", b"2"])
|
||||||
|
pad_compare([b"1", b"2", b"3"], [5], None, [b"1", b"2", b"3", b"", b""])
|
||||||
|
|
||||||
|
|
||||||
|
def test_pad_end_exceptions():
|
||||||
|
with pytest.raises(RuntimeError) as info:
|
||||||
|
pad_compare([1, 2], [3], "-1", [])
|
||||||
|
assert "Source and pad_value tensors are not of the same type." in str(info.value)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as info:
|
||||||
|
pad_compare([b"1", b"2", b"3", b"4", b"5"], [2], 1, [])
|
||||||
|
assert "Source and pad_value tensors are not of the same type." in str(info.value)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_pad_end_basics()
|
||||||
|
test_pad_end_str()
|
||||||
|
test_pad_end_exceptions()
|
Loading…
Reference in New Issue