!3010 mindspore.dataset c_transform support for RandomApply RandomChoice, Compose and RandomSelectSubpolicy

Merge pull request !3010 from ZiruiWu/random_tensor_ops
This commit is contained in:
mindspore-ci-bot 2020-07-16 03:21:22 +08:00 committed by Gitee
commit 60927ef130
22 changed files with 1029 additions and 11 deletions

View File

@ -42,8 +42,6 @@
#include "minddata/dataset/util/status.h"
#include "minddata/mindrecord/include/shard_category.h"
#include "minddata/mindrecord/include/shard_distributed_sample.h"
#include "minddata/mindrecord/include/shard_sample.h"
#include "minddata/mindrecord/include/shard_shuffle.h"
#include "pybind11/stl.h"
#include "utils/log_adapter.h"

View File

@ -16,6 +16,7 @@
#include <exception>
#include "minddata/dataset/api/de_pipeline.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
@ -35,9 +36,9 @@
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/gnn/graph.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/kernels/compose_op.h"
#include "minddata/dataset/kernels/data/concatenate_op.h"
#include "minddata/dataset/kernels/data/duplicate_op.h"
#include "minddata/dataset/kernels/data/fill_op.h"
@ -61,11 +62,12 @@
#include "minddata/dataset/kernels/image/random_crop_decode_resize_op.h"
#include "minddata/dataset/kernels/image/random_crop_op.h"
#include "minddata/dataset/kernels/image/random_crop_with_bbox_op.h"
#include "minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h"
#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
#include "minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h"
#include "minddata/dataset/kernels/image/random_resize_op.h"
#include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h"
#include "minddata/dataset/kernels/image/random_rotation_op.h"
#include "minddata/dataset/kernels/image/random_select_subpolicy_op.h"
#include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
#include "minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h"
#include "minddata/dataset/kernels/image/rescale_op.h"
@ -74,6 +76,9 @@
#include "minddata/dataset/kernels/image/resize_with_bbox_op.h"
#include "minddata/dataset/kernels/image/uniform_aug_op.h"
#include "minddata/dataset/kernels/no_op.h"
#include "minddata/dataset/kernels/py_func_op.h"
#include "minddata/dataset/kernels/random_apply_op.h"
#include "minddata/dataset/kernels/random_choice_op.h"
#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h"
#include "minddata/dataset/text/kernels/lookup_op.h"
#include "minddata/dataset/text/kernels/ngram_op.h"
@ -88,6 +93,7 @@
#include "minddata/mindrecord/include/shard_sample.h"
#include "minddata/mindrecord/include/shard_sequential_sample.h"
#include "mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"
@ -113,6 +119,24 @@ namespace dataset {
if (rc.IsError()) throw std::runtime_error(rc.ToString()); \
} while (false)
Status PyListToTensorOps(const py::list &py_ops, std::vector<std::shared_ptr<TensorOp>> *ops) {
RETURN_UNEXPECTED_IF_NULL(ops);
for (auto op : py_ops) {
if (py::isinstance<TensorOp>(op)) {
ops->emplace_back(op.cast<std::shared_ptr<TensorOp>>());
} else if (py::isinstance<py::function>(op)) {
ops->emplace_back(std::make_shared<PyFuncOp>(op.cast<py::function>()));
} else {
RETURN_STATUS_UNEXPECTED("element is neither a TensorOp nor a pyfunc.");
}
}
CHECK_FAIL_RETURN_UNEXPECTED(!ops->empty(), "TensorOp list is empty.");
for (auto const &op : *ops) {
RETURN_UNEXPECTED_IF_NULL(op);
}
return Status::OK();
}
void bindDEPipeline(py::module *m) {
(void)py::class_<DEPipeline>(*m, "DEPipeline")
.def(py::init<>())
@ -623,7 +647,7 @@ void bindTokenizerOps(py::module *m) {
WordIdType default_id = vocab->Lookup(word);
if (default_id == Vocab::kNoTokenExists) {
THROW_IF_ERROR(
Status(StatusCode::kUnexpectedError, "default unknown token:" + word + " doesn't exist in vocab."));
Status(StatusCode::kUnexpectedError, "default unknown token: " + word + " doesn't exist in vocab."));
}
return std::make_shared<LookupOp>(vocab, default_id);
}));
@ -868,6 +892,58 @@ void bindGraphData(py::module *m) {
});
}
void bindRandomTransformTensorOps(py::module *m) {
(void)py::class_<ComposeOp, TensorOp, std::shared_ptr<ComposeOp>>(*m, "ComposeOp")
.def(py::init([](const py::list &ops) {
std::vector<std::shared_ptr<TensorOp>> t_ops;
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
return std::make_shared<ComposeOp>(t_ops);
}));
(void)py::class_<RandomChoiceOp, TensorOp, std::shared_ptr<RandomChoiceOp>>(*m, "RandomChoiceOp")
.def(py::init([](const py::list &ops) {
std::vector<std::shared_ptr<TensorOp>> t_ops;
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
return std::make_shared<RandomChoiceOp>(t_ops);
}));
(void)py::class_<RandomApplyOp, TensorOp, std::shared_ptr<RandomApplyOp>>(*m, "RandomApplyOp")
.def(py::init([](double prob, const py::list &ops) {
std::vector<std::shared_ptr<TensorOp>> t_ops;
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
if (prob < 0 || prob > 1) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "prob needs to be within [0,1]."));
}
return std::make_shared<RandomApplyOp>(prob, t_ops);
}));
(void)py::class_<RandomSelectSubpolicyOp, TensorOp, std::shared_ptr<RandomSelectSubpolicyOp>>(
*m, "RandomSelectSubpolicyOp")
.def(py::init([](const py::list &py_policy) {
std::vector<Subpolicy> cpp_policy;
for (auto &py_sub : py_policy) {
cpp_policy.push_back({});
for (auto handle : py_sub.cast<py::list>()) {
py::tuple tp = handle.cast<py::tuple>();
if (tp.is_none() || tp.size() != 2) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "Each tuple in subpolicy should be (op, prob)."));
}
std::shared_ptr<TensorOp> t_op;
if (py::isinstance<TensorOp>(tp[0])) {
t_op = (tp[0]).cast<std::shared_ptr<TensorOp>>();
} else if (py::isinstance<py::function>(tp[0])) {
t_op = std::make_shared<PyFuncOp>((tp[0]).cast<py::function>());
} else {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "op is neither a tensorOp nor a pyfunc."));
}
double prob = (tp[1]).cast<py::float_>();
if (prob < 0 || prob > 1) {
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "prob needs to be with [0,1]."));
}
cpp_policy.back().emplace_back(std::make_pair(t_op, prob));
}
}
return std::make_shared<RandomSelectSubpolicyOp>(cpp_policy);
}));
}
// This is where we externalize the C logic as python modules
PYBIND11_MODULE(_c_dataengine, m) {
m.doc() = "pybind11 for _c_dataengine";
@ -949,6 +1025,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
bindVocabObjects(&m);
bindGraphData(&m);
bindDependIcuTokenizerOps(&m);
bindRandomTransformTensorOps(&m);
}
} // namespace dataset
} // namespace mindspore

