forked from mindspore-Ecosystem/mindspore
Move TensorOperation and validator functions down to /kernels/ir
This commit is contained in:
parent
e489b67a3a
commit
6c02670116
|
@ -96,6 +96,7 @@ add_dependencies(cpp-API core)
|
|||
add_dependencies(engine-ir-datasetops core)
|
||||
add_dependencies(engine-ir-datasetops-source core)
|
||||
add_dependencies(engine-ir-cache core)
|
||||
add_dependencies(kernels-ir core)
|
||||
|
||||
if(ENABLE_ACL)
|
||||
add_dependencies(kernels-dvpp-image core dvpp-utils)
|
||||
|
@ -144,6 +145,7 @@ set(submodules
|
|||
$<TARGET_OBJECTS:engine>
|
||||
$<TARGET_OBJECTS:text>
|
||||
$<TARGET_OBJECTS:text-kernels>
|
||||
$<TARGET_OBJECTS:kernels-ir>
|
||||
)
|
||||
|
||||
if(ENABLE_ACL)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
#include "minddata/dataset/kernels/ir/validators.h"
|
||||
|
||||
// Kernel data headers (in alphabetical order)
|
||||
#include "minddata/dataset/kernels/data/compose_op.h"
|
||||
|
@ -30,180 +31,6 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
/* ####################################### Validator Functions ############################################ */
|
||||
Status ValidateProbability(const std::string &op_name, const float probability) {
|
||||
if (probability < 0.0 || probability > 1.0) {
|
||||
std::string err_msg = op_name + ": probability must be between 0.0 and 1.0, got: " + std::to_string(probability);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateIntScalarPositive(const std::string &op_name, const std::string &scalar_name, int32_t scalar) {
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, scalar_name, scalar, {0}, true));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar) {
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, scalar_name, scalar, {0}, true));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorFillvalue(const std::string &op_name, const std::vector<uint8_t> &fill_value) {
|
||||
if (fill_value.empty() || (fill_value.size() != 1 && fill_value.size() != 3)) {
|
||||
std::string err_msg =
|
||||
op_name + ": fill_value expecting size 1 or 3, got fill_value.size(): " + std::to_string(fill_value.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
// Note that fill_value need to be in range [0, 255],
|
||||
// but we omit the check since its type is uint8_t
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorColorAttribute(const std::string &op_name, const std::string &attr_name,
|
||||
const std::vector<float> &attr, const std::vector<float> &range) {
|
||||
if (attr.empty() || attr.size() > 2) {
|
||||
std::string err_msg = op_name + ":" + attr_name + " expecting size 1 or 2, but got: " + std::to_string(attr.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
for (auto &attr_val : attr) {
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, attr_name, attr_val, range, false, false));
|
||||
}
|
||||
if (attr.size() == 2 && (attr[0] > attr[1])) {
|
||||
std::string err_msg = op_name + ":" + attr_name +
|
||||
" lower bound must be less or equal to upper bound, got lb: " + std::to_string(attr[0]) +
|
||||
", ub: " + std::to_string(attr[1]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean,
|
||||
const std::vector<float> &std) {
|
||||
if (mean.size() != 3) {
|
||||
std::string err_msg = op_name + ": mean expecting size 3, got size: " + std::to_string(mean.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (std.size() != 3) {
|
||||
std::string err_msg = op_name + ": std expecting size 3, got size: " + std::to_string(std.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
// check std/mean value
|
||||
for (int32_t i = 0; i < std.size(); ++i) {
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, "mean", mean[i], {0.0, 255.0}, false, false));
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, "std", std[i], {0.0, 255.0}, true, false));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding) {
|
||||
if (padding.empty() || padding.size() == 3 || padding.size() > 4) {
|
||||
std::string err_msg = op_name + ": padding expecting size 1, 2 or 4, got size: " + std::to_string(padding.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
for (const auto &pad_val : padding) {
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, "padding", pad_val, {0, INT_MAX}, false, false));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorPositive(const std::string &op_name, const std::string &vec_name,
|
||||
const std::vector<int32_t> &vec) {
|
||||
for (const auto &vec_val : vec) {
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, vec_name, vec_val, {0}, true));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorNonNegative(const std::string &op_name, const std::string &vec_name,
|
||||
const std::vector<int32_t> &vec) {
|
||||
for (const auto &vec_val : vec) {
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, vec_name, vec_val, {0}, false));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size) {
|
||||
if (size.empty() || size.size() > 2) {
|
||||
std::string err_msg = op_name + ": size expecting size 2, got size.size(): " + std::to_string(size.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
for (const auto &size_val : size) {
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, "size", size_val, {0, INT_MAX}, true, false));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorScale(const std::string &op_name, const std::vector<float> &scale) {
|
||||
if (scale.size() != 2) {
|
||||
std::string err_msg = op_name + ": scale expecting size 2, got scale.size(): " + std::to_string(scale.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", scale[0], {0}, false));
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", scale[1], {0}, true));
|
||||
if (scale[1] < scale[0]) {
|
||||
std::string err_msg = op_name + ": scale must be in the format of (min, max).";
|
||||
MS_LOG(ERROR) << op_name + ": scale must be in the format of (min, max), but got: " << scale;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorRatio(const std::string &op_name, const std::vector<float> &ratio) {
|
||||
if (ratio.size() != 2) {
|
||||
std::string err_msg = op_name + ": ratio expecting size 2, got ratio.size(): " + std::to_string(ratio.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", ratio[0], {0}, true));
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", ratio[1], {0}, true));
|
||||
if (ratio[1] < ratio[0]) {
|
||||
std::string err_msg = op_name + ": ratio must be in the format of (min, max).";
|
||||
MS_LOG(ERROR) << op_name + ": ratio must be in the format of (min, max), but got: " << ratio;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorTransforms(const std::string &op_name,
|
||||
const std::vector<std::shared_ptr<TensorOperation>> &transforms) {
|
||||
if (transforms.empty()) {
|
||||
std::string err_msg = op_name + ": transform list must not be empty.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
for (int32_t i = 0; i < transforms.size(); ++i) {
|
||||
if (transforms[i] == nullptr) {
|
||||
std::string err_msg =
|
||||
op_name + ": transform ops must not be null, got transform[" + std::to_string(i) + "] == nullptr.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool CmpFloat(const float a, const float b, float epsilon) { return (std::fabs(a - b) < epsilon); }
|
||||
|
||||
// Transform operations for data.
|
||||
namespace transforms {
|
||||
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#endif
|
||||
#include "minddata/dataset/kernels/ir/validators.h"
|
||||
|
||||
// Kernel image headers (in alphabetical order)
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/kernels/image/auto_contrast_op.h"
|
||||
|
|
|
@ -25,6 +25,9 @@
|
|||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/include/constants.h"
|
||||
|
||||
// (TEMPORARY) will be removed when Tensor op ir moved down
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
#ifndef INCLUDE_NLOHMANN_JSON_FWD_HPP_
|
||||
#define INCLUDE_NLOHMANN_JSON_FWD_HPP_
|
||||
namespace nlohmann {
|
||||
|
@ -45,8 +48,6 @@ using json = basic_json<>;
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class TensorOp;
|
||||
|
||||
// Char arrays storing name of corresponding classes (in alphabetical order)
|
||||
constexpr char kComposeOperation[] = "Compose";
|
||||
constexpr char kDuplicateOperation[] = "Duplicate";
|
||||
|
@ -58,108 +59,6 @@ constexpr char kRandomSelectSubpolicyOperation[] = "RandomSelectSubpolicy";
|
|||
constexpr char kTypeCastOperation[] = "TypeCast";
|
||||
constexpr char kUniqueOperation[] = "Unique";
|
||||
|
||||
// Abstract class to represent a dataset in the data pipeline.
|
||||
class TensorOperation : public std::enable_shared_from_this<TensorOperation> {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
TensorOperation() : random_op_(false) {}
|
||||
|
||||
/// \brief Constructor
|
||||
explicit TensorOperation(bool random) : random_op_(random) {}
|
||||
|
||||
/// \brief Destructor
|
||||
~TensorOperation() = default;
|
||||
|
||||
/// \brief Pure virtual function to convert a TensorOperation class into a runtime TensorOp object.
|
||||
/// \return shared pointer to the newly created TensorOp.
|
||||
virtual std::shared_ptr<TensorOp> Build() = 0;
|
||||
|
||||
virtual Status ValidateParams() = 0;
|
||||
|
||||
virtual std::string Name() const = 0;
|
||||
|
||||
/// \brief Check whether the operation is deterministic.
|
||||
/// \return true if this op is a random op (returns non-deterministic result e.g. RandomCrop)
|
||||
bool IsRandomOp() const { return random_op_; }
|
||||
|
||||
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
bool random_op_;
|
||||
};
|
||||
|
||||
// Helper function to validate probability
|
||||
Status ValidateProbability(const std::string &op_name, const float probability);
|
||||
|
||||
// Helper function to positive int scalar
|
||||
Status ValidateIntScalarPositive(const std::string &op_name, const std::string &scalar_name, int32_t scalar);
|
||||
|
||||
// Helper function to positive float scalar
|
||||
Status ValidateFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar);
|
||||
|
||||
// Helper function to validate scalar
|
||||
template <typename T>
|
||||
Status ValidateScalar(const std::string &op_name, const std::string &scalar_name, const T scalar,
|
||||
const std::vector<T> &range, bool left_open_interval = false, bool right_open_interval = false) {
|
||||
if (range.empty() || range.size() > 2) {
|
||||
std::string err_msg = "Range check expecting size 1 or 2, but got: " + std::to_string(range.size());
|
||||
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
if ((left_open_interval && scalar <= range[0]) || (!left_open_interval && scalar < range[0])) {
|
||||
std::string interval_description = left_open_interval ? " greater than " : " greater than or equal to ";
|
||||
std::string err_msg = op_name + ":" + scalar_name + " must be" + interval_description + std::to_string(range[0]) +
|
||||
", got: " + std::to_string(scalar);
|
||||
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
if (range.size() == 2) {
|
||||
if ((right_open_interval && scalar >= range[1]) || (!right_open_interval && scalar > range[1])) {
|
||||
std::string left_bracket = left_open_interval ? "(" : "[";
|
||||
std::string right_bracket = right_open_interval ? ")" : "]";
|
||||
std::string err_msg = op_name + ":" + scalar_name + " is out of range " + left_bracket +
|
||||
std::to_string(range[0]) + ", " + std::to_string(range[1]) + right_bracket +
|
||||
", got: " + std::to_string(scalar);
|
||||
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Helper function to validate color attribute
|
||||
Status ValidateVectorColorAttribute(const std::string &op_name, const std::string &attr_name,
|
||||
const std::vector<float> &attr, const std::vector<float> &range);
|
||||
|
||||
// Helper function to validate fill value
|
||||
Status ValidateVectorFillvalue(const std::string &op_name, const std::vector<uint8_t> &fill_value);
|
||||
|
||||
// Helper function to validate mean/std value
|
||||
Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean, const std::vector<float> &std);
|
||||
|
||||
// Helper function to validate padding
|
||||
Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding);
|
||||
|
||||
// Helper function to validate positive value
|
||||
Status ValidateVectorPositive(const std::string &op_name, const std::string &vec_name, const std::vector<int32_t> &vec);
|
||||
|
||||
// Helper function to validate non-negative value
|
||||
Status ValidateVectorNonNegative(const std::string &op_name, const std::string &vec_name,
|
||||
const std::vector<int32_t> &vec);
|
||||
|
||||
// Helper function to validate size of size
|
||||
Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size);
|
||||
|
||||
// Helper function to validate scale
|
||||
Status ValidateVectorScale(const std::string &op_name, const std::vector<float> &scale);
|
||||
|
||||
// Helper function to validate ratio
|
||||
Status ValidateVectorRatio(const std::string &op_name, const std::vector<float> &ratio);
|
||||
|
||||
// Helper function to validate transforms
|
||||
Status ValidateVectorTransforms(const std::string &op_name,
|
||||
const std::vector<std::shared_ptr<TensorOperation>> &transforms);
|
||||
|
||||
// Helper function to compare float value
|
||||
bool CmpFloat(const float a, const float b, float epsilon = 0.0000000001f);
|
||||
|
||||
// Transform operations for performing data transformation.
|
||||
namespace transforms {
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
add_subdirectory(image)
|
||||
add_subdirectory(data)
|
||||
add_subdirectory(ir)
|
||||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
if(ENABLE_PYTHON)
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
|
||||
set(DATASET_KERNELS_IR_SRC_FILES
|
||||
validators.cc
|
||||
)
|
||||
|
||||
add_library(kernels-ir OBJECT ${DATASET_KERNELS_IR_SRC_FILES})
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_TENSOR_OPERATION_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_TENSOR_OPERATION_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Abstract class to represent a dataset in the data pipeline.
|
||||
class TensorOperation : public std::enable_shared_from_this<TensorOperation> {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
TensorOperation() : random_op_(false) {}
|
||||
|
||||
/// \brief Constructor
|
||||
explicit TensorOperation(bool random) : random_op_(random) {}
|
||||
|
||||
/// \brief Destructor
|
||||
~TensorOperation() = default;
|
||||
|
||||
/// \brief Pure virtual function to convert a TensorOperation class into a runtime TensorOp object.
|
||||
/// \return shared pointer to the newly created TensorOp.
|
||||
virtual std::shared_ptr<TensorOp> Build() = 0;
|
||||
|
||||
virtual Status ValidateParams() = 0;
|
||||
|
||||
virtual std::string Name() const = 0;
|
||||
|
||||
/// \brief Check whether the operation is deterministic.
|
||||
/// \return true if this op is a random op (returns non-deterministic result e.g. RandomCrop)
|
||||
bool IsRandomOp() const { return random_op_; }
|
||||
|
||||
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
bool random_op_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_TENSOR_OPERATION_H_
|
|
@ -0,0 +1,195 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/kernels/ir/validators.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
/* ####################################### Validator Functions ############################################ */
|
||||
Status ValidateProbability(const std::string &op_name, const float probability) {
|
||||
if (probability < 0.0 || probability > 1.0) {
|
||||
std::string err_msg = op_name + ": probability must be between 0.0 and 1.0, got: " + std::to_string(probability);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateIntScalarPositive(const std::string &op_name, const std::string &scalar_name, int32_t scalar) {
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, scalar_name, scalar, {0}, true));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar) {
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, scalar_name, scalar, {0}, true));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorFillvalue(const std::string &op_name, const std::vector<uint8_t> &fill_value) {
|
||||
if (fill_value.empty() || (fill_value.size() != 1 && fill_value.size() != 3)) {
|
||||
std::string err_msg =
|
||||
op_name + ": fill_value expecting size 1 or 3, got fill_value.size(): " + std::to_string(fill_value.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
// Note that fill_value need to be in range [0, 255],
|
||||
// but we omit the check since its type is uint8_t
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorColorAttribute(const std::string &op_name, const std::string &attr_name,
|
||||
const std::vector<float> &attr, const std::vector<float> &range) {
|
||||
if (attr.empty() || attr.size() > 2) {
|
||||
std::string err_msg = op_name + ":" + attr_name + " expecting size 1 or 2, but got: " + std::to_string(attr.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
for (auto &attr_val : attr) {
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, attr_name, attr_val, range, false, false));
|
||||
}
|
||||
if (attr.size() == 2 && (attr[0] > attr[1])) {
|
||||
std::string err_msg = op_name + ":" + attr_name +
|
||||
" lower bound must be less or equal to upper bound, got lb: " + std::to_string(attr[0]) +
|
||||
", ub: " + std::to_string(attr[1]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean,
|
||||
const std::vector<float> &std) {
|
||||
if (mean.size() != 3) {
|
||||
std::string err_msg = op_name + ": mean expecting size 3, got size: " + std::to_string(mean.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (std.size() != 3) {
|
||||
std::string err_msg = op_name + ": std expecting size 3, got size: " + std::to_string(std.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
// check std/mean value
|
||||
for (int32_t i = 0; i < std.size(); ++i) {
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, "mean", mean[i], {0.0, 255.0}, false, false));
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, "std", std[i], {0.0, 255.0}, true, false));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding) {
|
||||
if (padding.empty() || padding.size() == 3 || padding.size() > 4) {
|
||||
std::string err_msg = op_name + ": padding expecting size 1, 2 or 4, got size: " + std::to_string(padding.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
for (const auto &pad_val : padding) {
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, "padding", pad_val, {0, INT_MAX}, false, false));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorPositive(const std::string &op_name, const std::string &vec_name,
|
||||
const std::vector<int32_t> &vec) {
|
||||
for (const auto &vec_val : vec) {
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, vec_name, vec_val, {0}, true));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorNonNegative(const std::string &op_name, const std::string &vec_name,
|
||||
const std::vector<int32_t> &vec) {
|
||||
for (const auto &vec_val : vec) {
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, vec_name, vec_val, {0}, false));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size) {
|
||||
if (size.empty() || size.size() > 2) {
|
||||
std::string err_msg = op_name + ": size expecting size 2, got size.size(): " + std::to_string(size.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
for (const auto &size_val : size) {
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, "size", size_val, {0, INT_MAX}, true, false));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorScale(const std::string &op_name, const std::vector<float> &scale) {
|
||||
if (scale.size() != 2) {
|
||||
std::string err_msg = op_name + ": scale expecting size 2, got scale.size(): " + std::to_string(scale.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", scale[0], {0}, false));
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", scale[1], {0}, true));
|
||||
if (scale[1] < scale[0]) {
|
||||
std::string err_msg = op_name + ": scale must be in the format of (min, max).";
|
||||
MS_LOG(ERROR) << op_name + ": scale must be in the format of (min, max), but got: " << scale;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorRatio(const std::string &op_name, const std::vector<float> &ratio) {
|
||||
if (ratio.size() != 2) {
|
||||
std::string err_msg = op_name + ": ratio expecting size 2, got ratio.size(): " + std::to_string(ratio.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", ratio[0], {0}, true));
|
||||
RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", ratio[1], {0}, true));
|
||||
if (ratio[1] < ratio[0]) {
|
||||
std::string err_msg = op_name + ": ratio must be in the format of (min, max).";
|
||||
MS_LOG(ERROR) << op_name + ": ratio must be in the format of (min, max), but got: " << ratio;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ValidateVectorTransforms(const std::string &op_name,
|
||||
const std::vector<std::shared_ptr<TensorOperation>> &transforms) {
|
||||
if (transforms.empty()) {
|
||||
std::string err_msg = op_name + ": transform list must not be empty.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
for (int32_t i = 0; i < transforms.size(); ++i) {
|
||||
if (transforms[i] == nullptr) {
|
||||
std::string err_msg =
|
||||
op_name + ": transform ops must not be null, got transform[" + std::to_string(i) + "] == nullptr.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool CmpFloat(const float a, const float b, float epsilon) { return (std::fabs(a - b) < epsilon); }
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,105 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VALIDATORS_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VALIDATORS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class TensorOperation;
|
||||
|
||||
// Helper function to validate probability
|
||||
Status ValidateProbability(const std::string &op_name, const float probability);
|
||||
|
||||
// Helper function to positive int scalar
|
||||
Status ValidateIntScalarPositive(const std::string &op_name, const std::string &scalar_name, int32_t scalar);
|
||||
|
||||
// Helper function to positive float scalar
|
||||
Status ValidateFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar);
|
||||
|
||||
// Helper function to validate scalar
|
||||
template <typename T>
|
||||
Status ValidateScalar(const std::string &op_name, const std::string &scalar_name, const T scalar,
|
||||
const std::vector<T> &range, bool left_open_interval = false, bool right_open_interval = false) {
|
||||
if (range.empty() || range.size() > 2) {
|
||||
std::string err_msg = "Range check expecting size 1 or 2, but got: " + std::to_string(range.size());
|
||||
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
if ((left_open_interval && scalar <= range[0]) || (!left_open_interval && scalar < range[0])) {
|
||||
std::string interval_description = left_open_interval ? " greater than " : " greater than or equal to ";
|
||||
std::string err_msg = op_name + ":" + scalar_name + " must be" + interval_description + std::to_string(range[0]) +
|
||||
", got: " + std::to_string(scalar);
|
||||
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
if (range.size() == 2) {
|
||||
if ((right_open_interval && scalar >= range[1]) || (!right_open_interval && scalar > range[1])) {
|
||||
std::string left_bracket = left_open_interval ? "(" : "[";
|
||||
std::string right_bracket = right_open_interval ? ")" : "]";
|
||||
std::string err_msg = op_name + ":" + scalar_name + " is out of range " + left_bracket +
|
||||
std::to_string(range[0]) + ", " + std::to_string(range[1]) + right_bracket +
|
||||
", got: " + std::to_string(scalar);
|
||||
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Helper function to validate color attribute
|
||||
Status ValidateVectorColorAttribute(const std::string &op_name, const std::string &attr_name,
|
||||
const std::vector<float> &attr, const std::vector<float> &range);
|
||||
|
||||
// Helper function to validate fill value
|
||||
Status ValidateVectorFillvalue(const std::string &op_name, const std::vector<uint8_t> &fill_value);
|
||||
|
||||
// Helper function to validate mean/std value
|
||||
Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean, const std::vector<float> &std);
|
||||
|
||||
// Helper function to validate padding
|
||||
Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding);
|
||||
|
||||
// Helper function to validate positive value
|
||||
Status ValidateVectorPositive(const std::string &op_name, const std::string &vec_name, const std::vector<int32_t> &vec);
|
||||
|
||||
// Helper function to validate non-negative value
|
||||
Status ValidateVectorNonNegative(const std::string &op_name, const std::string &vec_name,
|
||||
const std::vector<int32_t> &vec);
|
||||
|
||||
// Helper function to validate size of size
|
||||
Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size);
|
||||
|
||||
// Helper function to validate scale
|
||||
Status ValidateVectorScale(const std::string &op_name, const std::vector<float> &scale);
|
||||
|
||||
// Helper function to validate ratio
|
||||
Status ValidateVectorRatio(const std::string &op_name, const std::vector<float> &ratio);
|
||||
|
||||
// Helper function to validate transforms
|
||||
Status ValidateVectorTransforms(const std::string &op_name,
|
||||
const std::vector<std::shared_ptr<TensorOperation>> &transforms);
|
||||
|
||||
// Helper function to compare float value
|
||||
bool CmpFloat(const float a, const float b, float epsilon = 0.0000000001f);
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VALIDATORS_H_
|
|
@ -99,6 +99,7 @@ AUX_SOURCE_DIRECTORY(${MINDDATA_DIR}/util MINDDATA_UTIL_SRC_FILES)
|
|||
|
||||
AUX_SOURCE_DIRECTORY(${MINDDATA_DIR}/kernels/image/lite_cv MINDDATA_KERNELS_IMAGE_LITE_CV_FILES)
|
||||
|
||||
AUX_SOURCE_DIRECTORY(${MINDDATA_DIR}/kernels/ir MINDDATA_KERNELS_IR_SRC_FILES)
|
||||
|
||||
if(BUILD_MINDDATA STREQUAL "full")
|
||||
include_directories("${MINDDATA_DIR}/kernels/image")
|
||||
|
@ -200,6 +201,7 @@ if(BUILD_MINDDATA STREQUAL "full")
|
|||
${MINDDATA_DIR}/kernels/data/random_choice_op.cc
|
||||
${MINDDATA_DIR}/kernels/data/type_cast_op.cc
|
||||
${MINDDATA_DIR}/kernels/image/exif_utils.cc
|
||||
${MINDDATA_DIR}/kernels/ir/validators.cc
|
||||
${MINDDATA_DIR}/callback/callback_manager.cc
|
||||
${MINDDATA_DIR}/util/task_manager.cc
|
||||
${MINDDATA_DIR}/util/services.cc
|
||||
|
@ -294,6 +296,7 @@ elseif(BUILD_MINDDATA STREQUAL "wrapper")
|
|||
${CORE_DIR}/utils/ms_utils.cc
|
||||
${MINDDATA_TODAPI_SRC}
|
||||
${MINDSPORE_LITE_CXXAPI_SRC}
|
||||
${MINDDATA_DIR}/kernels/ir/validators.cc
|
||||
)
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
@ -387,6 +390,7 @@ elseif(BUILD_MINDDATA STREQUAL "lite")
|
|||
${MINDDATA_DIR}/api/transforms.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../src/common/log_adapter.cc
|
||||
${CORE_DIR}/utils/ms_utils.cc
|
||||
${MINDDATA_DIR}/kernels/ir/validators.cc
|
||||
)
|
||||
|
||||
target_link_libraries(minddata-lite
|
||||
|
|
Loading…
Reference in New Issue