Move TensorOperation and validator functions down to /kernels/ir

This commit is contained in:
TinaMengtingZhang 2021-02-08 12:10:42 -05:00
parent e489b67a3a
commit 6c02670116
10 changed files with 382 additions and 279 deletions

View File

@ -96,6 +96,7 @@ add_dependencies(cpp-API core)
add_dependencies(engine-ir-datasetops core)
add_dependencies(engine-ir-datasetops-source core)
add_dependencies(engine-ir-cache core)
add_dependencies(kernels-ir core)
if(ENABLE_ACL)
add_dependencies(kernels-dvpp-image core dvpp-utils)
@ -144,6 +145,7 @@ set(submodules
$<TARGET_OBJECTS:engine>
$<TARGET_OBJECTS:text>
$<TARGET_OBJECTS:text-kernels>
$<TARGET_OBJECTS:kernels-ir>
)
if(ENABLE_ACL)

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,6 +15,7 @@
*/
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/kernels/ir/validators.h"
// Kernel data headers (in alphabetical order)
#include "minddata/dataset/kernels/data/compose_op.h"
@ -30,180 +31,6 @@
namespace mindspore {
namespace dataset {
/* ####################################### Validator Functions ############################################ */
Status ValidateProbability(const std::string &op_name, const float probability) {
if (probability < 0.0 || probability > 1.0) {
std::string err_msg = op_name + ": probability must be between 0.0 and 1.0, got: " + std::to_string(probability);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
Status ValidateIntScalarPositive(const std::string &op_name, const std::string &scalar_name, int32_t scalar) {
RETURN_IF_NOT_OK(ValidateScalar(op_name, scalar_name, scalar, {0}, true));
return Status::OK();
}
Status ValidateFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar) {
RETURN_IF_NOT_OK(ValidateScalar(op_name, scalar_name, scalar, {0}, true));
return Status::OK();
}
Status ValidateVectorFillvalue(const std::string &op_name, const std::vector<uint8_t> &fill_value) {
if (fill_value.empty() || (fill_value.size() != 1 && fill_value.size() != 3)) {
std::string err_msg =
op_name + ": fill_value expecting size 1 or 3, got fill_value.size(): " + std::to_string(fill_value.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
// Note that fill_value need to be in range [0, 255],
// but we omit the check since its type is uint8_t
return Status::OK();
}
Status ValidateVectorColorAttribute(const std::string &op_name, const std::string &attr_name,
const std::vector<float> &attr, const std::vector<float> &range) {
if (attr.empty() || attr.size() > 2) {
std::string err_msg = op_name + ":" + attr_name + " expecting size 1 or 2, but got: " + std::to_string(attr.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
for (auto &attr_val : attr) {
RETURN_IF_NOT_OK(ValidateScalar(op_name, attr_name, attr_val, range, false, false));
}
if (attr.size() == 2 && (attr[0] > attr[1])) {
std::string err_msg = op_name + ":" + attr_name +
" lower bound must be less or equal to upper bound, got lb: " + std::to_string(attr[0]) +
", ub: " + std::to_string(attr[1]);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean,
const std::vector<float> &std) {
if (mean.size() != 3) {
std::string err_msg = op_name + ": mean expecting size 3, got size: " + std::to_string(mean.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (std.size() != 3) {
std::string err_msg = op_name + ": std expecting size 3, got size: " + std::to_string(std.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
// check std/mean value
for (int32_t i = 0; i < std.size(); ++i) {
RETURN_IF_NOT_OK(ValidateScalar(op_name, "mean", mean[i], {0.0, 255.0}, false, false));
RETURN_IF_NOT_OK(ValidateScalar(op_name, "std", std[i], {0.0, 255.0}, true, false));
}
return Status::OK();
}
Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding) {
if (padding.empty() || padding.size() == 3 || padding.size() > 4) {
std::string err_msg = op_name + ": padding expecting size 1, 2 or 4, got size: " + std::to_string(padding.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
for (const auto &pad_val : padding) {
RETURN_IF_NOT_OK(ValidateScalar(op_name, "padding", pad_val, {0, INT_MAX}, false, false));
}
return Status::OK();
}
Status ValidateVectorPositive(const std::string &op_name, const std::string &vec_name,
const std::vector<int32_t> &vec) {
for (const auto &vec_val : vec) {
RETURN_IF_NOT_OK(ValidateScalar(op_name, vec_name, vec_val, {0}, true));
}
return Status::OK();
}
Status ValidateVectorNonNegative(const std::string &op_name, const std::string &vec_name,
const std::vector<int32_t> &vec) {
for (const auto &vec_val : vec) {
RETURN_IF_NOT_OK(ValidateScalar(op_name, vec_name, vec_val, {0}, false));
}
return Status::OK();
}
Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size) {
if (size.empty() || size.size() > 2) {
std::string err_msg = op_name + ": size expecting size 2, got size.size(): " + std::to_string(size.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
for (const auto &size_val : size) {
RETURN_IF_NOT_OK(ValidateScalar(op_name, "size", size_val, {0, INT_MAX}, true, false));
}
return Status::OK();
}
Status ValidateVectorScale(const std::string &op_name, const std::vector<float> &scale) {
if (scale.size() != 2) {
std::string err_msg = op_name + ": scale expecting size 2, got scale.size(): " + std::to_string(scale.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", scale[0], {0}, false));
RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", scale[1], {0}, true));
if (scale[1] < scale[0]) {
std::string err_msg = op_name + ": scale must be in the format of (min, max).";
MS_LOG(ERROR) << op_name + ": scale must be in the format of (min, max), but got: " << scale;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
Status ValidateVectorRatio(const std::string &op_name, const std::vector<float> &ratio) {
if (ratio.size() != 2) {
std::string err_msg = op_name + ": ratio expecting size 2, got ratio.size(): " + std::to_string(ratio.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", ratio[0], {0}, true));
RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", ratio[1], {0}, true));
if (ratio[1] < ratio[0]) {
std::string err_msg = op_name + ": ratio must be in the format of (min, max).";
MS_LOG(ERROR) << op_name + ": ratio must be in the format of (min, max), but got: " << ratio;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
Status ValidateVectorTransforms(const std::string &op_name,
const std::vector<std::shared_ptr<TensorOperation>> &transforms) {
if (transforms.empty()) {
std::string err_msg = op_name + ": transform list must not be empty.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
for (int32_t i = 0; i < transforms.size(); ++i) {
if (transforms[i] == nullptr) {
std::string err_msg =
op_name + ": transform ops must not be null, got transform[" + std::to_string(i) + "] == nullptr.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}
return Status::OK();
}
bool CmpFloat(const float a, const float b, float epsilon) { return (std::fabs(a - b) < epsilon); }
// Transform operations for data.
namespace transforms {

View File

@ -19,6 +19,8 @@
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/image_utils.h"
#endif
#include "minddata/dataset/kernels/ir/validators.h"
// Kernel image headers (in alphabetical order)
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/auto_contrast_op.h"

View File

@ -25,6 +25,9 @@
#include "include/api/status.h"
#include "minddata/dataset/include/constants.h"
// (TEMPORARY) will be removed when Tensor op ir moved down
#include "minddata/dataset/kernels/ir/tensor_operation.h"
#ifndef INCLUDE_NLOHMANN_JSON_FWD_HPP_
#define INCLUDE_NLOHMANN_JSON_FWD_HPP_
namespace nlohmann {
@ -45,8 +48,6 @@ using json = basic_json<>;
namespace mindspore {
namespace dataset {
class TensorOp;
// Char arrays storing name of corresponding classes (in alphabetical order)
constexpr char kComposeOperation[] = "Compose";
constexpr char kDuplicateOperation[] = "Duplicate";
@ -58,108 +59,6 @@ constexpr char kRandomSelectSubpolicyOperation[] = "RandomSelectSubpolicy";
constexpr char kTypeCastOperation[] = "TypeCast";
constexpr char kUniqueOperation[] = "Unique";
// Abstract class to represent a dataset in the data pipeline.
class TensorOperation : public std::enable_shared_from_this<TensorOperation> {
public:
/// \brief Constructor
TensorOperation() : random_op_(false) {}
/// \brief Constructor
explicit TensorOperation(bool random) : random_op_(random) {}
/// \brief Destructor
~TensorOperation() = default;
/// \brief Pure virtual function to convert a TensorOperation class into a runtime TensorOp object.
/// \return shared pointer to the newly created TensorOp.
virtual std::shared_ptr<TensorOp> Build() = 0;
virtual Status ValidateParams() = 0;
virtual std::string Name() const = 0;
/// \brief Check whether the operation is deterministic.
/// \return true if this op is a random op (returns non-deterministic result e.g. RandomCrop)
bool IsRandomOp() const { return random_op_; }
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
protected:
bool random_op_;
};
// Helper function to validate probability
Status ValidateProbability(const std::string &op_name, const float probability);
// Helper function to positive int scalar
Status ValidateIntScalarPositive(const std::string &op_name, const std::string &scalar_name, int32_t scalar);
// Helper function to positive float scalar
Status ValidateFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar);
// Helper function to validate scalar
template <typename T>
Status ValidateScalar(const std::string &op_name, const std::string &scalar_name, const T scalar,
const std::vector<T> &range, bool left_open_interval = false, bool right_open_interval = false) {
if (range.empty() || range.size() > 2) {
std::string err_msg = "Range check expecting size 1 or 2, but got: " + std::to_string(range.size());
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
}
if ((left_open_interval && scalar <= range[0]) || (!left_open_interval && scalar < range[0])) {
std::string interval_description = left_open_interval ? " greater than " : " greater than or equal to ";
std::string err_msg = op_name + ":" + scalar_name + " must be" + interval_description + std::to_string(range[0]) +
", got: " + std::to_string(scalar);
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
}
if (range.size() == 2) {
if ((right_open_interval && scalar >= range[1]) || (!right_open_interval && scalar > range[1])) {
std::string left_bracket = left_open_interval ? "(" : "[";
std::string right_bracket = right_open_interval ? ")" : "]";
std::string err_msg = op_name + ":" + scalar_name + " is out of range " + left_bracket +
std::to_string(range[0]) + ", " + std::to_string(range[1]) + right_bracket +
", got: " + std::to_string(scalar);
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
}
}
return Status::OK();
}
// Helper function to validate color attribute
Status ValidateVectorColorAttribute(const std::string &op_name, const std::string &attr_name,
const std::vector<float> &attr, const std::vector<float> &range);
// Helper function to validate fill value
Status ValidateVectorFillvalue(const std::string &op_name, const std::vector<uint8_t> &fill_value);
// Helper function to validate mean/std value
Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean, const std::vector<float> &std);
// Helper function to validate padding
Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding);
// Helper function to validate positive value
Status ValidateVectorPositive(const std::string &op_name, const std::string &vec_name, const std::vector<int32_t> &vec);
// Helper function to validate non-negative value
Status ValidateVectorNonNegative(const std::string &op_name, const std::string &vec_name,
const std::vector<int32_t> &vec);
// Helper function to validate size of size
Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size);
// Helper function to validate scale
Status ValidateVectorScale(const std::string &op_name, const std::vector<float> &scale);
// Helper function to validate ratio
Status ValidateVectorRatio(const std::string &op_name, const std::vector<float> &ratio);
// Helper function to validate transforms
Status ValidateVectorTransforms(const std::string &op_name,
const std::vector<std::shared_ptr<TensorOperation>> &transforms);
// Helper function to compare float value
bool CmpFloat(const float a, const float b, float epsilon = 0.0000000001f);
// Transform operations for performing data transformation.
namespace transforms {

View File

@ -1,5 +1,6 @@
add_subdirectory(image)
add_subdirectory(data)
add_subdirectory(ir)
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
if(ENABLE_PYTHON)

View File

@ -0,0 +1,8 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
set(DATASET_KERNELS_IR_SRC_FILES
validators.cc
)
add_library(kernels-ir OBJECT ${DATASET_KERNELS_IR_SRC_FILES})

View File

@ -0,0 +1,60 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_TENSOR_OPERATION_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_TENSOR_OPERATION_H_
#include <memory>
#include <string>
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Abstract class to represent a dataset in the data pipeline.
class TensorOperation : public std::enable_shared_from_this<TensorOperation> {
public:
/// \brief Constructor
TensorOperation() : random_op_(false) {}
/// \brief Constructor
explicit TensorOperation(bool random) : random_op_(random) {}
/// \brief Destructor
~TensorOperation() = default;
/// \brief Pure virtual function to convert a TensorOperation class into a runtime TensorOp object.
/// \return shared pointer to the newly created TensorOp.
virtual std::shared_ptr<TensorOp> Build() = 0;
virtual Status ValidateParams() = 0;
virtual std::string Name() const = 0;
/// \brief Check whether the operation is deterministic.
/// \return true if this op is a random op (returns non-deterministic result e.g. RandomCrop)
bool IsRandomOp() const { return random_op_; }
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
protected:
bool random_op_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_TENSOR_OPERATION_H_

View File

@ -0,0 +1,195 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/ir/validators.h"
namespace mindspore {
namespace dataset {
/* ####################################### Validator Functions ############################################ */
Status ValidateProbability(const std::string &op_name, const float probability) {
if (probability < 0.0 || probability > 1.0) {
std::string err_msg = op_name + ": probability must be between 0.0 and 1.0, got: " + std::to_string(probability);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
Status ValidateIntScalarPositive(const std::string &op_name, const std::string &scalar_name, int32_t scalar) {
RETURN_IF_NOT_OK(ValidateScalar(op_name, scalar_name, scalar, {0}, true));
return Status::OK();
}
Status ValidateFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar) {
RETURN_IF_NOT_OK(ValidateScalar(op_name, scalar_name, scalar, {0}, true));
return Status::OK();
}
Status ValidateVectorFillvalue(const std::string &op_name, const std::vector<uint8_t> &fill_value) {
if (fill_value.empty() || (fill_value.size() != 1 && fill_value.size() != 3)) {
std::string err_msg =
op_name + ": fill_value expecting size 1 or 3, got fill_value.size(): " + std::to_string(fill_value.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
// Note that fill_value need to be in range [0, 255],
// but we omit the check since its type is uint8_t
return Status::OK();
}
Status ValidateVectorColorAttribute(const std::string &op_name, const std::string &attr_name,
const std::vector<float> &attr, const std::vector<float> &range) {
if (attr.empty() || attr.size() > 2) {
std::string err_msg = op_name + ":" + attr_name + " expecting size 1 or 2, but got: " + std::to_string(attr.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
for (auto &attr_val : attr) {
RETURN_IF_NOT_OK(ValidateScalar(op_name, attr_name, attr_val, range, false, false));
}
if (attr.size() == 2 && (attr[0] > attr[1])) {
std::string err_msg = op_name + ":" + attr_name +
" lower bound must be less or equal to upper bound, got lb: " + std::to_string(attr[0]) +
", ub: " + std::to_string(attr[1]);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean,
const std::vector<float> &std) {
if (mean.size() != 3) {
std::string err_msg = op_name + ": mean expecting size 3, got size: " + std::to_string(mean.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (std.size() != 3) {
std::string err_msg = op_name + ": std expecting size 3, got size: " + std::to_string(std.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
// check std/mean value
for (int32_t i = 0; i < std.size(); ++i) {
RETURN_IF_NOT_OK(ValidateScalar(op_name, "mean", mean[i], {0.0, 255.0}, false, false));
RETURN_IF_NOT_OK(ValidateScalar(op_name, "std", std[i], {0.0, 255.0}, true, false));
}
return Status::OK();
}
Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding) {
if (padding.empty() || padding.size() == 3 || padding.size() > 4) {
std::string err_msg = op_name + ": padding expecting size 1, 2 or 4, got size: " + std::to_string(padding.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
for (const auto &pad_val : padding) {
RETURN_IF_NOT_OK(ValidateScalar(op_name, "padding", pad_val, {0, INT_MAX}, false, false));
}
return Status::OK();
}
Status ValidateVectorPositive(const std::string &op_name, const std::string &vec_name,
const std::vector<int32_t> &vec) {
for (const auto &vec_val : vec) {
RETURN_IF_NOT_OK(ValidateScalar(op_name, vec_name, vec_val, {0}, true));
}
return Status::OK();
}
Status ValidateVectorNonNegative(const std::string &op_name, const std::string &vec_name,
const std::vector<int32_t> &vec) {
for (const auto &vec_val : vec) {
RETURN_IF_NOT_OK(ValidateScalar(op_name, vec_name, vec_val, {0}, false));
}
return Status::OK();
}
Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size) {
if (size.empty() || size.size() > 2) {
std::string err_msg = op_name + ": size expecting size 2, got size.size(): " + std::to_string(size.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
for (const auto &size_val : size) {
RETURN_IF_NOT_OK(ValidateScalar(op_name, "size", size_val, {0, INT_MAX}, true, false));
}
return Status::OK();
}
Status ValidateVectorScale(const std::string &op_name, const std::vector<float> &scale) {
if (scale.size() != 2) {
std::string err_msg = op_name + ": scale expecting size 2, got scale.size(): " + std::to_string(scale.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", scale[0], {0}, false));
RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", scale[1], {0}, true));
if (scale[1] < scale[0]) {
std::string err_msg = op_name + ": scale must be in the format of (min, max).";
MS_LOG(ERROR) << op_name + ": scale must be in the format of (min, max), but got: " << scale;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
Status ValidateVectorRatio(const std::string &op_name, const std::vector<float> &ratio) {
if (ratio.size() != 2) {
std::string err_msg = op_name + ": ratio expecting size 2, got ratio.size(): " + std::to_string(ratio.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", ratio[0], {0}, true));
RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", ratio[1], {0}, true));
if (ratio[1] < ratio[0]) {
std::string err_msg = op_name + ": ratio must be in the format of (min, max).";
MS_LOG(ERROR) << op_name + ": ratio must be in the format of (min, max), but got: " << ratio;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
Status ValidateVectorTransforms(const std::string &op_name,
const std::vector<std::shared_ptr<TensorOperation>> &transforms) {
if (transforms.empty()) {
std::string err_msg = op_name + ": transform list must not be empty.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
for (int32_t i = 0; i < transforms.size(); ++i) {
if (transforms[i] == nullptr) {
std::string err_msg =
op_name + ": transform ops must not be null, got transform[" + std::to_string(i) + "] == nullptr.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}
return Status::OK();
}
bool CmpFloat(const float a, const float b, float epsilon) { return (std::fabs(a - b) < epsilon); }
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,105 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VALIDATORS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VALIDATORS_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class TensorOperation;
// Helper function to validate probability
Status ValidateProbability(const std::string &op_name, const float probability);
// Helper function to positive int scalar
Status ValidateIntScalarPositive(const std::string &op_name, const std::string &scalar_name, int32_t scalar);
// Helper function to positive float scalar
Status ValidateFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar);
// Helper function to validate scalar
template <typename T>
Status ValidateScalar(const std::string &op_name, const std::string &scalar_name, const T scalar,
const std::vector<T> &range, bool left_open_interval = false, bool right_open_interval = false) {
if (range.empty() || range.size() > 2) {
std::string err_msg = "Range check expecting size 1 or 2, but got: " + std::to_string(range.size());
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
}
if ((left_open_interval && scalar <= range[0]) || (!left_open_interval && scalar < range[0])) {
std::string interval_description = left_open_interval ? " greater than " : " greater than or equal to ";
std::string err_msg = op_name + ":" + scalar_name + " must be" + interval_description + std::to_string(range[0]) +
", got: " + std::to_string(scalar);
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
}
if (range.size() == 2) {
if ((right_open_interval && scalar >= range[1]) || (!right_open_interval && scalar > range[1])) {
std::string left_bracket = left_open_interval ? "(" : "[";
std::string right_bracket = right_open_interval ? ")" : "]";
std::string err_msg = op_name + ":" + scalar_name + " is out of range " + left_bracket +
std::to_string(range[0]) + ", " + std::to_string(range[1]) + right_bracket +
", got: " + std::to_string(scalar);
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
}
}
return Status::OK();
}
// Helper function to validate color attribute
Status ValidateVectorColorAttribute(const std::string &op_name, const std::string &attr_name,
const std::vector<float> &attr, const std::vector<float> &range);
// Helper function to validate fill value
Status ValidateVectorFillvalue(const std::string &op_name, const std::vector<uint8_t> &fill_value);
// Helper function to validate mean/std value
Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean, const std::vector<float> &std);
// Helper function to validate padding
Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding);
// Helper function to validate positive value
Status ValidateVectorPositive(const std::string &op_name, const std::string &vec_name, const std::vector<int32_t> &vec);
// Helper function to validate non-negative value
Status ValidateVectorNonNegative(const std::string &op_name, const std::string &vec_name,
const std::vector<int32_t> &vec);
// Helper function to validate size of size
Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size);
// Helper function to validate scale
Status ValidateVectorScale(const std::string &op_name, const std::vector<float> &scale);
// Helper function to validate ratio
Status ValidateVectorRatio(const std::string &op_name, const std::vector<float> &ratio);
// Helper function to validate transforms
Status ValidateVectorTransforms(const std::string &op_name,
const std::vector<std::shared_ptr<TensorOperation>> &transforms);
// Helper function to compare float value
bool CmpFloat(const float a, const float b, float epsilon = 0.0000000001f);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VALIDATORS_H_

View File

@ -99,6 +99,7 @@ AUX_SOURCE_DIRECTORY(${MINDDATA_DIR}/util MINDDATA_UTIL_SRC_FILES)
AUX_SOURCE_DIRECTORY(${MINDDATA_DIR}/kernels/image/lite_cv MINDDATA_KERNELS_IMAGE_LITE_CV_FILES)
AUX_SOURCE_DIRECTORY(${MINDDATA_DIR}/kernels/ir MINDDATA_KERNELS_IR_SRC_FILES)
if(BUILD_MINDDATA STREQUAL "full")
include_directories("${MINDDATA_DIR}/kernels/image")
@ -200,6 +201,7 @@ if(BUILD_MINDDATA STREQUAL "full")
${MINDDATA_DIR}/kernels/data/random_choice_op.cc
${MINDDATA_DIR}/kernels/data/type_cast_op.cc
${MINDDATA_DIR}/kernels/image/exif_utils.cc
${MINDDATA_DIR}/kernels/ir/validators.cc
${MINDDATA_DIR}/callback/callback_manager.cc
${MINDDATA_DIR}/util/task_manager.cc
${MINDDATA_DIR}/util/services.cc
@ -294,6 +296,7 @@ elseif(BUILD_MINDDATA STREQUAL "wrapper")
${CORE_DIR}/utils/ms_utils.cc
${MINDDATA_TODAPI_SRC}
${MINDSPORE_LITE_CXXAPI_SRC}
${MINDDATA_DIR}/kernels/ir/validators.cc
)
find_package(Threads REQUIRED)
@ -387,6 +390,7 @@ elseif(BUILD_MINDDATA STREQUAL "lite")
${MINDDATA_DIR}/api/transforms.cc
${CMAKE_CURRENT_SOURCE_DIR}/../src/common/log_adapter.cc
${CORE_DIR}/utils/ms_utils.cc
${MINDDATA_DIR}/kernels/ir/validators.cc
)
target_link_libraries(minddata-lite