!28545 Try to optimize after LaunchKernelMod fails

Merge pull request !28545 from tanghuikang/optimize_after_launch_fail
This commit is contained in:
i-robot 2022-01-05 06:19:24 +00:00 committed by Gitee
commit 63f29057c9
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 53 additions and 25 deletions

View File

@ -1665,6 +1665,9 @@ void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
}
auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
MS_EXCEPTION_IF_NULL(mem_scheduler);
if (mem_scheduler->optimized()) {
return;
}
mem_scheduler->SetMemHandler(mem_manager_);
mem_scheduler->SetTotalStep(graph.execution_order().size());
@ -1680,9 +1683,17 @@ void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
bool KernelRuntime::LaunchKernels(const session::KernelGraph &graph) {
UseMemSchedulerIfNeeded(graph);
if (!LaunchKernelMod(graph)) {
MS_LOG(ERROR) << "LaunchKernelMod failed!";
return false;
while (!LaunchKernelMod(graph)) {
if (!UseMemScheduler()) {
MS_LOG(ERROR) << "LaunchKernelMod failed!";
return false;
}
auto mem_scheduler = mem_scheduler_manager_.GetMemScheduler(graph.graph_id());
MS_EXCEPTION_IF_NULL(mem_scheduler);
if (!mem_scheduler->Optimize()) {
MS_LOG(ERROR) << "LaunchKernelMod failed!";
return false;
}
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);

View File

@ -29,6 +29,7 @@ namespace {
constexpr float kMaxMemReuseFactor = 1.0;
constexpr float kMinMemReuseFactor = 0.5;
constexpr float kRetryFactor = 0.1;
constexpr size_t kMockTimes = 3;
double GetCurrentTime() {
#ifdef _MSC_VER
@ -232,37 +233,49 @@ void MemScheduler::OptMemUsage(float mem_used_factor) {
bool MemScheduler::Optimize() {
AdjustFirstEventIndex();
float mem_used_factor = kMaxMemReuseFactor;
while (!optimized_ && mem_used_factor >= kMinMemReuseFactor) {
float mem_used_factor = optimized_ ? mem_used_factor_ - kRetryFactor : kMaxMemReuseFactor;
while (mem_used_factor >= kMinMemReuseFactor) {
OptMemUsage(mem_used_factor);
current_step_ = 0;
bool ret = true;
for (size_t step = 0; step < total_step_; ++step) {
ret = PreCompute(nullptr);
auto &step_events = step_events_[step];
for (auto &event : step_events) {
if (event->type != kGet) {
continue;
}
auto ptr = GetOrMalloc(event->key, event->mem_size);
if (ptr == nullptr) {
ret = false;
break;
}
}
for (size_t mock_time = 0; mock_time < kMockTimes; ++mock_time) {
ret = MockOneStep();
if (!ret) {
break;
}
PostCompute(nullptr);
}
if (ret) {
optimized_ = true;
} else {
ClearAllocatedMem();
mem_used_factor -= kRetryFactor;
return true;
}
ClearAllocatedMem();
mem_used_factor -= kRetryFactor;
}
return false;
}
// Performs one pass over all steps: resets current_step_ to 0, then for each of the
// total_step_ steps calls PreCompute(nullptr), requests memory through GetOrMalloc for
// every kGet event stored for that step, and finally calls PostCompute(nullptr).
// Returns false as soon as PreCompute fails, GetOrMalloc yields nullptr, or PostCompute fails.
bool MemScheduler::MockOneStep() {
// Start the pass from the first step.
current_step_ = 0;
for (size_t step = 0; step < total_step_; ++step) {
bool ret = PreCompute(nullptr);
if (!ret) {
return false;
}
auto &step_events = step_events_[step];
for (auto &event : step_events) {
// Only kGet events trigger a memory request; every other event type is skipped here.
if (event->type != kGet) {
continue;
}
auto ptr = GetOrMalloc(event->key, event->mem_size);
// A null pointer from GetOrMalloc aborts the pass.
if (ptr == nullptr) {
return false;
}
}
ret = PostCompute(nullptr);
if (!ret) {
return false;
}
}
// NOTE(review): as rendered here this statement executes first, which makes the
// `return true;` below unreachable. It looks like a removed diff line (the old tail of
// Optimize()) shown without its '-' marker -- confirm against the actual file.
return optimized_;
return true;
}
void MemScheduler::AdjustFirstEventIndex() {

View File

@ -45,6 +45,8 @@ class MemScheduler {
void set_need_record_event(bool flag) { need_record_event_ = flag; }
bool optimized() const { return optimized_; }
void Update();
void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }
@ -89,6 +91,8 @@ class MemScheduler {
void OptMemUsage(float mem_used_factor = 1.0f);
bool MockOneStep();
void AdjustFirstEventIndex();
std::map<const void *, MemPriority> mem_priority_;
@ -104,7 +108,7 @@ class MemScheduler {
size_t current_step_{0};
bool need_record_event_{true};
bool optimized_{false};
float mem_used_factor_{0.9};
float mem_used_factor_{1.0};
double compute_start_time_{0};
std::vector<double> compute_time_;
bool record_compute_time_{false};