solver ignore union contiguous tensors

This commit is contained in:
reku1997 2022-10-18 10:40:08 +08:00
parent d15722f1b0
commit 62b6ae7c93
2 changed files with 17 additions and 8 deletions

View File

@ -162,7 +162,6 @@ bool Somas::Assign(const session::KernelGraph &graph) {
Solve(graph);
GenGraphStatisticInfo();
if (enable_cache_) {
SaveSomasResult(graph);
}
@ -349,9 +348,6 @@ void Somas::UpdateSomasResultToGraph(const session::KernelGraph &graph) {
log_merged_blocks.emplace_back(block.start_offset_, block.size_);
all_block_size = std::max(block.start_offset_ + block.size_, all_block_size);
}
MS_EXCEPTION_IF_CHECK_FAIL(all_block_size == reused_memory_size_,
"All block size and Reused memory size not equal: " + std::to_string(all_block_size) +
", " + std::to_string(reused_memory_size_));
std::sort(log_merged_blocks.begin(), log_merged_blocks.end(),
[](const std::pair<size_t, size_t> &A, const std::pair<size_t, size_t> &B) { return A.second > B.second; });
MS_LOG(INFO) << "Merged Block size: " << log_merged_blocks.size();
@ -359,6 +355,9 @@ void Somas::UpdateSomasResultToGraph(const session::KernelGraph &graph) {
MS_LOG(INFO) << "Merged Block: " << i << ", offset: " << log_merged_blocks[i].first
<< ", size: " << log_merged_blocks[i].second;
}
MS_EXCEPTION_IF_CHECK_FAIL(all_block_size == reused_memory_size_,
"All block size and Reused memory size not equal: " + std::to_string(all_block_size) +
", " + std::to_string(reused_memory_size_));
}
bool Somas::LoadSomasResult(const string &filename) {
@ -692,9 +691,9 @@ void Somas::InitSomasOutputAndWorkspaceTensors(const session::KernelGraph &graph
if (aligned_size == 0) {
// Device Address still need to be allocated when output_size is 0
aligned_size = GetAlignSize(kZeroAlignSize);
MS_LOG(INFO) << "Node output size is zero: " << kernel->fullname_with_scope() << " output size " << size
<< " align size " << aligned_size;
}
MS_LOG(INFO) << "Node " << kernel->fullname_with_scope() << " output size " << size << " align size "
<< aligned_size;
auto tensor =
std::make_shared<SomasTensor>(output_tensor_index, node->GetId(), stream_id, size, aligned_size, kLifeLongNone);
MS_EXCEPTION_IF_NULL(tensor);
@ -1512,8 +1511,10 @@ void Somas::Solve(const session::KernelGraph &graph) {
auto status =
somas_solver_->Solving(graph, &solver_tensor_desc_map_, &reuse_matrix_, processed_contiguous_tensors_list_, false);
MS_LOG(INFO) << "End Solving";
GenGraphStatisticInfo();
if (status != SUCCESS) {
GenGraphStatisticInfo();
MS_LOG(EXCEPTION) << "SOMAS Solving Failed.";
}
@ -1728,6 +1729,14 @@ void Somas::UpdateUnionTensorsConflict() {
}
}
}
// solver should ignore union contiguous tensors.
for (auto ref_list_pair : contiguous_list_with_ref_index_map_) {
size_t index_second = ref_list_pair.second;
for (size_t x : contiguous_tensors_list_[index_second]) {
tensors_list_[x]->aligned_size_ = 0;
}
}
}
std::string Somas::GetSplitName(const std::string &scope_name) {

View File

@ -17,7 +17,7 @@ import os
import pytest
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single