Added random node fix

This commit is contained in:
Eric 2020-11-11 14:48:21 -05:00
parent 7d6039d384
commit a21eb2d527
7 changed files with 28 additions and 36 deletions

View File

@ -691,10 +691,12 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
}
return vocab;
}
#endif
// Function to create a BatchDataset on top of this dataset.
// Passes shared_from_this() as the input dataset, together with batch_size and drop_remainder,
// to the BatchDataset constructor and returns the new object as a shared pointer.
// NOTE(review): if shared_from_this() comes from std::enable_shared_from_this, this Dataset must
// already be owned by a std::shared_ptr when Batch() is called - confirm against the class declaration.
// NOTE(review): drop_remainder presumably controls whether a trailing partial batch is discarded;
// its handling is in BatchDataset, not shown here - confirm.
std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remainder) {
return std::make_shared<BatchDataset>(shared_from_this(), batch_size, drop_remainder);
}
#endif
// SchemaObj constructor.
// Stores a copy of schema_file in schema_file_, sets num_rows_ to 0 and dataset_type_ to an empty string.
// The constructor body is empty, so the schema file is not read here.
SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {}
// SchemaObj init function
@ -969,16 +971,14 @@ VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task,
#endif
RandomDataDataset::RandomDataDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema,
const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) {
auto ds =
std::make_shared<RandomNode>(total_rows, std::move(schema), std::move(columns_list), std::move(sampler), cache);
std::shared_ptr<DatasetCache> cache) {
auto ds = std::make_shared<RandomNode>(total_rows, std::move(schema), std::move(columns_list), cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
RandomDataDataset::RandomDataDataset(const int32_t &total_rows, std::string schema_path,
const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache) {
auto ds = std::make_shared<RandomNode>(total_rows, std::move(schema_path), std::move(columns_list),
std::move(sampler), cache);
std::shared_ptr<DatasetCache> cache) {
auto ds = std::make_shared<RandomNode>(total_rows, std::move(schema_path), std::move(columns_list), cache);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
#ifndef ENABLE_ANDROID

View File

@ -158,6 +158,7 @@ std::shared_ptr<MixUpBatchOperation> MixUpBatch(float alpha) {
// Input validation
return op->ValidateParams() ? op : nullptr;
}
#endif
// Function to create NormalizeOperation.
std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vector<float> std) {
@ -166,6 +167,7 @@ std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vect
return op->ValidateParams() ? op : nullptr;
}
#ifndef ENABLE_ANDROID
// Function to create PadOperation.
std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value,
BorderType padding_mode) {
@ -702,6 +704,7 @@ Status MixUpBatchOperation::ValidateParams() {
// Creates a MixUpBatchOp from the stored alpha_ value and returns it as a TensorOp shared pointer.
std::shared_ptr<TensorOp> MixUpBatchOperation::Build() { return std::make_shared<MixUpBatchOp>(alpha_); }
#endif
// NormalizeOperation
// Constructor: copies the mean and std vectors into mean_ and std_.
// No validation is done here; the body is empty.
// NOTE(review): Build() in this file indexes mean_[0..2] and std_[0..2], so three elements each
// appear to be expected - confirm that ValidateParams() enforces this.
NormalizeOperation::NormalizeOperation(std::vector<float> mean, std::vector<float> std) : mean_(mean), std_(std) {}
@ -736,6 +739,7 @@ std::shared_ptr<TensorOp> NormalizeOperation::Build() {
return std::make_shared<NormalizeOp>(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]);
}
#ifndef ENABLE_ANDROID
// PadOperation
// Constructor: copies padding and fill_value into padding_ and fill_value_, and stores padding_mode
// in padding_mode_. No validation is done here; the body is empty.
PadOperation::PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, BorderType padding_mode)
: padding_(padding), fill_value_(fill_value), padding_mode_(padding_mode) {}

View File

@ -36,8 +36,6 @@ Status RandomNode::ValidateParams() {
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(ValidateDatasetSampler("RandomNode", sampler_));
if (!columns_list_.empty()) {
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RandomNode", "columns_list", columns_list_));
}
@ -89,6 +87,14 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() {
data_schema->LoadSchemaString(schema_json_string, columns_to_load);
}
}
// RandomOp by itself is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we save the sampler here in a leaf node that does not use sampling.
// RandomOp doesn't support a sampler and should not support sharding, so the selected sampler should just be sequential.
std::shared_ptr<SamplerObj> sampler_ = SelectSampler(total_rows_, false, 1, 0);
std::shared_ptr<RandomDataOp> op;
op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_,
std::move(data_schema), std::move(sampler_->Build()));

