!14907 Address codex warning for CLUENode::CreateKeyMapForBuild() r1.2 branch
From: @lixiachen Reviewed-by: @robingrosman,@liucunwei Signed-off-by: @liucunwei
This commit is contained in:
commit
1766f12888
|
@ -83,88 +83,87 @@ std::vector<std::string> CLUENode::split(const std::string &s, char delim) {
|
|||
return res;
|
||||
}
|
||||
|
||||
std::map<std::string, std::string> CLUENode::CreateKeyMapForBuild() {
|
||||
std::map<std::string, std::string> CLUENode::CreateKeyMapForAFQMCOrCMNLITask() {
|
||||
std::map<std::string, std::string> key_map;
|
||||
if (task_ == "AFQMC") {
|
||||
if (usage_ == "train" || usage_ == "eval") {
|
||||
key_map["sentence1"] = "sentence1";
|
||||
key_map["sentence2"] = "sentence2";
|
||||
key_map["label"] = "label";
|
||||
} else { // usage_ == "test"
|
||||
key_map["id"] = "id";
|
||||
key_map["sentence1"] = "sentence1";
|
||||
key_map["sentence2"] = "sentence2";
|
||||
}
|
||||
if (usage_ == "train" || usage_ == "eval") {
|
||||
key_map["label"] = "label";
|
||||
} else { // usage_ == "test"
|
||||
key_map["id"] = "id";
|
||||
}
|
||||
if (task_ == "CMNLI") {
|
||||
if (usage_ == "train" || usage_ == "eval") {
|
||||
key_map["sentence1"] = "sentence1";
|
||||
key_map["sentence2"] = "sentence2";
|
||||
key_map["label"] = "label";
|
||||
} else { // usage_ == "test"
|
||||
key_map["id"] = "id";
|
||||
key_map["sentence1"] = "sentence1";
|
||||
key_map["sentence2"] = "sentence2";
|
||||
}
|
||||
key_map["sentence1"] = "sentence1";
|
||||
key_map["sentence2"] = "sentence2";
|
||||
return key_map;
|
||||
}
|
||||
|
||||
std::map<std::string, std::string> CLUENode::CreateKeyMapForCSLTask() {
|
||||
std::map<std::string, std::string> key_map;
|
||||
if (usage_ == "train" || usage_ == "eval") {
|
||||
key_map["label"] = "label";
|
||||
}
|
||||
if (task_ == "CSL") {
|
||||
if (usage_ == "train" || usage_ == "eval") {
|
||||
key_map["id"] = "id";
|
||||
key_map["abst"] = "abst";
|
||||
key_map["keyword"] = "keyword";
|
||||
key_map["label"] = "label";
|
||||
} else { // usage_ == "test"
|
||||
key_map["id"] = "id";
|
||||
key_map["abst"] = "abst";
|
||||
key_map["keyword"] = "keyword";
|
||||
}
|
||||
key_map["id"] = "id";
|
||||
key_map["abst"] = "abst";
|
||||
key_map["keyword"] = "keyword";
|
||||
return key_map;
|
||||
}
|
||||
|
||||
std::map<std::string, std::string> CLUENode::CreateKeyMapForIFLYTEKTask() {
|
||||
std::map<std::string, std::string> key_map;
|
||||
if (usage_ == "train" || usage_ == "eval") {
|
||||
key_map["label"] = "label";
|
||||
key_map["label_des"] = "label_des";
|
||||
} else { // usage_ == "test"
|
||||
key_map["id"] = "id";
|
||||
}
|
||||
if (task_ == "IFLYTEK") {
|
||||
if (usage_ == "train" || usage_ == "eval") {
|
||||
key_map["label"] = "label";
|
||||
key_map["label_des"] = "label_des";
|
||||
key_map["sentence"] = "sentence";
|
||||
} else { // usage_ == "test"
|
||||
key_map["id"] = "id";
|
||||
key_map["sentence"] = "sentence";
|
||||
}
|
||||
key_map["sentence"] = "sentence";
|
||||
return key_map;
|
||||
}
|
||||
|
||||
std::map<std::string, std::string> CLUENode::CreateKeyMapForTNEWSTask() {
|
||||
std::map<std::string, std::string> key_map;
|
||||
if (usage_ == "train" || usage_ == "eval") {
|
||||
key_map["label"] = "label";
|
||||
key_map["label_desc"] = "label_desc";
|
||||
} else { // usage_ == "test"
|
||||
key_map["id"] = "id";
|
||||
}
|
||||
if (task_ == "TNEWS") {
|
||||
if (usage_ == "train" || usage_ == "eval") {
|
||||
key_map["label"] = "label";
|
||||
key_map["label_desc"] = "label_desc";
|
||||
key_map["sentence"] = "sentence";
|
||||
key_map["keywords"] = "keywords";
|
||||
} else { // usage_ == "test"
|
||||
key_map["id"] = "id";
|
||||
key_map["sentence"] = "sentence";
|
||||
key_map["keywords"] = "keywords";
|
||||
}
|
||||
key_map["sentence"] = "sentence";
|
||||
key_map["keywords"] = "keywords";
|
||||
return key_map;
|
||||
}
|
||||
|
||||
std::map<std::string, std::string> CLUENode::CreateKeyMapForWSCTask() {
|
||||
std::map<std::string, std::string> key_map;
|
||||
if (usage_ == "train" || usage_ == "eval") {
|
||||
key_map["label"] = "label";
|
||||
}
|
||||
if (task_ == "WSC") {
|
||||
if (usage_ == "train" || usage_ == "eval") {
|
||||
key_map["span1_index"] = "target/span1_index";
|
||||
key_map["span2_index"] = "target/span2_index";
|
||||
key_map["span1_text"] = "target/span1_text";
|
||||
key_map["span2_text"] = "target/span2_text";
|
||||
key_map["idx"] = "idx";
|
||||
key_map["label"] = "label";
|
||||
key_map["text"] = "text";
|
||||
} else { // usage_ == "test"
|
||||
key_map["span1_index"] = "target/span1_index";
|
||||
key_map["span2_index"] = "target/span2_index";
|
||||
key_map["span1_text"] = "target/span1_text";
|
||||
key_map["span2_text"] = "target/span2_text";
|
||||
key_map["idx"] = "idx";
|
||||
key_map["text"] = "text";
|
||||
}
|
||||
key_map["span1_index"] = "target/span1_index";
|
||||
key_map["span2_index"] = "target/span2_index";
|
||||
key_map["span1_text"] = "target/span1_text";
|
||||
key_map["span2_text"] = "target/span2_text";
|
||||
key_map["idx"] = "idx";
|
||||
key_map["text"] = "text";
|
||||
return key_map;
|
||||
}
|
||||
|
||||
std::map<std::string, std::string> CLUENode::CreateKeyMap() {
|
||||
std::map<std::string, std::string> key_map;
|
||||
if (task_ == "AFQMC" || task_ == "CMNLI") {
|
||||
key_map = CreateKeyMapForAFQMCOrCMNLITask();
|
||||
} else if (task_ == "CSL") {
|
||||
key_map = CreateKeyMapForCSLTask();
|
||||
} else if (task_ == "IFLYTEK") {
|
||||
key_map = CreateKeyMapForIFLYTEKTask();
|
||||
} else if (task_ == "TNEWS") {
|
||||
key_map = CreateKeyMapForTNEWSTask();
|
||||
} else if (task_ == "WSC") {
|
||||
key_map = CreateKeyMapForWSCTask();
|
||||
}
|
||||
return key_map;
|
||||
}
|
||||
|
||||
// Function to build CLUENode
|
||||
Status CLUENode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
auto key_map = CreateKeyMapForBuild();
|
||||
auto key_map = CreateKeyMap();
|
||||
ColKeyMap ck_map;
|
||||
for (auto &p : key_map) {
|
||||
ck_map.insert({p.first, split(p.second, '/')});
|
||||
|
@ -246,11 +245,11 @@ Status CLUENode::to_json(nlohmann::json *out_json) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class.
|
||||
// CLUE by itself is a non-mappable dataset that does not support sampling.
|
||||
// However, if a cache operator is injected at some other place higher in the tree, that cache can
|
||||
// inherit this sampler from the leaf, providing sampling support from the caching layer.
|
||||
// That is why we setup the sampler for a leaf node that does not use sampling.
|
||||
// Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent
|
||||
// class. CLUE by itself is a non-mappable dataset that does not support sampling. However, if a cache operator is
|
||||
// injected at some other place higher in the tree, that cache can inherit this sampler from the leaf, providing
|
||||
// sampling support from the caching layer. That is why we setup the sampler for a leaf node that does not use
|
||||
// sampling.
|
||||
Status CLUENode::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) {
|
||||
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||
*sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_);
|
||||
|
|
|
@ -50,10 +50,6 @@ class CLUENode : public NonMappableSourceNode {
|
|||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief Generate a key map to be used in Build() according to usage and task
|
||||
/// \return The generated key map
|
||||
std::map<std::string, std::string> CreateKeyMapForBuild();
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
/// \return Status Status::OK() if build successfully
|
||||
|
@ -111,6 +107,30 @@ class CLUENode : public NonMappableSourceNode {
|
|||
/// \return A string vector
|
||||
std::vector<std::string> split(const std::string &s, char delim);
|
||||
|
||||
/// \brief Generate a key map for AFQMC or CMNLI task according to usage
|
||||
/// \return The generated key map
|
||||
std::map<std::string, std::string> CreateKeyMapForAFQMCOrCMNLITask();
|
||||
|
||||
/// \brief Generate a key map for CSL task according to usage
|
||||
/// \return The generated key map
|
||||
std::map<std::string, std::string> CreateKeyMapForCSLTask();
|
||||
|
||||
/// \brief Generate a key map for IFLYTEK task according to usage
|
||||
/// \return The generated key map
|
||||
std::map<std::string, std::string> CreateKeyMapForIFLYTEKTask();
|
||||
|
||||
/// \brief Generate a key map for TNEWS task according to usage
|
||||
/// \return The generated key map
|
||||
std::map<std::string, std::string> CreateKeyMapForTNEWSTask();
|
||||
|
||||
/// \brief Generate a key map for WSC task according to usage
|
||||
/// \return The generated key map
|
||||
std::map<std::string, std::string> CreateKeyMapForWSCTask();
|
||||
|
||||
/// \brief Generate a key map to be used in Build() according to usage and task
|
||||
/// \return The generated key map
|
||||
std::map<std::string, std::string> CreateKeyMap();
|
||||
|
||||
std::vector<std::string> dataset_files_;
|
||||
std::string task_;
|
||||
std::string usage_;
|
||||
|
|
Loading…
Reference in New Issue