dataset: Pybind change for data transforms

This commit is contained in:
Cathy Wong 2021-03-18 17:27:14 -04:00
parent 83b25e10e9
commit 78392e0d66
14 changed files with 93 additions and 62 deletions

View File

@ -20,15 +20,11 @@
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/core/tensor_helpers.h"
#include "minddata/dataset/kernels/data/concatenate_op.h"
#include "minddata/dataset/kernels/data/duplicate_op.h"
#include "minddata/dataset/kernels/data/fill_op.h"
#include "minddata/dataset/kernels/data/mask_op.h"
#include "minddata/dataset/kernels/data/one_hot_op.h"
#include "minddata/dataset/kernels/data/pad_end_op.h"
#include "minddata/dataset/kernels/data/slice_op.h"
#include "minddata/dataset/kernels/data/to_float16_op.h"
#include "minddata/dataset/kernels/data/type_cast_op.h"
#include "minddata/dataset/kernels/data/unique_op.h"
namespace mindspore {
namespace dataset {
@ -38,15 +34,6 @@ PYBIND_REGISTER(ConcatenateOp, 1, ([](const py::module *m) {
.def(py::init<int8_t, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>());
}));
PYBIND_REGISTER(
DuplicateOp, 1, ([](const py::module *m) {
(void)py::class_<DuplicateOp, TensorOp, std::shared_ptr<DuplicateOp>>(*m, "DuplicateOp").def(py::init<>());
}));
PYBIND_REGISTER(UniqueOp, 1, ([](const py::module *m) {
(void)py::class_<UniqueOp, TensorOp, std::shared_ptr<UniqueOp>>(*m, "UniqueOp").def(py::init<>());
}));
PYBIND_REGISTER(
FillOp, 1, ([](const py::module *m) {
(void)py::class_<FillOp, TensorOp, std::shared_ptr<FillOp>>(*m, "FillOp").def(py::init<std::shared_ptr<Tensor>>());
@ -57,11 +44,6 @@ PYBIND_REGISTER(MaskOp, 1, ([](const py::module *m) {
.def(py::init<RelationalOp, std::shared_ptr<Tensor>, DataType>());
}));
PYBIND_REGISTER(
OneHotOp, 1, ([](const py::module *m) {
(void)py::class_<OneHotOp, TensorOp, std::shared_ptr<OneHotOp>>(*m, "OneHotOp").def(py::init<int32_t>());
}));
PYBIND_REGISTER(PadEndOp, 1, ([](const py::module *m) {
(void)py::class_<PadEndOp, TensorOp, std::shared_ptr<PadEndOp>>(*m, "PadEndOp")
.def(py::init<TensorShape, std::shared_ptr<Tensor>>());
@ -111,12 +93,6 @@ PYBIND_REGISTER(ToFloat16Op, 1, ([](const py::module *m) {
.def(py::init<>());
}));
PYBIND_REGISTER(TypeCastOp, 1, ([](const py::module *m) {
(void)py::class_<TypeCastOp, TensorOp, std::shared_ptr<TypeCastOp>>(*m, "TypeCastOp")
.def(py::init<DataType>())
.def(py::init<std::string>());
}));
PYBIND_REGISTER(RelationalOp, 0, ([](const py::module *m) {
(void)py::enum_<RelationalOp>(*m, "RelationalOp", py::arithmetic())
.value("EQ", RelationalOp::kEqual)

View File

@ -64,6 +64,28 @@ PYBIND_REGISTER(
}));
}));
PYBIND_REGISTER(
DuplicateOperation, 1, ([](const py::module *m) {
(void)py::class_<transforms::DuplicateOperation, TensorOperation, std::shared_ptr<transforms::DuplicateOperation>>(
*m, "DuplicateOperation")
.def(py::init([]() {
auto duplicate = std::make_shared<transforms::DuplicateOperation>();
THROW_IF_ERROR(duplicate->ValidateParams());
return duplicate;
}));
}));
PYBIND_REGISTER(
OneHotOperation, 1, ([](const py::module *m) {
(void)py::class_<transforms::OneHotOperation, TensorOperation, std::shared_ptr<transforms::OneHotOperation>>(
*m, "OneHotOperation")
.def(py::init([](int32_t num_classes) {
auto one_hot = std::make_shared<transforms::OneHotOperation>(num_classes);
THROW_IF_ERROR(one_hot->ValidateParams());
return one_hot;
}));
}));
PYBIND_REGISTER(RandomChoiceOperation, 1, ([](const py::module *m) {
(void)py::class_<transforms::RandomChoiceOperation, TensorOperation,
std::shared_ptr<transforms::RandomChoiceOperation>>(*m, "RandomChoiceOperation")
@ -87,5 +109,28 @@ PYBIND_REGISTER(RandomApplyOperation, 1, ([](const py::module *m) {
return random_apply;
}));
}));
PYBIND_REGISTER(
TypeCastOperation, 1, ([](const py::module *m) {
(void)py::class_<transforms::TypeCastOperation, TensorOperation, std::shared_ptr<transforms::TypeCastOperation>>(
*m, "TypeCastOperation")
.def(py::init([](std::string data_type) {
auto type_cast = std::make_shared<transforms::TypeCastOperation>(data_type);
THROW_IF_ERROR(type_cast->ValidateParams());
return type_cast;
}));
}));
PYBIND_REGISTER(
UniqueOperation, 1, ([](const py::module *m) {
(void)py::class_<transforms::UniqueOperation, TensorOperation, std::shared_ptr<transforms::UniqueOperation>>(
*m, "UniqueOperation")
.def(py::init([]() {
auto unique = std::make_shared<transforms::UniqueOperation>();
THROW_IF_ERROR(unique->ValidateParams());
return unique;
}));
}));
} // namespace dataset
} // namespace mindspore

View File

@ -64,7 +64,7 @@ std::shared_ptr<TensorOperation> Duplicate::Parse() { return std::make_shared<Du
// Constructor to OneHot
struct OneHot::Data {
explicit Data(int32_t num_classes) : num_classes_(num_classes) {}
float num_classes_;
int32_t num_classes_;
};
OneHot::OneHot(int32_t num_classes) : data_(std::make_shared<Data>(num_classes)) {}

View File

@ -351,7 +351,7 @@ class SentencePieceTokenizer final : public TensorTransform {
/// \param[in] vocab a SentencePieceVocab object.
/// \param[in] out_type The type of output.
SentencePieceTokenizer(const std::shared_ptr<SentencePieceVocab> &vocab,
mindspore::dataset::SPieceTokenizerOutType out_typee);
mindspore::dataset::SPieceTokenizerOutType out_type);
/// \brief Constructor.
/// \param[in] vocab_path vocab model file path.
@ -398,14 +398,14 @@ class SlidingWindow final : public TensorTransform {
};
/// \brief Tensor operation to convert every element of a string tensor to a number.
/// Strings are casted according to the rules specified in the following links:
/// Strings are cast according to the rules specified in the following links:
/// https://en.cppreference.com/w/cpp/string/basic_string/stof,
/// https://en.cppreference.com/w/cpp/string/basic_string/stoul,
/// except that any strings which represent negative numbers cannot be cast to an unsigned integer type.
class ToNumber final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] data_type of the tensor to be casted to. Must be a numeric type.
/// \param[in] data_type of the tensor to be cast to. Must be a numeric type.
explicit ToNumber(const std::string &data_type) : ToNumber(StringToChar(data_type)) {}
explicit ToNumber(const std::vector<char> &data_type);

View File

@ -38,11 +38,5 @@ Status OneHotOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector
return Status(StatusCode::kMDUnexpectedError, "OneHot: invalid input shape.");
}
Status OneHotOp::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["num_classes"] = num_classes_;
*out_json = args;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -37,8 +37,6 @@ class OneHotOp : public TensorOp {
std::string Name() const override { return kOneHotOp; }
Status to_json(nlohmann::json *out_json) override;
private:
int num_classes_;
};

View File

@ -34,11 +34,5 @@ Status TypeCastOp::OutputType(const std::vector<DataType> &inputs, std::vector<D
return Status::OK();
}
Status TypeCastOp::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["data_type"] = type_.ToString();
*out_json = args;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -43,8 +43,6 @@ class TypeCastOp : public TensorOp {
std::string Name() const override { return kTypeCastOp; }
Status to_json(nlohmann::json *out_json) override;
private:
DataType type_;
};

View File

@ -78,6 +78,13 @@ Status OneHotOperation::ValidateParams() {
std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); }
Status OneHotOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["num_classes"] = num_classes_;
*out_json = args;
return Status::OK();
}
// PreBuiltOperation
PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(tensor_op) {
#ifdef ENABLE_PYTHON
@ -149,6 +156,13 @@ Status TypeCastOperation::ValidateParams() {
std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<TypeCastOp>(data_type_); }
Status TypeCastOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["data_type"] = data_type_;
*out_json = args;
return Status::OK();
}
#ifndef ENABLE_ANDROID
// UniqueOperation
Status UniqueOperation::ValidateParams() { return Status::OK(); }

View File

@ -81,8 +81,10 @@ class OneHotOperation : public TensorOperation {
std::string Name() const override { return kOneHotOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
float num_classes_;
int32_t num_classes_;
};
class PreBuiltOperation : public TensorOperation {
@ -147,6 +149,8 @@ class TypeCastOperation : public TensorOperation {
std::string Name() const override { return kTypeCastOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
std::string data_type_;
};

View File

@ -362,8 +362,7 @@ def construct_tensor_ops(operations):
if hasattr(op_module_vis, op_name):
op_class = getattr(op_module_vis, op_name, None)
elif hasattr(op_module_trans, op_name[:-2]):
op_name = op_name[:-2] # to remove op from the back of the name
elif hasattr(op_module_trans, op_name):
op_class = getattr(op_module_trans, op_name, None)
else:
raise RuntimeError(op_name + " is not yet supported by deserialize().")

View File

@ -387,18 +387,18 @@ class ToNumber(TextTensorOperation):
"""
Tensor operation to convert every element of a string tensor to a number.
Strings are casted according to the rules specified in the following links:
Strings are cast according to the rules specified in the following links:
https://en.cppreference.com/w/cpp/string/basic_string/stof,
https://en.cppreference.com/w/cpp/string/basic_string/stoul,
except that any strings which represent negative numbers cannot be cast to an
unsigned integer type.
Args:
data_type (mindspore.dtype): mindspore.dtype to be casted to. Must be
data_type (mindspore.dtype): mindspore.dtype to be cast to. Must be
a numeric type.
Raises:
RuntimeError: If strings are invalid to cast, or are out of range after being casted.
RuntimeError: If strings are invalid to cast, or are out of range after being cast.
Examples:
>>> import mindspore.common.dtype as mstype

View File

@ -21,7 +21,7 @@ import numpy as np
import mindspore.common.dtype as mstype
import mindspore._c_dataengine as cde
from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_option, check_slice_op, \
from .validators import check_num_classes, check_ms_type, check_fill_value, check_slice_option, check_slice_op, \
check_mask_op, check_pad_end, check_concat_type, check_random_transform_ops
from ..core.datatypes import mstype_to_detype
@ -52,7 +52,7 @@ class TensorOperation:
raise NotImplementedError("TensorOperation has to implement parse() method.")
class OneHot(cde.OneHotOp):
class OneHot(TensorOperation):
"""
Tensor operation to apply one hot encoding.
@ -72,7 +72,9 @@ class OneHot(cde.OneHotOp):
@check_num_classes
def __init__(self, num_classes):
self.num_classes = num_classes
super().__init__(num_classes)
def parse(self):
return cde.OneHotOperation(self.num_classes)
class Fill(cde.FillOp):
@ -102,7 +104,7 @@ class Fill(cde.FillOp):
super().__init__(cde.Tensor(np.array(fill_value)))
class TypeCast(cde.TypeCastOp):
class TypeCast(TensorOperation):
"""
Tensor operation to cast to a given MindSpore data type.
@ -123,11 +125,13 @@ class TypeCast(cde.TypeCastOp):
>>> dataset = dataset.map(operations=type_cast_op)
"""
@check_de_type
@check_ms_type
def __init__(self, data_type):
data_type = mstype_to_detype(data_type)
self.data_type = str(data_type)
super().__init__(data_type)
def parse(self):
return cde.TypeCastOperation(self.data_type)
class _SliceOption(cde.SliceOption):
@ -314,7 +318,7 @@ class Concatenate(cde.ConcatenateOp):
super().__init__(axis, prepend, append)
class Duplicate(cde.DuplicateOp):
class Duplicate(TensorOperation):
"""
Duplicate the input tensor to output, only support transform one column each time.
@ -337,8 +341,11 @@ class Duplicate(cde.DuplicateOp):
>>> # +---------+---------+
"""
def parse(self):
return cde.DuplicateOperation()
class Unique(cde.UniqueOp):
class Unique(TensorOperation):
"""
Perform the unique operation on the input tensor, only support transform one column each time.
@ -373,9 +380,11 @@ class Unique(cde.UniqueOp):
>>> # +---------+-----------------+---------+
"""
def parse(self):
return cde.UniqueOperation()
class Compose():
class Compose(TensorOperation):
"""
Compose a list of transforms into a single transform.
@ -401,7 +410,7 @@ class Compose():
return cde.ComposeOperation(operations)
class RandomApply():
class RandomApply(TensorOperation):
"""
Randomly perform a series of transforms with a given probability.
@ -429,7 +438,7 @@ class RandomApply():
return cde.RandomApplyOperation(self.prob, operations)
class RandomChoice():
class RandomChoice(TensorOperation):
"""
Randomly select one transform from a list of transforms to perform operation.

View File

@ -87,7 +87,7 @@ def check_num_classes(method):
return new_method
def check_de_type(method):
def check_ms_type(method):
"""Wrapper method to check the parameters of data type."""
@wraps(method)