forked from mindspore-Ecosystem/mindspore
Added random node fix
This commit is contained in:
parent
7d6039d384
commit
a21eb2d527
|
@ -691,10 +691,12 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
|
|||
}
|
||||
return vocab;
|
||||
}
|
||||
|
||||
#endif
|
||||
std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remainder) {
|
||||
return std::make_shared<BatchDataset>(shared_from_this(), batch_size, drop_remainder);
|
||||
}
|
||||
#endif
|
||||
|
||||
SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {}
|
||||
|
||||
// SchemaObj init function
|
||||
|
@ -969,16 +971,14 @@ VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task,
|
|||
#endif
|
||||
RandomDataDataset::RandomDataDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema,
|
||||
const std::vector<std::string> &columns_list,
|
||||
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) {
|
||||
auto ds =
|
||||
std::make_shared<RandomNode>(total_rows, std::move(schema), std::move(columns_list), std::move(sampler), cache);
|
||||
std::shared_ptr<DatasetCache> cache) {
|
||||
auto ds = std::make_shared<RandomNode>(total_rows, std::move(schema), std::move(columns_list), cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
RandomDataDataset::RandomDataDataset(const int32_t &total_rows, std::string schema_path,
|
||||
const std::vector<std::string> &columns_list,
|
||||
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) {
|
||||
auto ds = std::make_shared<RandomNode>(total_rows, std::move(schema_path), std::move(columns_list),
|
||||
std::move(sampler), cache);
|
||||
std::shared_ptr<DatasetCache> cache) {
|
||||
auto ds = std::make_shared<RandomNode>(total_rows, std::move(schema_path), std::move(columns_list), cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
#ifndef ENABLE_ANDROID
|
||||
|
|
|
@ -158,6 +158,7 @@ std::shared_ptr<MixUpBatchOperation> MixUpBatch(float alpha) {
|
|||
// Input validation
|
||||
return op->ValidateParams() ? op : nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Function to create NormalizeOperation.
|
||||
std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vector<float> std) {
|
||||
|
@ -166,6 +167,7 @@ std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vect
|
|||
return op->ValidateParams() ? op : nullptr;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Function to create PadOperation.
|
||||
std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value,
|
||||
BorderType padding_mode) {
|
||||
|
@ -702,6 +704,7 @@ Status MixUpBatchOperation::ValidateParams() {
|
|||
|
||||
std::shared_ptr<TensorOp> MixUpBatchOperation::Build() { return std::make_shared<MixUpBatchOp>(alpha_); }
|
||||
|
||||
#endif
|
||||
// NormalizeOperation
|
||||
NormalizeOperation::NormalizeOperation(std::vector<float> mean, std::vector<float> std) : mean_(mean), std_(std) {}
|
||||
|
||||
|
@ -736,6 +739,7 @@ std::shared_ptr<TensorOp> NormalizeOperation::Build() {
|
|||
return std::make_shared<NormalizeOp>(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// PadOperation
|
||||
PadOperation::PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, BorderType padding_mode)
|
||||
: padding_(padding), fill_value_(fill_value), padding_mode_(padding_mode) {}
|
||||
|
|
|
@ -36,8 +36,6 @@ Status RandomNode::ValidateParams() {
|
|||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("RandomNode", sampler_));
|
||||
|
||||
if (!columns_list_.empty()) {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RandomNode", "columns_list", columns_list_));
|
||||
}
|
||||
|
@ -89,6 +87,14 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() {
|
|||
data_schema->LoadSchemaString(schema_json_string, columns_to_load);
|
||||
}
|
||||
}
|
||||
|
||||
// RandomOp by itself is a non-mappable dataset that does not support sampling.
|
||||
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
// That is why we save the sampler here in a leaf node that does not use sampling.
|
||||
// RandomOp doesn't support sampler, should not support sharding, select sampler should just be sequential.
|
||||
std::shared_ptr<SamplerObj> sampler_ = SelectSampler(total_rows_, false, 1, 0);
|
||||
|
||||
std::shared_ptr<RandomDataOp> op;
|
||||
op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_,
|
||||
std::move(data_schema), std::move(sampler_->Build()));
|
||||
|
|
|
@ -36,22 +36,20 @@ class RandomNode : public DatasetNode {
|
|||
|
||||
/// \brief Constructor
|
||||
RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list,
|
||||
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
|
||||
std::shared_ptr<DatasetCache> cache)
|
||||
: DatasetNode(std::move(cache)),
|
||||
total_rows_(total_rows),
|
||||
schema_path_(""),
|
||||
schema_(std::move(schema)),
|
||||
columns_list_(columns_list),
|
||||
sampler_(std::move(sampler)) {}
|
||||
columns_list_(columns_list) {}
|
||||
|
||||
/// \brief Constructor
|
||||
RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list,
|
||||
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
|
||||
std::shared_ptr<DatasetCache> cache)
|
||||
: DatasetNode(std::move(cache)),
|
||||
total_rows_(total_rows),
|
||||
schema_path_(schema_path),
|
||||
columns_list_(columns_list),
|
||||
sampler_(std::move(sampler)) {}
|
||||
columns_list_(columns_list) {}
|
||||
|
||||
/// \brief Destructor
|
||||
~RandomNode() = default;
|
||||
|
|
|
@ -810,11 +810,10 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset
|
|||
class RandomDataDataset : public Dataset {
|
||||
public:
|
||||
RandomDataDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema,
|
||||
const std::vector<std::string> &columns_list, const std::shared_ptr<SamplerObj> &sampler,
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
const std::vector<std::string> &columns_list, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
RandomDataDataset(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list,
|
||||
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache);
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
};
|
||||
|
||||
/// \brief Function to create a RandomDataset
|
||||
|
@ -829,16 +828,13 @@ class RandomDataDataset : public Dataset {
|
|||
template <typename T = std::shared_ptr<SchemaObj>>
|
||||
std::shared_ptr<RandomDataDataset> RandomData(const int32_t &total_rows = 0, const T &schema = nullptr,
|
||||
const std::vector<std::string> &columns_list = {},
|
||||
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
std::shared_ptr<RandomDataDataset> ds;
|
||||
if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) {
|
||||
std::shared_ptr<SchemaObj> schema_obj = schema;
|
||||
ds = std::make_shared<RandomDataDataset>(total_rows, std::move(schema_obj), std::move(columns_list),
|
||||
std::move(sampler), cache);
|
||||
ds = std::make_shared<RandomDataDataset>(total_rows, std::move(schema_obj), std::move(columns_list), cache);
|
||||
} else {
|
||||
ds = std::make_shared<RandomDataDataset>(total_rows, std::move(schema), std::move(columns_list), std::move(sampler),
|
||||
cache);
|
||||
ds = std::make_shared<RandomDataDataset>(total_rows, std::move(schema), std::move(columns_list), cache);
|
||||
}
|
||||
return ds;
|
||||
}
|
||||
|
|
|
@ -434,7 +434,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheRandomDataCApi) {
|
|||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
schema->add_column("image", mindspore::TypeId::kNumberTypeUInt8, {2});
|
||||
schema->add_column("label", mindspore::TypeId::kNumberTypeUInt8, {1});
|
||||
std::shared_ptr<Dataset> ds = RandomData(4, schema, {}, RandomSampler(), some_cache);
|
||||
std::shared_ptr<Dataset> ds = RandomData(4, schema, {}, some_cache);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Repeat operation on ds
|
||||
|
|
|
@ -402,18 +402,6 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetBasic7) {
|
|||
GlobalContext::config_manager()->set_seed(curr_seed);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomDatasetWithNullSampler) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetWithNullSampler.";
|
||||
|
||||
// Create a RandomDataset
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
schema->add_column("image", mindspore::TypeId::kNumberTypeUInt8, {2});
|
||||
schema->add_column("label", mindspore::TypeId::kNumberTypeUInt8, {1});
|
||||
std::shared_ptr<Dataset> ds = RandomData(50, schema, {}, nullptr);
|
||||
// Expect failure: sampler can not be nullptr
|
||||
EXPECT_EQ(ds->CreateIterator(), nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomDatasetDuplicateColumnName) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetDuplicateColumnName.";
|
||||
|
||||
|
|
Loading…
Reference in New Issue