!14907 Address codex warning for CLUENode::CreateKepMapForBuild() r1.2 branch

From: @lixiachen
Reviewed-by: @robingrosman,@liucunwei
Signed-off-by: @liucunwei
This commit is contained in:
mindspore-ci-bot 2021-04-10 16:35:16 +08:00 committed by Gitee
commit 1766f12888
2 changed files with 98 additions and 79 deletions

View File

@ -83,88 +83,87 @@ std::vector<std::string> CLUENode::split(const std::string &s, char delim) {
return res;
}
std::map<std::string, std::string> CLUENode::CreateKeyMapForBuild() {
std::map<std::string, std::string> CLUENode::CreateKeyMapForAFQMCOrCMNLITask() {
std::map<std::string, std::string> key_map;
if (task_ == "AFQMC") {
if (usage_ == "train" || usage_ == "eval") {
key_map["sentence1"] = "sentence1";
key_map["sentence2"] = "sentence2";
key_map["label"] = "label";
} else { // usage_ == "test"
key_map["id"] = "id";
key_map["sentence1"] = "sentence1";
key_map["sentence2"] = "sentence2";
}
if (usage_ == "train" || usage_ == "eval") {
key_map["label"] = "label";
} else { // usage_ == "test"
key_map["id"] = "id";
}
if (task_ == "CMNLI") {
if (usage_ == "train" || usage_ == "eval") {
key_map["sentence1"] = "sentence1";
key_map["sentence2"] = "sentence2";
key_map["label"] = "label";
} else { // usage_ == "test"
key_map["id"] = "id";
key_map["sentence1"] = "sentence1";
key_map["sentence2"] = "sentence2";
}
key_map["sentence1"] = "sentence1";
key_map["sentence2"] = "sentence2";
return key_map;
}
std::map<std::string, std::string> CLUENode::CreateKeyMapForCSLTask() {
std::map<std::string, std::string> key_map;
if (usage_ == "train" || usage_ == "eval") {
key_map["label"] = "label";
}
if (task_ == "CSL") {
if (usage_ == "train" || usage_ == "eval") {
key_map["id"] = "id";
key_map["abst"] = "abst";
key_map["keyword"] = "keyword";
key_map["label"] = "label";
} else { // usage_ == "test"
key_map["id"] = "id";
key_map["abst"] = "abst";
key_map["keyword"] = "keyword";
}
key_map["id"] = "id";
key_map["abst"] = "abst";
key_map["keyword"] = "keyword";
return key_map;
}
std::map<std::string, std::string> CLUENode::CreateKeyMapForIFLYTEKTask() {
std::map<std::string, std::string> key_map;
if (usage_ == "train" || usage_ == "eval") {
key_map["label"] = "label";
key_map["label_des"] = "label_des";
} else { // usage_ == "test"
key_map["id"] = "id";
}
if (task_ == "IFLYTEK") {
if (usage_ == "train" || usage_ == "eval") {
key_map["label"] = "label";
key_map["label_des"] = "label_des";
key_map["sentence"] = "sentence";
} else { // usage_ == "test"
key_map["id"] = "id";
key_map["sentence"] = "sentence";
}
key_map["sentence"] = "sentence";
return key_map;
}
std::map<std::string, std::string> CLUENode::CreateKeyMapForTNEWSTask() {
std::map<std::string, std::string> key_map;
if (usage_ == "train" || usage_ == "eval") {
key_map["label"] = "label";
key_map["label_desc"] = "label_desc";
} else { // usage_ == "test"
key_map["id"] = "id";
}
if (task_ == "TNEWS") {
if (usage_ == "train" || usage_ == "eval") {
key_map["label"] = "label";
key_map["label_desc"] = "label_desc";
key_map["sentence"] = "sentence";
key_map["keywords"] = "keywords";
} else { // usage_ == "test"
key_map["id"] = "id";
key_map["sentence"] = "sentence";
key_map["keywords"] = "keywords";
}
key_map["sentence"] = "sentence";
key_map["keywords"] = "keywords";
return key_map;
}
std::map<std::string, std::string> CLUENode::CreateKeyMapForWSCTask() {
std::map<std::string, std::string> key_map;
if (usage_ == "train" || usage_ == "eval") {
key_map["label"] = "label";
}
if (task_ == "WSC") {
if (usage_ == "train" || usage_ == "eval") {
key_map["span1_index"] = "target/span1_index";
key_map["span2_index"] = "target/span2_index";
key_map["span1_text"] = "target/span1_text";
key_map["span2_text"] = "target/span2_text";
key_map["idx"] = "idx";
key_map["label"] = "label";
key_map["text"] = "text";
} else { // usage_ == "test"
key_map["span1_index"] = "target/span1_index";
key_map["span2_index"] = "target/span2_index";
key_map["span1_text"] = "target/span1_text";
key_map["span2_text"] = "target/span2_text";
key_map["idx"] = "idx";
key_map["text"] = "text";
}
key_map["span1_index"] = "target/span1_index";
key_map["span2_index"] = "target/span2_index";
key_map["span1_text"] = "target/span1_text";
key_map["span2_text"] = "target/span2_text";
key_map["idx"] = "idx";
key_map["text"] = "text";
return key_map;
}
std::map<std::string, std::string> CLUENode::CreateKeyMap() {
std::map<std::string, std::string> key_map;
if (task_ == "AFQMC" || task_ == "CMNLI") {
key_map = CreateKeyMapForAFQMCOrCMNLITask();
} else if (task_ == "CSL") {
key_map = CreateKeyMapForCSLTask();
} else if (task_ == "IFLYTEK") {
key_map = CreateKeyMapForIFLYTEKTask();
} else if (task_ == "TNEWS") {
key_map = CreateKeyMapForTNEWSTask();
} else if (task_ == "WSC") {
key_map = CreateKeyMapForWSCTask();
}
return key_map;
}
// Function to build CLUENode
Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
auto key_map = CreateKeyMapForBuild();
auto key_map = CreateKeyMap();
ColKeyMap ck_map;
for (auto &p : key_map) {
ck_map.insert({p.first, split(p.second, '/')});
@ -246,11 +245,11 @@ Status CLUENode::to_json(nlohmann::json *out_json) {
return Status::OK();
}
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
// CLUE by itself is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we setup the sampler for a leaf node that does not use sampling.
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent
// class. CLUE by itself is a non-mappable dataset that does not support sampling. However, if a cache operator is
// injected at some other place higher in the tree, that cache can inherit this sampler from the leaf, providing
// sampling support from the caching layer. That is why we setup the sampler for a leaf node that does not use
// sampling.
Status CLUENode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
*sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);

View File

@ -50,10 +50,6 @@ class CLUENode : public NonMappableSourceNode {
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief Generate a key map to be used in Build() according to usage and task
/// \return The generated key map
std::map<std::string, std::string> CreateKeyMapForBuild();
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create
/// \return Status Status::OK() if build successfully
@ -111,6 +107,30 @@ class CLUENode : public NonMappableSourceNode {
/// \return A string vector
std::vector<std::string> split(const std::string &s, char delim);
/// \brief Generate a key map for AFQMC or CMNLI task according to usage
/// \return The generated key map
std::map<std::string, std::string> CreateKeyMapForAFQMCOrCMNLITask();
/// \brief Generate a key map for CSL task according to usage
/// \return The generated key map
std::map<std::string, std::string> CreateKeyMapForCSLTask();
/// \brief Generate a key map for IFLYTEK task according to usage
/// \return The generated key map
std::map<std::string, std::string> CreateKeyMapForIFLYTEKTask();
/// \brief Generate a key map for TNEWS task according to usage
/// \return The generated key map
std::map<std::string, std::string> CreateKeyMapForTNEWSTask();
/// \brief Generate a key map for WSC task according to usage
/// \return The generated key map
std::map<std::string, std::string> CreateKeyMapForWSCTask();
/// \brief Generate a key map to be used in Build() according to usage and task
/// \return The generated key map
std::map<std::string, std::string> CreateKeyMap();
std::vector<std::string> dataset_files_;
std::string task_;
std::string usage_;