Added fix for arm64

This commit is contained in:
Eric 2020-10-20 11:07:24 -04:00
parent cc4aa65743
commit 3691260031
4 changed files with 41 additions and 12 deletions

View File

@ -222,7 +222,6 @@ std::shared_ptr<ManifestNode> Manifest(const std::string &dataset_file, const st
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr;
}
#endif
// Function to create a MindDataNode.
std::shared_ptr<MindDataNode> MindData(const std::string &dataset_file, const std::vector<std::string> &columns_list,
@ -244,6 +243,7 @@ std::shared_ptr<MindDataNode> MindData(const std::vector<std::string> &dataset_f
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr;
}
#endif
// Function to create a MnistNode.
std::shared_ptr<MnistNode> Mnist(const std::string &dataset_dir, const std::string &usage,
@ -961,7 +961,7 @@ Status CLUENode::ValidateParams() {
}
if (num_samples_ < 0) {
std::string err_msg = "CLUENode: Invalid number of samples: " + num_samples_;
std::string err_msg = "CLUENode: Invalid number of samples: " + std::to_string(num_samples_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
@ -1200,7 +1200,7 @@ std::vector<std::shared_ptr<DatasetOp>> CocoNode::Build() {
schema->AddColumn(ColDescriptor(std::string("area"), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
break;
default:
MS_LOG(ERROR) << "CocoNode::Build : Invalid task type: " << task_type;
MS_LOG(ERROR) << "CocoNode::Build : Invalid task type";
return {};
}
std::shared_ptr<CocoOp> op =
@ -1234,7 +1234,7 @@ Status CSVNode::ValidateParams() {
}
if (num_samples_ < 0) {
std::string err_msg = "CSVNode: Invalid number of samples: " + num_samples_;
std::string err_msg = "CSVNode: Invalid number of samples: " + std::to_string(num_samples_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
@ -1560,7 +1560,8 @@ std::vector<std::shared_ptr<DatasetOp>> MnistNode::Build() {
// ValideParams for RandomNode
Status RandomNode::ValidateParams() {
if (total_rows_ < 0) {
std::string err_msg = "RandomNode: total_rows must be greater than or equal 0, now get " + total_rows_;
std::string err_msg =
"RandomNode: total_rows must be greater than or equal 0, now get " + std::to_string(total_rows_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
@ -1638,7 +1639,7 @@ Status TextFileNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TextFileNode", dataset_files_));
if (num_samples_ < 0) {
std::string err_msg = "TextFileNode: Invalid number of samples: " + num_samples_;
std::string err_msg = "TextFileNode: Invalid number of samples: " + std::to_string(num_samples_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
@ -1858,7 +1859,7 @@ std::vector<std::shared_ptr<DatasetOp>> BatchNode::Build() {
Status BatchNode::ValidateParams() {
if (batch_size_ <= 0) {
std::string err_msg = "Batch: batch_size should be positive integer, but got: " + batch_size_;
std::string err_msg = "Batch: batch_size should be positive integer, but got: " + std::to_string(batch_size_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
@ -2158,7 +2159,7 @@ std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() {
Status RepeatNode::ValidateParams() {
if (repeat_count_ <= 0 && repeat_count_ != -1) {
std::string err_msg =
"Repeat: repeat_count should be either -1 or positive integer, repeat_count_: " + repeat_count_;
"Repeat: repeat_count should be either -1 or positive integer, repeat_count_: " + std::to_string(repeat_count_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
@ -2185,7 +2186,7 @@ std::vector<std::shared_ptr<DatasetOp>> ShuffleNode::Build() {
// Function to validate the parameters for ShuffleNode
Status ShuffleNode::ValidateParams() {
if (shuffle_size_ <= 1) {
std::string err_msg = "ShuffleNode: Invalid input, shuffle_size: " + shuffle_size_;
std::string err_msg = "ShuffleNode: Invalid input, shuffle_size: " + std::to_string(shuffle_size_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
@ -2210,7 +2211,7 @@ std::vector<std::shared_ptr<DatasetOp>> SkipNode::Build() {
// Function to validate the parameters for SkipNode
Status SkipNode::ValidateParams() {
if (skip_count_ <= -1) {
std::string err_msg = "Skip: skip_count should not be negative, skip_count: " + skip_count_;
std::string err_msg = "Skip: skip_count should not be negative, skip_count: " + std::to_string(skip_count_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
@ -2234,7 +2235,8 @@ std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() {
// Function to validate the parameters for TakeNode
Status TakeNode::ValidateParams() {
if (take_count_ <= 0 && take_count_ != -1) {
std::string err_msg = "Take: take_count should be either -1 or positive integer, take_count: " + take_count_;
std::string err_msg =
"Take: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

View File

@ -23,6 +23,7 @@
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
#ifndef ENABLE_ANDROID
#include "minddata/mindrecord/include/shard_distributed_sample.h"
#include "minddata/mindrecord/include/shard_operator.h"
#include "minddata/mindrecord/include/shard_pk_sample.h"
@ -30,6 +31,7 @@
#include "minddata/mindrecord/include/shard_sequential_sample.h"
#include "minddata/mindrecord/include/shard_shuffle.h"
#include "minddata/dataset/util/random.h"
#endif
namespace mindspore {
namespace dataset {
@ -150,12 +152,14 @@ std::shared_ptr<Sampler> DistributedSamplerObj::Build() {
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardDistributedSample>(num_shards_, shard_id_, shuffle_, seed_,
num_samples_, offset_);
return mind_sampler;
}
#endif
// PKSampler
PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples)
@ -181,6 +185,7 @@ std::shared_ptr<Sampler> PKSamplerObj::Build() {
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
@ -193,6 +198,7 @@ std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
return mind_sampler;
}
#endif
// RandomSampler
RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples)
@ -214,6 +220,7 @@ std::shared_ptr<Sampler> RandomSamplerObj::Build() {
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
bool reshuffle_each_epoch_ = true;
@ -222,6 +229,7 @@ std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset
return mind_sampler;
}
#endif
// SequentialSampler
SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples)
@ -248,12 +256,14 @@ std::shared_ptr<Sampler> SequentialSamplerObj::Build() {
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> SequentialSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardSequentialSample>(num_samples_, start_index_);
return mind_sampler;
}
#endif
// SubsetRandomSampler
SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
@ -275,12 +285,14 @@ std::shared_ptr<Sampler> SubsetRandomSamplerObj::Build() {
return sampler;
}
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_, GetSeed());
return mind_sampler;
}
#endif
// WeightedRandomSampler
WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)

View File

@ -20,7 +20,7 @@
#include <cstdint>
#include <string>
#include "utils/log_adapter.h"
#include "minddata/dataset/util/log_adapter.h"
namespace mindspore {
namespace dataset {

View File

@ -19,7 +19,10 @@
#include <vector>
#include <memory>
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
#endif
namespace mindspore {
namespace dataset {
@ -45,10 +48,12 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
/// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<Sampler> Build() = 0;
#ifndef ENABLE_ANDROID
/// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
/// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
/// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; }
#endif
};
class DistributedSamplerObj;
@ -123,7 +128,9 @@ class DistributedSamplerObj : public SamplerObj {
std::shared_ptr<Sampler> Build() override;
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
bool ValidateParams() override;
@ -145,7 +152,9 @@ class PKSamplerObj : public SamplerObj {
std::shared_ptr<Sampler> Build() override;
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
bool ValidateParams() override;
@ -163,7 +172,9 @@ class RandomSamplerObj : public SamplerObj {
std::shared_ptr<Sampler> Build() override;
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
bool ValidateParams() override;
@ -180,7 +191,9 @@ class SequentialSamplerObj : public SamplerObj {
std::shared_ptr<Sampler> Build() override;
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
bool ValidateParams() override;
@ -197,7 +210,9 @@ class SubsetRandomSamplerObj : public SamplerObj {
std::shared_ptr<Sampler> Build() override;
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
#endif
bool ValidateParams() override;