forked from mindspore-Ecosystem/mindspore
C++ API: ValidateParams support for TensorOps; plus minor dataset fixes
This commit is contained in:
parent
79b5fee04b
commit
051bc60edb
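The change repeated throughout this diff is that every TensorOperation::ValidateParams() now returns Status instead of bool, so callers receive the failure reason rather than a bare false. Below is a minimal sketch of that pattern; the Status class, the MS_LOG_ERROR macro, and RETURN_STATUS_SYNTAX_ERROR shown here are simplified stand-ins for the MindSpore internals, not the real implementations.

// Minimal sketch of the bool -> Status validation pattern used throughout this commit.
// Status, MS_LOG_ERROR and RETURN_STATUS_SYNTAX_ERROR below are simplified stand-ins.
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>

// Simplified Status: carries an ok/error flag plus a message.
class Status {
 public:
  static Status OK() { return Status(true, ""); }
  static Status SyntaxError(const std::string &msg) { return Status(false, msg); }
  bool IsOk() const { return ok_; }
  const std::string &ToString() const { return msg_; }

 private:
  Status(bool ok, std::string msg) : ok_(ok), msg_(std::move(msg)) {}
  bool ok_;
  std::string msg_;
};

// Stand-ins for the MindSpore logging/error macros.
#define MS_LOG_ERROR(msg) (std::cerr << "[ERROR] " << (msg) << std::endl)
#define RETURN_STATUS_SYNTAX_ERROR(msg) return Status::SyntaxError(msg)

// Old style (before this commit): the caller only learns "true/false".
bool ValidateOneHotOld(int32_t num_classes) {
  if (num_classes < 0) {
    MS_LOG_ERROR("OneHot: Number of classes cannot be negative.");
    return false;
  }
  return true;
}

// New style (this commit): the failure reason travels back in the Status.
Status ValidateOneHotNew(int32_t num_classes) {
  if (num_classes < 0) {
    std::string err_msg =
        "OneHot: Number of classes cannot be negative. Number of classes: " + std::to_string(num_classes);
    MS_LOG_ERROR(err_msg);
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}

int main() {
  Status rc = ValidateOneHotNew(-3);
  if (!rc.IsOk()) {
    std::cout << "validation failed: " << rc.ToString() << std::endl;
  }
  return 0;
}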
@ -48,17 +48,21 @@ LookupOperation::LookupOperation(const std::shared_ptr<Vocab> &vocab, const std:
                                  const DataType &data_type)
     : vocab_(vocab), unknown_token_(unknown_token), default_id_(Vocab::kNoTokenExists), data_type_(data_type) {}
 
-bool LookupOperation::ValidateParams() {
+Status LookupOperation::ValidateParams() {
   if (vocab_ == nullptr) {
-    MS_LOG(ERROR) << "Lookup: vocab object type is incorrect or null.";
-    return false;
+    std::string err_msg = "Lookup: vocab object type is incorrect or null.";
+    MS_LOG(ERROR) << err_msg;
+    RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
+
   default_id_ = vocab_->Lookup(unknown_token_);
   if (default_id_ == Vocab::kNoTokenExists) {
-    MS_LOG(ERROR) << "Lookup: " << unknown_token_ << " doesn't exist in vocab.";
-    return false;
+    std::string err_msg = "Lookup: " + unknown_token_ + " doesn't exist in vocab.";
+    MS_LOG(ERROR) << err_msg;
+    RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
-  return true;
+
+  return Status::OK();
 }
 
 std::shared_ptr<TensorOp> LookupOperation::Build() {
@ -62,13 +62,15 @@ std::shared_ptr<TypeCastOperation> TypeCast(std::string data_type) {
 // OneHotOperation
 OneHotOperation::OneHotOperation(int32_t num_classes) : num_classes_(num_classes) {}
 
-bool OneHotOperation::ValidateParams() {
+Status OneHotOperation::ValidateParams() {
   if (num_classes_ < 0) {
-    MS_LOG(ERROR) << "OneHot: Number of classes cannot be negative. Number of classes: " << num_classes_;
-    return false;
+    std::string err_msg =
+      "OneHot: Number of classes cannot be negative. Number of classes: " + std::to_string(num_classes_);
+    MS_LOG(ERROR) << err_msg;
+    RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
 
-  return true;
+  return Status::OK();
 }
 
 std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); }
@ -76,17 +78,18 @@ std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<One
 // TypeCastOperation
 TypeCastOperation::TypeCastOperation(std::string data_type) : data_type_(data_type) {}
 
-bool TypeCastOperation::ValidateParams() {
+Status TypeCastOperation::ValidateParams() {
   std::vector<std::string> predefine_type = {"bool", "int8", "uint8", "int16", "uint16", "int32", "uint32",
                                              "int64", "uint64", "float16", "float32", "float64", "string"};
   auto itr = std::find(predefine_type.begin(), predefine_type.end(), data_type_);
   if (itr == predefine_type.end()) {
-    MS_LOG(ERROR) << "TypeCast: Only support type bool, int8, uint8, int16, uint16, int32, uint32, "
+    std::string err_msg = "TypeCast: Invalid data type: " + data_type_;
+    MS_LOG(ERROR) << "TypeCast: Only supports data type bool, int8, uint8, int16, uint16, int32, uint32, "
                   << "int64, uint64, float16, float32, float64, string, but got " << data_type_;
-    return false;
+    RETURN_STATUS_SYNTAX_ERROR(err_msg);
   }
 
-  return true;
+  return Status::OK();
 }
 
 std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<TypeCastOp>(data_type_); }
@ -353,11 +353,17 @@ std::shared_ptr<UniformAugOperation> UniformAugment(std::vector<std::shared_ptr<
 }
 
 /* ####################################### Validator Functions ############################################ */
-bool CheckVectorPositive(const std::vector<int32_t> &size) {
+Status ValidateVectorPositive(const std::string &dataset_name, const std::vector<int32_t> &size) {
   for (int i = 0; i < size.size(); ++i) {
-    if (size[i] <= 0) return false;
+    if (size[i] <= 0) {
+      std::string err_msg =
+        dataset_name + ": Non-positive size value: " + std::to_string(size[i]) + " at element: " + std::to_string(i);
+      MS_LOG(ERROR) << err_msg;
+      RETURN_STATUS_SYNTAX_ERROR(err_msg);
+    }
   }
-  return true;
+
+  return Status::OK();
 }
 
 bool CmpFloat(const float &a, const float &b, float epsilon = 0.0000000001f) { return (std::fabs(a - b) < epsilon); }
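The hunk above replaces the bool helper CheckVectorPositive with a Status-returning ValidateVectorPositive that names the offending op and element; later hunks (see ResizeOperation::ValidateParams) chain it through RETURN_IF_NOT_OK. A hedged sketch of that composition follows, using simplified stand-ins for Status and the minddata macros rather than the real ones.

// Sketch: composing a reusable Status-returning validator, mirroring
// ValidateVectorPositive + RETURN_IF_NOT_OK. Status and the macros are stand-ins.
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

struct Status {
  bool ok;
  std::string msg;
  static Status OK() { return {true, ""}; }
  static Status SyntaxError(std::string m) { return {false, std::move(m)}; }
};

#define RETURN_STATUS_SYNTAX_ERROR(msg) return Status::SyntaxError(msg)
// Propagate the first failing Status to the caller, otherwise keep going.
#define RETURN_IF_NOT_OK(expr)   \
  do {                           \
    Status _rc = (expr);         \
    if (!_rc.ok) return _rc;     \
  } while (false)

// Shared check: every element of `size` must be strictly positive.
Status ValidateVectorPositive(const std::string &op_name, const std::vector<int32_t> &size) {
  for (size_t i = 0; i < size.size(); ++i) {
    if (size[i] <= 0) {
      RETURN_STATUS_SYNTAX_ERROR(op_name + ": Non-positive size value: " + std::to_string(size[i]) +
                                 " at element: " + std::to_string(i));
    }
  }
  return Status::OK();
}

// Example caller, shaped like ResizeOperation::ValidateParams() after this commit.
Status ValidateResize(const std::vector<int32_t> &size) {
  if (size.empty() || size.size() > 2) {
    RETURN_STATUS_SYNTAX_ERROR("Resize: size vector has incorrect size: " + std::to_string(size.size()));
  }
  RETURN_IF_NOT_OK(ValidateVectorPositive("Resize", size));
  return Status::OK();
}

int main() {
  Status rc = ValidateResize({256, -1});
  return rc.ok ? 0 : 1;
}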
@ -369,23 +375,26 @@ bool CmpFloat(const float &a, const float &b, float epsilon = 0.0000000001f) { r
|
|||
// CenterCropOperation
|
||||
CenterCropOperation::CenterCropOperation(std::vector<int32_t> size) : size_(size) {}
|
||||
|
||||
bool CenterCropOperation::ValidateParams() {
|
||||
Status CenterCropOperation::ValidateParams() {
|
||||
if (size_.empty() || size_.size() > 2) {
|
||||
MS_LOG(ERROR) << "CenterCrop: size vector has incorrect size.";
|
||||
return false;
|
||||
std::string err_msg = "CenterCrop: size vector has incorrect size.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
// We have to limit crop size due to library restrictions, optimized to only iterate over size_ once
|
||||
for (int i = 0; i < size_.size(); ++i) {
|
||||
if (size_[i] <= 0) {
|
||||
MS_LOG(ERROR) << "CenterCrop: invalid size, size must be greater than 0, got: " << size_[i];
|
||||
return false;
|
||||
std::string err_msg = "CenterCrop: invalid size, size must be greater than 0, got: " + std::to_string(size_[i]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (size_[i] == INT_MAX) {
|
||||
MS_LOG(ERROR) << "CenterCrop: invalid size, size too large, got: " << size_[i];
|
||||
return false;
|
||||
std::string err_msg = "CenterCrop: invalid size, size too large, got: " + std::to_string(size_[i]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> CenterCropOperation::Build() {
|
||||
|
@ -405,35 +414,41 @@ std::shared_ptr<TensorOp> CenterCropOperation::Build() {
|
|||
CropOperation::CropOperation(std::vector<int32_t> coordinates, std::vector<int32_t> size)
|
||||
: coordinates_(coordinates), size_(size) {}
|
||||
|
||||
bool CropOperation::ValidateParams() {
|
||||
Status CropOperation::ValidateParams() {
|
||||
// Do some input validation.
|
||||
if (coordinates_.size() != 2) {
|
||||
MS_LOG(ERROR) << "Crop: coordinates must be a vector of two values";
|
||||
return false;
|
||||
std::string err_msg = "Crop: coordinates must be a vector of two values";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
// we don't check the coordinates here because we don't have access to image dimensions
|
||||
if (size_.empty() || size_.size() > 2) {
|
||||
MS_LOG(ERROR) << "Crop: size must be a vector of one or two values";
|
||||
return false;
|
||||
std::string err_msg = "Crop: size must be a vector of one or two values";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
// We have to limit crop size due to library restrictions, optimized to only iterate over size_ once
|
||||
for (int i = 0; i < size_.size(); ++i) {
|
||||
if (size_[i] <= 0) {
|
||||
MS_LOG(ERROR) << "Crop: invalid size, size must be greater than 0, got: " << size_[i];
|
||||
return false;
|
||||
std::string err_msg = "Crop: invalid size, size must be greater than 0, got: " + std::to_string(size_[i]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (size_[i] == INT_MAX) {
|
||||
MS_LOG(ERROR) << "Crop: invalid size, size too large, got: " << size_[i];
|
||||
return false;
|
||||
std::string err_msg = "Crop: invalid size, size too large, got: " + std::to_string(size_[i]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < coordinates_.size(); ++j) {
|
||||
if (coordinates_[j] < 0) {
|
||||
MS_LOG(ERROR) << "Crop: invalid coordinates, coordinates must be greater than 0, got: " << coordinates_[j];
|
||||
return false;
|
||||
std::string err_msg =
|
||||
"Crop: invalid coordinates, coordinates must be greater than 0, got: " + std::to_string(coordinates_[j]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> CropOperation::Build() {
|
||||
|
@ -456,16 +471,19 @@ std::shared_ptr<TensorOp> CropOperation::Build() {
|
|||
CutMixBatchOperation::CutMixBatchOperation(ImageBatchFormat image_batch_format, float alpha, float prob)
|
||||
: image_batch_format_(image_batch_format), alpha_(alpha), prob_(prob) {}
|
||||
|
||||
bool CutMixBatchOperation::ValidateParams() {
|
||||
Status CutMixBatchOperation::ValidateParams() {
|
||||
if (alpha_ <= 0) {
|
||||
MS_LOG(ERROR) << "CutMixBatch: alpha must be a positive floating value however it is: " << alpha_;
|
||||
return false;
|
||||
std::string err_msg =
|
||||
"CutMixBatch: alpha must be a positive floating value however it is: " + std::to_string(alpha_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (prob_ < 0 || prob_ > 1) {
|
||||
MS_LOG(ERROR) << "CutMixBatch: Probability has to be between 0 and 1.";
|
||||
return false;
|
||||
std::string err_msg = "CutMixBatch: Probability has to be between 0 and 1.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> CutMixBatchOperation::Build() {
|
||||
|
@ -476,16 +494,18 @@ std::shared_ptr<TensorOp> CutMixBatchOperation::Build() {
|
|||
// CutOutOperation
|
||||
CutOutOperation::CutOutOperation(int32_t length, int32_t num_patches) : length_(length), num_patches_(num_patches) {}
|
||||
|
||||
bool CutOutOperation::ValidateParams() {
|
||||
Status CutOutOperation::ValidateParams() {
|
||||
if (length_ < 0) {
|
||||
MS_LOG(ERROR) << "CutOut: length cannot be negative";
|
||||
return false;
|
||||
std::string err_msg = "CutOut: length cannot be negative";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (num_patches_ < 0) {
|
||||
MS_LOG(ERROR) << "CutOut: number of patches cannot be negative";
|
||||
return false;
|
||||
std::string err_msg = "CutOut: number of patches cannot be negative";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> CutOutOperation::Build() {
|
||||
|
@ -496,25 +516,27 @@ std::shared_ptr<TensorOp> CutOutOperation::Build() {
|
|||
// DecodeOperation
|
||||
DecodeOperation::DecodeOperation(bool rgb) : rgb_(rgb) {}
|
||||
|
||||
bool DecodeOperation::ValidateParams() { return true; }
|
||||
Status DecodeOperation::ValidateParams() { return Status::OK(); }
|
||||
|
||||
std::shared_ptr<TensorOp> DecodeOperation::Build() { return std::make_shared<DecodeOp>(rgb_); }
|
||||
|
||||
// HwcToChwOperation
|
||||
bool HwcToChwOperation::ValidateParams() { return true; }
|
||||
Status HwcToChwOperation::ValidateParams() { return Status::OK(); }
|
||||
|
||||
std::shared_ptr<TensorOp> HwcToChwOperation::Build() { return std::make_shared<HwcToChwOp>(); }
|
||||
|
||||
// MixUpOperation
|
||||
MixUpBatchOperation::MixUpBatchOperation(float alpha) : alpha_(alpha) {}
|
||||
|
||||
bool MixUpBatchOperation::ValidateParams() {
|
||||
Status MixUpBatchOperation::ValidateParams() {
|
||||
if (alpha_ <= 0) {
|
||||
MS_LOG(ERROR) << "MixUpBatch: alpha must be a positive floating value however it is: " << alpha_;
|
||||
return false;
|
||||
std::string err_msg =
|
||||
"MixUpBatch: alpha must be a positive floating value however it is: " + std::to_string(alpha_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> MixUpBatchOperation::Build() { return std::make_shared<MixUpBatchOp>(alpha_); }
|
||||
|
@ -522,23 +544,26 @@ std::shared_ptr<TensorOp> MixUpBatchOperation::Build() { return std::make_shared
|
|||
// NormalizeOperation
|
||||
NormalizeOperation::NormalizeOperation(std::vector<float> mean, std::vector<float> std) : mean_(mean), std_(std) {}
|
||||
|
||||
bool NormalizeOperation::ValidateParams() {
|
||||
Status NormalizeOperation::ValidateParams() {
|
||||
if (mean_.size() != 3) {
|
||||
MS_LOG(ERROR) << "Normalize: mean vector has incorrect size: " << mean_.size();
|
||||
return false;
|
||||
std::string err_msg = "Normalize: mean vector has incorrect size: " + std::to_string(mean_.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (std_.size() != 3) {
|
||||
MS_LOG(ERROR) << "Normalize: std vector has incorrect size: " << std_.size();
|
||||
return false;
|
||||
std::string err_msg = "Normalize: std vector has incorrect size: " + std::to_string(std_.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
// check std value
|
||||
for (int i = 0; i < std_.size(); ++i) {
|
||||
if (std_[i] < 0.0f || std_[i] > 255.0f || CmpFloat(std_[i], 0.0f)) {
|
||||
MS_LOG(ERROR) << "Normalize: std vector has incorrect value: " << std_[i];
|
||||
return false;
|
||||
std::string err_msg = "Normalize: std vector has incorrect value: " + std::to_string(std_[i]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> NormalizeOperation::Build() {
|
||||
|
@ -549,17 +574,19 @@ std::shared_ptr<TensorOp> NormalizeOperation::Build() {
|
|||
PadOperation::PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, BorderType padding_mode)
|
||||
: padding_(padding), fill_value_(fill_value), padding_mode_(padding_mode) {}
|
||||
|
||||
bool PadOperation::ValidateParams() {
|
||||
Status PadOperation::ValidateParams() {
|
||||
if (padding_.empty() || padding_.size() == 3 || padding_.size() > 4) {
|
||||
MS_LOG(ERROR) << "Pad: padding vector has incorrect size: padding.size()";
|
||||
return false;
|
||||
std::string err_msg = "Pad: padding vector has incorrect size: padding.size()";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (fill_value_.empty() || (fill_value_.size() != 1 && fill_value_.size() != 3)) {
|
||||
MS_LOG(ERROR) << "Pad: fill_value vector has incorrect size: fill_value.size()";
|
||||
return false;
|
||||
std::string err_msg = "Pad: fill_value vector has incorrect size: fill_value.size()";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> PadOperation::Build() {
|
||||
|
@ -613,88 +640,107 @@ RandomAffineOperation::RandomAffineOperation(const std::vector<float_t> &degrees
|
|||
interpolation_(interpolation),
|
||||
fill_value_(fill_value) {}
|
||||
|
||||
bool RandomAffineOperation::ValidateParams() {
|
||||
Status RandomAffineOperation::ValidateParams() {
|
||||
// Degrees
|
||||
if (degrees_.size() != 2) {
|
||||
MS_LOG(ERROR) << "RandomAffine: degrees expecting size 2, got: degrees.size() = " << degrees_.size();
|
||||
return false;
|
||||
std::string err_msg =
|
||||
"RandomAffine: degrees expecting size 2, got: degrees.size() = " + std::to_string(degrees_.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (degrees_[0] > degrees_[1]) {
|
||||
MS_LOG(ERROR) << "RandomAffine: minimum of degrees range is greater than maximum: min = " << degrees_[0]
|
||||
<< ", max = " << degrees_[1];
|
||||
return false;
|
||||
std::string err_msg =
|
||||
"RandomAffine: minimum of degrees range is greater than maximum: min = " + std::to_string(degrees_[0]) +
|
||||
", max = " + std::to_string(degrees_[1]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
// Translate
|
||||
if (translate_range_.size() != 2 && translate_range_.size() != 4) {
|
||||
MS_LOG(ERROR) << "RandomAffine: translate_range expecting size 2 or 4, got: translate_range.size() = "
|
||||
<< translate_range_.size();
|
||||
return false;
|
||||
std::string err_msg = "RandomAffine: translate_range expecting size 2 or 4, got: translate_range.size() = " +
|
||||
std::to_string(translate_range_.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (translate_range_[0] > translate_range_[1]) {
|
||||
MS_LOG(ERROR) << "RandomAffine: minimum of translate range on x is greater than maximum: min = "
|
||||
<< translate_range_[0] << ", max = " << translate_range_[1];
|
||||
return false;
|
||||
std::string err_msg = "RandomAffine: minimum of translate range on x is greater than maximum: min = " +
|
||||
std::to_string(translate_range_[0]) + ", max = " + std::to_string(translate_range_[1]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (translate_range_[0] < -1 || translate_range_[0] > 1) {
|
||||
MS_LOG(ERROR) << "RandomAffine: minimum of translate range on x is out of range of [-1, 1], value = "
|
||||
<< translate_range_[0];
|
||||
return false;
|
||||
std::string err_msg = "RandomAffine: minimum of translate range on x is out of range of [-1, 1], value = " +
|
||||
std::to_string(translate_range_[0]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (translate_range_[1] < -1 || translate_range_[1] > 1) {
|
||||
MS_LOG(ERROR) << "RandomAffine: maximum of translate range on x is out of range of [-1, 1], value = "
|
||||
<< translate_range_[1];
|
||||
return false;
|
||||
std::string err_msg = "RandomAffine: maximum of translate range on x is out of range of [-1, 1], value = " +
|
||||
std::to_string(translate_range_[1]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (translate_range_.size() == 4) {
|
||||
if (translate_range_[2] > translate_range_[3]) {
|
||||
MS_LOG(ERROR) << "RandomAffine: minimum of translate range on y is greater than maximum: min = "
|
||||
<< translate_range_[2] << ", max = " << translate_range_[3];
|
||||
return false;
|
||||
std::string err_msg = "RandomAffine: minimum of translate range on y is greater than maximum: min = " +
|
||||
std::to_string(translate_range_[2]) + ", max = " + std::to_string(translate_range_[3]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (translate_range_[2] < -1 || translate_range_[2] > 1) {
|
||||
MS_LOG(ERROR) << "RandomAffine: minimum of translate range on y is out of range of [-1, 1], value = "
|
||||
<< translate_range_[2];
|
||||
return false;
|
||||
std::string err_msg = "RandomAffine: minimum of translate range on y is out of range of [-1, 1], value = " +
|
||||
std::to_string(translate_range_[2]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (translate_range_[3] < -1 || translate_range_[3] > 1) {
|
||||
MS_LOG(ERROR) << "RandomAffine: maximum of translate range on y is out of range of [-1, 1], value = "
|
||||
<< translate_range_[3];
|
||||
return false;
|
||||
std::string err_msg = "RandomAffine: maximum of translate range on y is out of range of [-1, 1], value = " +
|
||||
std::to_string(translate_range_[3]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
// Scale
|
||||
if (scale_range_.size() != 2) {
|
||||
MS_LOG(ERROR) << "RandomAffine: scale_range vector has incorrect size: scale_range.size() = "
|
||||
<< scale_range_.size();
|
||||
return false;
|
||||
std::string err_msg = "RandomAffine: scale_range vector has incorrect size: scale_range.size() = " +
|
||||
std::to_string(scale_range_.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (scale_range_[0] > scale_range_[1]) {
|
||||
MS_LOG(ERROR) << "RandomAffine: minimum of scale range is greater than maximum: min = " << scale_range_[0]
|
||||
<< ", max = " << scale_range_[1];
|
||||
return false;
|
||||
std::string err_msg =
|
||||
"RandomAffine: minimum of scale range is greater than maximum: min = " + std::to_string(scale_range_[0]) +
|
||||
", max = " + std::to_string(scale_range_[1]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
// Shear
|
||||
if (shear_ranges_.size() != 2 && shear_ranges_.size() != 4) {
|
||||
MS_LOG(ERROR) << "RandomAffine: shear_ranges expecting size 2 or 4, got: shear_ranges.size() = "
|
||||
<< shear_ranges_.size();
|
||||
return false;
|
||||
std::string err_msg = "RandomAffine: shear_ranges expecting size 2 or 4, got: shear_ranges.size() = " +
|
||||
std::to_string(shear_ranges_.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (shear_ranges_[0] > shear_ranges_[1]) {
|
||||
MS_LOG(ERROR) << "RandomAffine: minimum of horizontal shear range is greater than maximum: min = "
|
||||
<< shear_ranges_[0] << ", max = " << shear_ranges_[1];
|
||||
return false;
|
||||
std::string err_msg = "RandomAffine: minimum of horizontal shear range is greater than maximum: min = " +
|
||||
std::to_string(shear_ranges_[0]) + ", max = " + std::to_string(shear_ranges_[1]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (shear_ranges_.size() == 4 && shear_ranges_[2] > shear_ranges_[3]) {
|
||||
MS_LOG(ERROR) << "RandomAffine: minimum of vertical shear range is greater than maximum: min = " << shear_ranges_[2]
|
||||
<< ", max = " << scale_range_[3];
|
||||
return false;
|
||||
std::string err_msg = "RandomAffine: minimum of vertical shear range is greater than maximum: min = " +
|
||||
std::to_string(shear_ranges_[2]) + ", max = " + std::to_string(scale_range_[3]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
// Fill Value
|
||||
if (fill_value_.size() != 3) {
|
||||
MS_LOG(ERROR) << "RandomAffine: fill_value vector has incorrect size: fill_value.size() = " << fill_value_.size();
|
||||
return false;
|
||||
std::string err_msg =
|
||||
"RandomAffine: fill_value vector has incorrect size: fill_value.size() = " + std::to_string(fill_value_.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomAffineOperation::Build() {
|
||||
|
@ -712,13 +758,14 @@ std::shared_ptr<TensorOp> RandomAffineOperation::Build() {
|
|||
// RandomColorOperation.
|
||||
RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) {}
|
||||
|
||||
bool RandomColorOperation::ValidateParams() {
|
||||
Status RandomColorOperation::ValidateParams() {
|
||||
// Do some input validation.
|
||||
if (t_lb_ > t_ub_) {
|
||||
MS_LOG(ERROR) << "RandomColor: lower bound must be less or equal to upper bound";
|
||||
return false;
|
||||
std::string err_msg = "RandomColor: lower bound must be less or equal to upper bound";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// RandomColorAdjustOperation.
|
||||
|
@ -726,25 +773,29 @@ RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> bright
|
|||
std::vector<float> saturation, std::vector<float> hue)
|
||||
: brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) {}
|
||||
|
||||
bool RandomColorAdjustOperation::ValidateParams() {
|
||||
Status RandomColorAdjustOperation::ValidateParams() {
|
||||
// Do some input validation.
|
||||
if (brightness_.empty() || brightness_.size() > 2) {
|
||||
MS_LOG(ERROR) << "RandomColorAdjust: brightness must be a vector of one or two values";
|
||||
return false;
|
||||
std::string err_msg = "RandomColorAdjust: brightness must be a vector of one or two values";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (contrast_.empty() || contrast_.size() > 2) {
|
||||
MS_LOG(ERROR) << "RandomColorAdjust: contrast must be a vector of one or two values";
|
||||
return false;
|
||||
std::string err_msg = "RandomColorAdjust: contrast must be a vector of one or two values";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (saturation_.empty() || saturation_.size() > 2) {
|
||||
MS_LOG(ERROR) << "RandomColorAdjust: saturation must be a vector of one or two values";
|
||||
return false;
|
||||
std::string err_msg = "RandomColorAdjust: saturation must be a vector of one or two values";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (hue_.empty() || hue_.size() > 2) {
|
||||
MS_LOG(ERROR) << "RandomColorAdjust: hue must be a vector of one or two values";
|
||||
return false;
|
||||
std::string err_msg = "RandomColorAdjust: hue must be a vector of one or two values";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomColorAdjustOperation::Build() {
|
||||
|
@ -784,22 +835,25 @@ RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<
|
|||
fill_value_(fill_value),
|
||||
padding_mode_(padding_mode) {}
|
||||
|
||||
bool RandomCropOperation::ValidateParams() {
|
||||
Status RandomCropOperation::ValidateParams() {
|
||||
if (size_.empty() || size_.size() > 2) {
|
||||
MS_LOG(ERROR) << "RandomCrop: size vector has incorrect size: " << size_.size();
|
||||
return false;
|
||||
std::string err_msg = "RandomCrop: size vector has incorrect size: " + std::to_string(size_.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (padding_.empty() || padding_.size() != 4) {
|
||||
MS_LOG(ERROR) << "RandomCrop: padding vector has incorrect size: padding.size()";
|
||||
return false;
|
||||
std::string err_msg = "RandomCrop: padding vector has incorrect size: padding.size()";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (fill_value_.empty() || fill_value_.size() != 3) {
|
||||
MS_LOG(ERROR) << "RandomCrop: fill_value vector has incorrect size: fill_value.size()";
|
||||
return false;
|
||||
std::string err_msg = "RandomCrop: fill_value vector has incorrect size: fill_value.size()";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomCropOperation::Build() {
|
||||
|
@ -831,37 +885,43 @@ RandomCropDecodeResizeOperation::RandomCropDecodeResizeOperation(std::vector<int
|
|||
InterpolationMode interpolation, int32_t max_attempts)
|
||||
: size_(size), scale_(scale), ratio_(ratio), interpolation_(interpolation), max_attempts_(max_attempts) {}
|
||||
|
||||
bool RandomCropDecodeResizeOperation::ValidateParams() {
|
||||
Status RandomCropDecodeResizeOperation::ValidateParams() {
|
||||
if (size_.empty() || size_.size() > 2) {
|
||||
MS_LOG(ERROR) << "RandomCropDecodeResize: size vector has incorrect size: " << size_.size();
|
||||
return false;
|
||||
std::string err_msg = "RandomCropDecodeResize: size vector has incorrect size: " + std::to_string(size_.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (scale_.empty() || scale_.size() != 2) {
|
||||
MS_LOG(ERROR) << "RandomCropDecodeResize: scale vector has incorrect size: " << scale_.size();
|
||||
return false;
|
||||
std::string err_msg = "RandomCropDecodeResize: scale vector has incorrect size: " + std::to_string(scale_.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (scale_[0] > scale_[1]) {
|
||||
MS_LOG(ERROR) << "RandomCropDecodeResize: scale should be in (min,max) format. Got (max,min).";
|
||||
return false;
|
||||
std::string err_msg = "RandomCropDecodeResize: scale should be in (min,max) format. Got (max,min).";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (ratio_.empty() || ratio_.size() != 2) {
|
||||
MS_LOG(ERROR) << "RandomCropDecodeResize: ratio vector has incorrect size: " << ratio_.size();
|
||||
return false;
|
||||
std::string err_msg = "RandomCropDecodeResize: ratio vector has incorrect size: " + std::to_string(ratio_.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (ratio_[0] > ratio_[1]) {
|
||||
MS_LOG(ERROR) << "RandomCropDecodeResize: ratio should be in (min,max) format. Got (max,min).";
|
||||
return false;
|
||||
std::string err_msg = "RandomCropDecodeResize: ratio should be in (min,max) format. Got (max,min).";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (max_attempts_ < 1) {
|
||||
MS_LOG(ERROR) << "RandomCropDecodeResize: max_attempts must be greater than or equal to 1.";
|
||||
return false;
|
||||
std::string err_msg = "RandomCropDecodeResize: max_attempts must be greater than or equal to 1.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomCropDecodeResizeOperation::Build() {
|
||||
|
@ -888,12 +948,14 @@ std::shared_ptr<TensorOp> RandomCropDecodeResizeOperation::Build() {
|
|||
// RandomHorizontalFlipOperation
|
||||
RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) : probability_(probability) {}
|
||||
|
||||
bool RandomHorizontalFlipOperation::ValidateParams() {
|
||||
Status RandomHorizontalFlipOperation::ValidateParams() {
|
||||
if (probability_ < 0.0 || probability_ > 1.0) {
|
||||
MS_LOG(ERROR) << "RandomHorizontalFlip: probability must be between 0.0 and 1.0.";
|
||||
return false;
|
||||
std::string err_msg = "RandomHorizontalFlip: probability must be between 0.0 and 1.0.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return true;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomHorizontalFlipOperation::Build() {
|
||||
|
@ -904,25 +966,30 @@ std::shared_ptr<TensorOp> RandomHorizontalFlipOperation::Build() {
|
|||
// RandomPosterizeOperation
|
||||
RandomPosterizeOperation::RandomPosterizeOperation(const std::vector<uint8_t> &bit_range) : bit_range_(bit_range) {}
|
||||
|
||||
bool RandomPosterizeOperation::ValidateParams() {
|
||||
Status RandomPosterizeOperation::ValidateParams() {
|
||||
if (bit_range_.size() != 2) {
|
||||
MS_LOG(ERROR) << "RandomPosterize: bit_range needs to be of size 2 but is of size: " << bit_range_.size();
|
||||
return false;
|
||||
std::string err_msg =
|
||||
"RandomPosterize: bit_range needs to be of size 2 but is of size: " + std::to_string(bit_range_.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (bit_range_[0] < 1 || bit_range_[0] > 8) {
|
||||
MS_LOG(ERROR) << "RandomPosterize: min_bit value is out of range [1-8]: " << bit_range_[0];
|
||||
return false;
|
||||
std::string err_msg = "RandomPosterize: min_bit value is out of range [1-8]: " + std::to_string(bit_range_[0]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (bit_range_[1] < 1 || bit_range_[1] > 8) {
|
||||
MS_LOG(ERROR) << "RandomPosterize: max_bit value is out of range [1-8]: " << bit_range_[1];
|
||||
return false;
|
||||
std::string err_msg = "RandomPosterize: max_bit value is out of range [1-8]: " + std::to_string(bit_range_[1]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (bit_range_[1] < bit_range_[0]) {
|
||||
MS_LOG(ERROR) << "RandomPosterize: max_bit value is less than min_bit: max =" << bit_range_[1]
|
||||
<< ", min = " << bit_range_[0];
|
||||
return false;
|
||||
std::string err_msg = "RandomPosterize: max_bit value is less than min_bit: max =" + std::to_string(bit_range_[1]) +
|
||||
", min = " + std::to_string(bit_range_[0]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomPosterizeOperation::Build() {
|
||||
|
@ -935,28 +1002,32 @@ RandomResizedCropOperation::RandomResizedCropOperation(std::vector<int32_t> size
|
|||
std::vector<float> ratio, InterpolationMode interpolation,
|
||||
int32_t max_attempts)
|
||||
: size_(size), scale_(scale), ratio_(ratio), interpolation_(interpolation), max_attempts_(max_attempts) {}
|
||||
bool RandomResizedCropOperation::ValidateParams() {
|
||||
Status RandomResizedCropOperation::ValidateParams() {
|
||||
if (size_.size() != 2 && size_.size() != 1) {
|
||||
MS_LOG(ERROR) << "RandomResizedCrop: size variable must have a length of 1 or 2 but it has a length of: "
|
||||
<< size_.size();
|
||||
return false;
|
||||
std::string err_msg = "RandomResizedCrop: size variable must have a length of 1 or 2 but it has a length of: " +
|
||||
std::to_string(size_.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (size_[0] < 0 || (size_.size() == 2 && size_[1] < 0)) {
|
||||
std::string err_msg = "RandomResizedCrop: size variable must only contain positive integers.";
|
||||
MS_LOG(ERROR) << "RandomResizedCrop: size variable must only contain positive integers. However, it is: " << size_;
|
||||
return false;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (scale_.size() != 2 || scale_[1] < scale_[0]) {
|
||||
std::string err_msg = "RandomResizedCrop: scale variable must have a size of two in the format of (min, max).";
|
||||
MS_LOG(ERROR)
|
||||
<< "RandomResizedCrop: scale variable must have a size of two in the format of (min, max). However, it is: "
|
||||
<< scale_;
|
||||
return false;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (ratio_.size() != 2 || ratio_[1] < ratio_[0]) {
|
||||
std::string err_msg = "RandomResizedCrop: ratio variable must be in the format of (min, max).";
|
||||
MS_LOG(ERROR) << "RandomResizedCrop: ratio variable must be in the format of (min, max). However , it is: "
|
||||
<< ratio_;
|
||||
return false;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomResizedCropOperation::Build() {
|
||||
|
@ -977,20 +1048,23 @@ RandomRotationOperation::RandomRotationOperation(std::vector<float> degrees, Int
|
|||
center_(center),
|
||||
fill_value_(fill_value) {}
|
||||
|
||||
bool RandomRotationOperation::ValidateParams() {
|
||||
Status RandomRotationOperation::ValidateParams() {
|
||||
if (degrees_.empty() || degrees_.size() != 2) {
|
||||
MS_LOG(ERROR) << "RandomRotation: degrees vector has incorrect size: degrees.size()";
|
||||
return false;
|
||||
std::string err_msg = "RandomRotation: degrees vector has incorrect size: degrees.size()";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (center_.empty() || center_.size() != 2) {
|
||||
MS_LOG(ERROR) << "RandomRotation: center vector has incorrect size: center.size()";
|
||||
return false;
|
||||
std::string err_msg = "RandomRotation: center vector has incorrect size: center.size()";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (fill_value_.empty() || fill_value_.size() != 3) {
|
||||
MS_LOG(ERROR) << "RandomRotation: fill_value vector has incorrect size: fill_value.size()";
|
||||
return false;
|
||||
std::string err_msg = "RandomRotation: fill_value vector has incorrect size: fill_value.size()";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomRotationOperation::Build() {
|
||||
|
@ -1003,12 +1077,13 @@ std::shared_ptr<TensorOp> RandomRotationOperation::Build() {
|
|||
// Function to create RandomSharpness.
|
||||
RandomSharpnessOperation::RandomSharpnessOperation(std::vector<float> degrees) : degrees_(degrees) {}
|
||||
|
||||
bool RandomSharpnessOperation::ValidateParams() {
|
||||
Status RandomSharpnessOperation::ValidateParams() {
|
||||
if (degrees_.empty() || degrees_.size() != 2) {
|
||||
MS_LOG(ERROR) << "RandomSharpness: degrees vector has incorrect size: degrees.size()";
|
||||
return false;
|
||||
std::string err_msg = "RandomSharpness: degrees vector has incorrect size: degrees.size()";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomSharpnessOperation::Build() {
|
||||
|
@ -1019,16 +1094,18 @@ std::shared_ptr<TensorOp> RandomSharpnessOperation::Build() {
|
|||
// RandomSolarizeOperation.
|
||||
RandomSolarizeOperation::RandomSolarizeOperation(std::vector<uint8_t> threshold) : threshold_(threshold) {}
|
||||
|
||||
bool RandomSolarizeOperation::ValidateParams() {
|
||||
Status RandomSolarizeOperation::ValidateParams() {
|
||||
if (threshold_.size() != 2) {
|
||||
MS_LOG(ERROR) << "RandomSolarize: threshold vector has incorrect size: " << threshold_.size();
|
||||
return false;
|
||||
std::string err_msg = "RandomSolarize: threshold vector has incorrect size: " + std::to_string(threshold_.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (threshold_.at(0) > threshold_.at(1)) {
|
||||
MS_LOG(ERROR) << "RandomSolarize: threshold must be passed in a min, max format";
|
||||
return false;
|
||||
std::string err_msg = "RandomSolarize: threshold must be passed in a min, max format";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomSolarizeOperation::Build() {
|
||||
|
@ -1039,12 +1116,14 @@ std::shared_ptr<TensorOp> RandomSolarizeOperation::Build() {
|
|||
// RandomVerticalFlipOperation
|
||||
RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) : probability_(probability) {}
|
||||
|
||||
bool RandomVerticalFlipOperation::ValidateParams() {
|
||||
Status RandomVerticalFlipOperation::ValidateParams() {
|
||||
if (probability_ < 0.0 || probability_ > 1.0) {
|
||||
MS_LOG(ERROR) << "RandomVerticalFlip: probability must be between 0.0 and 1.0.";
|
||||
return false;
|
||||
std::string err_msg = "RandomVerticalFlip: probability must be between 0.0 and 1.0.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return true;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RandomVerticalFlipOperation::Build() {
|
||||
|
@ -1055,12 +1134,14 @@ std::shared_ptr<TensorOp> RandomVerticalFlipOperation::Build() {
|
|||
// RescaleOperation
|
||||
RescaleOperation::RescaleOperation(float rescale, float shift) : rescale_(rescale), shift_(shift) {}
|
||||
|
||||
bool RescaleOperation::ValidateParams() {
|
||||
Status RescaleOperation::ValidateParams() {
|
||||
if (rescale_ < 0) {
|
||||
MS_LOG(ERROR) << "Rescale: rescale must be greater than or equal to 0, got: rescale = " << rescale_;
|
||||
return false;
|
||||
std::string err_msg =
|
||||
"Rescale: rescale must be greater than or equal to 0, got: rescale = " + std::to_string(rescale_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> RescaleOperation::Build() {
|
||||
|
@ -1072,15 +1153,16 @@ std::shared_ptr<TensorOp> RescaleOperation::Build() {
|
|||
ResizeOperation::ResizeOperation(std::vector<int32_t> size, InterpolationMode interpolation)
|
||||
: size_(size), interpolation_(interpolation) {}
|
||||
|
||||
bool ResizeOperation::ValidateParams() {
|
||||
Status ResizeOperation::ValidateParams() {
|
||||
if (size_.empty() || size_.size() > 2) {
|
||||
MS_LOG(ERROR) << "Resize: size vector has incorrect size: " << size_.size();
|
||||
return false;
|
||||
std::string err_msg = "Resize: size vector has incorrect size: " + std::to_string(size_.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (!CheckVectorPositive(size_)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateVectorPositive("Resize", size_));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> ResizeOperation::Build() {
|
||||
|
@ -1098,7 +1180,7 @@ std::shared_ptr<TensorOp> ResizeOperation::Build() {
|
|||
// RgbaToBgrOperation.
|
||||
RgbaToBgrOperation::RgbaToBgrOperation() {}
|
||||
|
||||
bool RgbaToBgrOperation::ValidateParams() { return true; }
|
||||
Status RgbaToBgrOperation::ValidateParams() { return Status::OK(); }
|
||||
|
||||
std::shared_ptr<TensorOp> RgbaToBgrOperation::Build() {
|
||||
std::shared_ptr<RgbaToBgrOp> tensor_op = std::make_shared<RgbaToBgrOp>();
|
||||
|
@ -1108,7 +1190,7 @@ std::shared_ptr<TensorOp> RgbaToBgrOperation::Build() {
|
|||
// RgbaToRgbOperation.
|
||||
RgbaToRgbOperation::RgbaToRgbOperation() {}
|
||||
|
||||
bool RgbaToRgbOperation::ValidateParams() { return true; }
|
||||
Status RgbaToRgbOperation::ValidateParams() { return Status::OK(); }
|
||||
|
||||
std::shared_ptr<TensorOp> RgbaToRgbOperation::Build() {
|
||||
std::shared_ptr<RgbaToRgbOp> tensor_op = std::make_shared<RgbaToRgbOp>();
|
||||
|
@ -1118,7 +1200,7 @@ std::shared_ptr<TensorOp> RgbaToRgbOperation::Build() {
|
|||
// SwapRedBlueOperation.
|
||||
SwapRedBlueOperation::SwapRedBlueOperation() {}
|
||||
|
||||
bool SwapRedBlueOperation::ValidateParams() { return true; }
|
||||
Status SwapRedBlueOperation::ValidateParams() { return Status::OK(); }
|
||||
|
||||
std::shared_ptr<TensorOp> SwapRedBlueOperation::Build() {
|
||||
std::shared_ptr<SwapRedBlueOp> tensor_op = std::make_shared<SwapRedBlueOp>();
|
||||
|
@ -1129,7 +1211,7 @@ std::shared_ptr<TensorOp> SwapRedBlueOperation::Build() {
|
|||
UniformAugOperation::UniformAugOperation(std::vector<std::shared_ptr<TensorOperation>> transforms, int32_t num_ops)
|
||||
: transforms_(transforms), num_ops_(num_ops) {}
|
||||
|
||||
bool UniformAugOperation::ValidateParams() { return true; }
|
||||
Status UniformAugOperation::ValidateParams() { return Status::OK(); }
|
||||
|
||||
std::shared_ptr<TensorOp> UniformAugOperation::Build() {
|
||||
std::vector<std::shared_ptr<TensorOp>> tensor_ops;
|
||||
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_PYTHON_TREE_CONSUMER_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_PYTHON_TREE_CONSUMER_H_
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -46,4 +46,4 @@ class PythonIterator : public IteratorConsumer {
|
|||
};
|
||||
|
||||
} // namespace mindspore::dataset
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_PYTHON_TREE_CONSUMER_H_
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_
|
||||
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_TREE_CONSUMER_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_TREE_CONSUMER_H_
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_TREE_CONSUMER_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_TREE_CONSUMER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -176,4 +176,4 @@ class TreeGetters : public TreeConsumer {
|
|||
};
|
||||
|
||||
} // namespace mindspore::dataset
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_TREE_CONSUMER_H_
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_TREE_CONSUMER_H_
|
||||
@ -41,12 +41,12 @@ BatchNode::BatchNode(std::shared_ptr<Dataset> child, int32_t batch_size, bool dr
|
|||
|
||||
Status BatchNode::ValidateParams() {
|
||||
if (batch_size_ <= 0) {
|
||||
std::string err_msg = "Batch: batch_size should be positive integer, but got: " + std::to_string(batch_size_);
|
||||
std::string err_msg = "BatchNode: batch_size should be positive integer, but got: " + std::to_string(batch_size_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (!cols_to_map_.empty()) {
|
||||
std::string err_msg = "cols_to_map functionality is not implemented in C++; this should be left empty.";
|
||||
std::string err_msg = "BatchNode: cols_to_map functionality is not implemented in C++; this should be left empty.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
@ -62,8 +62,9 @@ std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() {
|
|||
|
||||
Status BucketBatchByLengthNode::ValidateParams() {
|
||||
if (element_length_function_ == nullptr && column_names_.size() != 1) {
|
||||
std::string err_msg = "BucketBatchByLengthNode: element_length_function not specified, but not one column name: " +
|
||||
std::to_string(column_names_.size());
|
||||
std::string err_msg =
|
||||
"BucketBatchByLengthNode: when element_length_function is not specified, size of column_name must be 1 but is: " +
|
||||
std::to_string(column_names_.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
@ -23,11 +23,11 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
|
||||
Status TreeAdapter::BuildAndPrepare(std::shared_ptr<api::Dataset> root_ir, int32_t num_epoch) {
|
||||
// Check whether this function has been called before. If so, return fail
|
||||
// Check whether this function has been called before. If so, return failure
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(tree_ == nullptr, "ExecutionTree is already built.");
|
||||
RETURN_UNEXPECTED_IF_NULL(root_ir);
|
||||
|
||||
// this will evolve in the long run
|
||||
// This will evolve in the long run
|
||||
tree_ = std::make_unique<ExecutionTree>();
|
||||
|
||||
std::shared_ptr<DatasetOp> root_op;
|
||||
|
@ -37,7 +37,7 @@ Status TreeAdapter::BuildAndPrepare(std::shared_ptr<api::Dataset> root_ir, int32
|
|||
// Prepare the tree
|
||||
RETURN_IF_NOT_OK(tree_->Prepare(num_epoch));
|
||||
|
||||
// after the tree is prepared, the col_name_id_map can safely be obtained
|
||||
// After the tree is prepared, the col_name_id_map can safely be obtained
|
||||
column_name_map_ = tree_->root()->column_name_id_map();
|
||||
|
||||
return Status::OK();
|
||||
|
@ -47,7 +47,7 @@ Status TreeAdapter::GetNext(TensorRow *row) {
|
|||
RETURN_UNEXPECTED_IF_NULL(tree_);
|
||||
RETURN_UNEXPECTED_IF_NULL(row);
|
||||
row->clear(); // make sure row is empty
|
||||
// cur_db_ being a nullptr means this is the first call to get_next, launch ExecutionTree
|
||||
// When cur_db_ is a nullptr, it means this is the first call to get_next, launch ExecutionTree
|
||||
if (cur_db_ == nullptr) {
|
||||
RETURN_IF_NOT_OK(tree_->Launch());
|
||||
RETURN_IF_NOT_OK(tree_->root()->GetNextBuffer(&cur_db_)); // first buf can't be eof or empty buf with none flag
|
||||
|
@ -77,7 +77,7 @@ Status TreeAdapter::DFSBuildTree(std::shared_ptr<api::Dataset> ir, std::shared_p
|
|||
RETURN_IF_NOT_OK(ops[i - 1]->AddChild(ops[i]));
|
||||
}
|
||||
|
||||
// build the children of ir, once they return, add the return value to *op
|
||||
// Build the children of ir, once they return, add the return value to *op
|
||||
for (std::shared_ptr<api::Dataset> child_ir : ir->children) {
|
||||
std::shared_ptr<DatasetOp> child_op;
|
||||
RETURN_IF_NOT_OK(DFSBuildTree(child_ir, &child_op));
|
||||
@ -37,7 +37,7 @@ class TreeAdapter {
|
|||
|
||||
~TreeAdapter() = default;
|
||||
|
||||
// This will construct a ExeTree from a Dataset root and Prepare() the ExeTree
|
||||
// This will construct an ExeTree from a Dataset root and Prepare() the ExeTree
|
||||
// This function is only meant to be called once and needs to be called before GetNext
|
||||
// ExeTree will be launched when the first GetNext is called
|
||||
Status BuildAndPrepare(std::shared_ptr<api::Dataset> root, int32_t num_epoch = -1);
|
||||
|
@ -47,15 +47,15 @@ class TreeAdapter {
|
|||
// 2. GetNext will return empty row when eoe/eof is obtained
|
||||
Status GetNext(TensorRow *);
|
||||
|
||||
// this function will return the column_name_map once BuildAndPrepare() is called
|
||||
// This function will return the column_name_map once BuildAndPrepare() is called
|
||||
std::unordered_map<std::string, int32_t> GetColumnNameMap() const { return column_name_map_; }
|
||||
|
||||
// this function returns the TaskGroup associated with ExeTree, this is needed by DeviceQueueConsumer
|
||||
// This function returns the TaskGroup associated with ExeTree. This is needed by DeviceQueueConsumer
|
||||
// to be able to launch a thread. BuildAndPrepare needs to be called before this function
|
||||
TaskGroup *AllTasks() const { return tree_ != nullptr ? tree_->AllTasks() : nullptr; }
|
||||
|
||||
private:
|
||||
// this RECURSIVE function converts IR nodes into DatasetOp in ExecutionTree. ir could build a vector of ops. In
|
||||
// This RECURSIVE function converts IR nodes into DatasetOp in ExecutionTree. IR could build a vector of ops. In
|
||||
// such case, the first node is returned. Op is added as child when the current function returns.
|
||||
Status DFSBuildTree(std::shared_ptr<api::Dataset> ir, std::shared_ptr<DatasetOp> *op);
|
||||
|
||||
@ -89,6 +89,7 @@ enum class StatusCode : char {
|
|||
kTimeOut = 14,
|
||||
kBuddySpaceFull = 15,
|
||||
kNetWorkError = 16,
|
||||
kNotImplementedYet = 17,
|
||||
// Make this error code the last one. Add new error code above it.
|
||||
kUnexpectedError = 127
|
||||
};
|
||||
@ -17,13 +17,14 @@
|
|||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TEXT_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TEXT_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
#include "minddata/dataset/text/vocab.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "mindspore/ccsrc/minddata/dataset/core/data_type.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -56,7 +57,7 @@ class LookupOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<Vocab> vocab_;
|
||||
@ -17,10 +17,11 @@
|
|||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TRANSFORMS_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TRANSFORMS_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -41,7 +42,7 @@ class TensorOperation : public std::enable_shared_from_this<TensorOperation> {
|
|||
/// \return shared pointer to the newly created TensorOp.
|
||||
virtual std::shared_ptr<TensorOp> Build() = 0;
|
||||
|
||||
virtual bool ValidateParams() = 0;
|
||||
virtual Status ValidateParams() = 0;
|
||||
};
|
||||
|
||||
// Transform operations for performing data transformation.
|
||||
|
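With the base class now declaring `virtual Status ValidateParams() = 0;`, calling code can validate a set of operations up front and propagate the first descriptive error instead of a plain boolean. The sketch below is illustrative only: the trimmed-down TensorOperation, the OneHotLikeOperation demo op, and ValidateAll are assumptions made for this example, not MindSpore API.

// Sketch: caller-side propagation once ValidateParams() returns Status.
// TensorOperation here is a trimmed-down stand-in for the real interface.
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Status {
  bool ok;
  std::string msg;
  static Status OK() { return {true, ""}; }
  static Status Error(std::string m) { return {false, std::move(m)}; }
};

class TensorOperation {
 public:
  virtual ~TensorOperation() = default;
  virtual Status ValidateParams() = 0;  // was: virtual bool ValidateParams() = 0;
};

// Hypothetical op used only for the demo.
class OneHotLikeOperation : public TensorOperation {
 public:
  explicit OneHotLikeOperation(int num_classes) : num_classes_(num_classes) {}
  Status ValidateParams() override {
    if (num_classes_ < 0) {
      return Status::Error("OneHot: Number of classes cannot be negative: " + std::to_string(num_classes_));
    }
    return Status::OK();
  }

 private:
  int num_classes_;
};

// Validate every op and report the first failure with its message.
Status ValidateAll(const std::vector<std::shared_ptr<TensorOperation>> &ops) {
  for (const auto &op : ops) {
    Status rc = op->ValidateParams();
    if (!rc.ok) return rc;  // the old bool API could only say "something failed"
  }
  return Status::OK();
}

int main() {
  std::vector<std::shared_ptr<TensorOperation>> ops = {std::make_shared<OneHotLikeOperation>(-1)};
  Status rc = ValidateAll(ops);
  if (!rc.ok) std::cout << rc.msg << std::endl;
  return 0;
}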
@ -73,7 +74,7 @@ class OneHotOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
float num_classes_;
|
||||
|
@ -87,7 +88,7 @@ class TypeCastOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string data_type_;
|
||||
@ -17,10 +17,11 @@
|
|||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_VISION_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_VISION_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -326,7 +327,7 @@ class CenterCropOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<int32_t> size_;
|
||||
|
@ -340,7 +341,7 @@ class CropOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<int32_t> coordinates_;
|
||||
|
@ -355,7 +356,7 @@ class CutMixBatchOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
float alpha_;
|
||||
|
@ -371,7 +372,7 @@ class CutOutOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
int32_t length_;
|
||||
|
@ -387,7 +388,7 @@ class DecodeOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
bool rgb_;
|
||||
|
@ -399,7 +400,7 @@ class HwcToChwOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
};
|
||||
|
||||
class MixUpBatchOperation : public TensorOperation {
|
||||
|
@ -410,7 +411,7 @@ class MixUpBatchOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
float alpha_;
|
||||
|
@ -424,7 +425,7 @@ class NormalizeOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<float> mean_;
|
||||
|
@ -440,7 +441,7 @@ class PadOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<int32_t> padding_;
|
||||
|
@ -460,7 +461,7 @@ class RandomAffineOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<float_t> degrees_; // min_degree, max_degree
|
||||
|
@ -479,7 +480,7 @@ class RandomColorOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
float t_lb_;
|
||||
|
@ -495,7 +496,7 @@ class RandomColorAdjustOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<float> brightness_;
|
||||
|
@ -514,7 +515,7 @@ class RandomCropOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<int32_t> size_;
|
||||
|
@ -533,7 +534,7 @@ class RandomCropDecodeResizeOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<int32_t> size_;
|
||||
|
@ -551,7 +552,7 @@ class RandomHorizontalFlipOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
float probability_;
|
||||
|
@ -565,7 +566,7 @@ class RandomPosterizeOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<uint8_t> bit_range_;
|
||||
|
@ -582,7 +583,7 @@ class RandomResizedCropOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<int32_t> size_;
|
||||
|
@ -601,7 +602,7 @@ class RandomRotationOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<float> degrees_;
|
||||
|
@ -619,7 +620,7 @@ class RandomSharpnessOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<float> degrees_;
|
||||
|
@ -633,7 +634,7 @@ class RandomSolarizeOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<uint8_t> threshold_;
|
||||
|
@ -647,7 +648,7 @@ class RandomVerticalFlipOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
float probability_;
|
||||
|
@ -661,7 +662,7 @@ class RescaleOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
float rescale_;
|
||||
|
@ -677,7 +678,7 @@ class ResizeOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<int32_t> size_;
|
||||
|
@ -692,7 +693,7 @@ class RgbaToBgrOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
};
|
||||
|
||||
class RgbaToRgbOperation : public TensorOperation {
|
||||
|
@ -703,7 +704,7 @@ class RgbaToRgbOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
};
|
||||
|
||||
class SwapRedBlueOperation : public TensorOperation {
|
||||
|
@ -714,7 +715,7 @@ class SwapRedBlueOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
};
|
||||
|
||||
class UniformAugOperation : public TensorOperation {
|
||||
|
@ -725,7 +726,7 @@ class UniformAugOperation : public TensorOperation {
|
|||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<TensorOperation>> transforms_;
|
||||