!9521 [MEMOPT] Use bitset model to optimize somas

From: @laiyongqiang
Reviewed-by: @jjfeing,@chujinjin
Signed-off-by: @chujinjin
This commit is contained in:
mindspore-ci-bot 2020-12-08 18:59:22 +08:00 committed by Gitee
commit fee3aee30e
10 changed files with 217 additions and 267 deletions

View File

@ -122,6 +122,7 @@ void Somas::InitSomasStreamAndNode(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(node);
node->scope_full_name_ = kernel->fullname_with_scope();
nodes_list_.push_back(node);
stream->nodes_.push_back(node);
auto key = kernel.get();
nodes_map_[key] = node;
node_index++;
@ -565,54 +566,6 @@ void Somas::PreprocessingConflicts() {
}
}
// Returns true when the two (inclusive) lifetime intervals intersect or touch.
// Endpoints are normalized first because start_/end_ may be stored in either order.
static bool LifetimeOverlap(lifetime_t lifetime1, lifetime_t lifetime2) {
  size_t lower1 = std::min(lifetime1.start_, lifetime1.end_);
  size_t upper1 = std::max(lifetime1.start_, lifetime1.end_);
  size_t lower2 = std::min(lifetime2.start_, lifetime2.end_);
  size_t upper2 = std::max(lifetime2.start_, lifetime2.end_);
  // Two inclusive intervals overlap exactly when the later start is no later
  // than the earlier end (equivalent to: total span <= sum of lengths).
  return std::max(lower1, lower2) <= std::min(upper1, upper2);
}
static bool Subset(std::set<SomasStreamPtr> streamSet1, std::set<SomasStreamPtr> streamSet2) {
for (auto stream : streamSet1) {
if (streamSet2.count(stream) == 0) {
return false;
}
}
return true;
}
static void EraseSet(std::set<SomasStreamPtr> *streamSet, std::set<SomasStreamPtr> removeStreamsSet) {
for (auto stream : removeStreamsSet) {
streamSet->erase(stream);
}
}
static bool ValidSubset(std::set<SomasStreamPtr> destStreams, std::set<SomasStreamPtr> ancestorsAndSelf,
SomasTensorPtr ancestorTensor, SomasTensorPtr tensor) {
MS_EXCEPTION_IF_NULL(ancestorTensor);
MS_EXCEPTION_IF_NULL(tensor);
for (auto stream : destStreams) {
if (ancestorsAndSelf.count(stream) == 0) {
return false;
}
if (stream != tensor->GetSourceStream()) {
MS_EXCEPTION_IF_NULL(tensor->GetSourceStream());
if (tensor->GetSourceStream()->ancestor_streams_group_.count(stream) == 0 &&
ancestorTensor->max_destination_id_[stream] >
tensor->GetSourceNode()->anc_stream_max_order_[stream->GetId()]) {
return false;
}
} else {
if (ancestorTensor->max_destination_id_[stream] >= tensor->lifetime_.start_) {
return false;
}
}
}
return true;
}
void Somas::ComputeConflictPairs() {
if (tensors_list_.empty()) {
MS_LOG(INFO) << "No Tensor for Conflict computing";
@ -623,162 +576,145 @@ void Somas::ComputeConflictPairs() {
PreprocessingConflicts();
MS_LOG(INFO) << "End Preprocessing Conflicts";
MS_LOG(INFO) << "Start Array Initialization";
cannot_reuse_ =
std::make_shared<Array>(tensors_list_.back()->GetId() + 1,
tensors_list_.back()->GetId() + 1); // array size is (max_id + 1) x (max_id + 1)
MS_LOG(INFO) << "End Array Initialization";
MS_LOG(INFO) << "Start Conflict Computing (Bitset Model)";
std::sort(nodes_list_.begin(), nodes_list_.end(), NodeSort);
MS_LOG(INFO) << "Start Conflict Computing";
// Loop to add edges within each stream (node order within stream)
for (const auto &stream : streams_list_) {
auto &nodes = stream->nodes_;
std::sort(nodes.begin(), nodes.end(), NodeSort);
for (size_t i = 1; i < nodes.size(); i++) {
const auto &previous_node = nodes[i - 1];
const auto &current_node = nodes[i];
current_node->ancestor_nodes_.insert(previous_node);
}
}
size_t count_reuse = 0;
// Loop to add edges from end to beginning of next group
for (const auto &group : streams_groups_) {
for (size_t i = 1; i < group.size(); i++) {
int64_t previous_stream = group[i - 1];
int64_t current_stream = group[i];
// Loop for ancestor stream groups reuse
for (auto stream : streams_list_) {
std::set<SomasStreamPtr> ancestors = stream->ancestor_streams_group_;
auto it =
std::find_if(streams_list_.begin(), streams_list_.end(),
[previous_stream](const SomasStreamPtr &stream) { return stream->GetId() == previous_stream; });
if (it == streams_list_.end()) {
continue;
}
auto &last_node_in_prev_stream = (*it)->nodes_.back();
std::set<SomasStreamPtr> ancestors_and_self = ancestors;
ancestors_and_self.insert(stream);
it = std::find_if(streams_list_.begin(), streams_list_.end(),
[current_stream](const SomasStreamPtr &stream) { return stream->GetId() == current_stream; });
if (it == streams_list_.end()) {
continue;
}
auto &first_node_in_cur_stream = (*it)->nodes_.front();
for (auto ancestor_stream : ancestors) {
for (auto ancestor_tensor : ancestor_stream->tensors_) {
if (ancestor_tensor->GetAlignedSize() == 0) continue;
if (ancestor_tensor->IsLifelong()) continue;
if (ancestor_tensor->IsSemiLifelongEnd()) continue;
if (ancestor_tensor->IsRefOverlap()) continue;
first_node_in_cur_stream->ancestor_nodes_.insert(last_node_in_prev_stream);
}
}
if (!ancestor_tensor->IsBetweenStreams() || Subset(ancestor_tensor->destinationStreams_, ancestors)) {
for (auto tensor : stream->tensors_) {
if (tensor->IsGap()) continue;
if (tensor->GetAlignedSize() == 0) continue;
if (tensor->IsLifelong()) continue;
if (tensor->IsSemiLifelongStart()) continue;
if (tensor->IsRefOverlap()) continue;
// Loop to avoid tensors with empty destinations (add itself)
for (const auto &tensor : tensors_list_) {
if (tensor->destinations_.size() == 0) {
tensor->destinations_.insert(tensor->GetSourceNode());
}
}
(*cannot_reuse_)(ancestor_tensor->GetId(), tensor->GetId()) = 0;
(*cannot_reuse_)(tensor->GetId(), ancestor_tensor->GetId()) = 0;
count_reuse++;
}
MS_LOG(INFO) << "Start Bitset";
std::vector<DynamicBitSet> nodes_dependency;
size_t count = nodes_list_.back()->GetId() + 1;
for (size_t i = 0; i < count; i++) {
nodes_dependency.emplace_back(count);
}
MS_LOG(INFO) << "Start Path Computing";
// Loop to compute ancestor paths via bitset for time dependence
for (const auto &node : nodes_list_) {
for (const auto &ancestor : node->ancestor_nodes_) {
nodes_dependency[node->GetId()].SetBitTrue(ancestor->GetId());
Union(&nodes_dependency[node->GetId()], &nodes_dependency[ancestor->GetId()]);
}
}
MS_LOG(INFO) << "End Path Computing";
MS_LOG(INFO) << "Start Tensor Relation Computing";
count = tensors_list_.back()->GetId() + 1;
for (size_t i = 0; i < count; i++) {
tensor_relation.emplace_back(count);
}
for (size_t i = 0; i < tensors_list_.size(); i++) {
for (size_t j = i + 1; j < tensors_list_.size(); j++) {
auto t0 = tensors_list_[i];
auto t1 = tensors_list_[j];
if (t0 == t1 || t0->IsGap() || t1->IsGap() || t0->IsLifelong() || t1->IsLifelong() || t0->IsRefOverlap() ||
t1->IsRefOverlap() || t0->GetAlignedSize() == 0 || t1->GetAlignedSize() == 0) {
continue;
}
size_t t0_src_node = t0->GetSourceNode()->GetId();
size_t t1_src_node = t1->GetSourceNode()->GetId();
if (t0_src_node == t1_src_node) {
continue;
}
bool reuse = true;
bool all_dst_depend = false;
// check t0's all consumers is t1's source node's dependency or not
for (const auto &dst_node : t0->destinations_) {
if (nodes_dependency[t1_src_node].IsBitTrue(dst_node->GetId()) == false) {
// t0's consumer is not in t1's source node's dependency, not sure this consumer is done or not when t1
// produced
reuse = false;
all_dst_depend = false;
break;
} else if (t1_src_node == dst_node->GetId()) {
// t0 is t1's source node's input, can't reuse
reuse = false;
all_dst_depend = true;
break;
} else {
for (auto tensor : stream->tensors_) {
if (Subset(ancestor_tensor->destinationStreams_, ancestors_and_self) &&
ancestor_tensor->max_destination_id_[tensor->GetSourceStream()] < tensor->lifetime_.start_) {
if (tensor->IsGap()) continue;
if (tensor->GetAlignedSize() == 0) continue;
if (tensor->IsLifelong()) continue;
if (tensor->IsSemiLifelongStart()) continue;
if (tensor->IsRefOverlap()) continue;
// t0's consumer is in t1's source node's dependency, this consumer is done when t1 produced
reuse = true;
all_dst_depend = true;
}
}
(*cannot_reuse_)(ancestor_tensor->GetId(), tensor->GetId()) = 0;
(*cannot_reuse_)(tensor->GetId(), ancestor_tensor->GetId()) = 0;
count_reuse++;
}
if (all_dst_depend == false) {
// check t1's all consumers is t0's source node's dependency or not
reuse = true;
for (const auto &dst_node : t1->destinations_) {
if (nodes_dependency[t0_src_node].IsBitTrue(dst_node->GetId()) == false) {
reuse = false;
all_dst_depend = false;
break;
} else if (t0_src_node == dst_node->GetId()) {
reuse = false;
all_dst_depend = true;
break;
} else {
reuse = true;
all_dst_depend = true;
}
}
}
}
}
// Loop for ancestor streams (no groups)
for (auto stream : streams_list_) {
auto ancestors_no_groups = stream->ancestor_streams_;
EraseSet(&ancestors_no_groups, stream->ancestor_streams_group_);
for (auto ancestor_stream : ancestors_no_groups) {
for (auto ancestor_tensor : ancestor_stream->tensors_) {
if (ancestor_tensor->GetAlignedSize() == 0) continue;
if (ancestor_tensor->IsLifelong()) continue;
if (ancestor_tensor->IsSemiLifelongEnd()) continue;
if (ancestor_tensor->IsRefOverlap()) continue;
if (!ancestor_tensor->IsBetweenStreams()) {
for (auto tensor : stream->tensors_) {
if (tensor->IsGap()) continue;
if (tensor->GetAlignedSize() == 0) continue;
if (tensor->IsLifelong()) continue;
if (tensor->IsSemiLifelongStart()) continue;
if (tensor->IsRefOverlap()) continue;
size_t max_ancestor_order = tensor->GetSourceNode()->anc_stream_max_order_[ancestor_stream->GetId()];
if (ancestor_tensor->lifetime_.end_ <= max_ancestor_order) {
(*cannot_reuse_)(ancestor_tensor->GetId(), tensor->GetId()) = 0;
(*cannot_reuse_)(tensor->GetId(), ancestor_tensor->GetId()) = 0;
count_reuse++;
}
}
} else { // ancestor tensor goes to another stream (might go to same stream also)
std::set<SomasStreamPtr> dest_streams = ancestor_tensor->destinationStreams_;
std::set<SomasStreamPtr> ancestors = stream->ancestor_streams_;
std::set<SomasStreamPtr> ancestors_and_self = ancestors;
ancestors_and_self.insert(stream);
for (auto tensor : stream->tensors_) {
if (tensor->IsGap()) continue;
if (tensor->GetAlignedSize() == 0) continue;
if (tensor->IsLifelong()) continue;
if (tensor->IsSemiLifelongStart()) continue;
if (tensor->IsRefOverlap()) continue;
if (ValidSubset(dest_streams, ancestors_and_self, ancestor_tensor, tensor)) {
(*cannot_reuse_)(ancestor_tensor->GetId(), tensor->GetId()) = 0;
(*cannot_reuse_)(tensor->GetId(), ancestor_tensor->GetId()) = 0;
count_reuse++;
}
}
}
if (all_dst_depend == true && reuse == true) {
tensor_relation[t0->GetId()].SetBitTrue(t1->GetId());
tensor_relation[t1->GetId()].SetBitTrue(t0->GetId());
}
}
}
// Loop for same stream
for (auto stream : streams_list_) {
MS_EXCEPTION_IF_NULL(stream);
for (auto tensor1 : stream->tensors_) {
if (tensor1->GetAlignedSize() == 0) continue;
if (tensor1->IsLifelong()) continue;
if (tensor1->IsRefOverlap()) continue;
for (auto tensor2 : stream->tensors_) {
if (tensor2->GetId() >= tensor1->GetId())
break; // keep only when tensors kept sorted in tensors-vector of each stream, otherwise remove
if (tensor2->GetAlignedSize() == 0) continue;
if (tensor2->IsLifelong()) continue;
if (tensor2->IsRefOverlap()) continue;
// Between streams extra safety
if (tensor1->IsBetweenStreams() && tensor2->IsBetweenStreams()) continue;
// Check lifetime overlap
lifetime_t lifetime1 = tensor1->lifetime_;
lifetime_t lifetime2 = tensor2->lifetime_;
if (!LifetimeOverlap(lifetime1, lifetime2)) {
// Between-streams extra safety
if (tensor1->IsBetweenStreams() && lifetime1.end_ < lifetime2.start_) continue;
if (tensor2->IsBetweenStreams() && lifetime2.end_ < lifetime1.start_) continue;
// Semi-lifelong extra safety
if (lifetime1.end_ < lifetime2.start_ && (tensor2->IsSemiLifelongStart() || tensor1->IsSemiLifelongEnd()))
continue;
if (lifetime2.end_ < lifetime1.start_ && (tensor1->IsSemiLifelongStart() || tensor2->IsSemiLifelongEnd()))
continue;
// If arrived here, allow reuse
(*cannot_reuse_)(tensor2->GetId(), tensor1->GetId()) = 0;
(*cannot_reuse_)(tensor1->GetId(), tensor2->GetId()) = 0;
count_reuse++;
}
}
}
}
MS_LOG(INFO) << "End Conflict Computing";
MS_LOG(INFO) << "Found " << count_reuse << " tensor pairs of allowed reusability";
MS_LOG(INFO) << "End Tensor Relation Computing";
MS_LOG(INFO) << "End Conflict Computing (Bitset Model)";
}
bool Somas::NodeSort(SomasNodePtr node1, SomasNodePtr node2) { return node1->GetId() < node2->GetId(); }
bool Somas::Assign(const session::KernelGraph *graph) {
if (tensors_list_.empty()) {
MS_LOG(INFO) << "No Tensor for Assigner";
@ -795,13 +731,13 @@ bool Somas::Assign(const session::KernelGraph *graph) {
// Keep all constraints for first tensor in list
size_t tid_0 = ref_node_list[0];
for (SomasTensorPtr tensor : tensors_list_) {
if ((*cannot_reuse_)(tid_0, tensor->GetId()) == 1) {
if (tensor_relation[tid_0].IsBitTrue(tensor->GetId()) == false) {
continue;
}
for (size_t tid : ref_node_list) {
if ((*cannot_reuse_)(tid, tensor->GetId()) == 1) {
(*cannot_reuse_)(tid_0, tensor->GetId()) = 1;
(*cannot_reuse_)(tensor->GetId(), tid_0) = 1;
if (tensor_relation[tid].IsBitTrue(tensor->GetId()) == false) {
tensor_relation[tid_0].SetBitFalse(tensor->GetId());
tensor_relation[tensor->GetId()].SetBitFalse(tid_0);
break;
}
}
@ -921,8 +857,8 @@ bool Somas::Assign(const session::KernelGraph *graph) {
for (auto ref_overlap_list : ref_overlap_constraints_) {
for (size_t tid_1 : ref_overlap_list) {
for (size_t tid_2 : ref_overlap_list) {
(*cannot_reuse_)(tid_1, tid_2) = 0;
(*cannot_reuse_)(tid_2, tid_1) = 0;
tensor_relation[tid_1].SetBitTrue(tid_2);
tensor_relation[tid_2].SetBitTrue(tid_1);
}
}
}
@ -932,7 +868,7 @@ bool Somas::Assign(const session::KernelGraph *graph) {
for (auto tensor1 : tensors_list_) {
size_t count_constraints = 0;
for (auto tensor2 : tensors_list_) {
if ((*cannot_reuse_)(tensor1->GetId(), tensor2->GetId()) == 1) {
if (tensor_relation[tensor1->GetId()].IsBitTrue(tensor2->GetId()) == false) {
count_constraints++;
}
}
@ -943,7 +879,7 @@ bool Somas::Assign(const session::KernelGraph *graph) {
MS_LOG(INFO) << "Start Contiguous Gaps Preprocessing";
for (auto contiguous_list : contiguous_tensors_list_) {
if (contiguous_list.size() < 3) {
MS_LOG(ERROR) << "contiguous_list should has at least one input and two gap, now it is "
MS_LOG(ERROR) << "contiguous_list should have at least one input and two gap, now it is "
<< contiguous_list.size();
}
size_t front_gap_id = contiguous_list[0];
@ -959,10 +895,20 @@ bool Somas::Assign(const session::KernelGraph *graph) {
size_t back_neighbour_id = contiguous_list[contiguous_list.size() - 2];
for (SomasTensorPtr tensor : tensors_list_) {
MS_EXCEPTION_IF_NULL(tensor);
(*cannot_reuse_)(tensor->GetId(), front_gap_id) = (*cannot_reuse_)(tensor->GetId(), front_neighbour_id);
(*cannot_reuse_)(front_gap_id, tensor->GetId()) = (*cannot_reuse_)(front_neighbour_id, tensor->GetId());
(*cannot_reuse_)(tensor->GetId(), back_gap_id) = (*cannot_reuse_)(tensor->GetId(), back_neighbour_id);
(*cannot_reuse_)(back_gap_id, tensor->GetId()) = (*cannot_reuse_)(back_neighbour_id, tensor->GetId());
if (tensor_relation[tensor->GetId()].IsBitTrue(front_neighbour_id) == false) {
tensor_relation[tensor->GetId()].SetBitFalse(front_gap_id);
tensor_relation[front_gap_id].SetBitFalse(tensor->GetId());
} else {
tensor_relation[tensor->GetId()].SetBitTrue(front_gap_id);
tensor_relation[front_gap_id].SetBitTrue(tensor->GetId());
}
if (tensor_relation[tensor->GetId()].IsBitTrue(back_neighbour_id) == false) {
tensor_relation[tensor->GetId()].SetBitFalse(back_gap_id);
tensor_relation[back_gap_id].SetBitFalse(tensor->GetId());
} else {
tensor_relation[tensor->GetId()].SetBitTrue(back_gap_id);
tensor_relation[back_gap_id].SetBitTrue(tensor->GetId());
}
}
SomasTensorPtr front_neighbour = tensors_map_[front_neighbour_id];
SomasTensorPtr back_neighbour = tensors_map_[back_neighbour_id];
@ -995,8 +941,8 @@ bool Somas::Assign(const session::KernelGraph *graph) {
}
somas_solver_ = std::make_shared<SomasSolverPre>();
auto status =
somas_solver_->Solving(graph, &solver_tensor_desc_list_, cannot_reuse_, contiguous_tensors_list_removed_ref, false);
auto status = somas_solver_->Solving(graph, &solver_tensor_desc_list_, &tensor_relation,
contiguous_tensors_list_removed_ref, false);
MS_LOG(INFO) << "End Solving";
if (status != SUCCESS) {
GenStatisticInfo();

View File

@ -51,6 +51,9 @@ class Somas {
void DumpSomasBasicIR(const string filename);
void DumpSomasMemoryIR(const string filename);
static bool NodeSort(SomasNodePtr, SomasNodePtr);
std::vector<DynamicBitSet> tensor_relation;
private:
// Maps
std::unordered_map<size_t, SomasTensorPtr> tensors_map_;
@ -68,9 +71,6 @@ class Somas {
std::unordered_map<size_t, SomasSolverTensorDescPtr> solver_tensor_desc_list_;
SomasSolverPrePtr somas_solver_;
// Constraints
std::shared_ptr<Array> cannot_reuse_;
// Contiguous list
std::vector<vector<size_t>> contiguous_tensors_list_;

View File

@ -42,8 +42,8 @@ class SomasNode {
// Public attributes (mutated in code)
std::string scope_full_name_;
std::set<SomasNodePtr>
ancestor_nodes_; // keeping only distance *one* ancestor nodes; enough to ComputeAncestorNodes()
// node's dependency including data dependency and time dependency
std::set<SomasNodePtr> ancestor_nodes_;
std::set<SomasTensorPtr> tensors_;
std::vector<SomasTensorPtr> input_tensors_;

View File

@ -117,15 +117,15 @@ void FootPrint::Merge(vector<Interval> *interval_v, stack<Interval> *s) {
return;
}
void FootPrint::ConstrainedBLocks(const std::shared_ptr<Array> &constraints, const BlockTensor &b1,
const BlockTensor &b2, vector<Interval> *oInterval) {
void FootPrint::ConstrainedBLocks(std::vector<DynamicBitSet> *constraints, const BlockTensor &b1, const BlockTensor &b2,
vector<Interval> *oInterval) {
MS_EXCEPTION_IF_NULL(oInterval);
// propagate
size_t acum = m_offset_;
for (SomasSolverTensorDescPtr p1 = b1.m_start_tensor_; NULL != p1; p1 = p1->right_) {
for (SomasSolverTensorDescPtr p2 = b2.m_start_tensor_; NULL != p2; p2 = p2->right_) {
if ((*constraints)(p1->index_, p2->index_) == 1) {
if ((*constraints)[p1->index_].IsBitTrue(p2->index_) == false) {
Interval a = Interval(acum, acum + p1->size_);
Interval b = Interval(p2);
if (a.lb() < b.ub()) {
@ -136,7 +136,7 @@ void FootPrint::ConstrainedBLocks(const std::shared_ptr<Array> &constraints, con
acum += p1->size_;
}
}
bool FootPrint::findOffset(const std::shared_ptr<Array> &constraints, const BlockTensor &block, size_t *offset) {
bool FootPrint::findOffset(std::vector<DynamicBitSet> *constraints, const BlockTensor &block, size_t *offset) {
MS_EXCEPTION_IF_NULL(offset);
bool bretval = true;
vector<Interval> l_interval;
@ -150,7 +150,7 @@ bool FootPrint::findOffset(const std::shared_ptr<Array> &constraints, const Bloc
// transform constrained tensors in non eligible intervals
for (size_t i = 0; i < m_starts_.size(); i++) {
if (block.Alone() && m_starts_[i]->Alone() &&
(*constraints)(block.m_start_tensor_->index_, m_starts_[i]->m_start_tensor_->index_)) {
(*constraints)[block.m_start_tensor_->index_].IsBitTrue(m_starts_[i]->m_start_tensor_->index_) == false) {
if (m_algorithm_ != 1 && i == 0) return false;
Interval It = Interval(m_starts_[i]->m_start_tensor_);
l_interval.emplace_back(It);
@ -201,7 +201,7 @@ void FootPrint::printStats() {
MS_LOG(DEBUG) << "Footprint blocks: " << m_starts_.size() << " \toffset: " << m_offset_;
}
bool FastHeuristic::Eval(vector<BlockTensor> *block_tensors_v, std::shared_ptr<FootPrint> foot_print,
const std::shared_ptr<Array> &pConstraints) {
std::vector<DynamicBitSet> *pConstraints) {
MS_EXCEPTION_IF_NULL(foot_print);
auto start = std::chrono::system_clock::now();

View File

@ -143,8 +143,8 @@ class FootPrint : public std::enable_shared_from_this<FootPrint> {
void Destroy();
const size_t getOffset() { return m_offset_; }
void setOffset(const size_t &offset) { m_offset_ = offset; }
bool findOffset(const std::shared_ptr<Array> &constraints, const BlockTensor &block, size_t *offset);
void ConstrainedBLocks(const std::shared_ptr<Array> &constraints, const BlockTensor &b1, const BlockTensor &b2,
bool findOffset(std::vector<DynamicBitSet> *constraints, const BlockTensor &block, size_t *offset);
void ConstrainedBLocks(std::vector<DynamicBitSet> *constraints, const BlockTensor &b1, const BlockTensor &b2,
vector<Interval> *oInterval_l);
void Merge(vector<Interval> *l_interval, stack<Interval> *l_merged);
bool findFirst(stack<Interval> *merged, const BlockTensor &block, size_t *offset);
@ -167,7 +167,7 @@ class FastHeuristic {
void setAlignment(const size_t &a) { m_alignment_ = a; }
void Destroy();
bool Eval(vector<BlockTensor> *block_tensors_v, std::shared_ptr<FootPrint> foot_print,
const std::shared_ptr<Array> &pConstraints);
std::vector<DynamicBitSet> *pConstraints);
private:
size_t m_alignment_;

View File

@ -169,16 +169,14 @@ bool SomasSolverCore::Verify(const size_t &upperbound) {
t2 = t2_.second;
if (t1->index_ == t2->index_) continue;
bool blifelong = (t1->lifelong_ || t2->lifelong_) && (t1->index_ != t2->index_);
const size_t continuous = 2;
const size_t conflict = 1;
if ((*constraints_)(t1->index_, t2->index_) == continuous) { // continuous constraint
// t1 must be continous to t2
if (t2->right_ == t1) { // continuous constraint
// t1 must be continuous to t2
bool bcontinuous = t1->offset_ == (t2->offset_ + t2->size_);
if (!bcontinuous) {
MS_LOG(WARNING) << "Continuous constraint violation in tensors " << t1->index_ << " and" << t2->index_;
retval = false;
}
} else if (blifelong || (*constraints_)(t1->index_, t2->index_) == conflict) { // conflict constraint
} else if (blifelong || constraints_[t1->index_].IsBitTrue(t2->index_) == false) { // conflict constraint
size_t t1_ub = t1->offset_ + t1->size_;
size_t t2_ub = t2->offset_ + t2->size_;
bool b_overlap_lb = ((t2->offset_ >= t1->offset_) && (t2->offset_ < t1_ub));
@ -336,7 +334,7 @@ size_t SomasSolverCore::Search(const std::shared_ptr<FootPrint> &pFootprint) {
FastHeuristic fh;
MS_LOG(INFO) << "Calling FastSolver Search for " << block_tensors_.size() << " tensors ";
auto start = std::chrono::system_clock::now();
if (fh.Eval(&block_tensors_, pFootprint, constraints_)) {
if (fh.Eval(&block_tensors_, pFootprint, &constraints_)) {
result = pFootprint->Result();
auto end = std::chrono::system_clock::now();
timing_ = std::chrono::duration_cast<std::chrono::milliseconds>((end - start)).count();

View File

@ -33,9 +33,9 @@ class SomasSolverCore {
public:
/// Interface Function: receive parameters, creates the model to solve and then save the result
SomasSolverCore(const std::unordered_map<size_t, SomasSolverTensorDescPtr> &tensors,
const std::shared_ptr<Array> &constraints)
std::vector<DynamicBitSet> *constraints)
: tensors_(tensors),
constraints_(constraints),
constraints_(*constraints),
upperbound_(SIZE_MAX),
timing_(0),
lifelongmemory_(0),
@ -69,7 +69,7 @@ class SomasSolverCore {
private:
std::unordered_map<size_t, SomasSolverTensorDescPtr> tensors_;
vector<BlockTensor> block_tensors_;
std::shared_ptr<Array> constraints_;
std::vector<DynamicBitSet> constraints_;
size_t upperbound_{0};
size_t timing_{0};
size_t lifelongmemory_{0};

View File

@ -27,25 +27,15 @@ namespace mindspore {
namespace somas {
Status SomasSolverPre::Solving(const session::KernelGraph *graph,
std::unordered_map<size_t, SomasSolverTensorDescPtr> *ptensors,
std::shared_ptr<Array> pConstraints, const vector<vector<size_t>> &continuous_v,
std::vector<DynamicBitSet> *pConstraints, const vector<vector<size_t>> &continuous_v,
bool bVerifySolution, bool ball, SortingType sorting, FittingType fitting,
AlgorithmType algorithm) {
Status retval = SUCCESS;
try {
std::unordered_map<size_t, SomasSolverTensorDescPtr> &tensors = *ptensors;
std::unordered_map<size_t, SomasSolverTensorDescPtr>::iterator max =
std::max_element(tensors.begin(), tensors.end(),
[](const std::pair<size_t, SomasSolverTensorDescPtr> &a,
const std::pair<size_t, SomasSolverTensorDescPtr> &b) { return a.first < b.first; });
size_t maxIndex = max->first;
if (maxIndex > pConstraints->Rows() - 1) {
MS_LOG(WARNING) << "ERROR: MaxIndex invalid, MaxIndex " << maxIndex << ", Rows " << pConstraints->Rows();
return FAILED;
}
MS_LOG(INFO) << "Filling in constraints matrix..";
uint32_t continuous_cnt = 0;
// creating S Lists
for (auto &aux : continuous_v) {
for (uint32_t i = 0; i < aux.size() - 1; i++) {
@ -60,8 +50,6 @@ Status SomasSolverPre::Solving(const session::KernelGraph *graph,
return FAILED;
}
const size_t continuous = 2;
(*pConstraints)(index2, index1) = continuous;
if (tensors[index1]->right_)
MS_LOG(WARNING) << "Warning:tensor " << index1
<< " already has a right tensor (id: " << tensors[index1]->right_->index_;
@ -104,7 +92,7 @@ Status SomasSolverPre::Solving(const session::KernelGraph *graph,
void SomasSolverPre::Log(const session::KernelGraph *graph,
const unordered_map<size_t, SomasSolverTensorDescPtr> &tensors,
const std::shared_ptr<Array> &pConstraints, const vector<vector<size_t>> &continuous_v) {
std::vector<DynamicBitSet> *pConstraints, const vector<vector<size_t>> &continuous_v) {
MS_LOG(INFO) << "SomasSolver::Log Writing somas-input.txt..";
auto context_ptr = MsContext::GetInstance();
@ -143,7 +131,7 @@ void SomasSolverPre::Log(const session::KernelGraph *graph,
for (auto &t2 : tensors) {
size_t idx1 = t1.first;
size_t idx2 = t2.first;
if ((idx1 != idx2) && (*pConstraints)(idx1, idx2) == 1) {
if ((idx1 != idx2) && (*pConstraints)[idx1].IsBitTrue(idx2) == false) {
ofs_1 << "C " << idx1 << " " << idx2 << std::endl;
}
}

View File

@ -57,36 +57,52 @@ enum FittingType {
kNumFittingTypes
};
class Array {
class DynamicBitSet {
const size_t bit_width_ = 64;
size_t bit_size_;
std::vector<uint64_t> bit_;
inline size_t GetIndex(size_t index) { return index / bit_width_; }
inline uint64_t GetBitMask(size_t index) { return (((uint64_t)0x1) << (bit_width_ - 1 - (index % bit_width_))); }
inline void Reset(uint64_t val) {
bit_.clear();
for (size_t i = 0; i < bit_size_; i++) {
bit_.push_back(val);
}
}
public:
Array(const size_t &rows, const size_t &cols) : rows_(rows), cols_(cols) {
conflicts_array_ = std::make_unique<int[]>(rows * cols);
for (uint32_t i = 0; i < rows * cols; i++) {
conflicts_array_[i] = 1;
explicit DynamicBitSet(size_t count) {
bit_size_ = (count + bit_width_ - 1) / bit_width_;
Reset(0x0);
}
void SetBitTrue(size_t index, bool log = false) {
if (log) {
MS_LOG(INFO) << GetIndex(index) << " " << GetBitMask(index);
}
bit_[GetIndex(index)] |= GetBitMask(index);
}
void SetBitFalse(size_t index) { bit_[GetIndex(index)] &= (~GetBitMask(index)); }
bool IsBitTrue(size_t index) { return (bit_[GetIndex(index)] & GetBitMask(index)) != 0x0; }
void Log() {
std::cout << "Start Print Bitset ";
for (size_t i = 0; i < bit_size_; i++) {
std::cout << " bit [" << std::dec << i << "] = " << std::hex << bit_[i] << std::dec;
}
std::cout << std::endl;
}
friend void Union(DynamicBitSet *a, DynamicBitSet *b) {
for (size_t i = 0; i < (*a).bit_size_; i++) {
(*a).bit_[i] |= (*b).bit_[i];
}
}
Array(const Array &array) : rows_(array.rows_), cols_(array.cols_) {
conflicts_array_ = std::make_unique<int[]>(array.rows_ * array.cols_);
for (uint32_t i = 0; i < array.rows_ * array.cols_; i++) {
conflicts_array_[i] = array.conflicts_array_[i];
}
}
Array &operator=(const Array &array) { return *this; }
int &operator()(const size_t &i, const size_t &j) {
assert((i * cols_ + j) < (rows_ * cols_));
return conflicts_array_[i * cols_ + j];
}
const size_t &Rows() { return rows_; }
const size_t &Cols() { return cols_; }
private:
const size_t rows_;
const size_t cols_;
std::unique_ptr<int[]> conflicts_array_;
};
struct SomasSolverTensorDesc {
@ -140,14 +156,14 @@ class SomasSolverPre {
size_t GetMaxOffset() { return max_offset_; }
Status Solving(const session::KernelGraph *graph, std::unordered_map<size_t, SomasSolverTensorDescPtr> *tensors,
std::shared_ptr<Array> pConstraints, const vector<vector<size_t>> &continuous_v,
std::vector<DynamicBitSet> *pConstraints, const vector<vector<size_t>> &continuous_v,
bool bVerifySolution, // true -> Check continuous and non overlapping constraints solution
bool ball = true, // true -> run full set of heuristics, false -> run single heuristic specified
SortingType sorting = kGreaterSizeSmallerIndex, FittingType fitting = kBest,
AlgorithmType algorithm = kManyObjects);
void Log(const session::KernelGraph *graph, const unordered_map<size_t, SomasSolverTensorDescPtr> &tensors,
const std::shared_ptr<Array> &pConstraints_v, const vector<vector<size_t>> &continuous_v);
std::vector<DynamicBitSet> *pConstraints_v, const vector<vector<size_t>> &continuous_v);
private:
size_t max_offset_;

View File

@ -34,11 +34,13 @@ using SomasTensorPtr = std::shared_ptr<SomasTensor>;
class SomasStream {
public:
using SomasStreamPtr = std::shared_ptr<SomasStream>;
using SomasNodePtr = std::shared_ptr<SomasNode>;
// Attributes mutated in code
std::vector<SomasTensorPtr> tensors_; // vector needed for same-stream loop in ConflictComputing()
std::set<SomasStreamPtr> ancestor_streams_;
std::set<SomasStreamPtr> ancestor_streams_group_;
std::vector<SomasNodePtr> nodes_;
// Constructors/Destructors
explicit SomasStream(int64_t id) : id_(id) {}