forked from mindspore-Ecosystem/mindspore
Mask Op
This commit is contained in:
parent
e2012a1de9
commit
f2462bb00d
|
@ -38,6 +38,7 @@
|
|||
#include "dataset/kernels/image/resize_op.h"
|
||||
#include "dataset/kernels/image/uniform_aug_op.h"
|
||||
#include "dataset/kernels/data/fill_op.h"
|
||||
#include "dataset/kernels/data/mask_op.h"
|
||||
#include "dataset/kernels/data/slice_op.h"
|
||||
#include "dataset/kernels/data/type_cast_op.h"
|
||||
#include "dataset/engine/datasetops/source/cifar_op.h"
|
||||
|
@ -369,7 +370,7 @@ void bindTensorOps2(py::module *m) {
|
|||
*m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.")
|
||||
.def(py::init<std::shared_ptr<Tensor>>());
|
||||
|
||||
(void)py::class_<SliceOp, TensorOp, std::shared_ptr<SliceOp>>(*m, "SliceOp", "")
|
||||
(void)py::class_<SliceOp, TensorOp, std::shared_ptr<SliceOp>>(*m, "SliceOp", "Tensor Slice operation.")
|
||||
.def(py::init<bool>())
|
||||
.def(py::init([](const py::list &py_list) {
|
||||
std::vector<dsize_t> c_list;
|
||||
|
@ -400,6 +401,19 @@ void bindTensorOps2(py::module *m) {
|
|||
return std::make_shared<SliceOp>(c_slice);
|
||||
}));
|
||||
|
||||
(void)py::enum_<RelationalOp>(*m, "RelationalOp", py::arithmetic())
|
||||
.value("EQ", RelationalOp::kEqual)
|
||||
.value("NE", RelationalOp::kNotEqual)
|
||||
.value("LT", RelationalOp::kLess)
|
||||
.value("LE", RelationalOp::kLessEqual)
|
||||
.value("GT", RelationalOp::kGreater)
|
||||
.value("GE", RelationalOp::kGreaterEqual)
|
||||
.export_values();
|
||||
|
||||
(void)py::class_<MaskOp, TensorOp, std::shared_ptr<MaskOp>>(*m, "MaskOp",
|
||||
"Tensor operation mask using relational comparator")
|
||||
.def(py::init<RelationalOp, std::shared_ptr<Tensor>, DataType>());
|
||||
|
||||
(void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>(
|
||||
*m, "RandomRotationOp",
|
||||
"Tensor operation to apply RandomRotation."
|
||||
|
|
|
@ -699,7 +699,7 @@ Status Tensor::GetItemAt(T *o, const std::vector<dsize_t> &index) const {
|
|||
Status Tensor::GetItemAt(std::string_view *o, const std::vector<dsize_t> &index) const {
|
||||
RETURN_UNEXPECTED_IF_NULL(data_);
|
||||
RETURN_UNEXPECTED_IF_NULL(o);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Type is not DE_STRING");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Tensor type is not a string");
|
||||
|
||||
uchar *start = nullptr;
|
||||
offset_t length = 0;
|
||||
|
@ -932,17 +932,17 @@ Status Tensor::SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsiz
|
|||
dsize_t out_index = 0;
|
||||
dsize_t dim_length = shape_[0];
|
||||
dsize_t type_size = type_.SizeInBytes();
|
||||
dsize_t src_start = handleNeg(indices[0], dim_length);
|
||||
dsize_t src_start = HandleNeg(indices[0], dim_length);
|
||||
uchar *dst_addr = (*out)->data_;
|
||||
dsize_t count = 1;
|
||||
|
||||
for (dsize_t i = 0; i < indices.size(); i++) {
|
||||
dsize_t cur_index = handleNeg(indices[i], dim_length);
|
||||
dsize_t cur_index = HandleNeg(indices[i], dim_length);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
cur_index >= 0 && cur_index < dim_length,
|
||||
"Index " + std::to_string(indices[i]) + " is out of bounds [0," + std::to_string(dim_length) + ")");
|
||||
if (i < indices.size() - 1) {
|
||||
dsize_t next_index = handleNeg(indices[i + 1], dim_length);
|
||||
dsize_t next_index = HandleNeg(indices[i + 1], dim_length);
|
||||
if (next_index == cur_index + 1) {
|
||||
count++;
|
||||
continue;
|
||||
|
@ -951,7 +951,7 @@ Status Tensor::SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsiz
|
|||
memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size, count * type_size);
|
||||
out_index += count;
|
||||
if (i < indices.size() - 1) {
|
||||
src_start = handleNeg(indices[i + 1], dim_length); // next index
|
||||
src_start = HandleNeg(indices[i + 1], dim_length); // next index
|
||||
}
|
||||
count = 1;
|
||||
}
|
||||
|
@ -961,7 +961,7 @@ Status Tensor::SliceString(std::shared_ptr<Tensor> *out, const std::vector<dsize
|
|||
dsize_t dim_length = shape_[0];
|
||||
std::vector<std::string> strings;
|
||||
for (dsize_t index : indices) {
|
||||
dsize_t cur_index = handleNeg(index, dim_length);
|
||||
dsize_t cur_index = HandleNeg(index, dim_length);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
cur_index >= 0 && cur_index < dim_length,
|
||||
"Index " + std::to_string(index) + " is out of bounds [0," + std::to_string(dim_length) + ")");
|
||||
|
|
|
@ -348,7 +348,7 @@ class Tensor {
|
|||
}
|
||||
|
||||
// Handle negative indices.
|
||||
static inline dsize_t handleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; }
|
||||
static inline dsize_t HandleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; }
|
||||
|
||||
// Slice tensor bases on the given indicies. Copy the sliced data into out tensor. Only rank1 tensors are supported.
|
||||
// Based on the type of tensor, SliceNumeric or SliceString will be called
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
add_library(kernels-data OBJECT
|
||||
data_utils.cc
|
||||
one_hot_op.cc
|
||||
type_cast_op.cc
|
||||
to_float16_op.cc
|
||||
fill_op.cc
|
||||
slice_op.cc)
|
||||
data_utils.cc
|
||||
one_hot_op.cc
|
||||
type_cast_op.cc
|
||||
to_float16_op.cc
|
||||
fill_op.cc
|
||||
slice_op.cc
|
||||
mask_op.cc)
|
||||
|
|
|
@ -120,7 +120,7 @@ Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output
|
|||
std::unique_ptr<TypeCastOp> op(new TypeCastOp(to));
|
||||
|
||||
std::shared_ptr<Tensor> fill_output;
|
||||
op->Compute(fill_value, &fill_output);
|
||||
RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output));
|
||||
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input->shape(), input->type()));
|
||||
|
||||
|
@ -344,6 +344,8 @@ Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
|
|||
return PadEndString(src, dst, pad_shape, "");
|
||||
}
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(src->type().IsNumeric() == pad_val->type().IsNumeric(),
|
||||
"Source and pad_value tensors are not of the same type.");
|
||||
if (pad_val->type().IsNumeric()) {
|
||||
float val = 0;
|
||||
RETURN_IF_NOT_OK(pad_val->GetItemAt<float>(&val, {}));
|
||||
|
@ -454,5 +456,102 @@ Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::s
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status MaskHelper(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &output,
|
||||
const std::shared_ptr<Tensor> &value_tensor, RelationalOp op) {
|
||||
T value;
|
||||
RETURN_IF_NOT_OK(value_tensor->GetItemAt(&value, {}));
|
||||
auto in_itr = input->begin<T>();
|
||||
auto out_itr = output->begin<bool>();
|
||||
for (; in_itr != input->end<T>(); in_itr++, out_itr++) {
|
||||
switch (op) {
|
||||
case RelationalOp::kEqual:
|
||||
*out_itr = (*in_itr == value);
|
||||
break;
|
||||
case RelationalOp::kNotEqual:
|
||||
*out_itr = (*in_itr != value);
|
||||
break;
|
||||
case RelationalOp::kGreater:
|
||||
*out_itr = (*in_itr > value);
|
||||
break;
|
||||
case RelationalOp::kGreaterEqual:
|
||||
*out_itr = (*in_itr >= value);
|
||||
break;
|
||||
case RelationalOp::kLess:
|
||||
*out_itr = (*in_itr < value);
|
||||
break;
|
||||
case RelationalOp::kLessEqual:
|
||||
*out_itr = (*in_itr <= value);
|
||||
break;
|
||||
default:
|
||||
RETURN_STATUS_UNEXPECTED("Unknown relational operator.");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::shared_ptr<Tensor> &value,
|
||||
RelationalOp op) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input->type().IsNumeric() == value->type().IsNumeric(),
|
||||
"Cannot convert constant value to the type of the input tensor.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(value->shape() == TensorShape::CreateScalar(), "Value is not a scalar");
|
||||
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), DataType(DataType::DE_BOOL)));
|
||||
|
||||
std::unique_ptr<TypeCastOp> value_cast_op(new TypeCastOp(input->type()));
|
||||
std::shared_ptr<Tensor> casted_value;
|
||||
if (input->type().IsNumeric()) {
|
||||
RETURN_IF_NOT_OK(value_cast_op->Compute(value, &casted_value));
|
||||
} else {
|
||||
casted_value = value;
|
||||
}
|
||||
|
||||
switch (input->type().value()) {
|
||||
case DataType::DE_BOOL:
|
||||
RETURN_IF_NOT_OK(MaskHelper<bool>(input, *output, casted_value, op));
|
||||
break;
|
||||
case DataType::DE_INT8:
|
||||
RETURN_IF_NOT_OK(MaskHelper<int8_t>(input, *output, casted_value, op));
|
||||
break;
|
||||
case DataType::DE_UINT8:
|
||||
RETURN_IF_NOT_OK(MaskHelper<uint8_t>(input, *output, casted_value, op));
|
||||
break;
|
||||
case DataType::DE_UINT16:
|
||||
RETURN_IF_NOT_OK(MaskHelper<uint16_t>(input, *output, casted_value, op));
|
||||
break;
|
||||
case DataType::DE_INT16:
|
||||
RETURN_IF_NOT_OK(MaskHelper<int16_t>(input, *output, casted_value, op));
|
||||
break;
|
||||
case DataType::DE_UINT32:
|
||||
RETURN_IF_NOT_OK(MaskHelper<uint32_t>(input, *output, casted_value, op));
|
||||
break;
|
||||
case DataType::DE_INT32:
|
||||
RETURN_IF_NOT_OK(MaskHelper<int32_t>(input, *output, casted_value, op));
|
||||
break;
|
||||
case DataType::DE_UINT64:
|
||||
RETURN_IF_NOT_OK(MaskHelper<uint64_t>(input, *output, casted_value, op));
|
||||
break;
|
||||
case DataType::DE_INT64:
|
||||
RETURN_IF_NOT_OK(MaskHelper<int64_t>(input, *output, casted_value, op));
|
||||
break;
|
||||
case DataType::DE_FLOAT16:
|
||||
RETURN_IF_NOT_OK(MaskHelper<float16>(input, *output, casted_value, op));
|
||||
break;
|
||||
case DataType::DE_FLOAT32:
|
||||
RETURN_IF_NOT_OK(MaskHelper<float>(input, *output, casted_value, op));
|
||||
break;
|
||||
case DataType::DE_FLOAT64:
|
||||
RETURN_IF_NOT_OK(MaskHelper<double>(input, *output, casted_value, op));
|
||||
break;
|
||||
case DataType::DE_STRING:
|
||||
RETURN_IF_NOT_OK(MaskHelper<std::string_view>(input, *output, casted_value, op));
|
||||
break;
|
||||
case DataType::DE_UNKNOWN:
|
||||
RETURN_STATUS_UNEXPECTED("Unsupported input type.");
|
||||
break;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -119,6 +119,35 @@ Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor>
|
|||
Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst,
|
||||
const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim,
|
||||
const std::string &pad_value);
|
||||
|
||||
enum class RelationalOp {
|
||||
kEqual = 0, // ==
|
||||
kNotEqual, // !=
|
||||
kLess, // <
|
||||
kLessEqual, // <=
|
||||
kGreater, // >
|
||||
kGreaterEqual, // >=
|
||||
};
|
||||
|
||||
/// Helper method that masks the input tensor
|
||||
/// @tparam T type of the tensor
|
||||
/// @param input[in] input tensor
|
||||
/// @param output[out] output tensor
|
||||
/// @param value_tensor[in] scalar tensor value to compared with
|
||||
/// @param op[in] RelationalOp enum
|
||||
/// @return Status ok/error
|
||||
template <typename T>
|
||||
Status MaskHelper(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &output,
|
||||
const std::shared_ptr<Tensor> &value_tensor, RelationalOp op);
|
||||
|
||||
/// Mask the input tensor
|
||||
/// @param input[in] input tensor
|
||||
/// @param output[out] output tensor
|
||||
/// @param value[in] scalar tensor value to compared with
|
||||
/// @param op[in] RelationalOp enum
|
||||
/// @return Status ok/error
|
||||
Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::shared_ptr<Tensor> &value,
|
||||
RelationalOp op);
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "dataset/kernels/data/mask_op.h"
|
||||
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/kernels/tensor_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
Status MaskOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
std::shared_ptr<Tensor> temp_output;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(type_.IsNumeric(), "Cannot generate a string mask. Type should be numeric.");
|
||||
|
||||
RETURN_IF_NOT_OK(Mask(input, &temp_output, value_, op_));
|
||||
|
||||
// cast the output to the the required type. Skip casting if type_ is bool.
|
||||
if (type_ != DataType::DE_BOOL) {
|
||||
RETURN_IF_NOT_OK(cast_->Compute(temp_output, output));
|
||||
} else {
|
||||
*output = temp_output;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MaskOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
|
||||
RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
|
||||
outputs[0] = type_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_KERNELS_DATA_MASK_OP_H_
|
||||
#define DATASET_KERNELS_DATA_MASK_OP_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/kernels/tensor_op.h"
|
||||
#include "dataset/kernels/data/type_cast_op.h"
|
||||
#include "dataset/kernels/data/data_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class MaskOp : public TensorOp {
|
||||
public:
|
||||
MaskOp(RelationalOp op, std::shared_ptr<Tensor> value, DataType type = DataType(DataType::DE_BOOL))
|
||||
: op_(op), value_(std::move(value)), type_(type), cast_(new TypeCastOp(type)) {}
|
||||
|
||||
~MaskOp() override = default;
|
||||
|
||||
void Print(std::ostream &out) const override { out << "MaskOp"; }
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
|
||||
|
||||
private:
|
||||
RelationalOp op_;
|
||||
std::shared_ptr<Tensor> value_;
|
||||
DataType type_;
|
||||
std::unique_ptr<TypeCastOp> cast_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // DATASET_KERNELS_DATA_MASK_OP_H_
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,7 +16,6 @@
|
|||
#include "dataset/kernels/data/slice_op.h"
|
||||
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/kernels/data/data_utils.h"
|
||||
#include "dataset/kernels/tensor_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -36,8 +36,8 @@ class Slice {
|
|||
|
||||
std::vector<dsize_t> Indices(dsize_t length) {
|
||||
std::vector<dsize_t> indices;
|
||||
dsize_t index = std::min(Tensor::handleNeg(start_, length), length);
|
||||
dsize_t end_index = std::min(Tensor::handleNeg(stop_, length), length);
|
||||
dsize_t index = std::min(Tensor::HandleNeg(start_, length), length);
|
||||
dsize_t end_index = std::min(Tensor::HandleNeg(stop_, length), length);
|
||||
if (step_ > 0) {
|
||||
for (; index < end_index; index += step_) {
|
||||
indices.push_back(index);
|
||||
|
@ -80,4 +80,4 @@ class SliceOp : public TensorOp {
|
|||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // DATASET_KERNELS_DATA_ONE_HOT_OP_H_
|
||||
#endif // DATASET_KERNELS_DATA_SLICE_OP_H_
|
||||
|
|
|
@ -15,10 +15,14 @@
|
|||
"""
|
||||
This module c_transforms provides common operations, including OneHotOp and TypeCast.
|
||||
"""
|
||||
import numpy as np
|
||||
from enum import IntEnum
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore._c_dataengine as cde
|
||||
|
||||
from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op
|
||||
import numpy as np
|
||||
|
||||
from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op, check_mask_op
|
||||
from ..core.datatypes import mstype_to_detype
|
||||
|
||||
|
||||
|
@ -48,7 +52,6 @@ class Fill(cde.FillOp):
|
|||
|
||||
@check_fill_value
|
||||
def __init__(self, fill_value):
|
||||
print(fill_value)
|
||||
super().__init__(cde.Tensor(np.array(fill_value)))
|
||||
|
||||
|
||||
|
@ -108,3 +111,50 @@ class Slice(cde.SliceOp):
|
|||
elif dim0 is Ellipsis:
|
||||
dim0 = True
|
||||
super().__init__(dim0)
|
||||
|
||||
|
||||
class Relational(IntEnum):
|
||||
EQ = 0
|
||||
NE = 1
|
||||
GT = 2
|
||||
GE = 3
|
||||
LT = 4
|
||||
LE = 5
|
||||
|
||||
|
||||
DE_C_RELATIONAL = {Relational.EQ: cde.RelationalOp.EQ,
|
||||
Relational.NE: cde.RelationalOp.NE,
|
||||
Relational.GT: cde.RelationalOp.GT,
|
||||
Relational.GE: cde.RelationalOp.GE,
|
||||
Relational.LT: cde.RelationalOp.LT,
|
||||
Relational.LE: cde.RelationalOp.LE}
|
||||
|
||||
|
||||
class Mask(cde.MaskOp):
|
||||
"""
|
||||
Mask content of the input tensor with the given predicate.
|
||||
Any element of the tensor that matches the predicate will be evaluated to True, otherwise False.
|
||||
Args:
|
||||
operator (Relational): One of the relational operator EQ, NE LT, GT, LE or GE
|
||||
constant (python types (str, int, float, or bool): constant to be compared to.
|
||||
Constant will be casted to the type of the input tensor
|
||||
dtype (optional, mindspore.dtype): type of the generated mask. Default to bool
|
||||
Examples:
|
||||
>>> # Data before
|
||||
>>> # | col1 |
|
||||
>>> # +---------+
|
||||
>>> # | [1,2,3] |
|
||||
>>> # +---------+
|
||||
>>> data = data.map(operations=Mask(Relational.EQ, 2))
|
||||
>>> # Data after
|
||||
>>> # | col1 |
|
||||
>>> # +--------------------+
|
||||
>>> # | [False,True,False] |
|
||||
>>> # +--------------------+
|
||||
"""
|
||||
|
||||
@check_mask_op
|
||||
def __init__(self, operator, constant, dtype=mstype.bool_):
|
||||
dtype = mstype_to_detype(dtype)
|
||||
constant = cde.Tensor(np.array(constant))
|
||||
super().__init__(DE_C_RELATIONAL[operator], constant, dtype)
|
||||
|
|
|
@ -213,3 +213,40 @@ def check_slice_op(method):
|
|||
return method(self, *args)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_mask_op(method):
|
||||
"""Wrapper method to check the parameters of slice."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
operator, constant, dtype = (list(args) + 3 * [None])[:3]
|
||||
if "operator" in kwargs:
|
||||
operator = kwargs.get("operator")
|
||||
if "constant" in kwargs:
|
||||
constant = kwargs.get("constant")
|
||||
if "dtype" in kwargs:
|
||||
dtype = kwargs.get("dtype")
|
||||
|
||||
if operator is None:
|
||||
raise ValueError("operator is not provided.")
|
||||
if constant is None:
|
||||
raise ValueError("constant is not provided.")
|
||||
|
||||
from .c_transforms import Relational
|
||||
if not isinstance(operator, Relational):
|
||||
raise TypeError("operator is not a Relational operator enum.")
|
||||
|
||||
if not isinstance(constant, (str, float, bool, int)):
|
||||
raise TypeError("constant must be either a primitive python str, float, bool, or int")
|
||||
|
||||
if not isinstance(dtype, typing.Type):
|
||||
raise TypeError("dtype is not a MindSpore data type.")
|
||||
|
||||
kwargs["operator"] = operator
|
||||
kwargs["constant"] = constant
|
||||
kwargs["dtype"] = dtype
|
||||
|
||||
return method(self, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "dataset/core/client.h"
|
||||
#include "common/common.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "securec.h"
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/core/cv_tensor.h"
|
||||
#include "dataset/core/data_type.h"
|
||||
#include "dataset/util/de_error.h"
|
||||
#include "dataset/kernels/data/mask_op.h"
|
||||
#include "dataset/kernels/data/data_utils.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
class MindDataTestMaskOp : public UT::Common {
|
||||
public:
|
||||
MindDataTestMaskOp() {}
|
||||
|
||||
void SetUp() { GlobalInit(); }
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestMaskOp, Basics) {
|
||||
std::shared_ptr<Tensor> t;
|
||||
Tensor::CreateTensor(&t, std::vector<uint32_t>({1, 2, 3, 4, 5, 6}));
|
||||
std::shared_ptr<Tensor> v;
|
||||
Tensor::CreateTensor(&v, std::vector<uint32_t>({3}), TensorShape::CreateScalar());
|
||||
std::shared_ptr<MaskOp> op = std::make_shared<MaskOp>(RelationalOp::kEqual, v, DataType(DataType::DE_UINT16));
|
||||
std::shared_ptr<Tensor> out;
|
||||
ASSERT_TRUE(op->Compute(t, &out).IsOk());
|
||||
|
||||
op = std::make_shared<MaskOp>(RelationalOp::kNotEqual, v, DataType(DataType::DE_UINT16));
|
||||
ASSERT_TRUE(op->Compute(t, &out).IsOk());
|
||||
|
||||
op = std::make_shared<MaskOp>(RelationalOp::kLessEqual, v, DataType(DataType::DE_UINT16));
|
||||
ASSERT_TRUE(op->Compute(t, &out).IsOk());
|
||||
|
||||
op = std::make_shared<MaskOp>(RelationalOp::kLess, v, DataType(DataType::DE_UINT16));
|
||||
ASSERT_TRUE(op->Compute(t, &out).IsOk());
|
||||
|
||||
op = std::make_shared<MaskOp>(RelationalOp::kGreaterEqual, v, DataType(DataType::DE_UINT16));
|
||||
ASSERT_TRUE(op->Compute(t, &out).IsOk());
|
||||
|
||||
op = std::make_shared<MaskOp>(RelationalOp::kGreater, v, DataType(DataType::DE_UINT16));
|
||||
ASSERT_TRUE(op->Compute(t, &out).IsOk());
|
||||
}
|
|
@ -0,0 +1,132 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing Mask op in DE
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as ops
|
||||
|
||||
mstype_to_np_type = {
|
||||
mstype.bool_: np.bool,
|
||||
mstype.int8: np.int8,
|
||||
mstype.uint8: np.uint8,
|
||||
mstype.int16: np.int16,
|
||||
mstype.uint16: np.uint16,
|
||||
mstype.int32: np.int32,
|
||||
mstype.uint32: np.uint32,
|
||||
mstype.int64: np.int64,
|
||||
mstype.uint64: np.uint64,
|
||||
mstype.float16: np.float16,
|
||||
mstype.float32: np.float32,
|
||||
mstype.float64: np.float64,
|
||||
mstype.string: np.str
|
||||
}
|
||||
|
||||
|
||||
def mask_compare(array, op, constant, dtype=mstype.bool_):
|
||||
data = ds.NumpySlicesDataset([array])
|
||||
array = np.array(array)
|
||||
data = data.map(operations=ops.Mask(op, constant, dtype))
|
||||
for d in data:
|
||||
if op == ops.Relational.EQ:
|
||||
array = array == np.array(constant, dtype=array.dtype)
|
||||
elif op == ops.Relational.NE:
|
||||
array = array != np.array(constant, dtype=array.dtype)
|
||||
elif op == ops.Relational.GT:
|
||||
array = array > np.array(constant, dtype=array.dtype)
|
||||
elif op == ops.Relational.GE:
|
||||
array = array >= np.array(constant, dtype=array.dtype)
|
||||
elif op == ops.Relational.LT:
|
||||
array = array < np.array(constant, dtype=array.dtype)
|
||||
elif op == ops.Relational.LE:
|
||||
array = array <= np.array(constant, dtype=array.dtype)
|
||||
|
||||
array = array.astype(dtype=mstype_to_np_type[dtype])
|
||||
|
||||
np.testing.assert_array_equal(array, d[0])
|
||||
|
||||
|
||||
def test_int_comparison():
|
||||
for k in mstype_to_np_type:
|
||||
if k == mstype.string:
|
||||
continue
|
||||
mask_compare([1, 2, 3, 4, 5], ops.Relational.EQ, 3, k)
|
||||
mask_compare([1, 2, 3, 4, 5], ops.Relational.NE, 3, k)
|
||||
mask_compare([1, 2, 3, 4, 5], ops.Relational.LT, 3, k)
|
||||
mask_compare([1, 2, 3, 4, 5], ops.Relational.LE, 3, k)
|
||||
mask_compare([1, 2, 3, 4, 5], ops.Relational.GT, 3, k)
|
||||
mask_compare([1, 2, 3, 4, 5], ops.Relational.GE, 3, k)
|
||||
|
||||
|
||||
def test_float_comparison():
|
||||
for k in mstype_to_np_type:
|
||||
if k == mstype.string:
|
||||
continue
|
||||
mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.EQ, 3, k)
|
||||
mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.NE, 3, k)
|
||||
mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.LT, 3, k)
|
||||
mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.LE, 3, k)
|
||||
mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.GT, 3, k)
|
||||
mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.GE, 3, k)
|
||||
|
||||
|
||||
def test_float_comparison2():
|
||||
for k in mstype_to_np_type:
|
||||
if k == mstype.string:
|
||||
continue
|
||||
mask_compare([1, 2, 3, 4, 5], ops.Relational.EQ, 3.5, k)
|
||||
mask_compare([1, 2, 3, 4, 5], ops.Relational.NE, 3.5, k)
|
||||
mask_compare([1, 2, 3, 4, 5], ops.Relational.LT, 3.5, k)
|
||||
mask_compare([1, 2, 3, 4, 5], ops.Relational.LE, 3.5, k)
|
||||
mask_compare([1, 2, 3, 4, 5], ops.Relational.GT, 3.5, k)
|
||||
mask_compare([1, 2, 3, 4, 5], ops.Relational.GE, 3.5, k)
|
||||
|
||||
|
||||
def test_string_comparison():
|
||||
for k in mstype_to_np_type:
|
||||
if k == mstype.string:
|
||||
continue
|
||||
mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.EQ, "3.", k)
|
||||
mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.NE, "3.", k)
|
||||
mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.LT, "3.", k)
|
||||
mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.LE, "3.", k)
|
||||
mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.GT, "3.", k)
|
||||
mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.GE, "3.", k)
|
||||
|
||||
|
||||
def test_mask_exceptions_str():
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
mask_compare([1, 2, 3, 4, 5], ops.Relational.EQ, "3.5")
|
||||
assert "Cannot convert constant value to the type of the input tensor." in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
mask_compare(["1", "2", "3", "4", "5"], ops.Relational.EQ, 3.5)
|
||||
assert "Cannot convert constant value to the type of the input tensor." in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
mask_compare(["1", "2", "3", "4", "5"], ops.Relational.EQ, "3.5", mstype.string)
|
||||
assert "Cannot generate a string mask. Type should be numeric." in str(info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_int_comparison()
|
||||
test_float_comparison()
|
||||
test_float_comparison2()
|
||||
test_string_comparison()
|
||||
test_mask_exceptions_str()
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing TypeCast op in DE
|
||||
Testing Slice op in DE
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
@ -109,6 +109,10 @@ def test_slice_exceptions():
|
|||
slice_compare([1, 2, 3, 4, 5], slice(0))
|
||||
assert "Indices are empty, generated tensor would be empty." in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
slice_compare([1, 2, 3, 4, 5], slice(3, 1, 1))
|
||||
assert "Indices are empty, generated tensor would be empty." in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
slice_compare([1, 2, 3, 4, 5], slice(5, 10, 1))
|
||||
assert "Indices are empty, generated tensor would be empty." in str(info.value)
|
||||
|
@ -182,6 +186,10 @@ def test_slice_exceptions_str():
|
|||
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0))
|
||||
assert "Indices are empty, generated tensor would be empty." in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(3, 1, 1))
|
||||
assert "Indices are empty, generated tensor would be empty." in str(info.value)
|
||||
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(5, 10, 1))
|
||||
assert "Indices are empty, generated tensor would be empty." in str(info.value)
|
||||
|
|
Loading…
Reference in New Issue