forked from mindspore-Ecosystem/mindspore
Added fix for arm64
This commit is contained in:
parent
cc4aa65743
commit
3691260031
|
@ -222,7 +222,6 @@ std::shared_ptr<ManifestNode> Manifest(const std::string &dataset_file, const st
|
|||
// Call derived class validation method.
|
||||
return ds->ValidateParams() ? ds : nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Function to create a MindDataNode.
|
||||
std::shared_ptr<MindDataNode> MindData(const std::string &dataset_file, const std::vector<std::string> &columns_list,
|
||||
|
@ -244,6 +243,7 @@ std::shared_ptr<MindDataNode> MindData(const std::vector<std::string> &dataset_f
|
|||
// Call derived class validation method.
|
||||
return ds->ValidateParams() ? ds : nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Function to create a MnistNode.
|
||||
std::shared_ptr<MnistNode> Mnist(const std::string &dataset_dir, const std::string &usage,
|
||||
|
@ -961,7 +961,7 @@ Status CLUENode::ValidateParams() {
|
|||
}
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
std::string err_msg = "CLUENode: Invalid number of samples: " + num_samples_;
|
||||
std::string err_msg = "CLUENode: Invalid number of samples: " + std::to_string(num_samples_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
@ -1200,7 +1200,7 @@ std::vector<std::shared_ptr<DatasetOp>> CocoNode::Build() {
|
|||
schema->AddColumn(ColDescriptor(std::string("area"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "CocoNode::Build : Invalid task type: " << task_type;
|
||||
MS_LOG(ERROR) << "CocoNode::Build : Invalid task type";
|
||||
return {};
|
||||
}
|
||||
std::shared_ptr<CocoOp> op =
|
||||
|
@ -1234,7 +1234,7 @@ Status CSVNode::ValidateParams() {
|
|||
}
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
std::string err_msg = "CSVNode: Invalid number of samples: " + num_samples_;
|
||||
std::string err_msg = "CSVNode: Invalid number of samples: " + std::to_string(num_samples_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
@ -1560,7 +1560,8 @@ std::vector<std::shared_ptr<DatasetOp>> MnistNode::Build() {
|
|||
// ValideParams for RandomNode
|
||||
Status RandomNode::ValidateParams() {
|
||||
if (total_rows_ < 0) {
|
||||
std::string err_msg = "RandomNode: total_rows must be greater than or equal 0, now get " + total_rows_;
|
||||
std::string err_msg =
|
||||
"RandomNode: total_rows must be greater than or equal 0, now get " + std::to_string(total_rows_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
@ -1638,7 +1639,7 @@ Status TextFileNode::ValidateParams() {
|
|||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TextFileNode", dataset_files_));
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
std::string err_msg = "TextFileNode: Invalid number of samples: " + num_samples_;
|
||||
std::string err_msg = "TextFileNode: Invalid number of samples: " + std::to_string(num_samples_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
@ -1858,7 +1859,7 @@ std::vector<std::shared_ptr<DatasetOp>> BatchNode::Build() {
|
|||
|
||||
Status BatchNode::ValidateParams() {
|
||||
if (batch_size_ <= 0) {
|
||||
std::string err_msg = "Batch: batch_size should be positive integer, but got: " + batch_size_;
|
||||
std::string err_msg = "Batch: batch_size should be positive integer, but got: " + std::to_string(batch_size_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
@ -2158,7 +2159,7 @@ std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() {
|
|||
Status RepeatNode::ValidateParams() {
|
||||
if (repeat_count_ <= 0 && repeat_count_ != -1) {
|
||||
std::string err_msg =
|
||||
"Repeat: repeat_count should be either -1 or positive integer, repeat_count_: " + repeat_count_;
|
||||
"Repeat: repeat_count should be either -1 or positive integer, repeat_count_: " + std::to_string(repeat_count_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
@ -2185,7 +2186,7 @@ std::vector<std::shared_ptr<DatasetOp>> ShuffleNode::Build() {
|
|||
// Function to validate the parameters for ShuffleNode
|
||||
Status ShuffleNode::ValidateParams() {
|
||||
if (shuffle_size_ <= 1) {
|
||||
std::string err_msg = "ShuffleNode: Invalid input, shuffle_size: " + shuffle_size_;
|
||||
std::string err_msg = "ShuffleNode: Invalid input, shuffle_size: " + std::to_string(shuffle_size_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
@ -2210,7 +2211,7 @@ std::vector<std::shared_ptr<DatasetOp>> SkipNode::Build() {
|
|||
// Function to validate the parameters for SkipNode
|
||||
Status SkipNode::ValidateParams() {
|
||||
if (skip_count_ <= -1) {
|
||||
std::string err_msg = "Skip: skip_count should not be negative, skip_count: " + skip_count_;
|
||||
std::string err_msg = "Skip: skip_count should not be negative, skip_count: " + std::to_string(skip_count_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
@ -2234,7 +2235,8 @@ std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() {
|
|||
// Function to validate the parameters for TakeNode
|
||||
Status TakeNode::ValidateParams() {
|
||||
if (take_count_ <= 0 && take_count_ != -1) {
|
||||
std::string err_msg = "Take: take_count should be either -1 or positive integer, take_count: " + take_count_;
|
||||
std::string err_msg =
|
||||
"Take: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/mindrecord/include/shard_distributed_sample.h"
|
||||
#include "minddata/mindrecord/include/shard_operator.h"
|
||||
#include "minddata/mindrecord/include/shard_pk_sample.h"
|
||||
|
@ -30,6 +31,7 @@
|
|||
#include "minddata/mindrecord/include/shard_sequential_sample.h"
|
||||
#include "minddata/mindrecord/include/shard_shuffle.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -150,12 +152,14 @@ std::shared_ptr<Sampler> DistributedSamplerObj::Build() {
|
|||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDataset() {
|
||||
// runtime mindrecord sampler object
|
||||
auto mind_sampler = std::make_shared<mindrecord::ShardDistributedSample>(num_shards_, shard_id_, shuffle_, seed_,
|
||||
num_samples_, offset_);
|
||||
return mind_sampler;
|
||||
}
|
||||
#endif
|
||||
|
||||
// PKSampler
|
||||
PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples)
|
||||
|
@ -181,6 +185,7 @@ std::shared_ptr<Sampler> PKSamplerObj::Build() {
|
|||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
|
||||
// runtime mindrecord sampler object
|
||||
std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
|
||||
|
@ -193,6 +198,7 @@ std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
|
|||
|
||||
return mind_sampler;
|
||||
}
|
||||
#endif
|
||||
|
||||
// RandomSampler
|
||||
RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples)
|
||||
|
@ -214,6 +220,7 @@ std::shared_ptr<Sampler> RandomSamplerObj::Build() {
|
|||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset() {
|
||||
// runtime mindrecord sampler object
|
||||
bool reshuffle_each_epoch_ = true;
|
||||
|
@ -222,6 +229,7 @@ std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset
|
|||
|
||||
return mind_sampler;
|
||||
}
|
||||
#endif
|
||||
|
||||
// SequentialSampler
|
||||
SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples)
|
||||
|
@ -248,12 +256,14 @@ std::shared_ptr<Sampler> SequentialSamplerObj::Build() {
|
|||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> SequentialSamplerObj::BuildForMindDataset() {
|
||||
// runtime mindrecord sampler object
|
||||
auto mind_sampler = std::make_shared<mindrecord::ShardSequentialSample>(num_samples_, start_index_);
|
||||
|
||||
return mind_sampler;
|
||||
}
|
||||
#endif
|
||||
|
||||
// SubsetRandomSampler
|
||||
SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
|
||||
|
@ -275,12 +285,14 @@ std::shared_ptr<Sampler> SubsetRandomSamplerObj::Build() {
|
|||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindDataset() {
|
||||
// runtime mindrecord sampler object
|
||||
auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_, GetSeed());
|
||||
|
||||
return mind_sampler;
|
||||
}
|
||||
#endif
|
||||
|
||||
// WeightedRandomSampler
|
||||
WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
#include "minddata/dataset/util/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
|
|
@ -19,7 +19,10 @@
|
|||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -45,10 +48,12 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
|
|||
/// \return Shared pointers to the newly created Sampler
|
||||
virtual std::shared_ptr<Sampler> Build() = 0;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
|
||||
/// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
|
||||
/// \return Shared pointers to the newly created Sampler
|
||||
virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; }
|
||||
#endif
|
||||
};
|
||||
|
||||
class DistributedSamplerObj;
|
||||
|
@ -123,7 +128,9 @@ class DistributedSamplerObj : public SamplerObj {
|
|||
|
||||
std::shared_ptr<Sampler> Build() override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
bool ValidateParams() override;
|
||||
|
||||
|
@ -145,7 +152,9 @@ class PKSamplerObj : public SamplerObj {
|
|||
|
||||
std::shared_ptr<Sampler> Build() override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
bool ValidateParams() override;
|
||||
|
||||
|
@ -163,7 +172,9 @@ class RandomSamplerObj : public SamplerObj {
|
|||
|
||||
std::shared_ptr<Sampler> Build() override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
bool ValidateParams() override;
|
||||
|
||||
|
@ -180,7 +191,9 @@ class SequentialSamplerObj : public SamplerObj {
|
|||
|
||||
std::shared_ptr<Sampler> Build() override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
bool ValidateParams() override;
|
||||
|
||||
|
@ -197,7 +210,9 @@ class SubsetRandomSamplerObj : public SamplerObj {
|
|||
|
||||
std::shared_ptr<Sampler> Build() override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
bool ValidateParams() override;
|
||||
|
||||
|
|
Loading…
Reference in New Issue