View File

@ -36,22 +36,20 @@ class RandomNode : public DatasetNode {
/// \brief Constructor
RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
total_rows_(total_rows),
schema_path_(""),
schema_(std::move(schema)),
columns_list_(columns_list),
sampler_(std::move(sampler)) {}
columns_list_(columns_list) {}
/// \brief Constructor
RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)),
total_rows_(total_rows),
schema_path_(schema_path),
columns_list_(columns_list),
sampler_(std::move(sampler)) {}
columns_list_(columns_list) {}
/// \brief Destructor
~RandomNode() = default;

View File

@ -810,11 +810,10 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset
class RandomDataDataset : public Dataset {
public:
RandomDataDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema,
const std::vector<std::string> &columns_list, const std::shared_ptr<SamplerObj> &sampler,
std::shared_ptr<DatasetCache> cache);
const std::vector<std::string> &columns_list, std::shared_ptr<DatasetCache> cache);
RandomDataDataset(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache);
std::shared_ptr<DatasetCache> cache);
};
/// \brief Function to create a RandomDataset
@ -829,16 +828,13 @@ class RandomDataDataset : public Dataset {
template <typename T = std::shared_ptr<SchemaObj>>
std::shared_ptr<RandomDataDataset> RandomData(const int32_t &total_rows = 0, const T &schema = nullptr,
const std::vector<std::string> &columns_list = {},
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
const std::shared_ptr<DatasetCache> &cache = nullptr) {
std::shared_ptr<RandomDataDataset> ds;
if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) {
std::shared_ptr<SchemaObj> schema_obj = schema;
ds = std::make_shared<RandomDataDataset>(total_rows, std::move(schema_obj), std::move(columns_list),
std::move(sampler), cache);
ds = std::make_shared<RandomDataDataset>(total_rows, std::move(schema_obj), std::move(columns_list), cache);
} else {
ds = std::make_shared<RandomDataDataset>(total_rows, std::move(schema), std::move(columns_list), std::move(sampler),
cache);
ds = std::make_shared<RandomDataDataset>(total_rows, std::move(schema), std::move(columns_list), cache);
}
return ds;
}

View File

@ -434,7 +434,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestCacheRandomDataCApi) {
std::shared_ptr<SchemaObj> schema = Schema();
schema->add_column("image", mindspore::TypeId::kNumberTypeUInt8, {2});
schema->add_column("label", mindspore::TypeId::kNumberTypeUInt8, {1});
std::shared_ptr<Dataset> ds = RandomData(4, schema, {}, RandomSampler(), some_cache);
std::shared_ptr<Dataset> ds = RandomData(4, schema, {}, some_cache);
EXPECT_NE(ds, nullptr);
// Create a Repeat operation on ds

View File

@ -402,18 +402,6 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetBasic7) {
GlobalContext::config_manager()->set_seed(curr_seed);
}
TEST_F(MindDataTestPipeline, TestRandomDatasetWithNullSampler) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetWithNullSampler.";
// Create a RandomDataset
std::shared_ptr<SchemaObj> schema = Schema();
schema->add_column("image", mindspore::TypeId::kNumberTypeUInt8, {2});
schema->add_column("label", mindspore::TypeId::kNumberTypeUInt8, {1});
std::shared_ptr<Dataset> ds = RandomData(50, schema, {}, nullptr);
// Expect failure: sampler can not be nullptr
EXPECT_EQ(ds->CreateIterator(), nullptr);
}
TEST_F(MindDataTestPipeline, TestRandomDatasetDuplicateColumnName) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetDuplicateColumnName.";