support parallel computing for conflicts

This commit is contained in:
laiyongqiang 2021-01-12 17:34:21 +08:00
parent 8d05cd1ffd
commit af3c98e6ad
3 changed files with 137 additions and 80 deletions

View File

@ -34,10 +34,12 @@
#include "backend/optimizer/common/helper.h"
#include "utils/ms_context.h"
#include "debug/common.h"
#include "common/thread_pool.h"
namespace mindspore {
namespace somas {
constexpr auto kGapSize = 512;
constexpr auto kParallelComputeSizeThreshold = 2000;
std::map<TensorType, std::string> tensor_type_name_map = {{kCommon, "Common"},
{kOutputOnly, "OutputOnly"},
{kWorkspace, "Workspace"},
@ -641,7 +643,7 @@ void Somas::ComputeConflictPairs() {
MS_LOG(INFO) << "End Preprocessing Conflicts";
MS_LOG(INFO) << "Start Conflict Computing (Bitset Model)";
auto start_conflict = std::chrono::system_clock::now();
std::sort(nodes_list_.begin(), nodes_list_.end(), NodeSort);
// Loop to add edges within each stream (node order within stream)
@ -708,76 +710,107 @@ void Somas::ComputeConflictPairs() {
MS_LOG(INFO) << "Start Tensor Relation Computing";
count = tensors_list_.back()->GetId() + 1;
for (size_t i = 0; i < count; i++) {
tensor_relation.emplace_back(count);
reuse_matrix_.emplace_back(count);
}
for (size_t i = 0; i < tensors_list_.size(); i++) {
auto t0 = tensors_list_[i];
if (t0->IsLifelong() || t0->IsRefOverlap() || t0->GetAlignedSize() == 0) {
continue;
if (tensors_list_.size() < kParallelComputeSizeThreshold) {
ComputeMultiTensorConflicts(tensors_list_, tensors_list_, nodes_dependency, &reuse_matrix_);
} else {
MS_LOG(INFO) << "Tensor Num " << tensors_list_.size() << " is larger than " << kParallelComputeSizeThreshold;
MS_LOG(INFO) << "Enter Multi-Thread Mode...";
size_t process_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
MS_LOG(INFO) << "Threads Num is " << process_num;
size_t start_index = 0;
size_t total_size = tensors_list_.size();
size_t job_size = total_size / process_num;
if (job_size == 0) {
job_size = total_size;
}
size_t t0_src_node = t0->GetSourceNode()->GetId();
for (size_t j = i + 1; j < tensors_list_.size(); j++) {
auto t1 = tensors_list_[j];
if (t0 == t1 || t1->IsLifelong() || t1->IsRefOverlap() || t1->GetAlignedSize() == 0) {
continue;
}
size_t t1_src_node = t1->GetSourceNode()->GetId();
if (t0_src_node == t1_src_node) {
continue;
}
bool reuse = true;
bool all_dst_depend = false;
// check t0's all consumers is t1's source node's dependency or not
for (const auto &dst_node : t0->destinations_) {
if (nodes_dependency[t1_src_node].IsBitTrue(dst_node->GetId()) == false) {
// t0's consumer is not in t1's source node's dependency, not sure this consumer is done or not when t1
// produced
reuse = false;
all_dst_depend = false;
break;
} else if (t1_src_node == dst_node->GetId()) {
// t0 is t1's source node's input, can't reuse
reuse = false;
all_dst_depend = true;
break;
} else {
// t0's consumer is in t1's source node's dependency, this consumer is done when t1 produced
reuse = true;
all_dst_depend = true;
}
}
if (all_dst_depend == false) {
// check t1's all consumers is t0's source node's dependency or not
reuse = true;
for (const auto &dst_node : t1->destinations_) {
if (nodes_dependency[t0_src_node].IsBitTrue(dst_node->GetId()) == false) {
reuse = false;
all_dst_depend = false;
break;
} else if (t0_src_node == dst_node->GetId()) {
reuse = false;
all_dst_depend = true;
break;
} else {
reuse = true;
all_dst_depend = true;
}
}
}
if (all_dst_depend == true && reuse == true) {
tensor_relation[t0->GetId()].SetBitTrue(t1->GetId());
tensor_relation[t1->GetId()].SetBitTrue(t0->GetId());
}
std::vector<common::Task> tasks;
while (start_index < total_size) {
size_t end_index = (start_index + job_size) > total_size ? total_size : start_index + job_size;
auto jobs = std::vector<SomasTensorPtr>(tensors_list_.begin() + start_index, tensors_list_.begin() + end_index);
auto task = [this, jobs, &nodes_dependency]() {
this->ComputeMultiTensorConflicts(jobs, tensors_list_, nodes_dependency, &reuse_matrix_);
return common::SUCCESS;
};
tasks.emplace_back(task);
start_index += job_size;
}
common::ThreadPool::GetInstance().SyncRun(tasks);
}
MS_LOG(INFO) << "End Tensor Relation Computing";
MS_LOG(INFO) << "End Conflict Computing (Bitset Model)";
auto end_conflict = std::chrono::system_clock::now();
MS_LOG(INFO) << "End Conflict Computing (Bitset Model)(time taken "
<< std::chrono::duration_cast<std::chrono::milliseconds>(end_conflict - start_conflict).count() << "ms)";
}
void Somas::ComputeMultiTensorConflicts(const std::vector<SomasTensorPtr> &calc_tensors_list,
const std::vector<SomasTensorPtr> &all_tensors_list,
const vector<DynamicBitSet> &nodes_dependency,
std::vector<DynamicBitSet> *tensor_relation) const {
auto start = std::chrono::system_clock::now();
MS_LOG(INFO) << "Start Computing Conflicts Pairs, tensors list size is " << calc_tensors_list.size();
for (size_t i = 0; i < calc_tensors_list.size(); i++) {
auto calc_tensor = calc_tensors_list[i];
if (calc_tensor->IsLifelong() || calc_tensor->IsRefOverlap() || calc_tensor->GetAlignedSize() == 0) {
continue;
}
ComputeOneTensorConflicts(calc_tensor, all_tensors_list, nodes_dependency, tensor_relation);
}
auto end = std::chrono::system_clock::now();
MS_LOG(INFO) << "End Computing Conflicts Pairs (time taken "
<< std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms)";
}
void Somas::ComputeOneTensorConflicts(const std::shared_ptr<SomasTensor> &calc_tensor,
const std::vector<SomasTensorPtr> &all_tensors_list,
const vector<DynamicBitSet> &nodes_dependency,
std::vector<DynamicBitSet> *tensor_relation) const {
for (size_t j = 0; j < all_tensors_list.size(); j++) {
auto target_tensor = all_tensors_list[j];
if (calc_tensor == target_tensor || target_tensor->IsLifelong() || target_tensor->IsRefOverlap() ||
target_tensor->GetAlignedSize() == 0) {
continue;
}
size_t calc_src_node = calc_tensor->GetSourceNode()->GetId();
size_t target_src_node = target_tensor->GetSourceNode()->GetId();
if (calc_src_node == target_src_node) {
continue;
}
if ((*tensor_relation)[calc_tensor->GetId()].IsBitTrue(target_tensor->GetId()) ||
(*tensor_relation)[target_tensor->GetId()].IsBitTrue(calc_tensor->GetId())) {
continue;
}
bool reuse = true;
// check calc_tensor's all consumers is target_tensor's source node's dependency or not
for (const auto &dst_node : calc_tensor->destinations_) {
if (nodes_dependency[target_src_node].IsBitTrue(dst_node->GetId()) == false) {
// calc_tensor's consumer is not in target_tensor's source node's dependency, not sure this consumer is done or
// not when target_tensor produced
reuse = false;
break;
} else if (target_src_node == dst_node->GetId()) {
// calc_tensor is target_tensor's source node's input, can't reuse
reuse = false;
break;
} else {
// calc_tensor's consumer is in target_tensor's source node's dependency, this consumer is done when
// target_tensor produced
reuse = true;
}
}
if (reuse) {
// calc_tensor and target_tensor have dependencies so they can reuse each other
(*tensor_relation)[calc_tensor->GetId()].SetBitTrue(target_tensor->GetId());
(*tensor_relation)[target_tensor->GetId()].SetBitTrue(calc_tensor->GetId());
}
}
}
bool Somas::NodeSort(SomasNodePtr node1, SomasNodePtr node2) { return node1->GetId() < node2->GetId(); }
@ -798,13 +831,13 @@ bool Somas::Assign(const session::KernelGraph *graph) {
// Keep all constraints for first tensor in list
size_t tid_0 = ref_node_list[0];
for (SomasTensorPtr tensor : tensors_list_) {
if (tensor_relation[tid_0].IsBitTrue(tensor->GetId()) == false) {
if (reuse_matrix_[tid_0].IsBitTrue(tensor->GetId()) == false) {
continue;
}
for (size_t tid : ref_node_list) {
if (tensor_relation[tid].IsBitTrue(tensor->GetId()) == false) {
tensor_relation[tid_0].SetBitFalse(tensor->GetId());
tensor_relation[tensor->GetId()].SetBitFalse(tid_0);
if (reuse_matrix_[tid].IsBitTrue(tensor->GetId()) == false) {
reuse_matrix_[tid_0].SetBitFalse(tensor->GetId());
reuse_matrix_[tensor->GetId()].SetBitFalse(tid_0);
break;
}
}
@ -924,23 +957,21 @@ bool Somas::Assign(const session::KernelGraph *graph) {
for (auto ref_overlap_list : ref_overlap_constraints_) {
for (size_t tid_1 : ref_overlap_list) {
for (size_t tid_2 : ref_overlap_list) {
tensor_relation[tid_1].SetBitTrue(tid_2);
tensor_relation[tid_2].SetBitTrue(tid_1);
reuse_matrix_[tid_1].SetBitTrue(tid_2);
reuse_matrix_[tid_2].SetBitTrue(tid_1);
}
}
}
MS_LOG(INFO) << "End Solving Preprocessing for Ref Overlap";
#ifdef SOMAS_DEBUG
// Compute number of constraints for each tensor
auto tensors_num = tensors_list_.size();
for (auto tensor1 : tensors_list_) {
size_t count_constraints = 0;
for (auto tensor2 : tensors_list_) {
if (tensor_relation[tensor1->GetId()].IsBitTrue(tensor2->GetId()) == false) {
count_constraints++;
}
}
tensor1->num_constraints_ = count_constraints;
auto ones_num = reuse_matrix_[tensor1->GetId()].CountOnesNum();
tensor1->num_constraints_ = tensors_num - ones_num;
}
#endif
// Prepare solver info
MS_LOG(INFO) << "Start Loop to create solver info";
@ -960,7 +991,7 @@ bool Somas::Assign(const session::KernelGraph *graph) {
}
somas_solver_ = std::make_shared<SomasSolverPre>();
auto status = somas_solver_->Solving(graph, &solver_tensor_desc_list_, &tensor_relation,
auto status = somas_solver_->Solving(graph, &solver_tensor_desc_list_, &reuse_matrix_,
contiguous_tensors_list_removed_ref, false);
MS_LOG(INFO) << "End Solving";
if (status != SUCCESS) {

View File

@ -53,7 +53,7 @@ class Somas {
void DumpSomasMemoryIR(const string filename);
static bool NodeSort(SomasNodePtr, SomasNodePtr);
std::vector<DynamicBitSet> tensor_relation;
std::vector<DynamicBitSet> reuse_matrix_;
private:
// Maps
@ -128,6 +128,14 @@ class Somas {
SomasParameterPtr CreateSomasParameters(AnfNodePtr node, size_t index);
void InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel);
void InitAtomicCleanInputs(bool is_all_nop_node, const CNodePtr &kernel);
void ComputeOneTensorConflicts(const std::shared_ptr<SomasTensor> &calc_tensor,
const std::vector<SomasTensorPtr> &all_tensors_list,
const vector<DynamicBitSet> &nodes_dependency,
std::vector<DynamicBitSet> *tensor_relation) const;
void ComputeMultiTensorConflicts(const std::vector<SomasTensorPtr> &calc_tensors_list,
const std::vector<SomasTensorPtr> &all_tensors_list,
const vector<DynamicBitSet> &nodes_dependency,
std::vector<DynamicBitSet> *tensor_relation) const;
};
using SomasPtr = std::shared_ptr<Somas>;

View File

@ -92,6 +92,24 @@ class DynamicBitSet {
bool IsBitTrue(size_t index) const { return (bit_[GetIndex(index)] & GetBitMask(index)) != 0x0; }
size_t CountOnesNum() const {
size_t ret = 0;
static char ones_num_in_hex[] = "\0\1\1\2\1\2\2\3\1\2\2\3\2\3\3\4";
for (size_t i = 0; i < bit_size_; i++) {
auto value = bit_[i];
if (value == 0) {
continue;
}
char *char_value = reinterpret_cast<char *>(&value);
for (size_t j = 0; j < bit_width_ / CHAR_BIT; j++) {
ret += ones_num_in_hex[char_value[j] & 0xF];
char_value[j] >>= 4;
ret += ones_num_in_hex[char_value[j] & 0xF];
}
}
return ret;
}
void Log() {
std::cout << "Start Print Bitset ";
for (size_t i = 0; i < bit_size_; i++) {