View File

@ -4,11 +4,16 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
if (ENABLE_PYTHON)
add_library(kernels OBJECT
compose_op.cc
random_apply_op.cc
random_choice_op.cc
py_func_op.cc
tensor_op.cc)
target_include_directories(kernels PRIVATE ${pybind11_INCLUDE_DIRS})
else()
add_library(kernels OBJECT
compose_op.cc
random_apply_op.cc
random_choice_op.cc
tensor_op.cc)
endif()

View File

@ -0,0 +1,66 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/compose_op.h"
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/py_func_op.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
Status ComposeOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
std::vector<TensorShape> in_shapes = inputs;
for (auto &op : ops_) {
RETURN_IF_NOT_OK(op->OutputShape(in_shapes, outputs));
in_shapes = std::move(outputs); // outputs become empty after move
}
outputs = std::move(in_shapes);
return Status::OK();
}
Status ComposeOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
std::vector<DataType> in_types = inputs;
for (auto &op : ops_) {
RETURN_IF_NOT_OK(op->OutputType(in_types, outputs));
in_types = std::move(outputs); // outputs become empty after move
}
outputs = std::move(in_types);
return Status::OK();
}
Status ComposeOp::Compute(const TensorRow &inputs, TensorRow *outputs) {
IO_CHECK_VECTOR(inputs, outputs);
TensorRow in_rows = inputs;
for (auto &op : ops_) {
RETURN_IF_NOT_OK(op->Compute(in_rows, outputs));
in_rows = std::move(*outputs); // after move, *outputs become empty
}
(*outputs) = std::move(in_rows);
return Status::OK();
}
ComposeOp::ComposeOp(const std::vector<std::shared_ptr<TensorOp>> &ops) : ops_(ops) {
if (ops_.empty()) {
MS_LOG(ERROR) << "op_list is empty this might lead to Segmentation Fault.";
} else if (ops_.size() == 1) {
MS_LOG(WARNING) << "op_list has only 1 op. Compose is probably not needed.";
}
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,70 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_COMPOSE_OP_
#define DATASET_KERNELS_COMPOSE_OP_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
namespace mindspore {
namespace dataset {
class ComposeOp : public TensorOp {
public:
/// constructor
/// \param[in] ops list of TensorOps to compose into 1 TensorOp
explicit ComposeOp(const std::vector<std::shared_ptr<TensorOp>> &ops);
/// default destructor
~ComposeOp() override = default;
/// return the number of inputs the first tensorOp in compose takes
/// \return number of input tensors
uint32_t NumInput() override { return ops_.front()->NumInput(); }
/// return the number of outputs the last tensorOp in compose produces
/// \return number of output tensors
uint32_t NumOutput() override { return ops_.back()->NumOutput(); }
/// \param[in] inputs
/// \param[out] outputs
/// \return Status code
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
/// \param[in] inputs
/// \param[out] outputs
/// \return Status code
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
/// \param[in] input
/// \param[out] output
/// \return Status code
Status Compute(const TensorRow &input, TensorRow *output) override;
std::string Name() const override { return kComposeOp; }
private:
std::vector<std::shared_ptr<TensorOp>> ops_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_KERNELS_COMPOSE_OP_

View File

@ -19,6 +19,7 @@ add_library(kernels-image OBJECT
bounding_box_augment_op.cc
random_resize_op.cc
random_rotation_op.cc
random_select_subpolicy_op.cc
random_vertical_flip_op.cc
random_vertical_flip_with_bbox_op.cc
rescale_op.cc

View File

@ -0,0 +1,96 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/random_select_subpolicy_op.h"
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
Status RandomSelectSubpolicyOp::Compute(const TensorRow &input, TensorRow *output) {
TensorRow in_row = input;
size_t rand_num = rand_int_(gen_);
CHECK_FAIL_RETURN_UNEXPECTED(rand_num < policy_.size(), "invalid rand_num:" + std::to_string(rand_num));
for (auto &sub : policy_[rand_num]) {
if (rand_double_(gen_) <= sub.second) {
RETURN_IF_NOT_OK(sub.first->Compute(in_row, output));
in_row = std::move(*output);
}
}
*output = std::move(in_row);
return Status::OK();
}
uint32_t RandomSelectSubpolicyOp::NumInput() {
uint32_t num_in = policy_.front().front().first->NumInput();
for (auto &sub : policy_) {
for (auto p : sub) {
if (num_in != p.first->NumInput()) {
MS_LOG(WARNING) << "Unable to determine numInput.";
return 0;
}
}
}
return num_in;
}
uint32_t RandomSelectSubpolicyOp::NumOutput() {
uint32_t num_out = policy_.front().front().first->NumOutput();
for (auto &sub : policy_) {
for (auto p : sub) {
if (num_out != p.first->NumOutput()) {
MS_LOG(WARNING) << "Unable to determine numInput.";
return 0;
}
}
}
return num_out;
}
Status RandomSelectSubpolicyOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
outputs.clear();
outputs.resize(NumOutput(), TensorShape::CreateUnknownRankShape());
return Status::OK();
}
Status RandomSelectSubpolicyOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
RETURN_IF_NOT_OK(policy_.front().front().first->OutputType(inputs, outputs));
for (auto &sub : policy_) {
for (auto p : sub) {
std::vector<DataType> tmp_types;
RETURN_IF_NOT_OK(p.first->OutputType(inputs, tmp_types));
if (outputs != tmp_types) {
outputs.clear();
outputs.resize(NumOutput(), DataType(DataType::DE_UNKNOWN));
return Status::OK();
}
}
}
return Status::OK();
}
RandomSelectSubpolicyOp::RandomSelectSubpolicyOp(const std::vector<Subpolicy> &policy)
: gen_(GetSeed()), policy_(policy), rand_int_(0, policy.size() - 1), rand_double_(0, 1) {
if (policy_.empty()) {
MS_LOG(ERROR) << "policy in RandomSelectSubpolicyOp is empty.";
}
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,79 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_IMAGE_RANDOM_SELECT_SUBPOLICY_OP_
#define DATASET_KERNELS_IMAGE_RANDOM_SELECT_SUBPOLICY_OP_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
using Subpolicy = std::vector<std::pair<std::shared_ptr<TensorOp>, double>>;
class RandomSelectSubpolicyOp : public TensorOp {
public:
/// constructor
/// \param[in] policy policy to choose subpolicy from
explicit RandomSelectSubpolicyOp(const std::vector<Subpolicy> &policy);
/// destructor
~RandomSelectSubpolicyOp() override = default;
/// return number of input tensors
/// \return number of inputs if all ops in policy have the same NumInput, otherwise return 0
uint32_t NumInput() override;
/// return number of output tensors
/// \return number of outputs if all ops in policy have the same NumOutput, otherwise return 0
uint32_t NumOutput() override;
/// return unknown shapes
/// \param[in] inputs
/// \param[out] outputs
/// \return Status Code
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
/// return output type if all ops in policy return the same type, otherwise return unknown type
/// \param[in] inputs
/// \param[out] outputs
/// \return Status Code
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
/// \param[in] input
/// \param[out] output
/// \return Status code
Status Compute(const TensorRow &input, TensorRow *output) override;
std::string Name() const override { return kRandomSelectSubpolicyOp; }
private:
std::vector<Subpolicy> policy_;
std::mt19937 gen_; // mersenne_twister_engine
std::uniform_int_distribution<size_t> rand_int_;
std::uniform_real_distribution<double> rand_double_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_KERNELS_IMAGE_RANDOM_SELECT_SUBPOLICY_OP_

View File

@ -0,0 +1,68 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/random_apply_op.h"
#include <memory>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
uint32_t RandomApplyOp::NumOutput() {
if (compose_->NumOutput() != NumInput()) {
MS_LOG(WARNING) << "NumOutput!=NumInput (randomApply would randomly affect number of outputs).";
return 0;
}
return compose_->NumOutput();
}
Status RandomApplyOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
RETURN_IF_NOT_OK(compose_->OutputShape(inputs, outputs));
// randomApply either runs all ops or do nothing. If the two methods don't give the same result. return unknown shape.
if (inputs != outputs) { // when RandomApply is not applied, input should be the same as output
outputs.clear();
outputs.resize(NumOutput(), TensorShape::CreateUnknownRankShape());
}
return Status::OK();
}
Status RandomApplyOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
RETURN_IF_NOT_OK(compose_->OutputType(inputs, outputs));
if (inputs != outputs) { // when RandomApply is not applied, input should be the same as output
outputs.clear();
outputs.resize(NumOutput(), DataType(DataType::DE_UNKNOWN));
}
return Status::OK();
}
Status RandomApplyOp::Compute(const TensorRow &input, TensorRow *output) {
if (rand_double_(gen_) <= prob_) {
RETURN_IF_NOT_OK(compose_->Compute(input, output));
} else {
IO_CHECK_VECTOR(input, output);
*output = input; // copy over the tensors
}
return Status::OK();
}
RandomApplyOp::RandomApplyOp(double prob, const std::vector<std::shared_ptr<TensorOp>> &ops)
: prob_(prob), gen_(GetSeed()), rand_double_(0, 1) {
compose_ = std::make_unique<ComposeOp>(ops);
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,79 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_RANDOM_APPLY_OP_
#define DATASET_KERNELS_RANDOM_APPLY_OP_
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/compose_op.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
class RandomApplyOp : public TensorOp {
public:
/// constructor
/// \param[in] prob probability whether the list of TensorOps will be applied
/// \param[in] ops the list of TensorOps to apply with prob likelihood
explicit RandomApplyOp(double prob, const std::vector<std::shared_ptr<TensorOp>> &ops);
/// default destructor
~RandomApplyOp() = default;
/// return the number of inputs the first tensorOp in compose takes
/// \return number of input tensors
uint32_t NumInput() override { return compose_->NumInput(); }
/// return the number of outputs
/// \return number of output tensors
uint32_t NumOutput() override;
/// return output shape if randomApply won't affect the output shape, otherwise return unknown shape
/// \param[in] inputs
/// \param[out] outputs
/// \return Status code
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
/// return output type if randomApply won't affect the output type, otherwise return unknown type
/// \param[in] inputs
/// \param[out] outputs
/// \return Status code
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
/// \param[in] input
/// \param[out] output
/// \return Status code
Status Compute(const TensorRow &input, TensorRow *output) override;
std::string Name() const override { return kRandomApplyOp; }
private:
double prob_;
std::shared_ptr<TensorOp> compose_;
std::mt19937 gen_; // mersenne_twister_engine
std::uniform_real_distribution<double> rand_double_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_KERNELS_RANDOM_APPLY_OP_

View File

@ -0,0 +1,97 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/random_choice_op.h"
#include <memory>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
uint32_t RandomChoiceOp::NumInput() {
uint32_t num_input = ops_.front()->NumInput();
for (auto &op : ops_) {
uint32_t cur_num = op->NumInput();
if (num_input != cur_num && cur_num > 0) {
MS_LOG(WARNING) << "Unable to determine NumInput, ops in RandomChoice don't take the same number of input.";
return 0;
}
}
return num_input;
}
uint32_t RandomChoiceOp::NumOutput() {
uint32_t num_output = ops_.front()->NumOutput();
for (auto &op : ops_) {
uint32_t cur_num = op->NumOutput();
if (num_output != cur_num) {
MS_LOG(WARNING) << "Unable to determine NumInput, ops in RandomChoice don't have the same number of input.";
return 0;
}
}
return num_output;
}
Status RandomChoiceOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
RETURN_IF_NOT_OK(ops_.front()->OutputShape(inputs, outputs));
for (auto &op : ops_) {
std::vector<TensorShape> out_shapes;
RETURN_IF_NOT_OK(op->OutputShape(inputs, out_shapes));
if (outputs != out_shapes) {
MS_LOG(WARNING) << "TensorOp in RandomChoice don't return the same tensorShape.";
outputs.clear();
outputs.resize(NumOutput(), TensorShape::CreateUnknownRankShape());
return Status::OK();
}
}
return Status::OK();
}
Status RandomChoiceOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
RETURN_IF_NOT_OK(ops_.front()->OutputType(inputs, outputs));
for (auto &op : ops_) {
std::vector<DataType> out_types;
RETURN_IF_NOT_OK(op->OutputType(inputs, out_types));
if (outputs != out_types) {
MS_LOG(WARNING) << "TensorOp in RandomChoice don't return the same tensorType.";
outputs.clear();
outputs.resize(NumOutput(), DataType(DataType::DE_UNKNOWN));
return Status::OK();
}
}
return Status::OK();
}
Status RandomChoiceOp::Compute(const TensorRow &input, TensorRow *output) {
size_t rand_num = rand_int_(gen_);
CHECK_FAIL_RETURN_UNEXPECTED(rand_num < ops_.size(), "invalid rand_num:" + std::to_string(rand_num));
RETURN_IF_NOT_OK(ops_[rand_num]->Compute(input, output));
return Status::OK();
}
RandomChoiceOp::RandomChoiceOp(const std::vector<std::shared_ptr<TensorOp>> &ops)
: ops_(ops), gen_(GetSeed()), rand_int_(0, ops.size() - 1) {
if (ops_.empty()) {
MS_LOG(ERROR) << "op_list in RandomChoiceOp is empty.";
} else if (ops_.size() == 1) {
MS_LOG(WARNING) << "op_list has only 1 op, this op would be picked every time.";
}
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,77 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_RANDOM_CHOICE_OP_
#define DATASET_KERNELS_RANDOM_CHOICE_OP_
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/kernels/compose_op.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
class RandomChoiceOp : public TensorOp {
public:
/// constructor
/// \param[in] ops list of TensorOps to randomly choose 1 from
explicit RandomChoiceOp(const std::vector<std::shared_ptr<TensorOp>> &ops);
/// default destructor
~RandomChoiceOp() = default;
/// return the number of inputs. All op in ops_ should have the same number of inputs
/// \return number of input tensors
uint32_t NumInput() override;
/// return the number of outputs. All op in ops_ should have the same number of outputs
/// \return number of input tensors
uint32_t NumOutput() override;
/// return output shape if all ops in ops_ return the same shape, otherwise return unknown shape
/// \param[in] inputs
/// \param[out] outputs
/// \return Status code
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
/// return output type if all ops in ops_ return the same type, otherwise return unknown type
/// \param[in] inputs
/// \param[out] outputs
/// \return Status code
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
/// \param[in] input
/// \param[out] output
/// \return Status code
Status Compute(const TensorRow &input, TensorRow *output) override;
std::string Name() const override { return kRandomChoiceOp; }
private:
std::vector<std::shared_ptr<TensorOp>> ops_;
std::mt19937 gen_; // mersenne_twister_engine
std::uniform_int_distribution<size_t> rand_int_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_KERNELS_RANDOM_CHOICE_OP_

View File

@ -129,6 +129,10 @@ constexpr char kUnicodeCharTokenizerOp[] = "UnicodeCharTokenizerOp";
constexpr char kUnicodeScriptTokenizerOp[] = "UnicodeScriptTokenizerOp";
constexpr char kWhitespaceTokenizerOp[] = "WhitespaceTokenizerOp";
constexpr char kWordpieceTokenizerOp[] = "WordpieceTokenizerOp";
constexpr char kRandomChoiceOp[] = "RandomChoiceOp";
constexpr char kRandomApplyOp[] = "RandomApplyOp";
constexpr char kComposeOp[] = "ComposeOp";
constexpr char kRandomSelectSubpolicyOp[] = "RandomSelectSubpolicyOp";
// data
constexpr char kConcatenateOp[] = "kConcatenateOp";

View File

@ -19,6 +19,8 @@ import inspect
from multiprocessing import cpu_count
import os
import numpy as np
import mindspore._c_dataengine as cde
from ..engine import samplers
# POS_INT_MIN is used to limit values from starting from 0
@ -358,3 +360,9 @@ def check_gnn_list_or_ndarray(param, param_name):
if not param.dtype == np.int32:
raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
param_name, param.dtype))
def check_tensor_op(param, param_name):
"""check whether param is a tensor op or a callable python function"""
if not isinstance(param, cde.TensorOp) and not callable(param):
raise TypeError("{0} is not a c_transform op (TensorOp) nor a callable pyfunc.".format(param_name))

View File

@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
import mindspore._c_dataengine as cde
from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op, check_mask_op, \
check_pad_end, check_concat_type
check_pad_end, check_concat_type, check_random_transform_ops
from ..core.datatypes import mstype_to_detype
@ -82,7 +82,7 @@ class Slice(cde.SliceOp):
Maximum `n` number of arguments to slice a tensor of rank `n`.
One object in slices can be one of:
1. :py:obj:`int`: Slice this index only. Negative index is supported.
2. :py:obj:`list(int)`: Slice these indices ion the list only. Negative indices are supdeported.
2. :py:obj:`list(int)`: Slice these indices ion the list only. Negative indices are supported.
3. :py:obj:`slice`: Slice the generated indices from the slice object. Similar to `start:stop:step`.
4. :py:obj:`None`: Slice the whole dimension. Similar to `:` in python indexing.
5. :py:obj:`Ellipses`: Slice all dimensions between the two slices. Similar to `...` in python indexing.
@ -232,3 +232,50 @@ class Duplicate(cde.DuplicateOp):
>>> # | [1,2,3] | [1,2,3] |
>>> # +---------+---------+
"""
class Compose(cde.ComposeOp):
"""
Compose a list of transforms into a single transform.
Args:
transforms (list): List of transformations to be applied.
Example:
>>> compose = Compose([vision.Decode(), vision.RandomCrop()])
>>> dataset = ds.map(operations=compose)
"""
@check_random_transform_ops
def __init__(self, op_list):
super().__init__(op_list)
class RandomApply(cde.RandomApplyOp):
"""
Randomly performs a series of transforms with a given probability.
Args:
transforms (list): List of transformations to be applied.
prob (float, optional): The probability to apply the transformation list (default=0.5)
Example:
>>> rand_apply = RandomApply([vision.RandomCrop()])
>>> dataset = ds.map(operations=rand_apply)
"""
@check_random_transform_ops
def __init__(self, op_list, prob=0.5):
super().__init__(prob, op_list)
class RandomChoice(cde.RandomChoiceOp):
"""
Randomly selects one transform from a list of transforms to perform operation.
Args:
transforms (list): List of transformations to be chosen from to apply.
Example:
>>> rand_choice = RandomChoice([vision.CenterCrop(), vision.RandomCrop()])
>>> dataset = ds.map(operations=rand_choice)
"""
@check_random_transform_ops
def __init__(self, op_list):
super().__init__(op_list)

View File

@ -18,7 +18,8 @@ from functools import wraps
import numpy as np
from mindspore._c_expression import typing
from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive
from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive, \
check_tensor_op
# POS_INT_MIN is used to limit values from starting from 0
POS_INT_MIN = 1
@ -180,3 +181,22 @@ def check_concat_type(method):
return method(self, *args, **kwargs)
return new_method
def check_random_transform_ops(method):
"""Wrapper method to check the parameters of RandomChoice, RandomApply and Compose."""
@wraps(method)
def new_method(self, *args, **kwargs):
arg_list, _ = parse_user_args(method, *args, **kwargs)
type_check(arg_list[0], (list,), "op_list")
if not arg_list[0]:
raise ValueError("op_list can not be empty.")
for ind, op in enumerate(arg_list[0]):
check_tensor_op(op, "op_list[{0}]".format(ind))
if len(arg_list) == 2: # random apply takes an additional arg
type_check(arg_list[1], (float, int), "prob")
check_value(arg_list[1], (0, 1), "prob")
return method(self, *args, **kwargs)
return new_method

View File

@ -47,7 +47,7 @@ from .utils import Inter, Border
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, check_range, \
check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp, \
FLOAT_MAX_INTEGER
check_random_select_subpolicy_op, FLOAT_MAX_INTEGER
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
@ -712,3 +712,9 @@ class UniformAugment(cde.UniformAugOp):
self.operations = operations
self.num_ops = num_ops
super().__init__(operations, num_ops)
class RandomSelectSubpolicy(cde.RandomSelectSubpolicyOp):
@check_random_select_subpolicy_op
def __init__(self, policy):
super().__init__(policy)

View File

@ -21,7 +21,7 @@ from mindspore._c_dataengine import TensorOp
from .utils import Inter, Border
from ...core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list
check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, check_tensor_op
def check_crop_size(size):
@ -588,3 +588,26 @@ def check_compose_list(method):
return method(self, *args, **kwargs)
return new_method
def check_random_select_subpolicy_op(method):
"""Wrapper method to check the parameters of RandomSelectSubpolicyOp."""
@wraps(method)
def new_method(self, *args, **kwargs):
[policy], _ = parse_user_args(method, *args, **kwargs)
type_check(policy, (list,), "policy")
if not policy:
raise ValueError("policy can not be empty.")
for sub_ind, sub in enumerate(policy):
type_check(sub, (list,), "policy[{0}]".format([sub_ind]))
if not sub:
raise ValueError("policy[{0}] can not be empty.".format(sub_ind))
for op_ind, tp in enumerate(sub):
check_2tuple(tp, "policy[{0}][{1}]".format(sub_ind, op_ind))
check_tensor_op(tp[0], "op of (op, prob) in policy[{0}][{1}]".format(sub_ind, op_ind))
check_value(tp[1], (0, 1), "prob of (op, prob) policy[{0}][{1}]".format(sub_ind, op_ind))
return method(self, *args, **kwargs)
return new_method

View File

@ -0,0 +1,50 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as ops
import mindspore.dataset.transforms.py_transforms as py_ops
def test_compose():
ds.config.set_seed(0)
def test_config(arr, op_list):
try:
data = ds.NumpySlicesDataset(arr, column_names="col", shuffle=False)
data = data.map(input_columns=["col"], operations=ops.Compose(op_list))
res = []
for i in data.create_dict_iterator():
res.append(i["col"].tolist())
return res
except (TypeError, ValueError) as e:
return str(e)
# test simple compose with only 1 op, this would generate a warning
assert test_config([[1, 0], [3, 4]], [ops.Fill(2)]) == [[2, 2], [2, 2]]
# test 1 column -> 2columns -> 1 -> 2 -> 1
assert test_config([[1, 0]], [ops.Duplicate(), ops.Concatenate(), ops.Duplicate(), ops.Concatenate()]) == [
[1, 0] * 4]
# test one python transform followed by a C transform. type after oneHot is float (mixed use-case)
assert test_config([1, 0], [py_ops.OneHotOp(2), ops.TypeCast(mstype.int32)]) == [[[0, 1]], [[1, 0]]]
# test exceptions. compose, randomApply randomChoice use the same validator
assert "op_list[0] is not a c_transform op" in test_config([1, 0], [1, ops.TypeCast(mstype.int32)])
# test empty op list
assert "op_list can not be empty." in test_config([1, 0], [])
if __name__ == "__main__":
test_compose()

View File

@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as ops
def test_random_apply():
ds.config.set_seed(0)
def test_config(arr, op_list, prob=0.5):
try:
data = ds.NumpySlicesDataset(arr, column_names="col", shuffle=False)
data = data.map(input_columns=["col"], operations=ops.RandomApply(op_list, prob))
res = []
for i in data.create_dict_iterator():
res.append(i["col"].tolist())
return res
except (TypeError, ValueError) as e:
return str(e)
res1 = test_config([[0, 1]], [ops.Duplicate(), ops.Concatenate()])
assert res1 in [[[0, 1]], [[0, 1, 0, 1]]]
# test single nested compose
assert test_config([[0, 1, 2]], [ops.Compose([ops.Duplicate(), ops.Concatenate(), ops.Slice([0, 1, 2])])]) == [
[0, 1, 2]]
# test exception
assert "is not of type (<class 'list'>" in test_config([1, 0], ops.TypeCast(mstype.int32))
assert "Input prob is not within the required interval" in test_config([0, 1], [ops.Slice([0, 1])], 1.1)
assert "is not of type (<class 'float'>" in test_config([1, 0], [ops.TypeCast(mstype.int32)], None)
assert "op_list with value None is not of type (<class 'list'>" in test_config([1, 0], None)
if __name__ == "__main__":
test_random_apply()

View File

@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as ops
def test_random_choice():
ds.config.set_seed(0)
def test_config(arr, op_list):
try:
data = ds.NumpySlicesDataset(arr, column_names="col", shuffle=False)
data = data.map(input_columns=["col"], operations=ops.RandomChoice(op_list))
res = []
for i in data.create_dict_iterator():
res.append(i["col"].tolist())
return res
except (TypeError, ValueError) as e:
return str(e)
# test whether a op would be randomly chosen. In order to prevent random failure, both results need to be checked
res1 = test_config([[0, 1, 2]], [ops.PadEnd([4], 0), ops.Slice([0, 2])])
assert res1 in [[[0, 1, 2, 0]], [[0, 2]]]
# test nested structure
res2 = test_config([[0, 1, 2]], [ops.Compose([ops.Duplicate(), ops.Concatenate()]),
ops.Compose([ops.Slice([0, 1]), ops.OneHot(2)])])
assert res2 in [[[[1, 0], [0, 1]]], [[0, 1, 2, 0, 1, 2]]]
# test random_choice where there is only 1 op
assert test_config([[4, 3], [2, 1]], [ops.Slice([0])]) == [[4], [2]]
if __name__ == "__main__":
test_random_choice()

View File

@ -0,0 +1,51 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as ops
import mindspore.dataset.transforms.vision.c_transforms as visions
def test_random_select_subpolicy():
ds.config.set_seed(0)
def test_config(arr, policy):
try:
data = ds.NumpySlicesDataset(arr, column_names="col", shuffle=False)
data = data.map(input_columns=["col"], operations=visions.RandomSelectSubpolicy(policy))
res = []
for i in data.create_dict_iterator():
res.append(i["col"].tolist())
return res
except (TypeError, ValueError) as e:
return str(e)
# 3 possible outcomes
policy1 = [[(ops.PadEnd([4], 0), 0.5), (ops.Compose([ops.Duplicate(), ops.Concatenate()]), 1)],
[(ops.Slice([0, 1]), 0.5), (ops.Duplicate(), 1), (ops.Concatenate(), 1)]]
res1 = test_config([[1, 2, 3]], policy1)
assert res1 in [[[1, 2, 1, 2]], [[1, 2, 3, 1, 2, 3]], [[1, 2, 3, 0, 1, 2, 3, 0]]]
# test exceptions
assert "policy can not be empty." in test_config([[1, 2, 3]], [])
assert "policy[0] can not be empty." in test_config([[1, 2, 3]], [[]])
assert "op of (op, prob) in policy[1][0] is not a c_transform op (TensorOp) nor a callable pyfunc" in test_config(
[[1, 2, 3]], [[(ops.PadEnd([4], 0), 0.5)], [(1, 0.4)]])
assert "prob of (op, prob) policy[1][0] is not within the required interval of (0 to 1)" in test_config([[1]], [
[(ops.Duplicate(), 0)], [(ops.Duplicate(), -0.1)]])
if __name__ == "__main__":
test_random_select_subpolicy()