!3024 Decode + RandomCropAndResize fusion within MapOp
Merge pull request !3024 from Alexey_Shevlyakov/random_crop_decode_resize_fusion
Commit: 530d46eb47
@@ -27,6 +27,7 @@
namespace mindspore {
namespace dataset {

// Forward declare
class ExecutionTree;

@@ -181,6 +181,13 @@ class MapOp : public ParallelOp {
// @return Name of the current Op
std::string Name() const override { return "MapOp"; }

// List of tensor ops getter/setter
// @Return the vector of tensor ops by non-const reference
auto &TFuncs() { return tfuncs_; }

const auto &TFuncs() const { return tfuncs_; }

private:
// Local queues where worker threads can pop from.
// Popping directly from the Connector can block if the previous designated threads haven't pop.

@@ -188,7 +195,7 @@ class MapOp : public ParallelOp {
QueueList<std::unique_ptr<DataBuffer>> local_queues_;

// Static variables to be ready by worker threads, no modification and readonly
const std::vector<std::shared_ptr<TensorOp>> tfuncs_;
std::vector<std::shared_ptr<TensorOp>> tfuncs_;

// Variable to store the column name that the tensorOps are consuming
std::vector<std::string> in_columns_;
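For orientation only (not part of the diff): the point of exposing TFuncs() by non-const reference is that an optimization pass can rewrite a MapOp's tensor-op list in place. A minimal sketch, assuming map_op is an already-built std::shared_ptr<MapOp>:

// Illustrative sketch only; names follow this diff, map_op is an assumed variable.
std::vector<std::shared_ptr<TensorOp>> &ops = map_op->TFuncs();
if (ops.size() > 1) {
  ops.erase(ops.begin() + 1);  // a pass may drop an op it has folded into its predecessor
}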
@@ -23,6 +23,7 @@
#include "dataset/engine/opt/pre/removal_pass.h"
#include "dataset/engine/opt/pre/cache_transform_pass.h"
#include "dataset/engine/opt/post/repeat_pass.h"
#include "mindspore/ccsrc/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
#include "dataset/engine/perf/profiling.h"
#include "dataset/engine/perf/monitor.h"

@@ -35,6 +36,7 @@ ExecutionTree::ExecutionTree() : id_count_(0) {
prepare_flags_ = kDePrepNone;
perf_monitor_ = std::make_unique<Monitor>(this);
profiling_manager_ = std::make_unique<ProfilingManager>(this);
optimize_ = common::GetEnv("OPTIMIZE") == "true" ? true : false;
}

// Destructor

@@ -202,8 +204,10 @@ Status ExecutionTree::Prepare() {
// Pre optimization compulsory transformation
RETURN_IF_NOT_OK(this->PrepareTreePreAction());

// Optimization transformation
RETURN_IF_NOT_OK(this->Optimize());
// If optional optimizations are enabled
if (optimize_) {
RETURN_IF_NOT_OK(this->Optimize());
}

// Post optimization compulsory transformation
RETURN_IF_NOT_OK(this->PrepareTreePostAction());

@@ -248,9 +252,16 @@ Status ExecutionTree::PrepareTreePostAction() {
}

Status ExecutionTree::Optimize() {
// auto pp = new PrinterPass();
// bool modified = false;
// pp->Run(this, &modified);
// Vector of optimizations, currently only 1, add more as necessary
std::vector<std::unique_ptr<NodePass>> optimizations;
optimizations.push_back(std::make_unique<TensorOpFusionPass>());
// vector of flags for each optimization
std::vector<bool> modified(optimizations.size(), false);
for (auto i = 0; i < optimizations.size(); i++) {
auto m = false;
optimizations[i]->Run(this, &m);
modified[i] = m;
}
return Status::OK();
}

@@ -87,6 +87,8 @@ class ExecutionTree {
// @return Shared pointer to the current operator
std::shared_ptr<DatasetOp> get() { return nodes_[ind_]; }

bool operator==(const Iterator &rhs) { return nodes_[ind_] == rhs.nodes_[rhs.ind_]; }

bool operator!=(const Iterator &rhs) { return nodes_[ind_] != rhs.nodes_[rhs.ind_]; }

int32_t NumNodes() { return nodes_.size(); }

@@ -214,6 +216,21 @@ class ExecutionTree {
// Getter for profiling manager, no ownership
ProfilingManager *GetProfilingManager() { return profiling_manager_.get(); }

// Set optional optimization if tree has not been prepared yet
Status SetOptimize(bool value) {
if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) {
std::string optimize = (optimize_ == true) ? "true" : "false";
std::string msg = "Tree has already been prepared with OPTIMIZE set to " + optimize;
RETURN_STATUS_UNEXPECTED(msg);
} else {
optimize_ = value;
return Status::OK();
}
}

// Optional optimizations status
bool OptimizationEnabled() const { return optimize_; }

private:
// A helper functions for doing the recursive printing
// @param dataset_op - The dataset op to print

@@ -230,7 +247,10 @@ class ExecutionTree {
TreeState tree_state_; // Tracking the current tree state
std::unique_ptr<Monitor> perf_monitor_; // Performance Monitor
std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager
bool optimize_; // Flag to enable optional optimizations
};

inline bool operator==(const ExecutionTree::Iterator &lhs, const ExecutionTree::Iterator &rhs) { return lhs == rhs; }
} // namespace dataset
} // namespace mindspore
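A minimal usage sketch of the optional-optimization switch introduced above (illustrative only, not taken from the diff; it assumes a tree that is built but not yet prepared):

// The pass is gated by the OPTIMIZE environment variable read in the
// ExecutionTree constructor, or explicitly via SetOptimize() before Prepare().
auto tree = std::make_shared<ExecutionTree>();
// ... AssociateNode() / AssignRoot() calls would go here ...
Status rc = tree->SetOptimize(true);  // must happen before Prepare(); afterwards it returns an error Status
rc = tree->Prepare();                 // Optimize() and TensorOpFusionPass run only when enabled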
@@ -7,5 +7,6 @@ add_library(engine-opt OBJECT
pre/cache_transform_pass.cc
pre/removal_nodes.cc
pre/removal_pass.cc
optional/tensor_op_fusion_pass.cc
util/printer_pass.cc
)

@@ -0,0 +1,58 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <memory>
#include "dataset/engine/opt/optional/tensor_op_fusion_pass.h"
#include "dataset/kernels/image/decode_op.h"
#include "dataset/engine/datasetops/map_op.h"
#include "dataset/kernels/image/random_crop_decode_resize_op.h"

namespace mindspore {
namespace dataset {

Status TensorOpFusionPass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
// Most primitive pattern: DecodeOp immediately followed by RandomCropAndResizeOp
// Abstract into a more general member function that can find any pattern, expressed
// by regular expressions, for instance.
// Add a list of optimisation policies. For now, just this lambda
auto FindPattern = [](auto &tfuncs) {
auto it =
std::find_if(tfuncs.begin(), tfuncs.end(), [](const auto &tf) -> bool { return tf->Name() == kDecodeOp; });
auto next = it + 1;
if (it != tfuncs.end() && next != tfuncs.end() && (*next)->Name() == kRandomCropAndResizeOp) {
return it;
} else {
return tfuncs.end();
}
};

auto &tfuncs = node->TFuncs();
auto it = FindPattern(tfuncs);
if (it != tfuncs.end()) {
auto next = it + 1;
auto op = static_cast<RandomCropAndResizeOp *>(next->get());
*it = std::static_pointer_cast<TensorOp>(std::make_shared<RandomCropDecodeResizeOp>(*op));
tfuncs.erase(next);
}
if (modified != nullptr) {
*modified = true;
} else {
RETURN_STATUS_UNEXPECTED("modified is nullptr");
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
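To make the rewrite above easier to follow, here is a self-contained sketch of the same find-and-fuse step with plain strings standing in for tensor ops (illustrative only; these are not the MindSpore types):

#include <algorithm>
#include <string>
#include <vector>

// Fuse an adjacent "DecodeOp" + "RandomCropAndResizeOp" pair into a single fused op name.
void FuseDecodeResize(std::vector<std::string> &ops) {
  auto it = std::adjacent_find(ops.begin(), ops.end(), [](const std::string &a, const std::string &b) {
    return a == "DecodeOp" && b == "RandomCropAndResizeOp";
  });
  if (it != ops.end()) {
    *it = "RandomCropDecodeResizeOp";  // replace the first op of the pair with the fused op
    ops.erase(it + 1);                 // drop the now-redundant second op
  }
}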
@@ -0,0 +1,38 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_TENSOR_OP_FUSION_PASS_H_
#define DATASET_TENSOR_OP_FUSION_PASS_H_

#include <memory>
#include "dataset/engine/opt/pass.h"

namespace mindspore {
namespace dataset {

/// \class TensorOpFusionPass tensor_op_fusion_pass.h
/// \brief And optional optimization pass identifying and fusing
/// tensor ops within MapOp
class TensorOpFusionPass : public NodePass {
/// \brief Identifies and fuses tensor ops within MapOp
/// \param[in] node The node being visited
/// \param[inout] *modified indicates whether the node has been visited
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
};
} // namespace dataset
} // namespace mindspore

#endif  // DATASET_TENSOR_OP_FUSION_PASS_H_
@@ -55,6 +55,8 @@ class ConcatenateOp : public TensorOp {
/// Number of inputs the tensor operation accepts
uint32_t NumInput() override { return 0; }

std::string Name() const override { return kConcatenateOp; }

private:
int8_t axis_;
std::shared_ptr<Tensor> prepend_;

@@ -127,7 +127,7 @@ Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output
std::shared_ptr<Tensor> out, fill_output;

if (input_type != DataType::DE_STRING && fill_type != DataType::DE_STRING && input_type != fill_type) {
std::unique_ptr<TypeCastOp> op(new TypeCastOp(input_type));
auto op = std::make_unique<TypeCastOp>(input_type);
RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output));
} else {
fill_output = fill_value;

@@ -18,6 +18,7 @@
#include <vector>
#include <memory>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -36,6 +37,8 @@ class DuplicateOp : public TensorOp {
Status Compute(const TensorRow &input, TensorRow *output) override;

uint32_t NumOutput() override { return 2; }

std::string Name() const override { return kDuplicateOp; }
};
} // namespace dataset
} // namespace mindspore

@@ -35,6 +35,8 @@ class FillOp : public TensorOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

std::string Name() const override { return kFillOp; }

private:
std::shared_ptr<Tensor> fill_value_;
};

@@ -43,6 +43,8 @@ class MaskOp : public TensorOp {
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;

std::string Name() const override { return kMaskOp; }

private:
RelationalOp op_;
std::shared_ptr<Tensor> value_;

@@ -37,6 +37,8 @@ class OneHotOp : public TensorOp {
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;

std::string Name() const override { return kOneHotOp; }

private:
int num_classes_;
};

@@ -38,6 +38,8 @@ class PadEndOp : public TensorOp {
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;

std::string Name() const override { return kPadEndOp; }

private:
TensorShape output_shape_;
std::shared_ptr<Tensor> pad_val_;

@@ -71,6 +71,8 @@ class SliceOp : public TensorOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

std::string Name() const override { return kSliceOp; }

private:
// only on of the following will be valid
// given indices to slice the Tensor. Empty vector if invalid.

@@ -42,6 +42,8 @@ class ToFloat16Op : public TensorOp {
void Print(std::ostream &out) const override { out << "ToFloat16Op"; }

Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;

std::string Name() const override { return kToFloat16Op; }
};
} // namespace dataset
} // namespace mindspore

@@ -42,6 +42,8 @@ class TypeCastOp : public TensorOp {
void Print(std::ostream &out) const override { out << "TypeCastOp"; }
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;

std::string Name() const override { return kTypeCastOp; }

private:
DataType type_;
};
@@ -20,6 +20,7 @@
#include <memory>
#include <random>
#include <cstdlib>
#include <string>
#include <opencv2/imgproc/imgproc.hpp>
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -50,6 +51,8 @@ class BoundingBoxAugmentOp : public TensorOp {
Status Compute(const TensorRow &input, TensorRow *output) override;

std::string Name() const override { return kBoundingBoxAugmentOp; }

private:
float ratio_;
std::mt19937 rnd_;

@@ -18,6 +18,7 @@
#include <memory>
#include <vector>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -39,6 +40,8 @@ class CenterCropOp : public TensorOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;

std::string Name() const override { return kCenterCropOp; }

private:
int32_t crop_het_;
int32_t crop_wid_;

@@ -61,6 +61,8 @@ class CutOutOp : public TensorOp {
// @return Status - The error code return
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

std::string Name() const override { return kCutOutOp; }

private:
std::mt19937 rnd_;
int32_t box_height_;

@@ -18,6 +18,7 @@
#include <memory>
#include <vector>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -40,6 +41,8 @@ class DecodeOp : public TensorOp {
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;

std::string Name() const override { return kDecodeOp; }

private:
bool is_rgb_format_ = true;
};

@@ -18,6 +18,7 @@
#include <memory>
#include <vector>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -31,6 +32,8 @@ class HwcToChwOp : public TensorOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;

std::string Name() const override { return kHwcToChwOp; }
};
} // namespace dataset
} // namespace mindspore

@@ -311,7 +311,7 @@ Status JpegCropAndDecode(const std::shared_ptr<Tensor> &input, std::shared_ptr<T
TensorShape ts = TensorShape({crop_h, crop_w, kOutNumComponents});
auto output_tensor = std::make_shared<Tensor>(ts, DataType(DataType::DE_UINT8));
const int buffer_size = output_tensor->SizeInBytes();
JSAMPLE *buffer = static_cast<JSAMPLE *>(reinterpret_cast<uchar *>(&(*output_tensor->begin<uint8_t>())));
JSAMPLE *buffer = reinterpret_cast<JSAMPLE *>(&(*output_tensor->begin<uint8_t>()));
const int max_scanlines_to_read = skipped_scanlines + crop_h;
// stride refers to output tensor, which has 3 components at most
const int stride = crop_w * kOutNumComponents;

@@ -17,6 +17,7 @@
#define DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_

#include <memory>
#include <string>

#include "dataset/core/cv_tensor.h"
#include "dataset/core/tensor.h"

@@ -35,6 +36,8 @@ class NormalizeOp : public TensorOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

std::string Name() const override { return kNormalizeOp; }

private:
std::shared_ptr<CVTensor> mean_;
std::shared_ptr<CVTensor> std_;

@@ -18,6 +18,7 @@
#include <memory>
#include <vector>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -53,6 +54,8 @@ class PadOp : public TensorOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;

std::string Name() const override { return kPadOp; }

private:
int32_t pad_top_;
int32_t pad_bottom_;

@@ -57,6 +57,8 @@ class RandomColorAdjustOp : public TensorOp {
// @return Status - The error code return.
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

std::string Name() const override { return kRandomColorAdjustOp; }

private:
std::mt19937 rnd_;
float bright_factor_start_;

@@ -19,6 +19,7 @@
#include <memory>
#include <random>
#include <vector>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/image/image_utils.h"

@@ -41,6 +42,12 @@ class RandomCropAndResizeOp : public TensorOp {
float scale_ub = kDefScaleUb, float aspect_lb = kDefAspectLb, float aspect_ub = kDefAspectUb,
InterpolationMode interpolation = kDefInterpolation, int32_t max_iter = kDefMaxIter);

RandomCropAndResizeOp() = default;

RandomCropAndResizeOp(const RandomCropAndResizeOp &rhs) = default;

RandomCropAndResizeOp(RandomCropAndResizeOp &&rhs) = default;

~RandomCropAndResizeOp() override = default;

void Print(std::ostream &out) const override {

@@ -52,6 +59,8 @@ class RandomCropAndResizeOp : public TensorOp {
Status GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width);

std::string Name() const override { return kRandomCropAndResizeOp; }

protected:
int32_t target_height_;
int32_t target_width_;

@@ -17,6 +17,7 @@
#define DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_WITH_BBOX_OP_H_

#include "dataset/kernels/image/random_crop_and_resize_op.h"
#include <string>

namespace mindspore {
namespace dataset {

@@ -39,6 +40,8 @@ class RandomCropAndResizeWithBBoxOp : public RandomCropAndResizeOp {
}

Status Compute(const TensorRow &input, TensorRow *output) override;

std::string Name() const override { return kRandomCropAndResizeWithBBoxOp; }
};
} // namespace dataset
} // namespace mindspore

@@ -35,6 +35,8 @@ class RandomCropDecodeResizeOp : public RandomCropAndResizeOp {
float scale_ub = kDefScaleUb, float aspect_lb = kDefAspectLb, float aspect_ub = kDefAspectUb,
InterpolationMode interpolation = kDefInterpolation, int32_t max_iter = kDefMaxIter);

explicit RandomCropDecodeResizeOp(const RandomCropAndResizeOp &rhs) : RandomCropAndResizeOp(rhs) {}

~RandomCropDecodeResizeOp() override = default;

void Print(std::ostream &out) const override {

@@ -43,6 +45,8 @@ class RandomCropDecodeResizeOp : public RandomCropAndResizeOp {
}

Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

std::string Name() const override { return kRandomCropDecodeResizeOp; }
};
} // namespace dataset
} // namespace mindspore
@@ -19,6 +19,7 @@
#include <memory>
#include <random>
#include <vector>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -45,6 +46,10 @@ class RandomCropOp : public TensorOp {
BorderType border_types = kDefBorderType, bool pad_if_needed = kDefPadIfNeeded,
uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, uint8_t fill_b = kDefFillB);

RandomCropOp(const RandomCropOp &rhs) = default;

RandomCropOp(RandomCropOp &&rhs) = default;

~RandomCropOp() override = default;

void Print(std::ostream &out) const override { out << "RandomCropOp: " << crop_height_ << " " << crop_width_; }

@@ -72,6 +77,8 @@ class RandomCropOp : public TensorOp {
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;

std::string Name() const override { return kRandomCropOp; }

protected:
int32_t crop_height_ = 0;
int32_t crop_width_ = 0;

@@ -18,6 +18,7 @@
#include <memory>
#include <vector>
#include <string>

#include "dataset/kernels/image/random_crop_op.h"

@@ -41,6 +42,8 @@ class RandomCropWithBBoxOp : public RandomCropOp {
}

Status Compute(const TensorRow &input, TensorRow *output) override;

std::string Name() const override { return kRandomCropWithBBoxOp; }
};
} // namespace dataset
} // namespace mindspore

@@ -18,6 +18,7 @@
#include <memory>
#include <random>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -47,6 +48,8 @@ class RandomHorizontalFlipOp : public TensorOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

std::string Name() const override { return kRandomHorizontalFlipOp; }

private:
std::mt19937 rnd_;
std::bernoulli_distribution distribution_;

@@ -19,6 +19,7 @@
#include <memory>
#include <random>
#include <cstdlib>
#include <string>
#include <opencv2/imgproc/imgproc.hpp>
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -48,6 +49,8 @@ class RandomHorizontalFlipWithBBoxOp : public TensorOp {
Status Compute(const TensorRow &input, TensorRow *output) override;

std::string Name() const override { return kRandomHorizontalFlipWithBBoxOp; }

private:
std::mt19937 rnd_;
std::bernoulli_distribution distribution_;

@@ -18,6 +18,7 @@
#include <memory>
#include <random>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/image/resize_op.h"

@@ -45,6 +46,8 @@ class RandomResizeOp : public ResizeOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

std::string Name() const override { return kRandomResizeOp; }

private:
std::mt19937 random_generator_;
std::uniform_int_distribution<int> distribution_{0, 3};

@@ -19,6 +19,7 @@
#include <memory>
#include <random>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/image/resize_op.h"

@@ -46,6 +47,8 @@ class RandomResizeWithBBoxOp : public ResizeWithBBoxOp {
Status Compute(const TensorRow &input, TensorRow *output) override;

std::string Name() const override { return kRandomResizeWithBBoxOp; }

private:
std::mt19937 random_generator_;
std::uniform_int_distribution<int> distribution_{0, 3};

@@ -19,6 +19,7 @@
#include <memory>
#include <random>
#include <vector>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -68,6 +69,8 @@ class RandomRotationOp : public TensorOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;

std::string Name() const override { return kRandomRotationOp; }

private:
float degree_start_;
float degree_end_;

@@ -18,6 +18,7 @@
#include <memory>
#include <random>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -41,6 +42,8 @@ class RandomVerticalFlipOp : public TensorOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

std::string Name() const override { return kRandomVerticalFlipOp; }

private:
std::mt19937 rnd_;
std::bernoulli_distribution distribution_;

@@ -18,6 +18,7 @@
#include <memory>
#include <random>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -42,6 +43,8 @@ class RandomVerticalFlipWithBBoxOp : public TensorOp {
Status Compute(const TensorRow &input, TensorRow *output) override;

std::string Name() const override { return kRandomVerticalFlipWithBBoxOp; }

private:
std::mt19937 rnd_;
std::bernoulli_distribution distribution_;

@@ -18,6 +18,7 @@
#include <memory>
#include <vector>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -38,6 +39,8 @@ class RescaleOp : public TensorOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;

std::string Name() const override { return kRescaleOp; }

private:
float rescale_;
float shift_;

@@ -51,6 +51,8 @@ class ResizeBilinearOp : public ResizeOp {
// Name: Print()
// Description: A function that prints info about the node
void Print(std::ostream &out) const override;

std::string Name() const override { return kResizeBilinearOp; }
};
} // namespace dataset
} // namespace mindspore
@@ -18,6 +18,7 @@
#include <memory>
#include <vector>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/image/image_utils.h"

@@ -43,6 +44,10 @@ class ResizeOp : public TensorOp {
explicit ResizeOp(int32_t size1, int32_t size2 = kDefWidth, InterpolationMode mInterpolation = kDefInterpolation)
: size1_(size1), size2_(size2), interpolation_(mInterpolation) {}

ResizeOp(const ResizeOp &rhs) = default;

ResizeOp(ResizeOp &&rhs) = default;

~ResizeOp() override = default;

void Print(std::ostream &out) const override { out << "ResizeOp: " << size1_ << " " << size2_; }

@@ -50,6 +55,8 @@ class ResizeOp : public TensorOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;

std::string Name() const override { return kResizeOp; }

protected:
int32_t size1_;
int32_t size2_;

@@ -16,6 +16,7 @@
#ifndef DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H
#define DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H

#include <string>
#include "dataset/core/tensor.h"
#include "dataset/kernels/image/image_utils.h"
#include "dataset/kernels/tensor_op.h"

@@ -36,6 +37,8 @@ class ResizeWithBBoxOp : public ResizeOp {
void Print(std::ostream &out) const override { out << "ResizeWithBBoxOp: " << size1_ << " " << size2_; }

Status Compute(const TensorRow &input, TensorRow *output) override;

std::string Name() const override { return kResizeWithBBoxOp; }
};
} // namespace dataset
} // namespace mindspore

@@ -46,6 +46,8 @@ class UniformAugOp : public TensorOp {
// @return Status - The error code return
Status Compute(const TensorRow &input, TensorRow *output) override;

std::string Name() const override { return kUniformAugOp; }

private:
int32_t num_ops_;
std::vector<std::shared_ptr<TensorOp>> tensor_op_list_;

@@ -17,6 +17,7 @@
#define DATASET_KERNELS_NO_OP_H_

#include <memory>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -31,6 +32,8 @@ class NoOp : public TensorOp {
}

void Print(std::ostream &out) const override { out << "NoOp"; };

std::string Name() const override { return kNoOp; }
};
} // namespace dataset
} // namespace mindspore

@@ -20,6 +20,7 @@
#include <memory>
#include <vector>
#include <utility>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -38,6 +39,8 @@ class __attribute__((visibility("hidden"))) PyFuncOp : public TensorOp {
// Compute function for n-n mapping.
Status Compute(const TensorRow &input, TensorRow *output) override;

std::string Name() const override { return kPyFuncOp; }

private:
py::function py_func_ptr_;
};
@@ -85,6 +85,66 @@

namespace mindspore {
namespace dataset {

// image
constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp";
constexpr char kDecodeOp[] = "DecodeOp";
constexpr char kCenterCropOp[] = "CenterCropOp";
constexpr char kCutOutOp[] = "CutOutOp";
constexpr char kHwcToChwOp[] = "HwcToChwOp";
constexpr char kNormalizeOp[] = "NormalizeOp";
constexpr char kPadOp[] = "PadOp";
constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp";
constexpr char kRandomCropAndResizeOp[] = "RandomCropAndResizeOp";
constexpr char kRandomCropAndResizeWithBBoxOp[] = "RandomCropAndResizeWithBBoxOp";
constexpr char kRandomCropDecodeResizeOp[] = "RandomCropDecodeResizeOp";
constexpr char kRandomCropOp[] = "RandomCropOp";
constexpr char kRandomCropWithBBoxOp[] = "RandomCropWithBBoxOp";
constexpr char kRandomHorizontalFlipWithBBoxOp[] = "RandomHorizontalFlipWithBBoxOp";
constexpr char kRandomHorizontalFlipOp[] = "RandomHorizontalFlipOp";
constexpr char kRandomResizeOp[] = "RandomResizeOp";
constexpr char kRandomResizeWithBBoxOp[] = "RandomResizeWithBBoxOp";
constexpr char kRandomRotationOp[] = "RandomRotationOp";
constexpr char kRandomVerticalFlipOp[] = "RandomVerticalFlipOp";
constexpr char kRandomVerticalFlipWithBBoxOp[] = "RandomVerticalFlipWithBBoxOp";
constexpr char kRescaleOp[] = "RescaleOp";
constexpr char kResizeBilinearOp[] = "ResizeBilinearOp";
constexpr char kResizeOp[] = "ResizeOp";
constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp";
constexpr char kUniformAugOp[] = "UniformAugOp";

// text
constexpr char kBasicTokenizerOp[] = "BasicTokenizerOp";
constexpr char kBertTokenizerOp[] = "BertTokenizerOp";
constexpr char kCaseFoldOp[] = "CaseFoldOp";
constexpr char kJiebaTokenizerOp[] = "JiebaTokenizerOp";
constexpr char kLookupOp[] = "LookupOp";
constexpr char kNgramOp[] = "NgramOp";
constexpr char kNormalizeUTF8Op[] = "NormalizeUTF8Op";
constexpr char kRegexReplaceOp[] = "RegexReplaceOp";
constexpr char kRegexTokenizerOp[] = "RegexTokenizerOp";
constexpr char kToNumberOp[] = "ToNumberOp";
constexpr char kTruncateSequencePairOp[] = "TruncateSequencePairOp";
constexpr char kUnicodeCharTokenizerOp[] = "UnicodeCharTokenizerOp";
constexpr char kUnicodeScriptTokenizerOp[] = "UnicodeScriptTokenizerOp";
constexpr char kWhitespaceTokenizerOp[] = "WhitespaceTokenizerOp";
constexpr char kWordpieceTokenizerOp[] = "WordpieceTokenizerOp";

// data
constexpr char kConcatenateOp[] = "kConcatenateOp";
constexpr char kDuplicateOp[] = "DuplicateOp";
constexpr char kFillOp[] = "FillOp";
constexpr char kMaskOp[] = "MaskOp";
constexpr char kOneHotOp[] = "OneHotOp";
constexpr char kPadEndOp[] = "PadEndOp";
constexpr char kSliceOp[] = "SliceOp";
constexpr char kToFloat16Op[] = "ToFloat16Op";
constexpr char kTypeCastOp[] = "TypeCastOp";

// other
constexpr char kPyFuncOp[] = "PyFuncOp";
constexpr char kNoOp[] = "NoOp";

// A class that does a computation on a Tensor
class TensorOp {
public:

@@ -143,6 +203,8 @@ class TensorOp {
// @param outputs out: vector of the types of the output tensors to be filled.
// @return Status
virtual Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs);

virtual std::string Name() const = 0;
};
} // namespace dataset
} // namespace mindspore
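The string constants above are what the fusion pass keys on; a tiny illustrative sketch (simplified, not from the diff) of name-based matching instead of RTTI:

// Illustrative only: ops are identified by the new Name() strings rather than dynamic_cast.
bool IsDecode(const std::shared_ptr<TensorOp> &op) { return op->Name() == kDecodeOp; }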
@@ -54,6 +54,8 @@ class BasicTokenizerOp : public TensorOp {
std::string *outupt);
Status CaseFoldWithoutUnusedWords(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);

std::string Name() const override { return kBasicTokenizerOp; }

private:
static const char kCommonPattern[];
static const char kUnusedPattern[];

@@ -46,6 +46,8 @@ class BertTokenizerOp : public TensorOp {
Status Compute(const TensorRow &input, TensorRow *output) override;

std::string Name() const override { return kBertTokenizerOp; }

private:
WordpieceTokenizerOp wordpiece_tokenizer_;
BasicTokenizerOp basic_tokenizer_;

@@ -16,6 +16,7 @@
#ifndef DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_
#define DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_
#include <memory>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -33,6 +34,8 @@ class CaseFoldOp : public TensorOp {
void Print(std::ostream &out) const override { out << "CaseFoldOp"; }

Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

std::string Name() const override { return kCaseFoldOp; }
};
} // namespace dataset
} // namespace mindspore

@@ -57,6 +57,8 @@ class JiebaTokenizerOp : public TensorOp {
// @tag [Default ""] the tag of the word to be added.
Status AddWord(const std::string &word, int freq = 0);

std::string Name() const override { return kJiebaTokenizerOp; }

protected:
std::string hmm_model_path_;
std::string mp_dict_path_;

@@ -20,6 +20,7 @@
#include <memory>
#include <vector>
#include <utility>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -52,6 +53,8 @@ class LookupOp : public TensorOp {
// @return error code
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;

std::string Name() const override { return kLookupOp; }

private:
std::shared_ptr<Vocab> vocab_;
WordIdType default_id_;

@@ -58,6 +58,8 @@ class NgramOp : public TensorOp {
// @param std::ostream &out
void Print(std::ostream &out) const override;

std::string Name() const override { return kNgramOp; }

private:
std::vector<int32_t> ngrams_; // list of n grams
int32_t l_len_; // left padding length

@@ -16,6 +16,7 @@
#ifndef DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_
#define DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_
#include <memory>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -42,6 +43,8 @@ class NormalizeUTF8Op : public TensorOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

std::string Name() const override { return kNormalizeUTF8Op; }

private:
NormalizeForm normalize_form_;
};

@@ -42,6 +42,8 @@ class RegexReplaceOp : public TensorOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

std::string Name() const override { return kRegexReplaceOp; }

protected:
Status RegexReplace(icu::RegexMatcher *const matcher, const std::string_view &text, std::string *out) const;

@@ -53,6 +53,8 @@ class RegexTokenizerOp : public TensorOp {
Status GetRegexTokens(const std::string &text, std::vector<std::string> *out_tokens,
std::vector<uint32_t> *offsets_start, std::vector<uint32_t> *offsets_limit) const;

std::string Name() const override { return kRegexTokenizerOp; }

private:
const icu::UnicodeString delim_pattern_;
const icu::UnicodeString keep_delim_pattern_;

@@ -57,6 +57,8 @@ class ToNumberOp : public TensorOp {
// @param std::ostream &out
void Print(std::ostream &out) const override;

std::string Name() const override { return kToNumberOp; }

private:
template <typename T>
Status ToSignedIntegral(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);

@@ -40,6 +40,8 @@ class TruncateSequencePairOp : public TensorOp {
Status Compute(const TensorRow &input, TensorRow *output) override;

std::string Name() const override { return kTruncateSequencePairOp; }

private:
dsize_t max_length_;
};

@@ -16,6 +16,7 @@
#ifndef DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_
#define DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_
#include <memory>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -36,6 +37,8 @@ class UnicodeCharTokenizerOp : public TensorOp {
Status Compute(const TensorRow &input, TensorRow *output) override;

std::string Name() const override { return kUnicodeCharTokenizerOp; }

private:
bool with_offsets_;
};

@@ -16,6 +16,7 @@
#ifndef DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_
#define DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_
#include <memory>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -39,6 +40,8 @@ class UnicodeScriptTokenizerOp : public TensorOp {
Status Compute(const TensorRow &input, TensorRow *output) override;

std::string Name() const override { return kUnicodeScriptTokenizerOp; }

private:
bool keep_whitespace_; // If or not keep whitespace tokens
bool with_offsets_;

@@ -16,6 +16,7 @@
#ifndef DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_
#define DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_
#include <memory>
#include <string>

#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"

@@ -36,6 +37,8 @@ class WhitespaceTokenizerOp : public TensorOp {
Status Compute(const TensorRow &input, TensorRow *output) override;

std::string Name() const override { return kWhitespaceTokenizerOp; }

private:
bool with_offsets_;
};

@@ -58,6 +58,8 @@ class WordpieceTokenizerOp : public TensorOp {
Status GetTokens(const std::string &input_token, const uint32_t &basic_start, std::vector<std::string> *out_tokens,
std::vector<uint32_t> *offsets_start, std::vector<uint32_t> *offsets_limit) const;

std::string Name() const override { return kWordpieceTokenizerOp; }

private:
const std::shared_ptr<Vocab> vocab_;
const std::string suffix_indicator_;
@@ -55,7 +55,7 @@ SET(DE_UT_SRCS
resize_bilinear_op_test.cc
resize_op_test.cc
resize_with_bbox_op_test.cc
schema_test.cc
schema_test.cc
shuffle_op_test.cc
stand_alone_samplers_test.cc
status_test.cc

@@ -91,6 +91,7 @@ SET(DE_UT_SRCS
cyclic_array_test.cc
perf_data_test.cc
c_api_test.cc
tensor_op_fusion_pass_test.cc
)

add_executable(de_ut_tests ${DE_UT_SRCS})
@@ -17,6 +17,7 @@
#include <memory>
#include <vector>

#include "common/common.h"
#include "dataset/core/client.h"
#include "dataset/core/tensor.h"

@@ -35,93 +36,99 @@ namespace dataset {
namespace test {
class NoOp : public TensorOp {
public:
NoOp() {};
NoOp(){};

~NoOp() {};
~NoOp(){};

Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override {
*output = std::move(input);
return Status::OK();
};
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override {
*output = std::move(input);
return Status::OK();
};

void Print(std::ostream &out) const override { out << "NoOp"; };
void Print(std::ostream &out) const override { out << "NoOp"; };

std::string Name() const override { return kNoOp; }
};

class ThreeToOneOp : public TensorOp {
public:
ThreeToOneOp() {};
ThreeToOneOp(){};

~ThreeToOneOp() {};
~ThreeToOneOp(){};

uint32_t NumInput() override { return 3; }
// Compute function that holds the actual implementation of the operation.
Status Compute(const TensorRow &input, TensorRow *output) override {
output->push_back(input[0]);
return Status::OK();
};
uint32_t NumInput() override { return 3; }
// Compute function that holds the actual implementation of the operation.
Status Compute(const TensorRow &input, TensorRow *output) override {
output->push_back(input[0]);
return Status::OK();
};

void Print(std::ostream &out) const override { out << "ThreeToOneOp"; };
void Print(std::ostream &out) const override { out << "ThreeToOneOp"; };

std::string Name() const override { return "ThreeToOneOp"; }
};

class OneToThreeOp : public TensorOp {
public:
OneToThreeOp() {};
OneToThreeOp(){};

~OneToThreeOp() {};
~OneToThreeOp(){};

uint32_t NumOutput() override { return 3; }

// Compute function that holds the actual implementation of the operation.
// Simply pushing the same shared pointer of the first element of input vector three times.
Status Compute(const TensorRow &input, TensorRow *output) override {
output->push_back(input[0]);
output->push_back(input[0]);
output->push_back(input[0]);
return Status::OK();
};
// Compute function that holds the actual implementation of the operation.
// Simply pushing the same shared pointer of the first element of input vector three times.
Status Compute(const TensorRow &input, TensorRow *output) override {
output->push_back(input[0]);
output->push_back(input[0]);
output->push_back(input[0]);
return Status::OK();
};

void Print(std::ostream &out) const override { out << "OneToThreeOp"; };
void Print(std::ostream &out) const override { out << "OneToThreeOp"; };

std::string Name() const override { return "OneToThreeOp"; };
};
} // namespace test
} // namespace dataset
} // namespace mindspore

class MindDataTestMapOp : public UT::DatasetOpTesting {
public:
void SetUp() override {
DatasetOpTesting::SetUp();
dataset_path_ = datasets_root_path_ + "" + "/testDataset2/testDataset2.data";
schema_path_ = datasets_root_path_ + "" + "/testDataset2/datasetSchema.json";
void SetUp() override {
DatasetOpTesting::SetUp();
dataset_path_ = datasets_root_path_ + "" + "/testDataset2/testDataset2.data";
schema_path_ = datasets_root_path_ + "" + "/testDataset2/datasetSchema.json";

GlobalInit();
GlobalInit();

// Start with an empty execution tree
my_tree_ = std::make_shared<ExecutionTree>();
}
// Start with an empty execution tree
my_tree_ = std::make_shared<ExecutionTree>();
}

std::shared_ptr<TFReaderOp> CreateTFReaderOp() {
std::shared_ptr<TFReaderOp> my_tfreader_op;
TFReaderOp::Builder builder;
builder.SetDatasetFilesList({dataset_path_})
.SetColumnsToLoad({"image", "label", "A", "B"})
.SetRowsPerBuffer(2)
.SetWorkerConnectorSize(2)
.SetNumWorkers(2);
std::shared_ptr<TFReaderOp> CreateTFReaderOp() {
std::shared_ptr<TFReaderOp> my_tfreader_op;
TFReaderOp::Builder builder;
builder.SetDatasetFilesList({dataset_path_})
.SetColumnsToLoad({"image", "label", "A", "B"})
.SetRowsPerBuffer(2)
.SetWorkerConnectorSize(2)
.SetNumWorkers(2);

std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
schema->LoadSchemaFile(schema_path_, {});
builder.SetDataSchema(std::move(schema));
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
schema->LoadSchemaFile(schema_path_, {});
builder.SetDataSchema(std::move(schema));

Status rc = builder.Build(&my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
return my_tfreader_op;
}
Status rc = builder.Build(&my_tfreader_op);
EXPECT_TRUE(rc.IsOk());
return my_tfreader_op;
}

std::shared_ptr<ExecutionTree> my_tree_;

std::shared_ptr<ExecutionTree> my_tree_;
private:
std::string dataset_path_;
std::string schema_path_;
std::string dataset_path_;
std::string schema_path_;
};

std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path,

@@ -148,10 +155,7 @@ TEST_F(MindDataTestMapOp, TestAsMap) {
my_func_list.push_back(my_no_op);
std::shared_ptr<MapOp> my_map_op;
MapOp::Builder builder;
builder.SetInColNames({"image"})
.SetOutColNames({"X"})
.SetTensorFuncs(std::move(my_func_list))
.SetNumWorkers(1);
builder.SetInColNames({"image"}).SetOutColNames({"X"}).SetTensorFuncs(std::move(my_func_list)).SetNumWorkers(1);
rc = builder.Build(&my_map_op);
rc = my_tree_->AssociateNode(my_map_op);
EXPECT_TRUE(rc.IsOk());

@@ -200,9 +204,9 @@ TEST_F(MindDataTestMapOp, Test3to1) {
std::shared_ptr<MapOp> my_map_op;
MapOp::Builder builder;
builder.SetInColNames({"image", "A", "B"})
.SetOutColNames({"X"})
.SetTensorFuncs(std::move(my_func_list))
.SetNumWorkers(1);
.SetOutColNames({"X"})
.SetTensorFuncs(std::move(my_func_list))
.SetNumWorkers(1);
rc = builder.Build(&my_map_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->AssociateNode(my_map_op);

@@ -252,10 +256,9 @@ TEST_F(MindDataTestMapOp, Test1to3) {
std::shared_ptr<MapOp> my_map_op;
MapOp::Builder builder;
builder.SetInColNames({"image"})
.SetOutColNames({"X", "Y", "Z"})
.SetTensorFuncs(std::move(my_func_list))
.SetNumWorkers(1);

.SetOutColNames({"X", "Y", "Z"})
.SetTensorFuncs(std::move(my_func_list))
.SetNumWorkers(1);

// ProjectOp
std::vector<std::string> columns_to_project = {"X", "Y", "Z", "label", "A", "B"};

@@ -296,19 +299,18 @@ TEST_F(MindDataTestMapOp, Test1to3) {
// Getting the next row as vector (by position).
TensorRow tensor_list;
rc =di.FetchNextTensorRow(&tensor_list);
rc = di.FetchNextTensorRow(&tensor_list);
EXPECT_TRUE(rc.IsOk());

// Based on the schema file, create the golden result to compare with.
std::vector<DataType::Type> golden_types({DataType::Type::DE_UINT8, DataType::Type::DE_UINT8,
DataType::Type::DE_UINT8, DataType::Type::DE_INT64,
DataType::Type::DE_FLOAT32, DataType::Type::DE_INT64}
);
DataType::Type::DE_FLOAT32, DataType::Type::DE_INT64});

std::vector<uint64_t> golden_ranks({3, 3, 3, 1, 4, 1});

std::vector<TensorShape> golden_shapes({TensorShape({3, 4, 2}), TensorShape({3, 4, 2}), TensorShape({3, 4, 2}),
TensorShape({7}), TensorShape({1, 13, 14, 12}), TensorShape({9})} );
TensorShape({7}), TensorShape({1, 13, 14, 12}), TensorShape({9})});

while (!tensor_list.empty()) {
for (uint32_t i = 0; i < tensor_list.size(); i++) {

@@ -343,9 +345,9 @@ TEST_F(MindDataTestMapOp, TestMultiTensorOp) {
std::shared_ptr<MapOp> my_map_op;
MapOp::Builder builder;
builder.SetInColNames({"image", "A", "B"})
.SetOutColNames({"X", "Y", "Z"})
.SetTensorFuncs(std::move(my_func_list))
.SetNumWorkers(1);
.SetOutColNames({"X", "Y", "Z"})
.SetTensorFuncs(std::move(my_func_list))
.SetNumWorkers(1);
rc = builder.Build(&my_map_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->AssociateNode(my_map_op);

@@ -405,10 +407,7 @@ TEST_F(MindDataTestMapOp, TestTFReaderRepeatMap) {
std::shared_ptr<MapOp> my_map_op;
MapOp::Builder builder;
builder.SetInColNames({"label"})
.SetOutColNames({})
.SetTensorFuncs(std::move(my_func_list))
.SetNumWorkers(5);
builder.SetInColNames({"label"}).SetOutColNames({}).SetTensorFuncs(std::move(my_func_list)).SetNumWorkers(5);
rc = builder.Build(&my_map_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->AssociateNode(my_map_op);

@@ -440,7 +439,6 @@ TEST_F(MindDataTestMapOp, TestTFReaderRepeatMap) {
MS_LOG(INFO) << "row_count: " << row_count << ".";
rc = di.FetchNextTensorRow(&tensor_list);
EXPECT_TRUE(rc.IsOk());

}
ASSERT_EQ(row_count, 10 * num_repeats);
}

@@ -467,10 +465,7 @@ TEST_F(MindDataTestMapOp, TestTFReaderMapRepeat) {
std::shared_ptr<MapOp> my_map_op;
MapOp::Builder builder;
builder.SetInColNames({"label"})
.SetOutColNames({})
.SetTensorFuncs(std::move(my_func_list))
.SetNumWorkers(50);
builder.SetInColNames({"label"}).SetOutColNames({}).SetTensorFuncs(std::move(my_func_list)).SetNumWorkers(50);
rc = builder.Build(&my_map_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->AssociateNode(my_map_op);

@@ -536,25 +531,18 @@ TEST_F(MindDataTestMapOp, TFReader_Decode_Repeat_Resize) {
std::shared_ptr<MapOp> my_map_decode_op;
MapOp::Builder builder;
builder.SetInColNames({"image"})
.SetOutColNames({})
.SetTensorFuncs(std::move(my_func_list))
.SetNumWorkers(4);
builder.SetInColNames({"image"}).SetOutColNames({}).SetTensorFuncs(std::move(my_func_list)).SetNumWorkers(4);
rc = builder.Build(&my_map_decode_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->AssociateNode(my_map_decode_op);
EXPECT_TRUE(rc.IsOk());

auto resize_op = std::make_shared<ResizeOp>(300, 300);
std::vector<std::shared_ptr<TensorOp>> my_func_list2;
my_func_list2.push_back(resize_op);
std::shared_ptr<MapOp> my_map_resize_op;
MapOp::Builder builder2;
builder2.SetInColNames({"image"})
.SetOutColNames({})
.SetTensorFuncs(std::move(my_func_list2))
.SetNumWorkers(5);
builder2.SetInColNames({"image"}).SetOutColNames({}).SetTensorFuncs(std::move(my_func_list2)).SetNumWorkers(5);
rc = builder2.Build(&my_map_resize_op);
EXPECT_TRUE(rc.IsOk());
rc = my_tree_->AssociateNode(my_map_resize_op);

@@ -610,10 +598,7 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize) {
std::shared_ptr<MapOp> map_decode_map;
MapOp::Builder map_decode_builder;
map_decode_builder.SetInColNames({"image"})
.SetOutColNames({})
.SetTensorFuncs(func_list)
.SetNumWorkers(4);
map_decode_builder.SetInColNames({"image"}).SetOutColNames({}).SetTensorFuncs(func_list).SetNumWorkers(4);
rc = map_decode_builder.Build(&map_decode_map);
EXPECT_TRUE(rc.IsOk());

@@ -622,10 +607,7 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize) {
func_list2.push_back(resize_op);
std::shared_ptr<MapOp> map_resize_op;
MapOp::Builder map_resize_builder;
map_resize_builder.SetInColNames({"image"})
.SetOutColNames({})
.SetTensorFuncs(func_list2)
.SetNumWorkers(5);
map_resize_builder.SetInColNames({"image"}).SetOutColNames({}).SetTensorFuncs(func_list2).SetNumWorkers(5);
rc = map_resize_builder.Build(&map_resize_op);
EXPECT_TRUE(rc.IsOk());

@@ -704,7 +686,6 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize) {
EXPECT_EQ(result, result2);
}

TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize_NoInputColumns) {
Status rc;
MS_LOG(INFO) << "Doing ImageFolder_Decode_Repeat_Resize_NoInputColumns.";

@@ -722,10 +703,7 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize_NoInputColumns) {
std::shared_ptr<MapOp> map_decode_map;
MapOp::Builder map_decode_builder;
map_decode_builder.SetInColNames({})
.SetOutColNames({})
.SetTensorFuncs(func_list)
.SetNumWorkers(4);
map_decode_builder.SetInColNames({}).SetOutColNames({}).SetTensorFuncs(func_list).SetNumWorkers(4);
rc = map_decode_builder.Build(&map_decode_map);
EXPECT_TRUE(rc.IsOk());

@@ -761,3 +739,5 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize_NoInputColumns) {
}
EXPECT_TRUE(i == 88);
}
@@ -0,0 +1,105 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <memory>
#include <string>
#include "dataset/core/client.h"
#include "common/common.h"
#include "gtest/gtest.h"
#include "dataset/kernels/image/random_crop_and_resize_op.h"
#include "dataset/kernels/image/decode_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/execution_tree.h"

using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::MsLogLevel::INFO;

class MindDataTestTensorOpFusionPass : public UT::DatasetOpTesting {
public:
MindDataTestTensorOpFusionPass() = default;
void SetUp() override { GlobalInit(); }
};

TEST_F(MindDataTestTensorOpFusionPass, RandomCropDecodeResize_fusion_disabled) {
MS_LOG(INFO) << "Doing RandomCropDecodeResize_fusion";
std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path,
bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr,
std::map<std::string, int32_t> map = {}, bool decode = false);
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
auto rcar_op = std::make_shared<RandomCropAndResizeOp>();
auto decode_op = std::make_shared<DecodeOp>();
Status rc;
std::vector<std::shared_ptr<TensorOp>> func_list;
func_list.push_back(decode_op);
func_list.push_back(rcar_op);
std::shared_ptr<MapOp> map_op;
MapOp::Builder map_decode_builder;
map_decode_builder.SetInColNames({}).SetOutColNames({}).SetTensorFuncs(func_list).SetNumWorkers(4);
rc = map_decode_builder.Build(&map_op);
EXPECT_TRUE(rc.IsOk());
auto tree = std::make_shared<ExecutionTree>();
tree = Build({ImageFolder(16, 2, 32, "./", false), map_op});
rc = tree->SetOptimize(false);
EXPECT_TRUE(rc);
rc = tree->Prepare();
EXPECT_TRUE(rc.IsOk());
rc = tree->SetOptimize(false);
EXPECT_TRUE(rc.IsError());
auto it = tree->begin();
++it;
auto *m_op = &(*it);
auto tfuncs = static_cast<MapOp *>(m_op)->TFuncs();
auto func_it = tfuncs.begin();
EXPECT_EQ((*func_it)->Name(), kDecodeOp);
++func_it;
EXPECT_EQ((*func_it)->Name(), kRandomCropAndResizeOp);
}

TEST_F(MindDataTestTensorOpFusionPass, RandomCropDecodeResize_fusion_enabled) {
MS_LOG(INFO) << "Doing RandomCropDecodeResize_fusion";
std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path,
bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr,
std::map<std::string, int32_t> map = {}, bool decode = false);
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
auto rcar_op = std::make_shared<RandomCropAndResizeOp>();
auto decode_op = std::make_shared<DecodeOp>();
Status rc;
std::vector<std::shared_ptr<TensorOp>> func_list;
func_list.push_back(decode_op);
func_list.push_back(rcar_op);
std::shared_ptr<MapOp> map_op;
MapOp::Builder map_decode_builder;
map_decode_builder.SetInColNames({}).SetOutColNames({}).SetTensorFuncs(func_list).SetNumWorkers(4);
rc = map_decode_builder.Build(&map_op);
EXPECT_TRUE(rc.IsOk());
auto tree = std::make_shared<ExecutionTree>();
tree = Build({ImageFolder(16, 2, 32, "./", false), map_op});
rc = tree->SetOptimize(true);
EXPECT_TRUE(rc);
rc = tree->Prepare();
EXPECT_TRUE(rc.IsOk());
rc = tree->SetOptimize(false);
EXPECT_TRUE(rc.IsError());
auto it = tree->begin();
++it;
auto *m_op = &(*it);
auto tfuncs = static_cast<MapOp *>(m_op)->TFuncs();
auto func_it = tfuncs.begin();
EXPECT_EQ((*func_it)->Name(), kRandomCropDecodeResizeOp);
EXPECT_EQ(++func_it, tfuncs.end());
}