!8292 handling contiguous ref node
From: @laiyongqiang
Reviewed-by:
Signed-off-by:
commit dc0cf6f66c
@@ -52,8 +52,6 @@ bool Somas::Allocate(const session::KernelGraph *graph) {
     MS_LOG(EXCEPTION) << "Somas Initialize Failed.";
   }
 
-  GenStatisticInfo();
-
   // Computing Conflict pairs
   MS_LOG(INFO) << "Start Computing Conflict Pairs";
   ComputeConflictPairs();
@@ -64,6 +62,7 @@ bool Somas::Allocate(const session::KernelGraph *graph) {
     MS_LOG(EXCEPTION) << "Somas Assign Failed.";
   }
 
+  GenStatisticInfo();
   return ret;
 }
 
@@ -459,6 +458,7 @@ SomasTensorPtr Somas::CreateGapTensor(size_t gap_tensor_id) {
   gap_tensor->type_ = kGap;
   gap_tensor->lifetime_.start_ = 0;
   gap_tensor->lifetime_.end_ = 0xffff;
   gap_tensor->aligned_size_ = gap_size;
+  tensors_map_[gap_tensor->GetId()] = gap_tensor;
   tensors_list_.push_back(gap_tensor);
   return gap_tensor;
@@ -474,32 +474,38 @@ void Somas::GenContiguousList(const session::KernelGraph *graph) {
     }
     std::vector<size_t> inputs;
     auto input_before_gap = CreateGapTensor(gap_tensor_id);
+    input_before_gap->contiguous_ = true;
     gap_tensor_id++;
     inputs.push_back(input_before_gap->GetId());
 
     for (const auto &input_tensor : node->input_tensors_) {
       comm_input_total_size_ += input_tensor->aligned_size_;
+      input_tensor->contiguous_ = true;
       inputs.push_back(input_tensor->GetId());
     }
 
     auto input_after_gap = CreateGapTensor(gap_tensor_id);
     gap_tensor_id++;
+    input_after_gap->contiguous_ = true;
     inputs.push_back(input_after_gap->GetId());
     contiguous_tensors_list_.push_back(inputs);
 
     std::vector<size_t> outputs;
     auto output_before_gap = CreateGapTensor(gap_tensor_id);
     gap_tensor_id++;
+    output_before_gap->contiguous_ = true;
     outputs.push_back(output_before_gap->GetId());
 
     for (const auto &output_tensor : node->output_tensors_) {
       output_tensor->lifelong_value_ = kLifeLongGraphStart;
       comm_output_total_size_ += output_tensor->aligned_size_;
+      output_tensor->contiguous_ = true;
       outputs.push_back(output_tensor->GetId());
     }
 
     auto output_after_gap = CreateGapTensor(gap_tensor_id);
     gap_tensor_id++;
+    output_after_gap->contiguous_ = true;
     outputs.push_back(output_after_gap->GetId());
     contiguous_tensors_list_.push_back(outputs);
   }
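
The hunk above is where contiguity is established: every input (and output) of a communication node is flagged contiguous_, and each block is bracketed by a pair of gap tensors so the solver reserves padding on both sides. A minimal standalone sketch of the resulting id-list layout (not MindSpore code: the Tensor struct, make_gap helper, ids, and sizes are all simplified stand-ins):

#include <cstddef>
#include <iostream>
#include <memory>
#include <unordered_map>
#include <vector>

// Minimal stand-in for SomasTensorPtr; only the fields the sketch needs.
struct Tensor {
  size_t id;
  size_t aligned_size;
  bool contiguous = false;
};

int main() {
  std::unordered_map<size_t, std::shared_ptr<Tensor>> tensors_map;
  std::vector<std::vector<size_t>> contiguous_tensors_list;
  size_t gap_tensor_id = 100;  // hypothetical id space for gap tensors

  // Plays the role of CreateGapTensor: a padding tensor around the block.
  auto make_gap = [&](size_t size) {
    auto gap = std::make_shared<Tensor>(Tensor{gap_tensor_id++, size, true});
    tensors_map[gap->id] = gap;
    return gap;
  };

  // Pretend one communication node has input tensors with ids 1 and 2.
  std::vector<size_t> inputs;
  inputs.push_back(make_gap(512)->id);  // gap before the block
  for (size_t tid : {size_t{1}, size_t{2}}) {
    tensors_map[tid] = std::make_shared<Tensor>(Tensor{tid, 1024, true});
    inputs.push_back(tid);  // real inputs, all marked contiguous
  }
  inputs.push_back(make_gap(512)->id);  // gap after the block

  contiguous_tensors_list.push_back(inputs);
  for (size_t tid : contiguous_tensors_list[0]) std::cout << tid << ' ';
  std::cout << '\n';  // prints: 100 1 2 101
}
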
@@ -785,11 +791,17 @@ bool Somas::Assign(const session::KernelGraph *graph) {
 
   // Ref Node Preprocessing
   MS_LOG(INFO) << "Start Solving Preprocessing for Ref Node";
+  std::map<size_t, size_t> contiguous_ref_map;
   for (auto ref_node_list : ref_node_constraints_) {
+    // Count contiguous tensors in ref list
+    size_t contiguous_in_ref_list = std::count_if(ref_node_list.begin(), ref_node_list.end(),
+                                                  [this](size_t tid) { return tensors_map_[tid]->contiguous_; });
     // Keep all constraints for first tensor in list
     size_t tid_0 = ref_node_list[0];
-    for (auto tensor : tensors_list_) {
-      if ((*cannot_reuse_)(tid_0, tensor->GetId()) == 1) continue;
+    for (SomasTensorPtr tensor : tensors_list_) {
+      if ((*cannot_reuse_)(tid_0, tensor->GetId()) == 1) {
+        continue;
+      }
       for (size_t tid : ref_node_list) {
         if ((*cannot_reuse_)(tid, tensor->GetId()) == 1) {
           (*cannot_reuse_)(tid_0, tensor->GetId()) = 1;
@@ -798,9 +810,110 @@ bool Somas::Assign(const session::KernelGraph *graph) {
         }
       }
     }
-    // Set rest to size 0, so that solver ignores them
+    // Set rest to size 0, so that solver ignores them (if not contiguous)
     for (size_t i = 1; i < ref_node_list.size(); ++i) {
-      tensors_map_[ref_node_list[i]]->aligned_size_ = 0;
+      if (!tensors_map_[ref_node_list[i]]->contiguous_) {
+        tensors_map_[ref_node_list[i]]->aligned_size_ = 0;
+      }
     }
+    // Keep info about contiguous and check for errors
+    if (ref_node_list.size() > 2 && contiguous_in_ref_list > 0) {
+      MS_LOG(WARNING) << "Ref node of size greater than two with at least one contiguous tensor in";
+    }
+    if (ref_node_list.size() == 2 && contiguous_in_ref_list == 1) {
+      MS_LOG(WARNING) << "Ref node of size two with only one contiguous tensor" << ref_node_list[0] << ":"
+                      << tensors_map_[ref_node_list[0]]->contiguous_ << ", " << ref_node_list[1] << ":"
+                      << tensors_map_[ref_node_list[1]]->contiguous_;
+    }
+    if (ref_node_list.size() == 2 && contiguous_in_ref_list == 2) {
+      contiguous_ref_map[ref_node_list[0]] = ref_node_list[1];
+    }
   }
+  // Handle contiguous ref node (remove ref from contiguous_tensors_list_)
+  std::map<size_t, size_t> contiguous_ref_list_map;
+  std::map<size_t, std::map<size_t, std::set<size_t>>> contiguous_ref_list_error_check_map;
+  for (auto ref_pair : contiguous_ref_map) {
+    size_t ref_first = ref_pair.first;
+    size_t ref_second = ref_pair.second;
+    bool found_first = false;
+    bool found_second = false;
+    size_t index_first = 0;
+    size_t index_second = 0;
+    size_t index_in_list_first = 0;
+    size_t index_in_list_second = 0;
+    for (size_t index = 0; index < contiguous_tensors_list_.size() && (!found_first || !found_second); index++) {
+      if (!found_first) {
+        auto iterator_first =
+          std::find(contiguous_tensors_list_[index].begin(), contiguous_tensors_list_[index].end(), ref_first);
+        if (iterator_first != contiguous_tensors_list_[index].end()) {
+          index_first = index;
+          index_in_list_first = iterator_first - contiguous_tensors_list_[index].begin();
+          found_first = true;
+        }
+      }
+      if (!found_second) {
+        auto iterator_second =
+          std::find(contiguous_tensors_list_[index].begin(), contiguous_tensors_list_[index].end(), ref_second);
+        if (iterator_second != contiguous_tensors_list_[index].end()) {
+          index_second = index;
+          index_in_list_second = iterator_second - contiguous_tensors_list_[index].begin();
+          found_second = true;
+        }
+      }
+    }
+    if (!found_first) {
+      MS_LOG(WARNING) << "Contiguous ref tensor " << ref_first << " not found in any contiguous list";
+    }
+    if (!found_second) {
+      MS_LOG(WARNING) << "Contiguous ref tensor " << ref_second << " not found in any contiguous list";
+    }
+    if (contiguous_ref_list_map.find(index_first) == contiguous_ref_list_map.end() ||
+        contiguous_ref_list_map[index_first] == index_second) {
+      contiguous_ref_list_map[index_first] = index_second;
+      // Checking for error cases
+      if (index_in_list_first != index_in_list_second) {
+        MS_LOG(WARNING) << "Inconsistency in contiguous ref: tensor " << ref_first << " in position "
+                        << index_in_list_first << " of contiguous list " << index_first << " and tensor " << ref_second
+                        << " in position " << index_in_list_second << " of contiguous list " << index_second;
+      }
+      contiguous_ref_list_error_check_map[index_first][index_second].insert(index_in_list_first);
+    } else {  // contiguous_ref_list_map.find(index_first) != contiguous_ref_list_map.end() &&
+              // contiguous_ref_list_map[index_first] != index_second
+      MS_LOG(WARNING) << "Contiguous list " << index_first << " associated (ref node) with two other contiguous lists: "
+                      << contiguous_ref_list_map[index_first] << " and " << index_second;
+    }
+  }
+
+  for (auto check_list_pair : contiguous_ref_list_error_check_map) {
+    auto first_list = check_list_pair.first;
+    auto index_set_map = check_list_pair.second;
+    for (auto index_set : index_set_map) {
+      auto second_list = index_set.first;
+      if (contiguous_tensors_list_[first_list].size() != contiguous_tensors_list_[second_list].size()) {
+        MS_LOG(WARNING) << "Contiguous lists " << first_list << " and " << second_list
+                        << " considered in ref do not have the same size";
+      }
+      for (size_t x = 0; x < contiguous_tensors_list_[second_list].size(); x++) {
+        if (contiguous_ref_list_error_check_map[first_list][second_list].count(x) == 0) {
+          MS_LOG(WARNING) << "Contiguous lists " << first_list << " and " << second_list
+                          << " considered in ref: ref pair at in-lists index " << x << " has not been considered";
+        }
+      }
+    }
+  }
+
+  std::set<vector<size_t>> contiguous_tensors_list_to_remove;
+  for (auto ref_list_pair : contiguous_ref_list_map) {
+    contiguous_tensors_list_to_remove.insert(contiguous_tensors_list_[ref_list_pair.second]);
+  }
+  vector<vector<size_t>> contiguous_tensors_list_removed_ref = contiguous_tensors_list_;
+  for (auto contiguous_list : contiguous_tensors_list_to_remove) {
+    auto iterator = std::find(contiguous_tensors_list_removed_ref.begin(), contiguous_tensors_list_removed_ref.end(),
+                              contiguous_list);
+    if (iterator != contiguous_tensors_list_removed_ref.end()) {
+      contiguous_tensors_list_removed_ref.erase(iterator);
+    } else {
+      MS_LOG(WARNING) << "Could not find contiguous list to remove for ref";
+    }
+  }
   MS_LOG(INFO) << "End Solving Preprocessing for Ref Node";
 
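
In summary, the preprocessing above records each two-tensor, fully contiguous ref node as a pair in contiguous_ref_map, lifts that pair to a pair of contiguous-list indices in contiguous_ref_list_map, cross-checks in-list positions and list sizes with warnings, and finally removes the second list of every pair so the solver places only one copy. A minimal standalone sketch of the lifting and removal (not MindSpore code; the lists, ids, and index-based erase are simplified stand-ins — the real code erases by matching list contents and warns when a list is missing):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <map>
#include <vector>

int main() {
  // Two contiguous lists that mirror each other through a ref-node relation:
  // tensor 1 (in list 0) and tensor 11 (in list 1) form the ref pair.
  std::vector<std::vector<size_t>> lists = {{100, 1, 2, 101}, {102, 11, 12, 103}};
  std::map<size_t, size_t> contiguous_ref_map = {{1, 11}};  // from preprocessing

  // Lift the tensor pair to a pair of list indices.
  std::map<size_t, size_t> contiguous_ref_list_map;
  for (auto [first, second] : contiguous_ref_map) {
    size_t index_first = 0, index_second = 0;
    for (size_t i = 0; i < lists.size(); ++i) {
      if (std::count(lists[i].begin(), lists[i].end(), first)) index_first = i;
      if (std::count(lists[i].begin(), lists[i].end(), second)) index_second = i;
    }
    contiguous_ref_list_map[index_first] = index_second;  // lists 0 and 1 alias
  }

  // Drop the second list of each pair so the solver allocates only one copy.
  auto lists_removed_ref = lists;
  for (auto [index_first, index_second] : contiguous_ref_list_map) {
    (void)index_first;
    lists_removed_ref.erase(lists_removed_ref.begin() + index_second);
  }
  std::cout << "lists passed to solver: " << lists_removed_ref.size() << '\n';  // 1
}

With the second list gone, its tensors temporarily have no offsets of their own; the postprocessing hunk further below fills them back in.
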
@@ -848,8 +961,13 @@ bool Somas::Assign(const session::KernelGraph *graph) {
   }
 
   somas_solver_ = std::make_shared<SomasSolverPre>();
-  somas_solver_->Solving(graph, &solver_tensor_desc_list_, cannot_reuse_, contiguous_tensors_list_, false);
+  auto status =
+    somas_solver_->Solving(graph, &solver_tensor_desc_list_, cannot_reuse_, contiguous_tensors_list_removed_ref, false);
   MS_LOG(INFO) << "End Solving";
+  if (status != SUCCESS) {
+    GenStatisticInfo();
+    MS_LOG(EXCEPTION) << "SOMAS Solving Failed.";
+  }
 
   // Update solver_tensor_desc offset to tensors list
   for (const auto &tensor : tensors_list_) {
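
Two things change in this hunk: the solver now consumes contiguous_tensors_list_removed_ref instead of the full contiguous_tensors_list_, and its return status, previously discarded, is checked so a failed solve dumps statistics and raises. A minimal sketch of the check pattern (not MindSpore code; Status and Solving here are stand-ins):

#include <iostream>
#include <stdexcept>

enum Status { SUCCESS, FAILED };  // stand-in for the real status enum

// Placeholder for SomasSolverPre::Solving; assume it reports success/failure.
Status Solving() { return SUCCESS; }

int main() {
  auto status = Solving();
  if (status != SUCCESS) {
    // The real code emits statistics (GenStatisticInfo) before raising.
    throw std::runtime_error("SOMAS Solving Failed.");
  }
  std::cout << "solving succeeded\n";
}
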
@@ -864,6 +982,15 @@ bool Somas::Assign(const session::KernelGraph *graph) {
       tensors_map_[ref_node_list[i]]->offset_ = tensors_map_[ref_node_list[0]]->offset_;
     }
   }
+  // Handle contiguous ref node
+  for (auto ref_list_pair : contiguous_ref_list_map) {
+    size_t index_first = ref_list_pair.first;
+    size_t index_second = ref_list_pair.second;
+    for (size_t x = 0; x < contiguous_tensors_list_[index_second].size(); x++) {
+      tensors_map_[contiguous_tensors_list_[index_second][x]]->offset_ =
+        tensors_map_[contiguous_tensors_list_[index_first][x]]->offset_;
+    }
+  }
   MS_LOG(INFO) << "\nEnd Solving Postprocessing for Ref Node";
 
   // Set mem_offset_ value by solver result
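
This postprocessing closes the loop: once the solver has placed the surviving list, every tensor of the removed list inherits the offset of the tensor at the same in-list position, so the two ref'd blocks end up aliased byte for byte. A standalone sketch under the same simplified assumptions as the earlier sketches:

#include <cstddef>
#include <iostream>
#include <map>
#include <vector>

int main() {
  // Same hypothetical lists as before; the solver only placed list 0.
  std::vector<std::vector<size_t>> lists = {{100, 1, 2, 101}, {102, 11, 12, 103}};
  std::map<size_t, size_t> offsets = {{100, 0}, {1, 512}, {2, 1536}, {101, 2560}};
  std::map<size_t, size_t> contiguous_ref_list_map = {{0, 1}};  // list 1 mirrors list 0

  // Copy offsets elementwise from the solved list to the removed one.
  for (auto [index_first, index_second] : contiguous_ref_list_map) {
    for (size_t x = 0; x < lists[index_second].size(); ++x) {
      offsets[lists[index_second][x]] = offsets[lists[index_first][x]];
    }
  }
  std::cout << "tensor 11 offset: " << offsets[11] << '\n';  // 512, same as tensor 1
}
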
@@ -38,6 +38,7 @@ SomasTensor::SomasTensor(size_t id, SomasNodePtr source_node, SomasStreamPtr sou
 
   ref_overlap_ = false;
   between_streams_ = false;
+  contiguous_ = false;
   num_constraints_ = 0;
 }
 
@@ -72,6 +72,7 @@ class SomasTensor {
 
   bool ref_overlap_;
   bool between_streams_;
+  bool contiguous_;
 
   lifetime_t lifetime_;
   TensorType type_;
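
The two SomasTensor hunks above supply the state that the somas.cc changes rely on: contiguous_ is declared alongside ref_overlap_ and between_streams_ and initialized to false in the constructor, so a tensor is contiguous only if GenContiguousList or the ref-node preprocessing marked it.
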
@@ -563,7 +563,7 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in
       MS_LOG(INFO) << "Already malloc index:" << i;
       continue;
     }
-    MS_LOG(DEBUG) << "Assign Node:" << node->fullname_with_scope() << " output memeory size:" << output_sizes[i];
+    MS_LOG(DEBUG) << "Assign Node:" << node->fullname_with_scope() << " output memory size:" << output_sizes[i];
     std::string output_format = AnfAlgo::GetOutputFormat(node, i);
     auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
     auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
@@ -634,7 +634,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
     }
     auto &node_value = value_node->value();
     MS_EXCEPTION_IF_NULL(node_value);
-    MS_LOG(DEBUG) << "Malloc memeory for " << value_node->fullname_with_scope();
+    MS_LOG(DEBUG) << "Malloc memory for " << value_node->fullname_with_scope();
     if (node_value->isa<Tensor>() || node_value->isa<ValueTuple>()) {
       AssignValueNodeTensor(value_node, node_value, 0);
     } else if (node_value->isa<StringImm>()) {
@@ -693,7 +693,7 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
   // communication nodes first
   for (auto &node : execution_nodes) {
     if (AnfAlgo::IsCommunicationOp(node)) {
-      // skip if the memory is already alocated
+      // skip if the memory is already allocated
       AssignCommunicationNodeMem(mem_type, node);
     } else {
       compute_nodes.emplace_back(node);
@@ -773,8 +773,8 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList
   auto pre_node = (cnode->inputs()[1])->cast<CNodePtr>();
   // set clean output address
   if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
-    auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
-    for (auto index : clean_output_indexs) {
+    auto clean_output_indexes = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
+    for (auto index : clean_output_indexes) {
      auto device_address = AnfAlgo::GetOutputAddr(pre_node, index);
      kernel::AddressPtr input = std::make_shared<kernel::Address>();
      MS_EXCEPTION_IF_NULL(input);
@@ -783,12 +783,12 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList
       input->size = device_address->size_;
       kernel_inputs->emplace_back(input);
     }
-    MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexs.size();
+    MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexes.size();
   }
   // set clean workspace address
   if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
-    auto clean_workspaces_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
-    for (const auto &index : clean_workspaces_indexs) {
+    auto clean_workspaces_indexes = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
+    for (const auto &index : clean_workspaces_indexes) {
      auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index);
      kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
      MS_EXCEPTION_IF_NULL(workspace);