forked from mindspore-Ecosystem/mindspore
implemented cpp random choice, apply and compos
python part of random ops added random select sub policy validators added comments added, remaining issues addressed add more python test cases fix ci fix CI fix order of include files addr review cmts addr review cmts reorg file fix compile err address review cmts address review cmts
This commit is contained in:
parent
cda333f760
commit
d233c54139
|
@ -42,8 +42,6 @@
|
||||||
#include "minddata/dataset/util/status.h"
|
#include "minddata/dataset/util/status.h"
|
||||||
#include "minddata/mindrecord/include/shard_category.h"
|
#include "minddata/mindrecord/include/shard_category.h"
|
||||||
#include "minddata/mindrecord/include/shard_distributed_sample.h"
|
#include "minddata/mindrecord/include/shard_distributed_sample.h"
|
||||||
#include "minddata/mindrecord/include/shard_sample.h"
|
|
||||||
#include "minddata/mindrecord/include/shard_shuffle.h"
|
|
||||||
#include "pybind11/stl.h"
|
#include "pybind11/stl.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
#include <exception>
|
#include <exception>
|
||||||
|
|
||||||
#include "minddata/dataset/api/de_pipeline.h"
|
#include "minddata/dataset/api/de_pipeline.h"
|
||||||
|
#include "minddata/dataset/engine/cache/cache_client.h"
|
||||||
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
|
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
|
||||||
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
|
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
|
||||||
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
|
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
|
||||||
|
@ -35,9 +36,9 @@
|
||||||
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
|
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
|
||||||
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
|
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
|
||||||
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
|
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
|
||||||
#include "minddata/dataset/engine/cache/cache_client.h"
|
|
||||||
#include "minddata/dataset/engine/gnn/graph.h"
|
#include "minddata/dataset/engine/gnn/graph.h"
|
||||||
#include "minddata/dataset/engine/jagged_connector.h"
|
#include "minddata/dataset/engine/jagged_connector.h"
|
||||||
|
#include "minddata/dataset/kernels/compose_op.h"
|
||||||
#include "minddata/dataset/kernels/data/concatenate_op.h"
|
#include "minddata/dataset/kernels/data/concatenate_op.h"
|
||||||
#include "minddata/dataset/kernels/data/duplicate_op.h"
|
#include "minddata/dataset/kernels/data/duplicate_op.h"
|
||||||
#include "minddata/dataset/kernels/data/fill_op.h"
|
#include "minddata/dataset/kernels/data/fill_op.h"
|
||||||
|
@ -61,11 +62,12 @@
|
||||||
#include "minddata/dataset/kernels/image/random_crop_decode_resize_op.h"
|
#include "minddata/dataset/kernels/image/random_crop_decode_resize_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_crop_op.h"
|
#include "minddata/dataset/kernels/image/random_crop_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_crop_with_bbox_op.h"
|
#include "minddata/dataset/kernels/image/random_crop_with_bbox_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h"
|
|
||||||
#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
|
#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
|
||||||
|
#include "minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_resize_op.h"
|
#include "minddata/dataset/kernels/image/random_resize_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h"
|
#include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_rotation_op.h"
|
#include "minddata/dataset/kernels/image/random_rotation_op.h"
|
||||||
|
#include "minddata/dataset/kernels/image/random_select_subpolicy_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
|
#include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
|
||||||
#include "minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h"
|
#include "minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h"
|
||||||
#include "minddata/dataset/kernels/image/rescale_op.h"
|
#include "minddata/dataset/kernels/image/rescale_op.h"
|
||||||
|
@ -74,6 +76,9 @@
|
||||||
#include "minddata/dataset/kernels/image/resize_with_bbox_op.h"
|
#include "minddata/dataset/kernels/image/resize_with_bbox_op.h"
|
||||||
#include "minddata/dataset/kernels/image/uniform_aug_op.h"
|
#include "minddata/dataset/kernels/image/uniform_aug_op.h"
|
||||||
#include "minddata/dataset/kernels/no_op.h"
|
#include "minddata/dataset/kernels/no_op.h"
|
||||||
|
#include "minddata/dataset/kernels/py_func_op.h"
|
||||||
|
#include "minddata/dataset/kernels/random_apply_op.h"
|
||||||
|
#include "minddata/dataset/kernels/random_choice_op.h"
|
||||||
#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h"
|
#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h"
|
||||||
#include "minddata/dataset/text/kernels/lookup_op.h"
|
#include "minddata/dataset/text/kernels/lookup_op.h"
|
||||||
#include "minddata/dataset/text/kernels/ngram_op.h"
|
#include "minddata/dataset/text/kernels/ngram_op.h"
|
||||||
|
@ -88,6 +93,7 @@
|
||||||
#include "minddata/mindrecord/include/shard_sample.h"
|
#include "minddata/mindrecord/include/shard_sample.h"
|
||||||
#include "minddata/mindrecord/include/shard_sequential_sample.h"
|
#include "minddata/mindrecord/include/shard_sequential_sample.h"
|
||||||
#include "mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h"
|
#include "mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h"
|
||||||
|
|
||||||
#include "pybind11/pybind11.h"
|
#include "pybind11/pybind11.h"
|
||||||
#include "pybind11/stl.h"
|
#include "pybind11/stl.h"
|
||||||
#include "pybind11/stl_bind.h"
|
#include "pybind11/stl_bind.h"
|
||||||
|
@ -113,6 +119,24 @@ namespace dataset {
|
||||||
if (rc.IsError()) throw std::runtime_error(rc.ToString()); \
|
if (rc.IsError()) throw std::runtime_error(rc.ToString()); \
|
||||||
} while (false)
|
} while (false)
|
||||||
|
|
||||||
|
Status PyListToTensorOps(const py::list &py_ops, std::vector<std::shared_ptr<TensorOp>> *ops) {
|
||||||
|
RETURN_UNEXPECTED_IF_NULL(ops);
|
||||||
|
for (auto op : py_ops) {
|
||||||
|
if (py::isinstance<TensorOp>(op)) {
|
||||||
|
ops->emplace_back(op.cast<std::shared_ptr<TensorOp>>());
|
||||||
|
} else if (py::isinstance<py::function>(op)) {
|
||||||
|
ops->emplace_back(std::make_shared<PyFuncOp>(op.cast<py::function>()));
|
||||||
|
} else {
|
||||||
|
RETURN_STATUS_UNEXPECTED("element is neither a TensorOp nor a pyfunc.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(!ops->empty(), "TensorOp list is empty.");
|
||||||
|
for (auto const &op : *ops) {
|
||||||
|
RETURN_UNEXPECTED_IF_NULL(op);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
void bindDEPipeline(py::module *m) {
|
void bindDEPipeline(py::module *m) {
|
||||||
(void)py::class_<DEPipeline>(*m, "DEPipeline")
|
(void)py::class_<DEPipeline>(*m, "DEPipeline")
|
||||||
.def(py::init<>())
|
.def(py::init<>())
|
||||||
|
@ -868,6 +892,58 @@ void bindGraphData(py::module *m) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void bindRandomTransformTensorOps(py::module *m) {
|
||||||
|
(void)py::class_<ComposeOp, TensorOp, std::shared_ptr<ComposeOp>>(*m, "ComposeOp")
|
||||||
|
.def(py::init([](const py::list &ops) {
|
||||||
|
std::vector<std::shared_ptr<TensorOp>> t_ops;
|
||||||
|
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
|
||||||
|
return std::make_shared<ComposeOp>(t_ops);
|
||||||
|
}));
|
||||||
|
(void)py::class_<RandomChoiceOp, TensorOp, std::shared_ptr<RandomChoiceOp>>(*m, "RandomChoiceOp")
|
||||||
|
.def(py::init([](const py::list &ops) {
|
||||||
|
std::vector<std::shared_ptr<TensorOp>> t_ops;
|
||||||
|
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
|
||||||
|
return std::make_shared<RandomChoiceOp>(t_ops);
|
||||||
|
}));
|
||||||
|
(void)py::class_<RandomApplyOp, TensorOp, std::shared_ptr<RandomApplyOp>>(*m, "RandomApplyOp")
|
||||||
|
.def(py::init([](double prob, const py::list &ops) {
|
||||||
|
std::vector<std::shared_ptr<TensorOp>> t_ops;
|
||||||
|
THROW_IF_ERROR(PyListToTensorOps(ops, &t_ops));
|
||||||
|
if (prob < 0 || prob > 1) {
|
||||||
|
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "prob needs to be within [0,1]."));
|
||||||
|
}
|
||||||
|
return std::make_shared<RandomApplyOp>(prob, t_ops);
|
||||||
|
}));
|
||||||
|
(void)py::class_<RandomSelectSubpolicyOp, TensorOp, std::shared_ptr<RandomSelectSubpolicyOp>>(
|
||||||
|
*m, "RandomSelectSubpolicyOp")
|
||||||
|
.def(py::init([](const py::list &py_policy) {
|
||||||
|
std::vector<Subpolicy> cpp_policy;
|
||||||
|
for (auto &py_sub : py_policy) {
|
||||||
|
cpp_policy.push_back({});
|
||||||
|
for (auto handle : py_sub.cast<py::list>()) {
|
||||||
|
py::tuple tp = handle.cast<py::tuple>();
|
||||||
|
if (tp.is_none() || tp.size() != 2) {
|
||||||
|
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "Each tuple in subpolicy should be (op, prob)."));
|
||||||
|
}
|
||||||
|
std::shared_ptr<TensorOp> t_op;
|
||||||
|
if (py::isinstance<TensorOp>(tp[0])) {
|
||||||
|
t_op = (tp[0]).cast<std::shared_ptr<TensorOp>>();
|
||||||
|
} else if (py::isinstance<py::function>(tp[0])) {
|
||||||
|
t_op = std::make_shared<PyFuncOp>((tp[0]).cast<py::function>());
|
||||||
|
} else {
|
||||||
|
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "op is neither a tensorOp nor a pyfunc."));
|
||||||
|
}
|
||||||
|
double prob = (tp[1]).cast<py::float_>();
|
||||||
|
if (prob < 0 || prob > 1) {
|
||||||
|
THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "prob needs to be with [0,1]."));
|
||||||
|
}
|
||||||
|
cpp_policy.back().emplace_back(std::make_pair(t_op, prob));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::make_shared<RandomSelectSubpolicyOp>(cpp_policy);
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
// This is where we externalize the C logic as python modules
|
// This is where we externalize the C logic as python modules
|
||||||
PYBIND11_MODULE(_c_dataengine, m) {
|
PYBIND11_MODULE(_c_dataengine, m) {
|
||||||
m.doc() = "pybind11 for _c_dataengine";
|
m.doc() = "pybind11 for _c_dataengine";
|
||||||
|
@ -949,6 +1025,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
|
||||||
bindVocabObjects(&m);
|
bindVocabObjects(&m);
|
||||||
bindGraphData(&m);
|
bindGraphData(&m);
|
||||||
bindDependIcuTokenizerOps(&m);
|
bindDependIcuTokenizerOps(&m);
|
||||||
|
bindRandomTransformTensorOps(&m);
|
||||||
}
|
}
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -4,11 +4,16 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
|
||||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||||
if (ENABLE_PYTHON)
|
if (ENABLE_PYTHON)
|
||||||
add_library(kernels OBJECT
|
add_library(kernels OBJECT
|
||||||
|
compose_op.cc
|
||||||
|
random_apply_op.cc
|
||||||
|
random_choice_op.cc
|
||||||
py_func_op.cc
|
py_func_op.cc
|
||||||
tensor_op.cc)
|
tensor_op.cc)
|
||||||
target_include_directories(kernels PRIVATE ${pybind11_INCLUDE_DIRS})
|
target_include_directories(kernels PRIVATE ${pybind11_INCLUDE_DIRS})
|
||||||
else()
|
else()
|
||||||
add_library(kernels OBJECT
|
add_library(kernels OBJECT
|
||||||
|
compose_op.cc
|
||||||
|
random_apply_op.cc
|
||||||
|
random_choice_op.cc
|
||||||
tensor_op.cc)
|
tensor_op.cc)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,66 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "minddata/dataset/kernels/compose_op.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#include "minddata/dataset/kernels/py_func_op.h"
|
||||||
|
#include "minddata/dataset/kernels/tensor_op.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
Status ComposeOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
|
||||||
|
std::vector<TensorShape> in_shapes = inputs;
|
||||||
|
for (auto &op : ops_) {
|
||||||
|
RETURN_IF_NOT_OK(op->OutputShape(in_shapes, outputs));
|
||||||
|
in_shapes = std::move(outputs); // outputs become empty after move
|
||||||
|
}
|
||||||
|
outputs = std::move(in_shapes);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
Status ComposeOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
|
||||||
|
std::vector<DataType> in_types = inputs;
|
||||||
|
for (auto &op : ops_) {
|
||||||
|
RETURN_IF_NOT_OK(op->OutputType(in_types, outputs));
|
||||||
|
in_types = std::move(outputs); // outputs become empty after move
|
||||||
|
}
|
||||||
|
outputs = std::move(in_types);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
Status ComposeOp::Compute(const TensorRow &inputs, TensorRow *outputs) {
|
||||||
|
IO_CHECK_VECTOR(inputs, outputs);
|
||||||
|
TensorRow in_rows = inputs;
|
||||||
|
for (auto &op : ops_) {
|
||||||
|
RETURN_IF_NOT_OK(op->Compute(in_rows, outputs));
|
||||||
|
in_rows = std::move(*outputs); // after move, *outputs become empty
|
||||||
|
}
|
||||||
|
(*outputs) = std::move(in_rows);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
ComposeOp::ComposeOp(const std::vector<std::shared_ptr<TensorOp>> &ops) : ops_(ops) {
|
||||||
|
if (ops_.empty()) {
|
||||||
|
MS_LOG(ERROR) << "op_list is empty this might lead to Segmentation Fault.";
|
||||||
|
} else if (ops_.size() == 1) {
|
||||||
|
MS_LOG(WARNING) << "op_list has only 1 op. Compose is probably not needed.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,70 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef DATASET_KERNELS_COMPOSE_OP_
|
||||||
|
#define DATASET_KERNELS_COMPOSE_OP_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#include "minddata/dataset/kernels/tensor_op.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
class ComposeOp : public TensorOp {
|
||||||
|
public:
|
||||||
|
/// constructor
|
||||||
|
/// \param[in] ops list of TensorOps to compose into 1 TensorOp
|
||||||
|
explicit ComposeOp(const std::vector<std::shared_ptr<TensorOp>> &ops);
|
||||||
|
|
||||||
|
/// default destructor
|
||||||
|
~ComposeOp() override = default;
|
||||||
|
|
||||||
|
/// return the number of inputs the first tensorOp in compose takes
|
||||||
|
/// \return number of input tensors
|
||||||
|
uint32_t NumInput() override { return ops_.front()->NumInput(); }
|
||||||
|
|
||||||
|
/// return the number of outputs the last tensorOp in compose produces
|
||||||
|
/// \return number of output tensors
|
||||||
|
uint32_t NumOutput() override { return ops_.back()->NumOutput(); }
|
||||||
|
|
||||||
|
/// \param[in] inputs
|
||||||
|
/// \param[out] outputs
|
||||||
|
/// \return Status code
|
||||||
|
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
|
||||||
|
|
||||||
|
/// \param[in] inputs
|
||||||
|
/// \param[out] outputs
|
||||||
|
/// \return Status code
|
||||||
|
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
|
||||||
|
|
||||||
|
/// \param[in] input
|
||||||
|
/// \param[out] output
|
||||||
|
/// \return Status code
|
||||||
|
Status Compute(const TensorRow &input, TensorRow *output) override;
|
||||||
|
|
||||||
|
std::string Name() const override { return kComposeOp; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<std::shared_ptr<TensorOp>> ops_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // DATASET_KERNELS_COMPOSE_OP_
|
|
@ -19,6 +19,7 @@ add_library(kernels-image OBJECT
|
||||||
bounding_box_augment_op.cc
|
bounding_box_augment_op.cc
|
||||||
random_resize_op.cc
|
random_resize_op.cc
|
||||||
random_rotation_op.cc
|
random_rotation_op.cc
|
||||||
|
random_select_subpolicy_op.cc
|
||||||
random_vertical_flip_op.cc
|
random_vertical_flip_op.cc
|
||||||
random_vertical_flip_with_bbox_op.cc
|
random_vertical_flip_with_bbox_op.cc
|
||||||
rescale_op.cc
|
rescale_op.cc
|
||||||
|
|
|
@ -0,0 +1,96 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "minddata/dataset/kernels/image/random_select_subpolicy_op.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#include "minddata/dataset/kernels/tensor_op.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
Status RandomSelectSubpolicyOp::Compute(const TensorRow &input, TensorRow *output) {
|
||||||
|
TensorRow in_row = input;
|
||||||
|
size_t rand_num = rand_int_(gen_);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(rand_num < policy_.size(), "invalid rand_num:" + std::to_string(rand_num));
|
||||||
|
for (auto &sub : policy_[rand_num]) {
|
||||||
|
if (rand_double_(gen_) <= sub.second) {
|
||||||
|
RETURN_IF_NOT_OK(sub.first->Compute(in_row, output));
|
||||||
|
in_row = std::move(*output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*output = std::move(in_row);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t RandomSelectSubpolicyOp::NumInput() {
|
||||||
|
uint32_t num_in = policy_.front().front().first->NumInput();
|
||||||
|
for (auto &sub : policy_) {
|
||||||
|
for (auto p : sub) {
|
||||||
|
if (num_in != p.first->NumInput()) {
|
||||||
|
MS_LOG(WARNING) << "Unable to determine numInput.";
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return num_in;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t RandomSelectSubpolicyOp::NumOutput() {
|
||||||
|
uint32_t num_out = policy_.front().front().first->NumOutput();
|
||||||
|
for (auto &sub : policy_) {
|
||||||
|
for (auto p : sub) {
|
||||||
|
if (num_out != p.first->NumOutput()) {
|
||||||
|
MS_LOG(WARNING) << "Unable to determine numInput.";
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return num_out;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status RandomSelectSubpolicyOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
|
||||||
|
outputs.clear();
|
||||||
|
outputs.resize(NumOutput(), TensorShape::CreateUnknownRankShape());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status RandomSelectSubpolicyOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
|
||||||
|
RETURN_IF_NOT_OK(policy_.front().front().first->OutputType(inputs, outputs));
|
||||||
|
for (auto &sub : policy_) {
|
||||||
|
for (auto p : sub) {
|
||||||
|
std::vector<DataType> tmp_types;
|
||||||
|
RETURN_IF_NOT_OK(p.first->OutputType(inputs, tmp_types));
|
||||||
|
if (outputs != tmp_types) {
|
||||||
|
outputs.clear();
|
||||||
|
outputs.resize(NumOutput(), DataType(DataType::DE_UNKNOWN));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
RandomSelectSubpolicyOp::RandomSelectSubpolicyOp(const std::vector<Subpolicy> &policy)
|
||||||
|
: gen_(GetSeed()), policy_(policy), rand_int_(0, policy.size() - 1), rand_double_(0, 1) {
|
||||||
|
if (policy_.empty()) {
|
||||||
|
MS_LOG(ERROR) << "policy in RandomSelectSubpolicyOp is empty.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,79 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef DATASET_KERNELS_IMAGE_RANDOM_SELECT_SUBPOLICY_OP_
|
||||||
|
#define DATASET_KERNELS_IMAGE_RANDOM_SELECT_SUBPOLICY_OP_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#include "minddata/dataset/kernels/tensor_op.h"
|
||||||
|
#include "minddata/dataset/util/random.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
using Subpolicy = std::vector<std::pair<std::shared_ptr<TensorOp>, double>>;
|
||||||
|
|
||||||
|
class RandomSelectSubpolicyOp : public TensorOp {
|
||||||
|
public:
|
||||||
|
/// constructor
|
||||||
|
/// \param[in] policy policy to choose subpolicy from
|
||||||
|
explicit RandomSelectSubpolicyOp(const std::vector<Subpolicy> &policy);
|
||||||
|
|
||||||
|
/// destructor
|
||||||
|
~RandomSelectSubpolicyOp() override = default;
|
||||||
|
|
||||||
|
/// return number of input tensors
|
||||||
|
/// \return number of inputs if all ops in policy have the same NumInput, otherwise return 0
|
||||||
|
uint32_t NumInput() override;
|
||||||
|
|
||||||
|
/// return number of output tensors
|
||||||
|
/// \return number of outputs if all ops in policy have the same NumOutput, otherwise return 0
|
||||||
|
uint32_t NumOutput() override;
|
||||||
|
|
||||||
|
/// return unknown shapes
|
||||||
|
/// \param[in] inputs
|
||||||
|
/// \param[out] outputs
|
||||||
|
/// \return Status Code
|
||||||
|
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
|
||||||
|
|
||||||
|
/// return output type if all ops in policy return the same type, otherwise return unknown type
|
||||||
|
/// \param[in] inputs
|
||||||
|
/// \param[out] outputs
|
||||||
|
/// \return Status Code
|
||||||
|
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
|
||||||
|
|
||||||
|
/// \param[in] input
|
||||||
|
/// \param[out] output
|
||||||
|
/// \return Status code
|
||||||
|
Status Compute(const TensorRow &input, TensorRow *output) override;
|
||||||
|
|
||||||
|
std::string Name() const override { return kRandomSelectSubpolicyOp; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<Subpolicy> policy_;
|
||||||
|
std::mt19937 gen_; // mersenne_twister_engine
|
||||||
|
std::uniform_int_distribution<size_t> rand_int_;
|
||||||
|
std::uniform_real_distribution<double> rand_double_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // DATASET_KERNELS_IMAGE_RANDOM_SELECT_SUBPOLICY_OP_
|
|
@ -0,0 +1,68 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "minddata/dataset/kernels/random_apply_op.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#include "minddata/dataset/kernels/tensor_op.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
uint32_t RandomApplyOp::NumOutput() {
|
||||||
|
if (compose_->NumOutput() != NumInput()) {
|
||||||
|
MS_LOG(WARNING) << "NumOutput!=NumInput (randomApply would randomly affect number of outputs).";
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
return compose_->NumOutput();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status RandomApplyOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
|
||||||
|
RETURN_IF_NOT_OK(compose_->OutputShape(inputs, outputs));
|
||||||
|
// randomApply either runs all ops or do nothing. If the two methods don't give the same result. return unknown shape.
|
||||||
|
if (inputs != outputs) { // when RandomApply is not applied, input should be the same as output
|
||||||
|
outputs.clear();
|
||||||
|
outputs.resize(NumOutput(), TensorShape::CreateUnknownRankShape());
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
Status RandomApplyOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
|
||||||
|
RETURN_IF_NOT_OK(compose_->OutputType(inputs, outputs));
|
||||||
|
if (inputs != outputs) { // when RandomApply is not applied, input should be the same as output
|
||||||
|
outputs.clear();
|
||||||
|
outputs.resize(NumOutput(), DataType(DataType::DE_UNKNOWN));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
Status RandomApplyOp::Compute(const TensorRow &input, TensorRow *output) {
|
||||||
|
if (rand_double_(gen_) <= prob_) {
|
||||||
|
RETURN_IF_NOT_OK(compose_->Compute(input, output));
|
||||||
|
} else {
|
||||||
|
IO_CHECK_VECTOR(input, output);
|
||||||
|
*output = input; // copy over the tensors
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
RandomApplyOp::RandomApplyOp(double prob, const std::vector<std::shared_ptr<TensorOp>> &ops)
|
||||||
|
: prob_(prob), gen_(GetSeed()), rand_double_(0, 1) {
|
||||||
|
compose_ = std::make_unique<ComposeOp>(ops);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,79 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef DATASET_KERNELS_RANDOM_APPLY_OP_
|
||||||
|
#define DATASET_KERNELS_RANDOM_APPLY_OP_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <random>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#include "minddata/dataset/kernels/compose_op.h"
|
||||||
|
#include "minddata/dataset/kernels/tensor_op.h"
|
||||||
|
#include "minddata/dataset/util/random.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
class RandomApplyOp : public TensorOp {
|
||||||
|
public:
|
||||||
|
/// constructor
|
||||||
|
/// \param[in] prob probability whether the list of TensorOps will be applied
|
||||||
|
/// \param[in] ops the list of TensorOps to apply with prob likelihood
|
||||||
|
explicit RandomApplyOp(double prob, const std::vector<std::shared_ptr<TensorOp>> &ops);
|
||||||
|
|
||||||
|
/// default destructor
|
||||||
|
~RandomApplyOp() = default;
|
||||||
|
|
||||||
|
/// return the number of inputs the first tensorOp in compose takes
|
||||||
|
/// \return number of input tensors
|
||||||
|
uint32_t NumInput() override { return compose_->NumInput(); }
|
||||||
|
|
||||||
|
/// return the number of outputs
|
||||||
|
/// \return number of output tensors
|
||||||
|
uint32_t NumOutput() override;
|
||||||
|
|
||||||
|
/// return output shape if randomApply won't affect the output shape, otherwise return unknown shape
|
||||||
|
/// \param[in] inputs
|
||||||
|
/// \param[out] outputs
|
||||||
|
/// \return Status code
|
||||||
|
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
|
||||||
|
|
||||||
|
/// return output type if randomApply won't affect the output type, otherwise return unknown type
|
||||||
|
/// \param[in] inputs
|
||||||
|
/// \param[out] outputs
|
||||||
|
/// \return Status code
|
||||||
|
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
|
||||||
|
|
||||||
|
/// \param[in] input
|
||||||
|
/// \param[out] output
|
||||||
|
/// \return Status code
|
||||||
|
Status Compute(const TensorRow &input, TensorRow *output) override;
|
||||||
|
|
||||||
|
std::string Name() const override { return kRandomApplyOp; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
double prob_;
|
||||||
|
std::shared_ptr<TensorOp> compose_;
|
||||||
|
std::mt19937 gen_; // mersenne_twister_engine
|
||||||
|
std::uniform_real_distribution<double> rand_double_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // DATASET_KERNELS_RANDOM_APPLY_OP_
|
|
@ -0,0 +1,97 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "minddata/dataset/kernels/random_choice_op.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#include "minddata/dataset/kernels/tensor_op.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
uint32_t RandomChoiceOp::NumInput() {
|
||||||
|
uint32_t num_input = ops_.front()->NumInput();
|
||||||
|
for (auto &op : ops_) {
|
||||||
|
uint32_t cur_num = op->NumInput();
|
||||||
|
if (num_input != cur_num && cur_num > 0) {
|
||||||
|
MS_LOG(WARNING) << "Unable to determine NumInput, ops in RandomChoice don't take the same number of input.";
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return num_input;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t RandomChoiceOp::NumOutput() {
|
||||||
|
uint32_t num_output = ops_.front()->NumOutput();
|
||||||
|
for (auto &op : ops_) {
|
||||||
|
uint32_t cur_num = op->NumOutput();
|
||||||
|
if (num_output != cur_num) {
|
||||||
|
MS_LOG(WARNING) << "Unable to determine NumInput, ops in RandomChoice don't have the same number of input.";
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return num_output;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status RandomChoiceOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
|
||||||
|
RETURN_IF_NOT_OK(ops_.front()->OutputShape(inputs, outputs));
|
||||||
|
for (auto &op : ops_) {
|
||||||
|
std::vector<TensorShape> out_shapes;
|
||||||
|
RETURN_IF_NOT_OK(op->OutputShape(inputs, out_shapes));
|
||||||
|
if (outputs != out_shapes) {
|
||||||
|
MS_LOG(WARNING) << "TensorOp in RandomChoice don't return the same tensorShape.";
|
||||||
|
outputs.clear();
|
||||||
|
outputs.resize(NumOutput(), TensorShape::CreateUnknownRankShape());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status RandomChoiceOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
|
||||||
|
RETURN_IF_NOT_OK(ops_.front()->OutputType(inputs, outputs));
|
||||||
|
for (auto &op : ops_) {
|
||||||
|
std::vector<DataType> out_types;
|
||||||
|
RETURN_IF_NOT_OK(op->OutputType(inputs, out_types));
|
||||||
|
if (outputs != out_types) {
|
||||||
|
MS_LOG(WARNING) << "TensorOp in RandomChoice don't return the same tensorType.";
|
||||||
|
outputs.clear();
|
||||||
|
outputs.resize(NumOutput(), DataType(DataType::DE_UNKNOWN));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status RandomChoiceOp::Compute(const TensorRow &input, TensorRow *output) {
|
||||||
|
size_t rand_num = rand_int_(gen_);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(rand_num < ops_.size(), "invalid rand_num:" + std::to_string(rand_num));
|
||||||
|
RETURN_IF_NOT_OK(ops_[rand_num]->Compute(input, output));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
RandomChoiceOp::RandomChoiceOp(const std::vector<std::shared_ptr<TensorOp>> &ops)
|
||||||
|
: ops_(ops), gen_(GetSeed()), rand_int_(0, ops.size() - 1) {
|
||||||
|
if (ops_.empty()) {
|
||||||
|
MS_LOG(ERROR) << "op_list in RandomChoiceOp is empty.";
|
||||||
|
} else if (ops_.size() == 1) {
|
||||||
|
MS_LOG(WARNING) << "op_list has only 1 op, this op would be picked every time.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,77 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef DATASET_KERNELS_RANDOM_CHOICE_OP_
|
||||||
|
#define DATASET_KERNELS_RANDOM_CHOICE_OP_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <random>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#include "minddata/dataset/kernels/tensor_op.h"
|
||||||
|
#include "minddata/dataset/kernels/compose_op.h"
|
||||||
|
#include "minddata/dataset/util/random.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
class RandomChoiceOp : public TensorOp {
|
||||||
|
public:
|
||||||
|
/// constructor
|
||||||
|
/// \param[in] ops list of TensorOps to randomly choose 1 from
|
||||||
|
explicit RandomChoiceOp(const std::vector<std::shared_ptr<TensorOp>> &ops);
|
||||||
|
|
||||||
|
/// default destructor
|
||||||
|
~RandomChoiceOp() = default;
|
||||||
|
|
||||||
|
/// return the number of inputs. All op in ops_ should have the same number of inputs
|
||||||
|
/// \return number of input tensors
|
||||||
|
uint32_t NumInput() override;
|
||||||
|
|
||||||
|
/// return the number of outputs. All op in ops_ should have the same number of outputs
|
||||||
|
/// \return number of input tensors
|
||||||
|
uint32_t NumOutput() override;
|
||||||
|
|
||||||
|
/// return output shape if all ops in ops_ return the same shape, otherwise return unknown shape
|
||||||
|
/// \param[in] inputs
|
||||||
|
/// \param[out] outputs
|
||||||
|
/// \return Status code
|
||||||
|
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
|
||||||
|
|
||||||
|
/// return output type if all ops in ops_ return the same type, otherwise return unknown type
|
||||||
|
/// \param[in] inputs
|
||||||
|
/// \param[out] outputs
|
||||||
|
/// \return Status code
|
||||||
|
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
|
||||||
|
|
||||||
|
/// \param[in] input
|
||||||
|
/// \param[out] output
|
||||||
|
/// \return Status code
|
||||||
|
Status Compute(const TensorRow &input, TensorRow *output) override;
|
||||||
|
|
||||||
|
std::string Name() const override { return kRandomChoiceOp; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<std::shared_ptr<TensorOp>> ops_;
|
||||||
|
std::mt19937 gen_; // mersenne_twister_engine
|
||||||
|
std::uniform_int_distribution<size_t> rand_int_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // DATASET_KERNELS_RANDOM_CHOICE_OP_
|
|
@ -129,6 +129,10 @@ constexpr char kUnicodeCharTokenizerOp[] = "UnicodeCharTokenizerOp";
|
||||||
constexpr char kUnicodeScriptTokenizerOp[] = "UnicodeScriptTokenizerOp";
|
constexpr char kUnicodeScriptTokenizerOp[] = "UnicodeScriptTokenizerOp";
|
||||||
constexpr char kWhitespaceTokenizerOp[] = "WhitespaceTokenizerOp";
|
constexpr char kWhitespaceTokenizerOp[] = "WhitespaceTokenizerOp";
|
||||||
constexpr char kWordpieceTokenizerOp[] = "WordpieceTokenizerOp";
|
constexpr char kWordpieceTokenizerOp[] = "WordpieceTokenizerOp";
|
||||||
|
constexpr char kRandomChoiceOp[] = "RandomChoiceOp";
|
||||||
|
constexpr char kRandomApplyOp[] = "RandomApplyOp";
|
||||||
|
constexpr char kComposeOp[] = "ComposeOp";
|
||||||
|
constexpr char kRandomSelectSubpolicyOp[] = "RandomSelectSubpolicyOp";
|
||||||
|
|
||||||
// data
|
// data
|
||||||
constexpr char kConcatenateOp[] = "kConcatenateOp";
|
constexpr char kConcatenateOp[] = "kConcatenateOp";
|
||||||
|
|
|
@ -19,6 +19,8 @@ import inspect
|
||||||
from multiprocessing import cpu_count
|
from multiprocessing import cpu_count
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore._c_dataengine as cde
|
||||||
from ..engine import samplers
|
from ..engine import samplers
|
||||||
|
|
||||||
# POS_INT_MIN is used to limit values from starting from 0
|
# POS_INT_MIN is used to limit values from starting from 0
|
||||||
|
@ -358,3 +360,9 @@ def check_gnn_list_or_ndarray(param, param_name):
|
||||||
if not param.dtype == np.int32:
|
if not param.dtype == np.int32:
|
||||||
raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
|
raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
|
||||||
param_name, param.dtype))
|
param_name, param.dtype))
|
||||||
|
|
||||||
|
|
||||||
|
def check_tensor_op(param, param_name):
|
||||||
|
"""check whether param is a tensor op or a callable python function"""
|
||||||
|
if not isinstance(param, cde.TensorOp) and not callable(param):
|
||||||
|
raise TypeError("{0} is not a c_transform op (TensorOp) nor a callable pyfunc.".format(param_name))
|
||||||
|
|
|
@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
|
||||||
import mindspore._c_dataengine as cde
|
import mindspore._c_dataengine as cde
|
||||||
|
|
||||||
from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op, check_mask_op, \
|
from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op, check_mask_op, \
|
||||||
check_pad_end, check_concat_type
|
check_pad_end, check_concat_type, check_random_transform_ops
|
||||||
from ..core.datatypes import mstype_to_detype
|
from ..core.datatypes import mstype_to_detype
|
||||||
|
|
||||||
|
|
||||||
|
@ -82,7 +82,7 @@ class Slice(cde.SliceOp):
|
||||||
Maximum `n` number of arguments to slice a tensor of rank `n`.
|
Maximum `n` number of arguments to slice a tensor of rank `n`.
|
||||||
One object in slices can be one of:
|
One object in slices can be one of:
|
||||||
1. :py:obj:`int`: Slice this index only. Negative index is supported.
|
1. :py:obj:`int`: Slice this index only. Negative index is supported.
|
||||||
2. :py:obj:`list(int)`: Slice these indices ion the list only. Negative indices are supdeported.
|
2. :py:obj:`list(int)`: Slice these indices ion the list only. Negative indices are supported.
|
||||||
3. :py:obj:`slice`: Slice the generated indices from the slice object. Similar to `start:stop:step`.
|
3. :py:obj:`slice`: Slice the generated indices from the slice object. Similar to `start:stop:step`.
|
||||||
4. :py:obj:`None`: Slice the whole dimension. Similar to `:` in python indexing.
|
4. :py:obj:`None`: Slice the whole dimension. Similar to `:` in python indexing.
|
||||||
5. :py:obj:`Ellipses`: Slice all dimensions between the two slices. Similar to `...` in python indexing.
|
5. :py:obj:`Ellipses`: Slice all dimensions between the two slices. Similar to `...` in python indexing.
|
||||||
|
@ -232,3 +232,50 @@ class Duplicate(cde.DuplicateOp):
|
||||||
>>> # | [1,2,3] | [1,2,3] |
|
>>> # | [1,2,3] | [1,2,3] |
|
||||||
>>> # +---------+---------+
|
>>> # +---------+---------+
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Compose(cde.ComposeOp):
|
||||||
|
"""
|
||||||
|
Compose a list of transforms into a single transform.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
transforms (list): List of transformations to be applied.
|
||||||
|
Example:
|
||||||
|
>>> compose = Compose([vision.Decode(), vision.RandomCrop()])
|
||||||
|
>>> dataset = ds.map(operations=compose)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@check_random_transform_ops
|
||||||
|
def __init__(self, op_list):
|
||||||
|
super().__init__(op_list)
|
||||||
|
|
||||||
|
|
||||||
|
class RandomApply(cde.RandomApplyOp):
|
||||||
|
"""
|
||||||
|
Randomly performs a series of transforms with a given probability.
|
||||||
|
Args:
|
||||||
|
transforms (list): List of transformations to be applied.
|
||||||
|
prob (float, optional): The probability to apply the transformation list (default=0.5)
|
||||||
|
Example:
|
||||||
|
>>> rand_apply = RandomApply([vision.RandomCrop()])
|
||||||
|
>>> dataset = ds.map(operations=rand_apply)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@check_random_transform_ops
|
||||||
|
def __init__(self, op_list, prob=0.5):
|
||||||
|
super().__init__(prob, op_list)
|
||||||
|
|
||||||
|
|
||||||
|
class RandomChoice(cde.RandomChoiceOp):
|
||||||
|
"""
|
||||||
|
Randomly selects one transform from a list of transforms to perform operation.
|
||||||
|
Args:
|
||||||
|
transforms (list): List of transformations to be chosen from to apply.
|
||||||
|
Example:
|
||||||
|
>>> rand_choice = RandomChoice([vision.CenterCrop(), vision.RandomCrop()])
|
||||||
|
>>> dataset = ds.map(operations=rand_choice)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@check_random_transform_ops
|
||||||
|
def __init__(self, op_list):
|
||||||
|
super().__init__(op_list)
|
||||||
|
|
|
@ -18,7 +18,8 @@ from functools import wraps
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mindspore._c_expression import typing
|
from mindspore._c_expression import typing
|
||||||
from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive
|
from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive, \
|
||||||
|
check_tensor_op
|
||||||
|
|
||||||
# POS_INT_MIN is used to limit values from starting from 0
|
# POS_INT_MIN is used to limit values from starting from 0
|
||||||
POS_INT_MIN = 1
|
POS_INT_MIN = 1
|
||||||
|
@ -180,3 +181,22 @@ def check_concat_type(method):
|
||||||
return method(self, *args, **kwargs)
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
|
||||||
|
def check_random_transform_ops(method):
|
||||||
|
"""Wrapper method to check the parameters of RandomChoice, RandomApply and Compose."""
|
||||||
|
|
||||||
|
@wraps(method)
|
||||||
|
def new_method(self, *args, **kwargs):
|
||||||
|
arg_list, _ = parse_user_args(method, *args, **kwargs)
|
||||||
|
type_check(arg_list[0], (list,), "op_list")
|
||||||
|
if not arg_list[0]:
|
||||||
|
raise ValueError("op_list can not be empty.")
|
||||||
|
for ind, op in enumerate(arg_list[0]):
|
||||||
|
check_tensor_op(op, "op_list[{0}]".format(ind))
|
||||||
|
if len(arg_list) == 2: # random apply takes an additional arg
|
||||||
|
type_check(arg_list[1], (float, int), "prob")
|
||||||
|
check_value(arg_list[1], (0, 1), "prob")
|
||||||
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return new_method
|
||||||
|
|
|
@ -47,7 +47,7 @@ from .utils import Inter, Border
|
||||||
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
|
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
|
||||||
check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, check_range, \
|
check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, check_range, \
|
||||||
check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp, \
|
check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp, \
|
||||||
FLOAT_MAX_INTEGER
|
check_random_select_subpolicy_op, FLOAT_MAX_INTEGER
|
||||||
|
|
||||||
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
|
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
|
||||||
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
|
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
|
||||||
|
@ -712,3 +712,9 @@ class UniformAugment(cde.UniformAugOp):
|
||||||
self.operations = operations
|
self.operations = operations
|
||||||
self.num_ops = num_ops
|
self.num_ops = num_ops
|
||||||
super().__init__(operations, num_ops)
|
super().__init__(operations, num_ops)
|
||||||
|
|
||||||
|
|
||||||
|
class RandomSelectSubpolicy(cde.RandomSelectSubpolicyOp):
|
||||||
|
@check_random_select_subpolicy_op
|
||||||
|
def __init__(self, policy):
|
||||||
|
super().__init__(policy)
|
||||||
|
|
|
@ -21,7 +21,7 @@ from mindspore._c_dataengine import TensorOp
|
||||||
|
|
||||||
from .utils import Inter, Border
|
from .utils import Inter, Border
|
||||||
from ...core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
|
from ...core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
|
||||||
check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list
|
check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, check_tensor_op
|
||||||
|
|
||||||
|
|
||||||
def check_crop_size(size):
|
def check_crop_size(size):
|
||||||
|
@ -588,3 +588,26 @@ def check_compose_list(method):
|
||||||
return method(self, *args, **kwargs)
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
|
||||||
|
def check_random_select_subpolicy_op(method):
|
||||||
|
"""Wrapper method to check the parameters of RandomSelectSubpolicyOp."""
|
||||||
|
|
||||||
|
@wraps(method)
|
||||||
|
def new_method(self, *args, **kwargs):
|
||||||
|
[policy], _ = parse_user_args(method, *args, **kwargs)
|
||||||
|
type_check(policy, (list,), "policy")
|
||||||
|
if not policy:
|
||||||
|
raise ValueError("policy can not be empty.")
|
||||||
|
for sub_ind, sub in enumerate(policy):
|
||||||
|
type_check(sub, (list,), "policy[{0}]".format([sub_ind]))
|
||||||
|
if not sub:
|
||||||
|
raise ValueError("policy[{0}] can not be empty.".format(sub_ind))
|
||||||
|
for op_ind, tp in enumerate(sub):
|
||||||
|
check_2tuple(tp, "policy[{0}][{1}]".format(sub_ind, op_ind))
|
||||||
|
check_tensor_op(tp[0], "op of (op, prob) in policy[{0}][{1}]".format(sub_ind, op_ind))
|
||||||
|
check_value(tp[1], (0, 1), "prob of (op, prob) policy[{0}][{1}]".format(sub_ind, op_ind))
|
||||||
|
|
||||||
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return new_method
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
import mindspore.dataset.transforms.c_transforms as ops
|
||||||
|
import mindspore.dataset.transforms.py_transforms as py_ops
|
||||||
|
|
||||||
|
|
||||||
|
def test_compose():
|
||||||
|
ds.config.set_seed(0)
|
||||||
|
|
||||||
|
def test_config(arr, op_list):
|
||||||
|
try:
|
||||||
|
data = ds.NumpySlicesDataset(arr, column_names="col", shuffle=False)
|
||||||
|
data = data.map(input_columns=["col"], operations=ops.Compose(op_list))
|
||||||
|
res = []
|
||||||
|
for i in data.create_dict_iterator():
|
||||||
|
res.append(i["col"].tolist())
|
||||||
|
return res
|
||||||
|
except (TypeError, ValueError) as e:
|
||||||
|
return str(e)
|
||||||
|
|
||||||
|
# test simple compose with only 1 op, this would generate a warning
|
||||||
|
assert test_config([[1, 0], [3, 4]], [ops.Fill(2)]) == [[2, 2], [2, 2]]
|
||||||
|
# test 1 column -> 2columns -> 1 -> 2 -> 1
|
||||||
|
assert test_config([[1, 0]], [ops.Duplicate(), ops.Concatenate(), ops.Duplicate(), ops.Concatenate()]) == [
|
||||||
|
[1, 0] * 4]
|
||||||
|
# test one python transform followed by a C transform. type after oneHot is float (mixed use-case)
|
||||||
|
assert test_config([1, 0], [py_ops.OneHotOp(2), ops.TypeCast(mstype.int32)]) == [[[0, 1]], [[1, 0]]]
|
||||||
|
# test exceptions. compose, randomApply randomChoice use the same validator
|
||||||
|
assert "op_list[0] is not a c_transform op" in test_config([1, 0], [1, ops.TypeCast(mstype.int32)])
|
||||||
|
# test empty op list
|
||||||
|
assert "op_list can not be empty." in test_config([1, 0], [])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_compose()
|
|
@ -0,0 +1,48 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
import mindspore.dataset.transforms.c_transforms as ops
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_apply():
|
||||||
|
ds.config.set_seed(0)
|
||||||
|
|
||||||
|
def test_config(arr, op_list, prob=0.5):
|
||||||
|
try:
|
||||||
|
data = ds.NumpySlicesDataset(arr, column_names="col", shuffle=False)
|
||||||
|
data = data.map(input_columns=["col"], operations=ops.RandomApply(op_list, prob))
|
||||||
|
res = []
|
||||||
|
for i in data.create_dict_iterator():
|
||||||
|
res.append(i["col"].tolist())
|
||||||
|
return res
|
||||||
|
except (TypeError, ValueError) as e:
|
||||||
|
return str(e)
|
||||||
|
|
||||||
|
res1 = test_config([[0, 1]], [ops.Duplicate(), ops.Concatenate()])
|
||||||
|
assert res1 in [[[0, 1]], [[0, 1, 0, 1]]]
|
||||||
|
# test single nested compose
|
||||||
|
assert test_config([[0, 1, 2]], [ops.Compose([ops.Duplicate(), ops.Concatenate(), ops.Slice([0, 1, 2])])]) == [
|
||||||
|
[0, 1, 2]]
|
||||||
|
# test exception
|
||||||
|
assert "is not of type (<class 'list'>" in test_config([1, 0], ops.TypeCast(mstype.int32))
|
||||||
|
assert "Input prob is not within the required interval" in test_config([0, 1], [ops.Slice([0, 1])], 1.1)
|
||||||
|
assert "is not of type (<class 'float'>" in test_config([1, 0], [ops.TypeCast(mstype.int32)], None)
|
||||||
|
assert "op_list with value None is not of type (<class 'list'>" in test_config([1, 0], None)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_random_apply()
|
|
@ -0,0 +1,48 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
import mindspore.dataset.transforms.c_transforms as ops
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_choice():
|
||||||
|
ds.config.set_seed(0)
|
||||||
|
|
||||||
|
def test_config(arr, op_list):
|
||||||
|
try:
|
||||||
|
data = ds.NumpySlicesDataset(arr, column_names="col", shuffle=False)
|
||||||
|
data = data.map(input_columns=["col"], operations=ops.RandomChoice(op_list))
|
||||||
|
res = []
|
||||||
|
for i in data.create_dict_iterator():
|
||||||
|
res.append(i["col"].tolist())
|
||||||
|
return res
|
||||||
|
except (TypeError, ValueError) as e:
|
||||||
|
return str(e)
|
||||||
|
|
||||||
|
# test whether a op would be randomly chosen. In order to prevent random failure, both results need to be checked
|
||||||
|
res1 = test_config([[0, 1, 2]], [ops.PadEnd([4], 0), ops.Slice([0, 2])])
|
||||||
|
assert res1 in [[[0, 1, 2, 0]], [[0, 2]]]
|
||||||
|
|
||||||
|
# test nested structure
|
||||||
|
res2 = test_config([[0, 1, 2]], [ops.Compose([ops.Duplicate(), ops.Concatenate()]),
|
||||||
|
ops.Compose([ops.Slice([0, 1]), ops.OneHot(2)])])
|
||||||
|
assert res2 in [[[[1, 0], [0, 1]]], [[0, 1, 2, 0, 1, 2]]]
|
||||||
|
# test random_choice where there is only 1 op
|
||||||
|
assert test_config([[4, 3], [2, 1]], [ops.Slice([0])]) == [[4], [2]]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_random_choice()
|
|
@ -0,0 +1,51 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
import mindspore.dataset.transforms.c_transforms as ops
|
||||||
|
import mindspore.dataset.transforms.vision.c_transforms as visions
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_select_subpolicy():
|
||||||
|
ds.config.set_seed(0)
|
||||||
|
|
||||||
|
def test_config(arr, policy):
|
||||||
|
try:
|
||||||
|
data = ds.NumpySlicesDataset(arr, column_names="col", shuffle=False)
|
||||||
|
data = data.map(input_columns=["col"], operations=visions.RandomSelectSubpolicy(policy))
|
||||||
|
res = []
|
||||||
|
for i in data.create_dict_iterator():
|
||||||
|
res.append(i["col"].tolist())
|
||||||
|
return res
|
||||||
|
except (TypeError, ValueError) as e:
|
||||||
|
return str(e)
|
||||||
|
|
||||||
|
# 3 possible outcomes
|
||||||
|
policy1 = [[(ops.PadEnd([4], 0), 0.5), (ops.Compose([ops.Duplicate(), ops.Concatenate()]), 1)],
|
||||||
|
[(ops.Slice([0, 1]), 0.5), (ops.Duplicate(), 1), (ops.Concatenate(), 1)]]
|
||||||
|
res1 = test_config([[1, 2, 3]], policy1)
|
||||||
|
assert res1 in [[[1, 2, 1, 2]], [[1, 2, 3, 1, 2, 3]], [[1, 2, 3, 0, 1, 2, 3, 0]]]
|
||||||
|
|
||||||
|
# test exceptions
|
||||||
|
assert "policy can not be empty." in test_config([[1, 2, 3]], [])
|
||||||
|
assert "policy[0] can not be empty." in test_config([[1, 2, 3]], [[]])
|
||||||
|
assert "op of (op, prob) in policy[1][0] is not a c_transform op (TensorOp) nor a callable pyfunc" in test_config(
|
||||||
|
[[1, 2, 3]], [[(ops.PadEnd([4], 0), 0.5)], [(1, 0.4)]])
|
||||||
|
assert "prob of (op, prob) policy[1][0] is not within the required interval of (0 to 1)" in test_config([[1]], [
|
||||||
|
[(ops.Duplicate(), 0)], [(ops.Duplicate(), -0.1)]])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_random_select_subpolicy()
|
Loading…
Reference in New Issue