forked from mindspore-Ecosystem/mindspore
support parallel computing for conflicts
This commit is contained in:
parent
8d05cd1ffd
commit
af3c98e6ad
|
@ -34,10 +34,12 @@
|
|||
#include "backend/optimizer/common/helper.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "debug/common.h"
|
||||
#include "common/thread_pool.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace somas {
|
||||
constexpr auto kGapSize = 512;
|
||||
constexpr auto kParallelComputeSizeThreshold = 2000;
|
||||
std::map<TensorType, std::string> tensor_type_name_map = {{kCommon, "Common"},
|
||||
{kOutputOnly, "OutputOnly"},
|
||||
{kWorkspace, "Workspace"},
|
||||
|
@ -641,7 +643,7 @@ void Somas::ComputeConflictPairs() {
|
|||
MS_LOG(INFO) << "End Preprocessing Conflicts";
|
||||
|
||||
MS_LOG(INFO) << "Start Conflict Computing (Bitset Model)";
|
||||
|
||||
auto start_conflict = std::chrono::system_clock::now();
|
||||
std::sort(nodes_list_.begin(), nodes_list_.end(), NodeSort);
|
||||
|
||||
// Loop to add edges within each stream (node order within stream)
|
||||
|
@ -708,76 +710,107 @@ void Somas::ComputeConflictPairs() {
|
|||
MS_LOG(INFO) << "Start Tensor Relation Computing";
|
||||
count = tensors_list_.back()->GetId() + 1;
|
||||
for (size_t i = 0; i < count; i++) {
|
||||
tensor_relation.emplace_back(count);
|
||||
reuse_matrix_.emplace_back(count);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < tensors_list_.size(); i++) {
|
||||
auto t0 = tensors_list_[i];
|
||||
if (t0->IsLifelong() || t0->IsRefOverlap() || t0->GetAlignedSize() == 0) {
|
||||
continue;
|
||||
if (tensors_list_.size() < kParallelComputeSizeThreshold) {
|
||||
ComputeMultiTensorConflicts(tensors_list_, tensors_list_, nodes_dependency, &reuse_matrix_);
|
||||
} else {
|
||||
MS_LOG(INFO) << "Tensor Num " << tensors_list_.size() << " is larger than " << kParallelComputeSizeThreshold;
|
||||
MS_LOG(INFO) << "Enter Multi-Thread Mode...";
|
||||
size_t process_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
|
||||
MS_LOG(INFO) << "Threads Num is " << process_num;
|
||||
|
||||
size_t start_index = 0;
|
||||
size_t total_size = tensors_list_.size();
|
||||
size_t job_size = total_size / process_num;
|
||||
if (job_size == 0) {
|
||||
job_size = total_size;
|
||||
}
|
||||
size_t t0_src_node = t0->GetSourceNode()->GetId();
|
||||
for (size_t j = i + 1; j < tensors_list_.size(); j++) {
|
||||
auto t1 = tensors_list_[j];
|
||||
|
||||
if (t0 == t1 || t1->IsLifelong() || t1->IsRefOverlap() || t1->GetAlignedSize() == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t t1_src_node = t1->GetSourceNode()->GetId();
|
||||
if (t0_src_node == t1_src_node) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool reuse = true;
|
||||
bool all_dst_depend = false;
|
||||
// check t0's all consumers is t1's source node's dependency or not
|
||||
for (const auto &dst_node : t0->destinations_) {
|
||||
if (nodes_dependency[t1_src_node].IsBitTrue(dst_node->GetId()) == false) {
|
||||
// t0's consumer is not in t1's source node's dependency, not sure this consumer is done or not when t1
|
||||
// produced
|
||||
reuse = false;
|
||||
all_dst_depend = false;
|
||||
break;
|
||||
} else if (t1_src_node == dst_node->GetId()) {
|
||||
// t0 is t1's source node's input, can't reuse
|
||||
reuse = false;
|
||||
all_dst_depend = true;
|
||||
break;
|
||||
} else {
|
||||
// t0's consumer is in t1's source node's dependency, this consumer is done when t1 produced
|
||||
reuse = true;
|
||||
all_dst_depend = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (all_dst_depend == false) {
|
||||
// check t1's all consumers is t0's source node's dependency or not
|
||||
reuse = true;
|
||||
for (const auto &dst_node : t1->destinations_) {
|
||||
if (nodes_dependency[t0_src_node].IsBitTrue(dst_node->GetId()) == false) {
|
||||
reuse = false;
|
||||
all_dst_depend = false;
|
||||
break;
|
||||
} else if (t0_src_node == dst_node->GetId()) {
|
||||
reuse = false;
|
||||
all_dst_depend = true;
|
||||
break;
|
||||
} else {
|
||||
reuse = true;
|
||||
all_dst_depend = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (all_dst_depend == true && reuse == true) {
|
||||
tensor_relation[t0->GetId()].SetBitTrue(t1->GetId());
|
||||
tensor_relation[t1->GetId()].SetBitTrue(t0->GetId());
|
||||
}
|
||||
std::vector<common::Task> tasks;
|
||||
while (start_index < total_size) {
|
||||
size_t end_index = (start_index + job_size) > total_size ? total_size : start_index + job_size;
|
||||
auto jobs = std::vector<SomasTensorPtr>(tensors_list_.begin() + start_index, tensors_list_.begin() + end_index);
|
||||
auto task = [this, jobs, &nodes_dependency]() {
|
||||
this->ComputeMultiTensorConflicts(jobs, tensors_list_, nodes_dependency, &reuse_matrix_);
|
||||
return common::SUCCESS;
|
||||
};
|
||||
tasks.emplace_back(task);
|
||||
start_index += job_size;
|
||||
}
|
||||
|
||||
common::ThreadPool::GetInstance().SyncRun(tasks);
|
||||
}
|
||||
MS_LOG(INFO) << "End Tensor Relation Computing";
|
||||
MS_LOG(INFO) << "End Conflict Computing (Bitset Model)";
|
||||
auto end_conflict = std::chrono::system_clock::now();
|
||||
MS_LOG(INFO) << "End Conflict Computing (Bitset Model)(time taken "
|
||||
<< std::chrono::duration_cast<std::chrono::milliseconds>(end_conflict - start_conflict).count() << "ms)";
|
||||
}
|
||||
|
||||
void Somas::ComputeMultiTensorConflicts(const std::vector<SomasTensorPtr> &calc_tensors_list,
|
||||
const std::vector<SomasTensorPtr> &all_tensors_list,
|
||||
const vector<DynamicBitSet> &nodes_dependency,
|
||||
std::vector<DynamicBitSet> *tensor_relation) const {
|
||||
auto start = std::chrono::system_clock::now();
|
||||
MS_LOG(INFO) << "Start Computing Conflicts Pairs, tensors list size is " << calc_tensors_list.size();
|
||||
for (size_t i = 0; i < calc_tensors_list.size(); i++) {
|
||||
auto calc_tensor = calc_tensors_list[i];
|
||||
if (calc_tensor->IsLifelong() || calc_tensor->IsRefOverlap() || calc_tensor->GetAlignedSize() == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ComputeOneTensorConflicts(calc_tensor, all_tensors_list, nodes_dependency, tensor_relation);
|
||||
}
|
||||
auto end = std::chrono::system_clock::now();
|
||||
MS_LOG(INFO) << "End Computing Conflicts Pairs (time taken "
|
||||
<< std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms)";
|
||||
}
|
||||
|
||||
void Somas::ComputeOneTensorConflicts(const std::shared_ptr<SomasTensor> &calc_tensor,
|
||||
const std::vector<SomasTensorPtr> &all_tensors_list,
|
||||
const vector<DynamicBitSet> &nodes_dependency,
|
||||
std::vector<DynamicBitSet> *tensor_relation) const {
|
||||
for (size_t j = 0; j < all_tensors_list.size(); j++) {
|
||||
auto target_tensor = all_tensors_list[j];
|
||||
if (calc_tensor == target_tensor || target_tensor->IsLifelong() || target_tensor->IsRefOverlap() ||
|
||||
target_tensor->GetAlignedSize() == 0) {
|
||||
continue;
|
||||
}
|
||||
size_t calc_src_node = calc_tensor->GetSourceNode()->GetId();
|
||||
size_t target_src_node = target_tensor->GetSourceNode()->GetId();
|
||||
if (calc_src_node == target_src_node) {
|
||||
continue;
|
||||
}
|
||||
if ((*tensor_relation)[calc_tensor->GetId()].IsBitTrue(target_tensor->GetId()) ||
|
||||
(*tensor_relation)[target_tensor->GetId()].IsBitTrue(calc_tensor->GetId())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool reuse = true;
|
||||
// check calc_tensor's all consumers is target_tensor's source node's dependency or not
|
||||
for (const auto &dst_node : calc_tensor->destinations_) {
|
||||
if (nodes_dependency[target_src_node].IsBitTrue(dst_node->GetId()) == false) {
|
||||
// calc_tensor's consumer is not in target_tensor's source node's dependency, not sure this consumer is done or
|
||||
// not when target_tensor produced
|
||||
reuse = false;
|
||||
break;
|
||||
} else if (target_src_node == dst_node->GetId()) {
|
||||
// calc_tensor is target_tensor's source node's input, can't reuse
|
||||
reuse = false;
|
||||
break;
|
||||
} else {
|
||||
// calc_tensor's consumer is in target_tensor's source node's dependency, this consumer is done when
|
||||
// target_tensor produced
|
||||
reuse = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (reuse) {
|
||||
// calc_tensor and target_tensor have dependencies so they can reuse each other
|
||||
(*tensor_relation)[calc_tensor->GetId()].SetBitTrue(target_tensor->GetId());
|
||||
(*tensor_relation)[target_tensor->GetId()].SetBitTrue(calc_tensor->GetId());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool Somas::NodeSort(SomasNodePtr node1, SomasNodePtr node2) { return node1->GetId() < node2->GetId(); }
|
||||
|
@ -798,13 +831,13 @@ bool Somas::Assign(const session::KernelGraph *graph) {
|
|||
// Keep all constraints for first tensor in list
|
||||
size_t tid_0 = ref_node_list[0];
|
||||
for (SomasTensorPtr tensor : tensors_list_) {
|
||||
if (tensor_relation[tid_0].IsBitTrue(tensor->GetId()) == false) {
|
||||
if (reuse_matrix_[tid_0].IsBitTrue(tensor->GetId()) == false) {
|
||||
continue;
|
||||
}
|
||||
for (size_t tid : ref_node_list) {
|
||||
if (tensor_relation[tid].IsBitTrue(tensor->GetId()) == false) {
|
||||
tensor_relation[tid_0].SetBitFalse(tensor->GetId());
|
||||
tensor_relation[tensor->GetId()].SetBitFalse(tid_0);
|
||||
if (reuse_matrix_[tid].IsBitTrue(tensor->GetId()) == false) {
|
||||
reuse_matrix_[tid_0].SetBitFalse(tensor->GetId());
|
||||
reuse_matrix_[tensor->GetId()].SetBitFalse(tid_0);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -924,23 +957,21 @@ bool Somas::Assign(const session::KernelGraph *graph) {
|
|||
for (auto ref_overlap_list : ref_overlap_constraints_) {
|
||||
for (size_t tid_1 : ref_overlap_list) {
|
||||
for (size_t tid_2 : ref_overlap_list) {
|
||||
tensor_relation[tid_1].SetBitTrue(tid_2);
|
||||
tensor_relation[tid_2].SetBitTrue(tid_1);
|
||||
reuse_matrix_[tid_1].SetBitTrue(tid_2);
|
||||
reuse_matrix_[tid_2].SetBitTrue(tid_1);
|
||||
}
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "End Solving Preprocessing for Ref Overlap";
|
||||
|
||||
#ifdef SOMAS_DEBUG
|
||||
// Compute number of constraints for each tensor
|
||||
auto tensors_num = tensors_list_.size();
|
||||
for (auto tensor1 : tensors_list_) {
|
||||
size_t count_constraints = 0;
|
||||
for (auto tensor2 : tensors_list_) {
|
||||
if (tensor_relation[tensor1->GetId()].IsBitTrue(tensor2->GetId()) == false) {
|
||||
count_constraints++;
|
||||
}
|
||||
}
|
||||
tensor1->num_constraints_ = count_constraints;
|
||||
auto ones_num = reuse_matrix_[tensor1->GetId()].CountOnesNum();
|
||||
tensor1->num_constraints_ = tensors_num - ones_num;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Prepare solver info
|
||||
MS_LOG(INFO) << "Start Loop to create solver info";
|
||||
|
@ -960,7 +991,7 @@ bool Somas::Assign(const session::KernelGraph *graph) {
|
|||
}
|
||||
|
||||
somas_solver_ = std::make_shared<SomasSolverPre>();
|
||||
auto status = somas_solver_->Solving(graph, &solver_tensor_desc_list_, &tensor_relation,
|
||||
auto status = somas_solver_->Solving(graph, &solver_tensor_desc_list_, &reuse_matrix_,
|
||||
contiguous_tensors_list_removed_ref, false);
|
||||
MS_LOG(INFO) << "End Solving";
|
||||
if (status != SUCCESS) {
|
||||
|
|
|
@ -53,7 +53,7 @@ class Somas {
|
|||
void DumpSomasMemoryIR(const string filename);
|
||||
|
||||
static bool NodeSort(SomasNodePtr, SomasNodePtr);
|
||||
std::vector<DynamicBitSet> tensor_relation;
|
||||
std::vector<DynamicBitSet> reuse_matrix_;
|
||||
|
||||
private:
|
||||
// Maps
|
||||
|
@ -128,6 +128,14 @@ class Somas {
|
|||
SomasParameterPtr CreateSomasParameters(AnfNodePtr node, size_t index);
|
||||
void InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel);
|
||||
void InitAtomicCleanInputs(bool is_all_nop_node, const CNodePtr &kernel);
|
||||
void ComputeOneTensorConflicts(const std::shared_ptr<SomasTensor> &calc_tensor,
|
||||
const std::vector<SomasTensorPtr> &all_tensors_list,
|
||||
const vector<DynamicBitSet> &nodes_dependency,
|
||||
std::vector<DynamicBitSet> *tensor_relation) const;
|
||||
void ComputeMultiTensorConflicts(const std::vector<SomasTensorPtr> &calc_tensors_list,
|
||||
const std::vector<SomasTensorPtr> &all_tensors_list,
|
||||
const vector<DynamicBitSet> &nodes_dependency,
|
||||
std::vector<DynamicBitSet> *tensor_relation) const;
|
||||
};
|
||||
|
||||
using SomasPtr = std::shared_ptr<Somas>;
|
||||
|
|
|
@ -92,6 +92,24 @@ class DynamicBitSet {
|
|||
|
||||
bool IsBitTrue(size_t index) const { return (bit_[GetIndex(index)] & GetBitMask(index)) != 0x0; }
|
||||
|
||||
size_t CountOnesNum() const {
|
||||
size_t ret = 0;
|
||||
static char ones_num_in_hex[] = "\0\1\1\2\1\2\2\3\1\2\2\3\2\3\3\4";
|
||||
for (size_t i = 0; i < bit_size_; i++) {
|
||||
auto value = bit_[i];
|
||||
if (value == 0) {
|
||||
continue;
|
||||
}
|
||||
char *char_value = reinterpret_cast<char *>(&value);
|
||||
for (size_t j = 0; j < bit_width_ / CHAR_BIT; j++) {
|
||||
ret += ones_num_in_hex[char_value[j] & 0xF];
|
||||
char_value[j] >>= 4;
|
||||
ret += ones_num_in_hex[char_value[j] & 0xF];
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void Log() {
|
||||
std::cout << "Start Print Bitset ";
|
||||
for (size_t i = 0; i < bit_size_; i++) {
|
||||
|
|
Loading…
Reference in New Issue