This commit is contained in:
hesham 2020-06-16 22:46:57 -04:00
parent e2012a1de9
commit f2462bb00d
15 changed files with 560 additions and 25 deletions

View File

@ -38,6 +38,7 @@
#include "dataset/kernels/image/resize_op.h"
#include "dataset/kernels/image/uniform_aug_op.h"
#include "dataset/kernels/data/fill_op.h"
#include "dataset/kernels/data/mask_op.h"
#include "dataset/kernels/data/slice_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
@ -369,7 +370,7 @@ void bindTensorOps2(py::module *m) {
*m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.")
.def(py::init<std::shared_ptr<Tensor>>());
(void)py::class_<SliceOp, TensorOp, std::shared_ptr<SliceOp>>(*m, "SliceOp", "")
(void)py::class_<SliceOp, TensorOp, std::shared_ptr<SliceOp>>(*m, "SliceOp", "Tensor Slice operation.")
.def(py::init<bool>())
.def(py::init([](const py::list &py_list) {
std::vector<dsize_t> c_list;
@ -400,6 +401,19 @@ void bindTensorOps2(py::module *m) {
return std::make_shared<SliceOp>(c_slice);
}));
(void)py::enum_<RelationalOp>(*m, "RelationalOp", py::arithmetic())
.value("EQ", RelationalOp::kEqual)
.value("NE", RelationalOp::kNotEqual)
.value("LT", RelationalOp::kLess)
.value("LE", RelationalOp::kLessEqual)
.value("GT", RelationalOp::kGreater)
.value("GE", RelationalOp::kGreaterEqual)
.export_values();
(void)py::class_<MaskOp, TensorOp, std::shared_ptr<MaskOp>>(*m, "MaskOp",
"Tensor operation mask using relational comparator")
.def(py::init<RelationalOp, std::shared_ptr<Tensor>, DataType>());
(void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>(
*m, "RandomRotationOp",
"Tensor operation to apply RandomRotation."

View File

@ -699,7 +699,7 @@ Status Tensor::GetItemAt(T *o, const std::vector<dsize_t> &index) const {
Status Tensor::GetItemAt(std::string_view *o, const std::vector<dsize_t> &index) const {
RETURN_UNEXPECTED_IF_NULL(data_);
RETURN_UNEXPECTED_IF_NULL(o);
CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Type is not DE_STRING");
CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Tensor type is not a string");
uchar *start = nullptr;
offset_t length = 0;
@ -932,17 +932,17 @@ Status Tensor::SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsiz
dsize_t out_index = 0;
dsize_t dim_length = shape_[0];
dsize_t type_size = type_.SizeInBytes();
dsize_t src_start = handleNeg(indices[0], dim_length);
dsize_t src_start = HandleNeg(indices[0], dim_length);
uchar *dst_addr = (*out)->data_;
dsize_t count = 1;
for (dsize_t i = 0; i < indices.size(); i++) {
dsize_t cur_index = handleNeg(indices[i], dim_length);
dsize_t cur_index = HandleNeg(indices[i], dim_length);
CHECK_FAIL_RETURN_UNEXPECTED(
cur_index >= 0 && cur_index < dim_length,
"Index " + std::to_string(indices[i]) + " is out of bounds [0," + std::to_string(dim_length) + ")");
if (i < indices.size() - 1) {
dsize_t next_index = handleNeg(indices[i + 1], dim_length);
dsize_t next_index = HandleNeg(indices[i + 1], dim_length);
if (next_index == cur_index + 1) {
count++;
continue;
@ -951,7 +951,7 @@ Status Tensor::SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsiz
memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size, count * type_size);
out_index += count;
if (i < indices.size() - 1) {
src_start = handleNeg(indices[i + 1], dim_length); // next index
src_start = HandleNeg(indices[i + 1], dim_length); // next index
}
count = 1;
}
@ -961,7 +961,7 @@ Status Tensor::SliceString(std::shared_ptr<Tensor> *out, const std::vector<dsize
dsize_t dim_length = shape_[0];
std::vector<std::string> strings;
for (dsize_t index : indices) {
dsize_t cur_index = handleNeg(index, dim_length);
dsize_t cur_index = HandleNeg(index, dim_length);
CHECK_FAIL_RETURN_UNEXPECTED(
cur_index >= 0 && cur_index < dim_length,
"Index " + std::to_string(index) + " is out of bounds [0," + std::to_string(dim_length) + ")");

View File

@ -348,7 +348,7 @@ class Tensor {
}
// Handle negative indices.
static inline dsize_t handleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; }
static inline dsize_t HandleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; }
// Slice tensor bases on the given indicies. Copy the sliced data into out tensor. Only rank1 tensors are supported.
// Based on the type of tensor, SliceNumeric or SliceString will be called

View File

@ -1,9 +1,10 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(kernels-data OBJECT
data_utils.cc
one_hot_op.cc
type_cast_op.cc
to_float16_op.cc
fill_op.cc
slice_op.cc)
data_utils.cc
one_hot_op.cc
type_cast_op.cc
to_float16_op.cc
fill_op.cc
slice_op.cc
mask_op.cc)

View File

@ -120,7 +120,7 @@ Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output
std::unique_ptr<TypeCastOp> op(new TypeCastOp(to));
std::shared_ptr<Tensor> fill_output;
op->Compute(fill_value, &fill_output);
RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output));
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input->shape(), input->type()));
@ -344,6 +344,8 @@ Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
return PadEndString(src, dst, pad_shape, "");
}
}
CHECK_FAIL_RETURN_UNEXPECTED(src->type().IsNumeric() == pad_val->type().IsNumeric(),
"Source and pad_value tensors are not of the same type.");
if (pad_val->type().IsNumeric()) {
float val = 0;
RETURN_IF_NOT_OK(pad_val->GetItemAt<float>(&val, {}));
@ -454,5 +456,102 @@ Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::s
}
return Status::OK();
}
template <typename T>
Status MaskHelper(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &output,
const std::shared_ptr<Tensor> &value_tensor, RelationalOp op) {
T value;
RETURN_IF_NOT_OK(value_tensor->GetItemAt(&value, {}));
auto in_itr = input->begin<T>();
auto out_itr = output->begin<bool>();
for (; in_itr != input->end<T>(); in_itr++, out_itr++) {
switch (op) {
case RelationalOp::kEqual:
*out_itr = (*in_itr == value);
break;
case RelationalOp::kNotEqual:
*out_itr = (*in_itr != value);
break;
case RelationalOp::kGreater:
*out_itr = (*in_itr > value);
break;
case RelationalOp::kGreaterEqual:
*out_itr = (*in_itr >= value);
break;
case RelationalOp::kLess:
*out_itr = (*in_itr < value);
break;
case RelationalOp::kLessEqual:
*out_itr = (*in_itr <= value);
break;
default:
RETURN_STATUS_UNEXPECTED("Unknown relational operator.");
}
}
return Status::OK();
}
Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::shared_ptr<Tensor> &value,
RelationalOp op) {
CHECK_FAIL_RETURN_UNEXPECTED(input->type().IsNumeric() == value->type().IsNumeric(),
"Cannot convert constant value to the type of the input tensor.");
CHECK_FAIL_RETURN_UNEXPECTED(value->shape() == TensorShape::CreateScalar(), "Value is not a scalar");
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), DataType(DataType::DE_BOOL)));
std::unique_ptr<TypeCastOp> value_cast_op(new TypeCastOp(input->type()));
std::shared_ptr<Tensor> casted_value;
if (input->type().IsNumeric()) {
RETURN_IF_NOT_OK(value_cast_op->Compute(value, &casted_value));
} else {
casted_value = value;
}
switch (input->type().value()) {
case DataType::DE_BOOL:
RETURN_IF_NOT_OK(MaskHelper<bool>(input, *output, casted_value, op));
break;
case DataType::DE_INT8:
RETURN_IF_NOT_OK(MaskHelper<int8_t>(input, *output, casted_value, op));
break;
case DataType::DE_UINT8:
RETURN_IF_NOT_OK(MaskHelper<uint8_t>(input, *output, casted_value, op));
break;
case DataType::DE_UINT16:
RETURN_IF_NOT_OK(MaskHelper<uint16_t>(input, *output, casted_value, op));
break;
case DataType::DE_INT16:
RETURN_IF_NOT_OK(MaskHelper<int16_t>(input, *output, casted_value, op));
break;
case DataType::DE_UINT32:
RETURN_IF_NOT_OK(MaskHelper<uint32_t>(input, *output, casted_value, op));
break;
case DataType::DE_INT32:
RETURN_IF_NOT_OK(MaskHelper<int32_t>(input, *output, casted_value, op));
break;
case DataType::DE_UINT64:
RETURN_IF_NOT_OK(MaskHelper<uint64_t>(input, *output, casted_value, op));
break;
case DataType::DE_INT64:
RETURN_IF_NOT_OK(MaskHelper<int64_t>(input, *output, casted_value, op));
break;
case DataType::DE_FLOAT16:
RETURN_IF_NOT_OK(MaskHelper<float16>(input, *output, casted_value, op));
break;
case DataType::DE_FLOAT32:
RETURN_IF_NOT_OK(MaskHelper<float>(input, *output, casted_value, op));
break;
case DataType::DE_FLOAT64:
RETURN_IF_NOT_OK(MaskHelper<double>(input, *output, casted_value, op));
break;
case DataType::DE_STRING:
RETURN_IF_NOT_OK(MaskHelper<std::string_view>(input, *output, casted_value, op));
break;
case DataType::DE_UNKNOWN:
RETURN_STATUS_UNEXPECTED("Unsupported input type.");
break;
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -119,6 +119,35 @@ Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor>
Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst,
const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim,
const std::string &pad_value);
enum class RelationalOp {
kEqual = 0, // ==
kNotEqual, // !=
kLess, // <
kLessEqual, // <=
kGreater, // >
kGreaterEqual, // >=
};
/// Helper method that masks the input tensor
/// @tparam T type of the tensor
/// @param input[in] input tensor
/// @param output[out] output tensor
/// @param value_tensor[in] scalar tensor value to compared with
/// @param op[in] RelationalOp enum
/// @return Status ok/error
template <typename T>
Status MaskHelper(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &output,
const std::shared_ptr<Tensor> &value_tensor, RelationalOp op);
/// Mask the input tensor
/// @param input[in] input tensor
/// @param output[out] output tensor
/// @param value[in] scalar tensor value to compared with
/// @param op[in] RelationalOp enum
/// @return Status ok/error
Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::shared_ptr<Tensor> &value,
RelationalOp op);
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/kernels/data/mask_op.h"
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
namespace mindspore {
namespace dataset {
Status MaskOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
std::shared_ptr<Tensor> temp_output;
CHECK_FAIL_RETURN_UNEXPECTED(type_.IsNumeric(), "Cannot generate a string mask. Type should be numeric.");
RETURN_IF_NOT_OK(Mask(input, &temp_output, value_, op_));
// cast the output to the the required type. Skip casting if type_ is bool.
if (type_ != DataType::DE_BOOL) {
RETURN_IF_NOT_OK(cast_->Compute(temp_output, output));
} else {
*output = temp_output;
}
return Status::OK();
}
Status MaskOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
outputs[0] = type_;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_DATA_MASK_OP_H_
#define DATASET_KERNELS_DATA_MASK_OP_H_
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/kernels/data/data_utils.h"
namespace mindspore {
namespace dataset {
class MaskOp : public TensorOp {
public:
MaskOp(RelationalOp op, std::shared_ptr<Tensor> value, DataType type = DataType(DataType::DE_BOOL))
: op_(op), value_(std::move(value)), type_(type), cast_(new TypeCastOp(type)) {}
~MaskOp() override = default;
void Print(std::ostream &out) const override { out << "MaskOp"; }
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
private:
RelationalOp op_;
std::shared_ptr<Tensor> value_;
DataType type_;
std::unique_ptr<TypeCastOp> cast_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_KERNELS_DATA_MASK_OP_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,7 +16,6 @@
#include "dataset/kernels/data/slice_op.h"
#include "dataset/core/tensor.h"
#include "dataset/kernels/data/data_utils.h"
#include "dataset/kernels/tensor_op.h"
namespace mindspore {

View File

@ -36,8 +36,8 @@ class Slice {
std::vector<dsize_t> Indices(dsize_t length) {
std::vector<dsize_t> indices;
dsize_t index = std::min(Tensor::handleNeg(start_, length), length);
dsize_t end_index = std::min(Tensor::handleNeg(stop_, length), length);
dsize_t index = std::min(Tensor::HandleNeg(start_, length), length);
dsize_t end_index = std::min(Tensor::HandleNeg(stop_, length), length);
if (step_ > 0) {
for (; index < end_index; index += step_) {
indices.push_back(index);
@ -80,4 +80,4 @@ class SliceOp : public TensorOp {
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_KERNELS_DATA_ONE_HOT_OP_H_
#endif // DATASET_KERNELS_DATA_SLICE_OP_H_

View File

@ -15,10 +15,14 @@
"""
This module c_transforms provides common operations, including OneHotOp and TypeCast.
"""
import numpy as np
from enum import IntEnum
import mindspore.common.dtype as mstype
import mindspore._c_dataengine as cde
from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op
import numpy as np
from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op, check_mask_op
from ..core.datatypes import mstype_to_detype
@ -48,7 +52,6 @@ class Fill(cde.FillOp):
@check_fill_value
def __init__(self, fill_value):
print(fill_value)
super().__init__(cde.Tensor(np.array(fill_value)))
@ -108,3 +111,50 @@ class Slice(cde.SliceOp):
elif dim0 is Ellipsis:
dim0 = True
super().__init__(dim0)
class Relational(IntEnum):
EQ = 0
NE = 1
GT = 2
GE = 3
LT = 4
LE = 5
DE_C_RELATIONAL = {Relational.EQ: cde.RelationalOp.EQ,
Relational.NE: cde.RelationalOp.NE,
Relational.GT: cde.RelationalOp.GT,
Relational.GE: cde.RelationalOp.GE,
Relational.LT: cde.RelationalOp.LT,
Relational.LE: cde.RelationalOp.LE}
class Mask(cde.MaskOp):
"""
Mask content of the input tensor with the given predicate.
Any element of the tensor that matches the predicate will be evaluated to True, otherwise False.
Args:
operator (Relational): One of the relational operator EQ, NE LT, GT, LE or GE
constant (python types (str, int, float, or bool): constant to be compared to.
Constant will be casted to the type of the input tensor
dtype (optional, mindspore.dtype): type of the generated mask. Default to bool
Examples:
>>> # Data before
>>> # | col1 |
>>> # +---------+
>>> # | [1,2,3] |
>>> # +---------+
>>> data = data.map(operations=Mask(Relational.EQ, 2))
>>> # Data after
>>> # | col1 |
>>> # +--------------------+
>>> # | [False,True,False] |
>>> # +--------------------+
"""
@check_mask_op
def __init__(self, operator, constant, dtype=mstype.bool_):
dtype = mstype_to_detype(dtype)
constant = cde.Tensor(np.array(constant))
super().__init__(DE_C_RELATIONAL[operator], constant, dtype)

View File

@ -213,3 +213,40 @@ def check_slice_op(method):
return method(self, *args)
return new_method
def check_mask_op(method):
"""Wrapper method to check the parameters of slice."""
@wraps(method)
def new_method(self, *args, **kwargs):
operator, constant, dtype = (list(args) + 3 * [None])[:3]
if "operator" in kwargs:
operator = kwargs.get("operator")
if "constant" in kwargs:
constant = kwargs.get("constant")
if "dtype" in kwargs:
dtype = kwargs.get("dtype")
if operator is None:
raise ValueError("operator is not provided.")
if constant is None:
raise ValueError("constant is not provided.")
from .c_transforms import Relational
if not isinstance(operator, Relational):
raise TypeError("operator is not a Relational operator enum.")
if not isinstance(constant, (str, float, bool, int)):
raise TypeError("constant must be either a primitive python str, float, bool, or int")
if not isinstance(dtype, typing.Type):
raise TypeError("dtype is not a MindSpore data type.")
kwargs["operator"] = operator
kwargs["constant"] = constant
kwargs["dtype"] = dtype
return method(self, **kwargs)
return new_method

View File

@ -0,0 +1,63 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <string>
#include "dataset/core/client.h"
#include "common/common.h"
#include "gtest/gtest.h"
#include "securec.h"
#include "dataset/core/tensor.h"
#include "dataset/core/cv_tensor.h"
#include "dataset/core/data_type.h"
#include "dataset/util/de_error.h"
#include "dataset/kernels/data/mask_op.h"
#include "dataset/kernels/data/data_utils.h"
using namespace mindspore::dataset;
namespace py = pybind11;
class MindDataTestMaskOp : public UT::Common {
public:
MindDataTestMaskOp() {}
void SetUp() { GlobalInit(); }
};
TEST_F(MindDataTestMaskOp, Basics) {
std::shared_ptr<Tensor> t;
Tensor::CreateTensor(&t, std::vector<uint32_t>({1, 2, 3, 4, 5, 6}));
std::shared_ptr<Tensor> v;
Tensor::CreateTensor(&v, std::vector<uint32_t>({3}), TensorShape::CreateScalar());
std::shared_ptr<MaskOp> op = std::make_shared<MaskOp>(RelationalOp::kEqual, v, DataType(DataType::DE_UINT16));
std::shared_ptr<Tensor> out;
ASSERT_TRUE(op->Compute(t, &out).IsOk());
op = std::make_shared<MaskOp>(RelationalOp::kNotEqual, v, DataType(DataType::DE_UINT16));
ASSERT_TRUE(op->Compute(t, &out).IsOk());
op = std::make_shared<MaskOp>(RelationalOp::kLessEqual, v, DataType(DataType::DE_UINT16));
ASSERT_TRUE(op->Compute(t, &out).IsOk());
op = std::make_shared<MaskOp>(RelationalOp::kLess, v, DataType(DataType::DE_UINT16));
ASSERT_TRUE(op->Compute(t, &out).IsOk());
op = std::make_shared<MaskOp>(RelationalOp::kGreaterEqual, v, DataType(DataType::DE_UINT16));
ASSERT_TRUE(op->Compute(t, &out).IsOk());
op = std::make_shared<MaskOp>(RelationalOp::kGreater, v, DataType(DataType::DE_UINT16));
ASSERT_TRUE(op->Compute(t, &out).IsOk());
}

View File

@ -0,0 +1,132 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing Mask op in DE
"""
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as ops
mstype_to_np_type = {
mstype.bool_: np.bool,
mstype.int8: np.int8,
mstype.uint8: np.uint8,
mstype.int16: np.int16,
mstype.uint16: np.uint16,
mstype.int32: np.int32,
mstype.uint32: np.uint32,
mstype.int64: np.int64,
mstype.uint64: np.uint64,
mstype.float16: np.float16,
mstype.float32: np.float32,
mstype.float64: np.float64,
mstype.string: np.str
}
def mask_compare(array, op, constant, dtype=mstype.bool_):
data = ds.NumpySlicesDataset([array])
array = np.array(array)
data = data.map(operations=ops.Mask(op, constant, dtype))
for d in data:
if op == ops.Relational.EQ:
array = array == np.array(constant, dtype=array.dtype)
elif op == ops.Relational.NE:
array = array != np.array(constant, dtype=array.dtype)
elif op == ops.Relational.GT:
array = array > np.array(constant, dtype=array.dtype)
elif op == ops.Relational.GE:
array = array >= np.array(constant, dtype=array.dtype)
elif op == ops.Relational.LT:
array = array < np.array(constant, dtype=array.dtype)
elif op == ops.Relational.LE:
array = array <= np.array(constant, dtype=array.dtype)
array = array.astype(dtype=mstype_to_np_type[dtype])
np.testing.assert_array_equal(array, d[0])
def test_int_comparison():
for k in mstype_to_np_type:
if k == mstype.string:
continue
mask_compare([1, 2, 3, 4, 5], ops.Relational.EQ, 3, k)
mask_compare([1, 2, 3, 4, 5], ops.Relational.NE, 3, k)
mask_compare([1, 2, 3, 4, 5], ops.Relational.LT, 3, k)
mask_compare([1, 2, 3, 4, 5], ops.Relational.LE, 3, k)
mask_compare([1, 2, 3, 4, 5], ops.Relational.GT, 3, k)
mask_compare([1, 2, 3, 4, 5], ops.Relational.GE, 3, k)
def test_float_comparison():
for k in mstype_to_np_type:
if k == mstype.string:
continue
mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.EQ, 3, k)
mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.NE, 3, k)
mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.LT, 3, k)
mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.LE, 3, k)
mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.GT, 3, k)
mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.GE, 3, k)
def test_float_comparison2():
for k in mstype_to_np_type:
if k == mstype.string:
continue
mask_compare([1, 2, 3, 4, 5], ops.Relational.EQ, 3.5, k)
mask_compare([1, 2, 3, 4, 5], ops.Relational.NE, 3.5, k)
mask_compare([1, 2, 3, 4, 5], ops.Relational.LT, 3.5, k)
mask_compare([1, 2, 3, 4, 5], ops.Relational.LE, 3.5, k)
mask_compare([1, 2, 3, 4, 5], ops.Relational.GT, 3.5, k)
mask_compare([1, 2, 3, 4, 5], ops.Relational.GE, 3.5, k)
def test_string_comparison():
for k in mstype_to_np_type:
if k == mstype.string:
continue
mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.EQ, "3.", k)
mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.NE, "3.", k)
mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.LT, "3.", k)
mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.LE, "3.", k)
mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.GT, "3.", k)
mask_compare(["1.5", "2.5", "3.", "4.5", "5.5"], ops.Relational.GE, "3.", k)
def test_mask_exceptions_str():
with pytest.raises(RuntimeError) as info:
mask_compare([1, 2, 3, 4, 5], ops.Relational.EQ, "3.5")
assert "Cannot convert constant value to the type of the input tensor." in str(info.value)
with pytest.raises(RuntimeError) as info:
mask_compare(["1", "2", "3", "4", "5"], ops.Relational.EQ, 3.5)
assert "Cannot convert constant value to the type of the input tensor." in str(info.value)
with pytest.raises(RuntimeError) as info:
mask_compare(["1", "2", "3", "4", "5"], ops.Relational.EQ, "3.5", mstype.string)
assert "Cannot generate a string mask. Type should be numeric." in str(info.value)
if __name__ == "__main__":
test_int_comparison()
test_float_comparison()
test_float_comparison2()
test_string_comparison()
test_mask_exceptions_str()

View File

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""
Testing TypeCast op in DE
Testing Slice op in DE
"""
import numpy as np
import pytest
@ -109,6 +109,10 @@ def test_slice_exceptions():
slice_compare([1, 2, 3, 4, 5], slice(0))
assert "Indices are empty, generated tensor would be empty." in str(info.value)
with pytest.raises(RuntimeError) as info:
slice_compare([1, 2, 3, 4, 5], slice(3, 1, 1))
assert "Indices are empty, generated tensor would be empty." in str(info.value)
with pytest.raises(RuntimeError) as info:
slice_compare([1, 2, 3, 4, 5], slice(5, 10, 1))
assert "Indices are empty, generated tensor would be empty." in str(info.value)
@ -182,6 +186,10 @@ def test_slice_exceptions_str():
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0))
assert "Indices are empty, generated tensor would be empty." in str(info.value)
with pytest.raises(RuntimeError) as info:
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(3, 1, 1))
assert "Indices are empty, generated tensor would be empty." in str(info.value)
with pytest.raises(RuntimeError) as info:
slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(5, 10, 1))
assert "Indices are empty, generated tensor would be empty." in str(info.value)