forked from mindspore-Ecosystem/mindspore
!9521 [MEMOPT] Use bitset model to optimize somas
From: @laiyongqiang Reviewed-by: @jjfeing,@chujinjin Signed-off-by: @chujinjin
This commit is contained in:
commit
fee3aee30e
|
@ -122,6 +122,7 @@ void Somas::InitSomasStreamAndNode(const session::KernelGraph *graph) {
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
node->scope_full_name_ = kernel->fullname_with_scope();
|
||||
nodes_list_.push_back(node);
|
||||
stream->nodes_.push_back(node);
|
||||
auto key = kernel.get();
|
||||
nodes_map_[key] = node;
|
||||
node_index++;
|
||||
|
@ -565,54 +566,6 @@ void Somas::PreprocessingConflicts() {
|
|||
}
|
||||
}
|
||||
|
||||
// Reports whether two lifetimes intersect.
// Each lifetime is normalized first, so its bounds are (min(start_, end_), max(start_, end_))
// regardless of the order the fields were stored in. The two ranges overlap (shared endpoints
// count as overlap) when the span covering both is no longer than the sum of their lengths.
static bool LifetimeOverlap(lifetime_t lifetime1, lifetime_t lifetime2) {
  // Normalized bounds of the first lifetime.
  const size_t low_a = lifetime1.end_ < lifetime1.start_ ? lifetime1.end_ : lifetime1.start_;
  const size_t high_a = lifetime1.start_ < lifetime1.end_ ? lifetime1.end_ : lifetime1.start_;
  // Normalized bounds of the second lifetime.
  const size_t low_b = lifetime2.end_ < lifetime2.start_ ? lifetime2.end_ : lifetime2.start_;
  const size_t high_b = lifetime2.start_ < lifetime2.end_ ? lifetime2.end_ : lifetime2.start_;

  const size_t outer_high = high_a < high_b ? high_b : high_a;
  const size_t outer_low = low_b < low_a ? low_b : low_a;
  const size_t covering_span = outer_high - outer_low;
  const size_t summed_lengths = high_b - low_b + high_a - low_a;
  return covering_span <= summed_lengths;
}
|
||||
|
||||
static bool Subset(std::set<SomasStreamPtr> streamSet1, std::set<SomasStreamPtr> streamSet2) {
|
||||
for (auto stream : streamSet1) {
|
||||
if (streamSet2.count(stream) == 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static void EraseSet(std::set<SomasStreamPtr> *streamSet, std::set<SomasStreamPtr> removeStreamsSet) {
|
||||
for (auto stream : removeStreamsSet) {
|
||||
streamSet->erase(stream);
|
||||
}
|
||||
}
|
||||
|
||||
static bool ValidSubset(std::set<SomasStreamPtr> destStreams, std::set<SomasStreamPtr> ancestorsAndSelf,
|
||||
SomasTensorPtr ancestorTensor, SomasTensorPtr tensor) {
|
||||
MS_EXCEPTION_IF_NULL(ancestorTensor);
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
for (auto stream : destStreams) {
|
||||
if (ancestorsAndSelf.count(stream) == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (stream != tensor->GetSourceStream()) {
|
||||
MS_EXCEPTION_IF_NULL(tensor->GetSourceStream());
|
||||
if (tensor->GetSourceStream()->ancestor_streams_group_.count(stream) == 0 &&
|
||||
ancestorTensor->max_destination_id_[stream] >
|
||||
tensor->GetSourceNode()->anc_stream_max_order_[stream->GetId()]) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (ancestorTensor->max_destination_id_[stream] >= tensor->lifetime_.start_) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void Somas::ComputeConflictPairs() {
|
||||
if (tensors_list_.empty()) {
|
||||
MS_LOG(INFO) << "No Tensor for Conflict computing";
|
||||
|
@ -623,162 +576,145 @@ void Somas::ComputeConflictPairs() {
|
|||
PreprocessingConflicts();
|
||||
MS_LOG(INFO) << "End Preprocessing Conflicts";
|
||||
|
||||
MS_LOG(INFO) << "Start Array Initialization";
|
||||
cannot_reuse_ =
|
||||
std::make_shared<Array>(tensors_list_.back()->GetId() + 1,
|
||||
tensors_list_.back()->GetId() + 1); // array size is (max_id + 1) x (max_id + 1)
|
||||
MS_LOG(INFO) << "End Array Initialization";
|
||||
MS_LOG(INFO) << "Start Conflict Computing (Bitset Model)";
|
||||
std::sort(nodes_list_.begin(), nodes_list_.end(), NodeSort);
|
||||
|
||||
MS_LOG(INFO) << "Start Conflict Computing";
|
||||
// Loop to add edges within each stream (node order within stream)
|
||||
for (const auto &stream : streams_list_) {
|
||||
auto &nodes = stream->nodes_;
|
||||
std::sort(nodes.begin(), nodes.end(), NodeSort);
|
||||
for (size_t i = 1; i < nodes.size(); i++) {
|
||||
const auto &previous_node = nodes[i - 1];
|
||||
const auto ¤t_node = nodes[i];
|
||||
current_node->ancestor_nodes_.insert(previous_node);
|
||||
}
|
||||
}
|
||||
|
||||
size_t count_reuse = 0;
|
||||
// Loop to add edges from end to beginning of next group
|
||||
for (const auto &group : streams_groups_) {
|
||||
for (size_t i = 1; i < group.size(); i++) {
|
||||
int64_t previous_stream = group[i - 1];
|
||||
int64_t current_stream = group[i];
|
||||
|
||||
// Loop for ancestor stream groups reuse
|
||||
for (auto stream : streams_list_) {
|
||||
std::set<SomasStreamPtr> ancestors = stream->ancestor_streams_group_;
|
||||
auto it =
|
||||
std::find_if(streams_list_.begin(), streams_list_.end(),
|
||||
[previous_stream](const SomasStreamPtr &stream) { return stream->GetId() == previous_stream; });
|
||||
if (it == streams_list_.end()) {
|
||||
continue;
|
||||
}
|
||||
auto &last_node_in_prev_stream = (*it)->nodes_.back();
|
||||
|
||||
std::set<SomasStreamPtr> ancestors_and_self = ancestors;
|
||||
ancestors_and_self.insert(stream);
|
||||
it = std::find_if(streams_list_.begin(), streams_list_.end(),
|
||||
[current_stream](const SomasStreamPtr &stream) { return stream->GetId() == current_stream; });
|
||||
if (it == streams_list_.end()) {
|
||||
continue;
|
||||
}
|
||||
auto &first_node_in_cur_stream = (*it)->nodes_.front();
|
||||
|
||||
for (auto ancestor_stream : ancestors) {
|
||||
for (auto ancestor_tensor : ancestor_stream->tensors_) {
|
||||
if (ancestor_tensor->GetAlignedSize() == 0) continue;
|
||||
if (ancestor_tensor->IsLifelong()) continue;
|
||||
if (ancestor_tensor->IsSemiLifelongEnd()) continue;
|
||||
if (ancestor_tensor->IsRefOverlap()) continue;
|
||||
first_node_in_cur_stream->ancestor_nodes_.insert(last_node_in_prev_stream);
|
||||
}
|
||||
}
|
||||
|
||||
if (!ancestor_tensor->IsBetweenStreams() || Subset(ancestor_tensor->destinationStreams_, ancestors)) {
|
||||
for (auto tensor : stream->tensors_) {
|
||||
if (tensor->IsGap()) continue;
|
||||
if (tensor->GetAlignedSize() == 0) continue;
|
||||
if (tensor->IsLifelong()) continue;
|
||||
if (tensor->IsSemiLifelongStart()) continue;
|
||||
if (tensor->IsRefOverlap()) continue;
|
||||
// Loop to avoid tensors with empty destinations (add itself)
|
||||
for (const auto &tensor : tensors_list_) {
|
||||
if (tensor->destinations_.size() == 0) {
|
||||
tensor->destinations_.insert(tensor->GetSourceNode());
|
||||
}
|
||||
}
|
||||
|
||||
(*cannot_reuse_)(ancestor_tensor->GetId(), tensor->GetId()) = 0;
|
||||
(*cannot_reuse_)(tensor->GetId(), ancestor_tensor->GetId()) = 0;
|
||||
count_reuse++;
|
||||
}
|
||||
MS_LOG(INFO) << "Start Bitset";
|
||||
std::vector<DynamicBitSet> nodes_dependency;
|
||||
|
||||
size_t count = nodes_list_.back()->GetId() + 1;
|
||||
for (size_t i = 0; i < count; i++) {
|
||||
nodes_dependency.emplace_back(count);
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Start Path Computing";
|
||||
// Loop to compute ancestor paths via bitset for time dependence
|
||||
for (const auto &node : nodes_list_) {
|
||||
for (const auto &ancestor : node->ancestor_nodes_) {
|
||||
nodes_dependency[node->GetId()].SetBitTrue(ancestor->GetId());
|
||||
Union(&nodes_dependency[node->GetId()], &nodes_dependency[ancestor->GetId()]);
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "End Path Computing";
|
||||
|
||||
MS_LOG(INFO) << "Start Tensor Relation Computing";
|
||||
count = tensors_list_.back()->GetId() + 1;
|
||||
for (size_t i = 0; i < count; i++) {
|
||||
tensor_relation.emplace_back(count);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < tensors_list_.size(); i++) {
|
||||
for (size_t j = i + 1; j < tensors_list_.size(); j++) {
|
||||
auto t0 = tensors_list_[i];
|
||||
auto t1 = tensors_list_[j];
|
||||
|
||||
if (t0 == t1 || t0->IsGap() || t1->IsGap() || t0->IsLifelong() || t1->IsLifelong() || t0->IsRefOverlap() ||
|
||||
t1->IsRefOverlap() || t0->GetAlignedSize() == 0 || t1->GetAlignedSize() == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t t0_src_node = t0->GetSourceNode()->GetId();
|
||||
size_t t1_src_node = t1->GetSourceNode()->GetId();
|
||||
if (t0_src_node == t1_src_node) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool reuse = true;
|
||||
bool all_dst_depend = false;
|
||||
// check t0's all consumers is t1's source node's dependency or not
|
||||
for (const auto &dst_node : t0->destinations_) {
|
||||
if (nodes_dependency[t1_src_node].IsBitTrue(dst_node->GetId()) == false) {
|
||||
// t0's consumer is not in t1's source node's dependency, not sure this consumer is done or not when t1
|
||||
// produced
|
||||
reuse = false;
|
||||
all_dst_depend = false;
|
||||
break;
|
||||
} else if (t1_src_node == dst_node->GetId()) {
|
||||
// t0 is t1's source node's input, can't reuse
|
||||
reuse = false;
|
||||
all_dst_depend = true;
|
||||
break;
|
||||
} else {
|
||||
for (auto tensor : stream->tensors_) {
|
||||
if (Subset(ancestor_tensor->destinationStreams_, ancestors_and_self) &&
|
||||
ancestor_tensor->max_destination_id_[tensor->GetSourceStream()] < tensor->lifetime_.start_) {
|
||||
if (tensor->IsGap()) continue;
|
||||
if (tensor->GetAlignedSize() == 0) continue;
|
||||
if (tensor->IsLifelong()) continue;
|
||||
if (tensor->IsSemiLifelongStart()) continue;
|
||||
if (tensor->IsRefOverlap()) continue;
|
||||
// t0's consumer is in t1's source node's dependency, this consumer is done when t1 produced
|
||||
reuse = true;
|
||||
all_dst_depend = true;
|
||||
}
|
||||
}
|
||||
|
||||
(*cannot_reuse_)(ancestor_tensor->GetId(), tensor->GetId()) = 0;
|
||||
(*cannot_reuse_)(tensor->GetId(), ancestor_tensor->GetId()) = 0;
|
||||
count_reuse++;
|
||||
}
|
||||
if (all_dst_depend == false) {
|
||||
// check t1's all consumers is t0's source node's dependency or not
|
||||
reuse = true;
|
||||
for (const auto &dst_node : t1->destinations_) {
|
||||
if (nodes_dependency[t0_src_node].IsBitTrue(dst_node->GetId()) == false) {
|
||||
reuse = false;
|
||||
all_dst_depend = false;
|
||||
break;
|
||||
} else if (t0_src_node == dst_node->GetId()) {
|
||||
reuse = false;
|
||||
all_dst_depend = true;
|
||||
break;
|
||||
} else {
|
||||
reuse = true;
|
||||
all_dst_depend = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Loop for ancestor streams (no groups)
|
||||
for (auto stream : streams_list_) {
|
||||
auto ancestors_no_groups = stream->ancestor_streams_;
|
||||
EraseSet(&ancestors_no_groups, stream->ancestor_streams_group_);
|
||||
|
||||
for (auto ancestor_stream : ancestors_no_groups) {
|
||||
for (auto ancestor_tensor : ancestor_stream->tensors_) {
|
||||
if (ancestor_tensor->GetAlignedSize() == 0) continue;
|
||||
if (ancestor_tensor->IsLifelong()) continue;
|
||||
if (ancestor_tensor->IsSemiLifelongEnd()) continue;
|
||||
if (ancestor_tensor->IsRefOverlap()) continue;
|
||||
|
||||
if (!ancestor_tensor->IsBetweenStreams()) {
|
||||
for (auto tensor : stream->tensors_) {
|
||||
if (tensor->IsGap()) continue;
|
||||
if (tensor->GetAlignedSize() == 0) continue;
|
||||
if (tensor->IsLifelong()) continue;
|
||||
if (tensor->IsSemiLifelongStart()) continue;
|
||||
if (tensor->IsRefOverlap()) continue;
|
||||
|
||||
size_t max_ancestor_order = tensor->GetSourceNode()->anc_stream_max_order_[ancestor_stream->GetId()];
|
||||
|
||||
if (ancestor_tensor->lifetime_.end_ <= max_ancestor_order) {
|
||||
(*cannot_reuse_)(ancestor_tensor->GetId(), tensor->GetId()) = 0;
|
||||
(*cannot_reuse_)(tensor->GetId(), ancestor_tensor->GetId()) = 0;
|
||||
count_reuse++;
|
||||
}
|
||||
}
|
||||
} else { // ancestor tensor goes to another stream (might go to same stream also)
|
||||
std::set<SomasStreamPtr> dest_streams = ancestor_tensor->destinationStreams_;
|
||||
std::set<SomasStreamPtr> ancestors = stream->ancestor_streams_;
|
||||
|
||||
std::set<SomasStreamPtr> ancestors_and_self = ancestors;
|
||||
ancestors_and_self.insert(stream);
|
||||
|
||||
for (auto tensor : stream->tensors_) {
|
||||
if (tensor->IsGap()) continue;
|
||||
if (tensor->GetAlignedSize() == 0) continue;
|
||||
if (tensor->IsLifelong()) continue;
|
||||
if (tensor->IsSemiLifelongStart()) continue;
|
||||
if (tensor->IsRefOverlap()) continue;
|
||||
|
||||
if (ValidSubset(dest_streams, ancestors_and_self, ancestor_tensor, tensor)) {
|
||||
(*cannot_reuse_)(ancestor_tensor->GetId(), tensor->GetId()) = 0;
|
||||
(*cannot_reuse_)(tensor->GetId(), ancestor_tensor->GetId()) = 0;
|
||||
count_reuse++;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (all_dst_depend == true && reuse == true) {
|
||||
tensor_relation[t0->GetId()].SetBitTrue(t1->GetId());
|
||||
tensor_relation[t1->GetId()].SetBitTrue(t0->GetId());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Loop for same stream
|
||||
for (auto stream : streams_list_) {
|
||||
MS_EXCEPTION_IF_NULL(stream);
|
||||
for (auto tensor1 : stream->tensors_) {
|
||||
if (tensor1->GetAlignedSize() == 0) continue;
|
||||
if (tensor1->IsLifelong()) continue;
|
||||
if (tensor1->IsRefOverlap()) continue;
|
||||
|
||||
for (auto tensor2 : stream->tensors_) {
|
||||
if (tensor2->GetId() >= tensor1->GetId())
|
||||
break; // keep only when tensors kept sorted in tensors-vector of each stream, otherwise remove
|
||||
|
||||
if (tensor2->GetAlignedSize() == 0) continue;
|
||||
if (tensor2->IsLifelong()) continue;
|
||||
if (tensor2->IsRefOverlap()) continue;
|
||||
|
||||
// Between streams extra safety
|
||||
if (tensor1->IsBetweenStreams() && tensor2->IsBetweenStreams()) continue;
|
||||
|
||||
// Check lifetime overlap
|
||||
lifetime_t lifetime1 = tensor1->lifetime_;
|
||||
lifetime_t lifetime2 = tensor2->lifetime_;
|
||||
|
||||
if (!LifetimeOverlap(lifetime1, lifetime2)) {
|
||||
// Between-streams extra safety
|
||||
if (tensor1->IsBetweenStreams() && lifetime1.end_ < lifetime2.start_) continue;
|
||||
if (tensor2->IsBetweenStreams() && lifetime2.end_ < lifetime1.start_) continue;
|
||||
|
||||
// Semi-lifelong extra safety
|
||||
if (lifetime1.end_ < lifetime2.start_ && (tensor2->IsSemiLifelongStart() || tensor1->IsSemiLifelongEnd()))
|
||||
continue;
|
||||
if (lifetime2.end_ < lifetime1.start_ && (tensor1->IsSemiLifelongStart() || tensor2->IsSemiLifelongEnd()))
|
||||
continue;
|
||||
|
||||
// If arrived here, allow reuse
|
||||
(*cannot_reuse_)(tensor2->GetId(), tensor1->GetId()) = 0;
|
||||
(*cannot_reuse_)(tensor1->GetId(), tensor2->GetId()) = 0;
|
||||
count_reuse++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "End Conflict Computing";
|
||||
MS_LOG(INFO) << "Found " << count_reuse << " tensor pairs of allowed reusability";
|
||||
MS_LOG(INFO) << "End Tensor Relation Computing";
|
||||
MS_LOG(INFO) << "End Conflict Computing (Bitset Model)";
|
||||
}
|
||||
|
||||
// Comparison predicate that orders nodes by ascending id: returns true when node1's id is
// smaller than node2's id. Both pointers are dereferenced without a null check.
bool Somas::NodeSort(SomasNodePtr node1, SomasNodePtr node2) {
  const auto first_id = node1->GetId();
  const auto second_id = node2->GetId();
  return first_id < second_id;
}
|
||||
|
||||
bool Somas::Assign(const session::KernelGraph *graph) {
|
||||
if (tensors_list_.empty()) {
|
||||
MS_LOG(INFO) << "No Tensor for Assigner";
|
||||
|
@ -795,13 +731,13 @@ bool Somas::Assign(const session::KernelGraph *graph) {
|
|||
// Keep all constraints for first tensor in list
|
||||
size_t tid_0 = ref_node_list[0];
|
||||
for (SomasTensorPtr tensor : tensors_list_) {
|
||||
if ((*cannot_reuse_)(tid_0, tensor->GetId()) == 1) {
|
||||
if (tensor_relation[tid_0].IsBitTrue(tensor->GetId()) == false) {
|
||||
continue;
|
||||
}
|
||||
for (size_t tid : ref_node_list) {
|
||||
if ((*cannot_reuse_)(tid, tensor->GetId()) == 1) {
|
||||
(*cannot_reuse_)(tid_0, tensor->GetId()) = 1;
|
||||
(*cannot_reuse_)(tensor->GetId(), tid_0) = 1;
|
||||
if (tensor_relation[tid].IsBitTrue(tensor->GetId()) == false) {
|
||||
tensor_relation[tid_0].SetBitFalse(tensor->GetId());
|
||||
tensor_relation[tensor->GetId()].SetBitFalse(tid_0);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -921,8 +857,8 @@ bool Somas::Assign(const session::KernelGraph *graph) {
|
|||
for (auto ref_overlap_list : ref_overlap_constraints_) {
|
||||
for (size_t tid_1 : ref_overlap_list) {
|
||||
for (size_t tid_2 : ref_overlap_list) {
|
||||
(*cannot_reuse_)(tid_1, tid_2) = 0;
|
||||
(*cannot_reuse_)(tid_2, tid_1) = 0;
|
||||
tensor_relation[tid_1].SetBitTrue(tid_2);
|
||||
tensor_relation[tid_2].SetBitTrue(tid_1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -932,7 +868,7 @@ bool Somas::Assign(const session::KernelGraph *graph) {
|
|||
for (auto tensor1 : tensors_list_) {
|
||||
size_t count_constraints = 0;
|
||||
for (auto tensor2 : tensors_list_) {
|
||||
if ((*cannot_reuse_)(tensor1->GetId(), tensor2->GetId()) == 1) {
|
||||
if (tensor_relation[tensor1->GetId()].IsBitTrue(tensor2->GetId()) == false) {
|
||||
count_constraints++;
|
||||
}
|
||||
}
|
||||
|
@ -943,7 +879,7 @@ bool Somas::Assign(const session::KernelGraph *graph) {
|
|||
MS_LOG(INFO) << "Start Contiguous Gaps Preprocessing";
|
||||
for (auto contiguous_list : contiguous_tensors_list_) {
|
||||
if (contiguous_list.size() < 3) {
|
||||
MS_LOG(ERROR) << "contiguous_list should has at least one input and two gap, now it is "
|
||||
MS_LOG(ERROR) << "contiguous_list should have at least one input and two gap, now it is "
|
||||
<< contiguous_list.size();
|
||||
}
|
||||
size_t front_gap_id = contiguous_list[0];
|
||||
|
@ -959,10 +895,20 @@ bool Somas::Assign(const session::KernelGraph *graph) {
|
|||
size_t back_neighbour_id = contiguous_list[contiguous_list.size() - 2];
|
||||
for (SomasTensorPtr tensor : tensors_list_) {
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
(*cannot_reuse_)(tensor->GetId(), front_gap_id) = (*cannot_reuse_)(tensor->GetId(), front_neighbour_id);
|
||||
(*cannot_reuse_)(front_gap_id, tensor->GetId()) = (*cannot_reuse_)(front_neighbour_id, tensor->GetId());
|
||||
(*cannot_reuse_)(tensor->GetId(), back_gap_id) = (*cannot_reuse_)(tensor->GetId(), back_neighbour_id);
|
||||
(*cannot_reuse_)(back_gap_id, tensor->GetId()) = (*cannot_reuse_)(back_neighbour_id, tensor->GetId());
|
||||
if (tensor_relation[tensor->GetId()].IsBitTrue(front_neighbour_id) == false) {
|
||||
tensor_relation[tensor->GetId()].SetBitFalse(front_gap_id);
|
||||
tensor_relation[front_gap_id].SetBitFalse(tensor->GetId());
|
||||
} else {
|
||||
tensor_relation[tensor->GetId()].SetBitTrue(front_gap_id);
|
||||
tensor_relation[front_gap_id].SetBitTrue(tensor->GetId());
|
||||
}
|
||||
if (tensor_relation[tensor->GetId()].IsBitTrue(back_neighbour_id) == false) {
|
||||
tensor_relation[tensor->GetId()].SetBitFalse(back_gap_id);
|
||||
tensor_relation[back_gap_id].SetBitFalse(tensor->GetId());
|
||||
} else {
|
||||
tensor_relation[tensor->GetId()].SetBitTrue(back_gap_id);
|
||||
tensor_relation[back_gap_id].SetBitTrue(tensor->GetId());
|
||||
}
|
||||
}
|
||||
SomasTensorPtr front_neighbour = tensors_map_[front_neighbour_id];
|
||||
SomasTensorPtr back_neighbour = tensors_map_[back_neighbour_id];
|
||||
|
@ -995,8 +941,8 @@ bool Somas::Assign(const session::KernelGraph *graph) {
|
|||
}
|
||||
|
||||
somas_solver_ = std::make_shared<SomasSolverPre>();
|
||||
auto status =
|
||||
somas_solver_->Solving(graph, &solver_tensor_desc_list_, cannot_reuse_, contiguous_tensors_list_removed_ref, false);
|
||||
auto status = somas_solver_->Solving(graph, &solver_tensor_desc_list_, &tensor_relation,
|
||||
contiguous_tensors_list_removed_ref, false);
|
||||
MS_LOG(INFO) << "End Solving";
|
||||
if (status != SUCCESS) {
|
||||
GenStatisticInfo();
|
||||
|
|
|
@ -51,6 +51,9 @@ class Somas {
|
|||
void DumpSomasBasicIR(const string filename);
|
||||
void DumpSomasMemoryIR(const string filename);
|
||||
|
||||
static bool NodeSort(SomasNodePtr, SomasNodePtr);
|
||||
std::vector<DynamicBitSet> tensor_relation;
|
||||
|
||||
private:
|
||||
// Maps
|
||||
std::unordered_map<size_t, SomasTensorPtr> tensors_map_;
|
||||
|
@ -68,9 +71,6 @@ class Somas {
|
|||
std::unordered_map<size_t, SomasSolverTensorDescPtr> solver_tensor_desc_list_;
|
||||
SomasSolverPrePtr somas_solver_;
|
||||
|
||||
// Constraints
|
||||
std::shared_ptr<Array> cannot_reuse_;
|
||||
|
||||
// Contiguous list
|
||||
std::vector<vector<size_t>> contiguous_tensors_list_;
|
||||
|
||||
|
|
|
@ -42,8 +42,8 @@ class SomasNode {
|
|||
// Public attributes (mutated in code)
|
||||
std::string scope_full_name_;
|
||||
|
||||
std::set<SomasNodePtr>
|
||||
ancestor_nodes_; // keeping only distance *one* ancestor nodes; enough to ComputeAncestorNodes()
|
||||
// node's dependency including data dependency and time dependency
|
||||
std::set<SomasNodePtr> ancestor_nodes_;
|
||||
std::set<SomasTensorPtr> tensors_;
|
||||
|
||||
std::vector<SomasTensorPtr> input_tensors_;
|
||||
|
|
|
@ -117,15 +117,15 @@ void FootPrint::Merge(vector<Interval> *interval_v, stack<Interval> *s) {
|
|||
|
||||
return;
|
||||
}
|
||||
void FootPrint::ConstrainedBLocks(const std::shared_ptr<Array> &constraints, const BlockTensor &b1,
|
||||
const BlockTensor &b2, vector<Interval> *oInterval) {
|
||||
void FootPrint::ConstrainedBLocks(std::vector<DynamicBitSet> *constraints, const BlockTensor &b1, const BlockTensor &b2,
|
||||
vector<Interval> *oInterval) {
|
||||
MS_EXCEPTION_IF_NULL(oInterval);
|
||||
// propagate
|
||||
size_t acum = m_offset_;
|
||||
|
||||
for (SomasSolverTensorDescPtr p1 = b1.m_start_tensor_; NULL != p1; p1 = p1->right_) {
|
||||
for (SomasSolverTensorDescPtr p2 = b2.m_start_tensor_; NULL != p2; p2 = p2->right_) {
|
||||
if ((*constraints)(p1->index_, p2->index_) == 1) {
|
||||
if ((*constraints)[p1->index_].IsBitTrue(p2->index_) == false) {
|
||||
Interval a = Interval(acum, acum + p1->size_);
|
||||
Interval b = Interval(p2);
|
||||
if (a.lb() < b.ub()) {
|
||||
|
@ -136,7 +136,7 @@ void FootPrint::ConstrainedBLocks(const std::shared_ptr<Array> &constraints, con
|
|||
acum += p1->size_;
|
||||
}
|
||||
}
|
||||
bool FootPrint::findOffset(const std::shared_ptr<Array> &constraints, const BlockTensor &block, size_t *offset) {
|
||||
bool FootPrint::findOffset(std::vector<DynamicBitSet> *constraints, const BlockTensor &block, size_t *offset) {
|
||||
MS_EXCEPTION_IF_NULL(offset);
|
||||
bool bretval = true;
|
||||
vector<Interval> l_interval;
|
||||
|
@ -150,7 +150,7 @@ bool FootPrint::findOffset(const std::shared_ptr<Array> &constraints, const Bloc
|
|||
// transform constrained tensors in non eligible intervals
|
||||
for (size_t i = 0; i < m_starts_.size(); i++) {
|
||||
if (block.Alone() && m_starts_[i]->Alone() &&
|
||||
(*constraints)(block.m_start_tensor_->index_, m_starts_[i]->m_start_tensor_->index_)) {
|
||||
(*constraints)[block.m_start_tensor_->index_].IsBitTrue(m_starts_[i]->m_start_tensor_->index_) == false) {
|
||||
if (m_algorithm_ != 1 && i == 0) return false;
|
||||
Interval It = Interval(m_starts_[i]->m_start_tensor_);
|
||||
l_interval.emplace_back(It);
|
||||
|
@ -201,7 +201,7 @@ void FootPrint::printStats() {
|
|||
MS_LOG(DEBUG) << "Footprint blocks: " << m_starts_.size() << " \toffset: " << m_offset_;
|
||||
}
|
||||
bool FastHeuristic::Eval(vector<BlockTensor> *block_tensors_v, std::shared_ptr<FootPrint> foot_print,
|
||||
const std::shared_ptr<Array> &pConstraints) {
|
||||
std::vector<DynamicBitSet> *pConstraints) {
|
||||
MS_EXCEPTION_IF_NULL(foot_print);
|
||||
auto start = std::chrono::system_clock::now();
|
||||
|
||||
|
|
|
@ -143,8 +143,8 @@ class FootPrint : public std::enable_shared_from_this<FootPrint> {
|
|||
void Destroy();
|
||||
const size_t getOffset() { return m_offset_; }
|
||||
void setOffset(const size_t &offset) { m_offset_ = offset; }
|
||||
bool findOffset(const std::shared_ptr<Array> &constraints, const BlockTensor &block, size_t *offset);
|
||||
void ConstrainedBLocks(const std::shared_ptr<Array> &constraints, const BlockTensor &b1, const BlockTensor &b2,
|
||||
bool findOffset(std::vector<DynamicBitSet> *constraints, const BlockTensor &block, size_t *offset);
|
||||
void ConstrainedBLocks(std::vector<DynamicBitSet> *constraints, const BlockTensor &b1, const BlockTensor &b2,
|
||||
vector<Interval> *oInterval_l);
|
||||
void Merge(vector<Interval> *l_interval, stack<Interval> *l_merged);
|
||||
bool findFirst(stack<Interval> *merged, const BlockTensor &block, size_t *offset);
|
||||
|
@ -167,7 +167,7 @@ class FastHeuristic {
|
|||
void setAlignment(const size_t &a) { m_alignment_ = a; }
|
||||
void Destroy();
|
||||
bool Eval(vector<BlockTensor> *block_tensors_v, std::shared_ptr<FootPrint> foot_print,
|
||||
const std::shared_ptr<Array> &pConstraints);
|
||||
std::vector<DynamicBitSet> *pConstraints);
|
||||
|
||||
private:
|
||||
size_t m_alignment_;
|
||||
|
|
|
@ -169,16 +169,14 @@ bool SomasSolverCore::Verify(const size_t &upperbound) {
|
|||
t2 = t2_.second;
|
||||
if (t1->index_ == t2->index_) continue;
|
||||
bool blifelong = (t1->lifelong_ || t2->lifelong_) && (t1->index_ != t2->index_);
|
||||
const size_t continuous = 2;
|
||||
const size_t conflict = 1;
|
||||
if ((*constraints_)(t1->index_, t2->index_) == continuous) { // continuous constraint
|
||||
// t1 must be continous to t2
|
||||
if (t2->right_ == t1) { // continuous constraint
|
||||
// t1 must be continuous to t2
|
||||
bool bcontinuous = t1->offset_ == (t2->offset_ + t2->size_);
|
||||
if (!bcontinuous) {
|
||||
MS_LOG(WARNING) << "Continuous constraint violation in tensors " << t1->index_ << " and" << t2->index_;
|
||||
retval = false;
|
||||
}
|
||||
} else if (blifelong || (*constraints_)(t1->index_, t2->index_) == conflict) { // conflict constraint
|
||||
} else if (blifelong || constraints_[t1->index_].IsBitTrue(t2->index_) == false) { // conflict constraint
|
||||
size_t t1_ub = t1->offset_ + t1->size_;
|
||||
size_t t2_ub = t2->offset_ + t2->size_;
|
||||
bool b_overlap_lb = ((t2->offset_ >= t1->offset_) && (t2->offset_ < t1_ub));
|
||||
|
@ -336,7 +334,7 @@ size_t SomasSolverCore::Search(const std::shared_ptr<FootPrint> &pFootprint) {
|
|||
FastHeuristic fh;
|
||||
MS_LOG(INFO) << "Calling FastSolver Search for " << block_tensors_.size() << " tensors ";
|
||||
auto start = std::chrono::system_clock::now();
|
||||
if (fh.Eval(&block_tensors_, pFootprint, constraints_)) {
|
||||
if (fh.Eval(&block_tensors_, pFootprint, &constraints_)) {
|
||||
result = pFootprint->Result();
|
||||
auto end = std::chrono::system_clock::now();
|
||||
timing_ = std::chrono::duration_cast<std::chrono::milliseconds>((end - start)).count();
|
||||
|
|
|
@ -33,9 +33,9 @@ class SomasSolverCore {
|
|||
public:
|
||||
/// Interface Function: receive parameters, creates the model to solve and then save the result
|
||||
SomasSolverCore(const std::unordered_map<size_t, SomasSolverTensorDescPtr> &tensors,
|
||||
const std::shared_ptr<Array> &constraints)
|
||||
std::vector<DynamicBitSet> *constraints)
|
||||
: tensors_(tensors),
|
||||
constraints_(constraints),
|
||||
constraints_(*constraints),
|
||||
upperbound_(SIZE_MAX),
|
||||
timing_(0),
|
||||
lifelongmemory_(0),
|
||||
|
@ -69,7 +69,7 @@ class SomasSolverCore {
|
|||
private:
|
||||
std::unordered_map<size_t, SomasSolverTensorDescPtr> tensors_;
|
||||
vector<BlockTensor> block_tensors_;
|
||||
std::shared_ptr<Array> constraints_;
|
||||
std::vector<DynamicBitSet> constraints_;
|
||||
size_t upperbound_{0};
|
||||
size_t timing_{0};
|
||||
size_t lifelongmemory_{0};
|
||||
|
|
|
@ -27,25 +27,15 @@ namespace mindspore {
|
|||
namespace somas {
|
||||
Status SomasSolverPre::Solving(const session::KernelGraph *graph,
|
||||
std::unordered_map<size_t, SomasSolverTensorDescPtr> *ptensors,
|
||||
std::shared_ptr<Array> pConstraints, const vector<vector<size_t>> &continuous_v,
|
||||
std::vector<DynamicBitSet> *pConstraints, const vector<vector<size_t>> &continuous_v,
|
||||
bool bVerifySolution, bool ball, SortingType sorting, FittingType fitting,
|
||||
AlgorithmType algorithm) {
|
||||
Status retval = SUCCESS;
|
||||
|
||||
try {
|
||||
std::unordered_map<size_t, SomasSolverTensorDescPtr> &tensors = *ptensors;
|
||||
std::unordered_map<size_t, SomasSolverTensorDescPtr>::iterator max =
|
||||
std::max_element(tensors.begin(), tensors.end(),
|
||||
[](const std::pair<size_t, SomasSolverTensorDescPtr> &a,
|
||||
const std::pair<size_t, SomasSolverTensorDescPtr> &b) { return a.first < b.first; });
|
||||
size_t maxIndex = max->first;
|
||||
if (maxIndex > pConstraints->Rows() - 1) {
|
||||
MS_LOG(WARNING) << "ERROR: MaxIndex invalid, MaxIndex " << maxIndex << ", Rows " << pConstraints->Rows();
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << "Filling in constraints matrix..";
|
||||
uint32_t continuous_cnt = 0;
|
||||
|
||||
// creating S Lists
|
||||
for (auto &aux : continuous_v) {
|
||||
for (uint32_t i = 0; i < aux.size() - 1; i++) {
|
||||
|
@ -60,8 +50,6 @@ Status SomasSolverPre::Solving(const session::KernelGraph *graph,
|
|||
return FAILED;
|
||||
}
|
||||
|
||||
const size_t continuous = 2;
|
||||
(*pConstraints)(index2, index1) = continuous;
|
||||
if (tensors[index1]->right_)
|
||||
MS_LOG(WARNING) << "Warning:tensor " << index1
|
||||
<< " already has a right tensor (id: " << tensors[index1]->right_->index_;
|
||||
|
@ -104,7 +92,7 @@ Status SomasSolverPre::Solving(const session::KernelGraph *graph,
|
|||
|
||||
void SomasSolverPre::Log(const session::KernelGraph *graph,
|
||||
const unordered_map<size_t, SomasSolverTensorDescPtr> &tensors,
|
||||
const std::shared_ptr<Array> &pConstraints, const vector<vector<size_t>> &continuous_v) {
|
||||
std::vector<DynamicBitSet> *pConstraints, const vector<vector<size_t>> &continuous_v) {
|
||||
MS_LOG(INFO) << "SomasSolver::Log Writing somas-input.txt..";
|
||||
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
|
@ -143,7 +131,7 @@ void SomasSolverPre::Log(const session::KernelGraph *graph,
|
|||
for (auto &t2 : tensors) {
|
||||
size_t idx1 = t1.first;
|
||||
size_t idx2 = t2.first;
|
||||
if ((idx1 != idx2) && (*pConstraints)(idx1, idx2) == 1) {
|
||||
if ((idx1 != idx2) && (*pConstraints)[idx1].IsBitTrue(idx2) == false) {
|
||||
ofs_1 << "C " << idx1 << " " << idx2 << std::endl;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -57,36 +57,52 @@ enum FittingType {
|
|||
kNumFittingTypes
|
||||
};
|
||||
|
||||
// NOTE(review): this span of the diff view interleaves two class definitions without +/- markers:
// Array (a rows x cols int matrix addressed as (i, j)) and DynamicBitSet (a bit vector packed into
// 64-bit words). Presumably Array is the removed side and DynamicBitSet the added side -- confirm
// against the raw diff before relying on either definition.
class Array {
|
||||
class DynamicBitSet {
|
||||
// Number of bits stored in each element of bit_.
const size_t bit_width_ = 64;
|
||||
// Number of words held in bit_; computed by the constructor from the requested bit count.
size_t bit_size_;
|
||||
// Packed storage: bit `index` lives in word index / bit_width_.
std::vector<uint64_t> bit_;
|
||||
|
||||
// Returns the position in bit_ of the word that holds bit `index`.
inline size_t GetIndex(size_t index) { return index / bit_width_; }
|
||||
|
||||
// Returns a one-bit mask selecting `index` inside its word. Positions count down from the most
// significant bit: index % bit_width_ == 0 maps to the top bit of the word.
inline uint64_t GetBitMask(size_t index) { return (((uint64_t)0x1) << (bit_width_ - 1 - (index % bit_width_))); }
|
||||
|
||||
// Discards the current contents and refills bit_ with bit_size_ copies of val.
inline void Reset(uint64_t val) {
|
||||
bit_.clear();
|
||||
for (size_t i = 0; i < bit_size_; i++) {
|
||||
bit_.push_back(val);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
// Array constructor: allocates rows * cols ints and initializes every entry to 1.
Array(const size_t &rows, const size_t &cols) : rows_(rows), cols_(cols) {
|
||||
conflicts_array_ = std::make_unique<int[]>(rows * cols);
|
||||
for (uint32_t i = 0; i < rows * cols; i++) {
|
||||
conflicts_array_[i] = 1;
|
||||
// DynamicBitSet constructor: sizes the set for `count` bits, rounded up to whole words, and
// clears every bit.
explicit DynamicBitSet(size_t count) {
|
||||
bit_size_ = (count + bit_width_ - 1) / bit_width_;
|
||||
Reset(0x0);
|
||||
}
|
||||
|
||||
// Sets bit `index`. When `log` is true, the word index and mask are first written to the INFO log.
// `index` is not range-checked here.
void SetBitTrue(size_t index, bool log = false) {
|
||||
if (log) {
|
||||
MS_LOG(INFO) << GetIndex(index) << " " << GetBitMask(index);
|
||||
}
|
||||
bit_[GetIndex(index)] |= GetBitMask(index);
|
||||
}
|
||||
|
||||
// Clears bit `index`. `index` is not range-checked here.
void SetBitFalse(size_t index) { bit_[GetIndex(index)] &= (~GetBitMask(index)); }
|
||||
|
||||
// Returns true when bit `index` is set. `index` is not range-checked here.
bool IsBitTrue(size_t index) { return (bit_[GetIndex(index)] & GetBitMask(index)) != 0x0; }
|
||||
|
||||
// Prints every word of the set to std::cout in hexadecimal, on a single line.
void Log() {
|
||||
std::cout << "Start Print Bitset ";
|
||||
for (size_t i = 0; i < bit_size_; i++) {
|
||||
std::cout << " bit [" << std::dec << i << "] = " << std::hex << bit_[i] << std::dec;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
// ORs every word of *b into *a. The loop runs over a's word count, so *b must hold at least as
// many words as *a; neither pointer is null-checked.
friend void Union(DynamicBitSet *a, DynamicBitSet *b) {
|
||||
for (size_t i = 0; i < (*a).bit_size_; i++) {
|
||||
(*a).bit_[i] |= (*b).bit_[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Array copy constructor: allocates a new buffer of the same dimensions and copies every entry.
Array(const Array &array) : rows_(array.rows_), cols_(array.cols_) {
|
||||
conflicts_array_ = std::make_unique<int[]>(array.rows_ * array.cols_);
|
||||
for (uint32_t i = 0; i < array.rows_ * array.cols_; i++) {
|
||||
conflicts_array_[i] = array.conflicts_array_[i];
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE(review): this assignment operator returns *this without copying anything from `array`.
Array &operator=(const Array &array) { return *this; }
|
||||
|
||||
// Returns a mutable reference to entry (i, j), stored at offset i * cols_ + j. The offset is
// range-checked only by assert.
int &operator()(const size_t &i, const size_t &j) {
|
||||
assert((i * cols_ + j) < (rows_ * cols_));
|
||||
return conflicts_array_[i * cols_ + j];
|
||||
}
|
||||
|
||||
// Dimension accessors.
const size_t &Rows() { return rows_; }
|
||||
const size_t &Cols() { return cols_; }
|
||||
|
||||
private:
|
||||
// Fixed matrix dimensions, set at construction.
const size_t rows_;
|
||||
const size_t cols_;
|
||||
// Flat buffer of rows_ * cols_ entries, indexed as i * cols_ + j.
std::unique_ptr<int[]> conflicts_array_;
|
||||
};
|
||||
|
||||
struct SomasSolverTensorDesc {
|
||||
|
@ -140,14 +156,14 @@ class SomasSolverPre {
|
|||
size_t GetMaxOffset() { return max_offset_; }
|
||||
|
||||
Status Solving(const session::KernelGraph *graph, std::unordered_map<size_t, SomasSolverTensorDescPtr> *tensors,
|
||||
std::shared_ptr<Array> pConstraints, const vector<vector<size_t>> &continuous_v,
|
||||
std::vector<DynamicBitSet> *pConstraints, const vector<vector<size_t>> &continuous_v,
|
||||
bool bVerifySolution, // true -> Check continuous and non overlapping constraints solution
|
||||
bool ball = true, // true -> run full set of heuristics, false -> run single heuristic specified
|
||||
SortingType sorting = kGreaterSizeSmallerIndex, FittingType fitting = kBest,
|
||||
AlgorithmType algorithm = kManyObjects);
|
||||
|
||||
void Log(const session::KernelGraph *graph, const unordered_map<size_t, SomasSolverTensorDescPtr> &tensors,
|
||||
const std::shared_ptr<Array> &pConstraints_v, const vector<vector<size_t>> &continuous_v);
|
||||
std::vector<DynamicBitSet> *pConstraints_v, const vector<vector<size_t>> &continuous_v);
|
||||
|
||||
private:
|
||||
size_t max_offset_;
|
||||
|
|
|
@ -34,11 +34,13 @@ using SomasTensorPtr = std::shared_ptr<SomasTensor>;
|
|||
class SomasStream {
|
||||
public:
|
||||
using SomasStreamPtr = std::shared_ptr<SomasStream>;
|
||||
using SomasNodePtr = std::shared_ptr<SomasNode>;
|
||||
|
||||
// Attributes mutated in code
|
||||
std::vector<SomasTensorPtr> tensors_; // vector needed for same-stream loop in ConflictComputing()
|
||||
std::set<SomasStreamPtr> ancestor_streams_;
|
||||
std::set<SomasStreamPtr> ancestor_streams_group_;
|
||||
std::vector<SomasNodePtr> nodes_;
|
||||
|
||||
// Constructors/Destructors
|
||||
explicit SomasStream(int64_t id) : id_(id) {}
|
||||
|
|
Loading…
Reference in New Issue