forked from mindspore-Ecosystem/mindspore
!1498 Cleanup and improvement for UniformAugOp
Merge pull request !1498 from anthonyaje/cleanup_uniform_aug
This commit is contained in:
commit
bfda2facfa
|
@ -290,7 +290,8 @@ void bindTensorOps1(py::module *m) {
|
||||||
|
|
||||||
(void)py::class_<UniformAugOp, TensorOp, std::shared_ptr<UniformAugOp>>(
|
(void)py::class_<UniformAugOp, TensorOp, std::shared_ptr<UniformAugOp>>(
|
||||||
*m, "UniformAugOp", "Tensor operation to apply random augmentation(s).")
|
*m, "UniformAugOp", "Tensor operation to apply random augmentation(s).")
|
||||||
.def(py::init<py::list, int32_t>(), py::arg("operations"), py::arg("NumOps") = UniformAugOp::kDefNumOps);
|
.def(py::init<std::vector<std::shared_ptr<TensorOp>>, int32_t>(), py::arg("operations"),
|
||||||
|
py::arg("NumOps") = UniformAugOp::kDefNumOps);
|
||||||
|
|
||||||
(void)py::class_<ResizeBilinearOp, TensorOp, std::shared_ptr<ResizeBilinearOp>>(
|
(void)py::class_<ResizeBilinearOp, TensorOp, std::shared_ptr<ResizeBilinearOp>>(
|
||||||
*m, "ResizeBilinearOp",
|
*m, "ResizeBilinearOp",
|
||||||
|
|
|
@ -13,23 +13,16 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
#include <utility>
|
||||||
#include "dataset/kernels/image/uniform_aug_op.h"
|
#include "dataset/kernels/image/uniform_aug_op.h"
|
||||||
#include "dataset/kernels/py_func_op.h"
|
|
||||||
#include "dataset/util/random.h"
|
#include "dataset/util/random.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
const int UniformAugOp::kDefNumOps = 2;
|
const int UniformAugOp::kDefNumOps = 2;
|
||||||
|
|
||||||
UniformAugOp::UniformAugOp(py::list op_list, int32_t num_ops) : num_ops_(num_ops) {
|
UniformAugOp::UniformAugOp(std::vector<std::shared_ptr<TensorOp>> op_list, int32_t num_ops)
|
||||||
std::shared_ptr<TensorOp> tensor_op;
|
: tensor_op_list_(op_list), num_ops_(num_ops) {
|
||||||
// iterate over the op list, cast them to TensorOp and add them to tensor_op_list_
|
|
||||||
for (auto op : op_list) {
|
|
||||||
// only C++ op is accepted
|
|
||||||
tensor_op = op.cast<std::shared_ptr<TensorOp>>();
|
|
||||||
tensor_op_list_.insert(tensor_op_list_.begin(), tensor_op);
|
|
||||||
}
|
|
||||||
|
|
||||||
rnd_.seed(GetSeed());
|
rnd_.seed(GetSeed());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,37 +31,28 @@ Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
|
||||||
std::vector<std::shared_ptr<Tensor>> *output) {
|
std::vector<std::shared_ptr<Tensor>> *output) {
|
||||||
IO_CHECK_VECTOR(input, output);
|
IO_CHECK_VECTOR(input, output);
|
||||||
|
|
||||||
// variables to copy the result to output if it is not already
|
|
||||||
std::vector<std::shared_ptr<Tensor>> even_out;
|
|
||||||
std::vector<std::shared_ptr<Tensor>> *even_out_ptr = &even_out;
|
|
||||||
int count = 1;
|
|
||||||
|
|
||||||
// randomly select ops to be applied
|
// randomly select ops to be applied
|
||||||
std::vector<std::shared_ptr<TensorOp>> selected_tensor_ops;
|
std::vector<std::shared_ptr<TensorOp>> selected_tensor_ops;
|
||||||
std::sample(tensor_op_list_.begin(), tensor_op_list_.end(), std::back_inserter(selected_tensor_ops), num_ops_, rnd_);
|
std::sample(tensor_op_list_.begin(), tensor_op_list_.end(), std::back_inserter(selected_tensor_ops), num_ops_, rnd_);
|
||||||
|
|
||||||
for (auto tensor_op = selected_tensor_ops.begin(); tensor_op != selected_tensor_ops.end(); ++tensor_op) {
|
bool first = true;
|
||||||
|
for (const auto &tensor_op : selected_tensor_ops) {
|
||||||
// Do NOT apply the op, if second random generator returned zero
|
// Do NOT apply the op, if second random generator returned zero
|
||||||
if (std::uniform_int_distribution<int>(0, 1)(rnd_)) {
|
if (std::uniform_int_distribution<int>(0, 1)(rnd_)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// apply C++ ops (note: python OPs are not accepted)
|
// apply C++ ops (note: python OPs are not accepted)
|
||||||
if (count == 1) {
|
if (first) {
|
||||||
RETURN_IF_NOT_OK((**tensor_op).Compute(input, output));
|
RETURN_IF_NOT_OK(tensor_op->Compute(input, output));
|
||||||
} else if (count % 2 == 0) {
|
first = false;
|
||||||
RETURN_IF_NOT_OK((**tensor_op).Compute(*output, even_out_ptr));
|
|
||||||
} else {
|
} else {
|
||||||
RETURN_IF_NOT_OK((**tensor_op).Compute(even_out, output));
|
RETURN_IF_NOT_OK(tensor_op->Compute(std::move(*output), output));
|
||||||
}
|
}
|
||||||
count++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// copy the result to output if it is not in output
|
// The case where no tensor op is applied.
|
||||||
if (count == 1) {
|
if (output->empty()) {
|
||||||
*output = input;
|
*output = input;
|
||||||
} else if ((count % 2 == 1)) {
|
|
||||||
(*output).swap(even_out);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
|
@ -24,9 +24,6 @@
|
||||||
#include "dataset/core/tensor.h"
|
#include "dataset/core/tensor.h"
|
||||||
#include "dataset/kernels/tensor_op.h"
|
#include "dataset/kernels/tensor_op.h"
|
||||||
#include "dataset/util/status.h"
|
#include "dataset/util/status.h"
|
||||||
#include "dataset/kernels/py_func_op.h"
|
|
||||||
|
|
||||||
#include "pybind11/stl.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
@ -36,10 +33,11 @@ class UniformAugOp : public TensorOp {
|
||||||
static const int kDefNumOps;
|
static const int kDefNumOps;
|
||||||
|
|
||||||
// Constructor for UniformAugOp
|
// Constructor for UniformAugOp
|
||||||
// @param list op_list: list of candidate C++ operations
|
// @param std::vector<std::shared_ptr<TensorOp>> op_list: list of candidate C++ operations
|
||||||
// @param list num_ops: number of augemtation operations to applied
|
// @param int32_t num_ops: number of augemtation operations to applied
|
||||||
UniformAugOp(py::list op_list, int32_t num_ops);
|
UniformAugOp(std::vector<std::shared_ptr<TensorOp>> op_list, int32_t num_ops);
|
||||||
|
|
||||||
|
// Destructor
|
||||||
~UniformAugOp() override = default;
|
~UniformAugOp() override = default;
|
||||||
|
|
||||||
void Print(std::ostream &out) const override { out << "UniformAugOp:: number of ops " << num_ops_; }
|
void Print(std::ostream &out) const override { out << "UniformAugOp:: number of ops " << num_ops_; }
|
||||||
|
|
Loading…
Reference in New Issue