forked from mindspore-Ecosystem/mindspore
!28545 Try optimize after LaunchKernelMod fail
Merge pull request !28545 from tanghuikang/optimize_after_launch_fail
This commit is contained in:
commit
63f29057c9
|
@ -1665,6 +1665,9 @@ void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
|
|||
}
|
||||
auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
|
||||
MS_EXCEPTION_IF_NULL(mem_scheduler);
|
||||
if (mem_scheduler->optimized()) {
|
||||
return;
|
||||
}
|
||||
mem_scheduler->SetMemHandler(mem_manager_);
|
||||
mem_scheduler->SetTotalStep(graph.execution_order().size());
|
||||
|
||||
|
@ -1680,9 +1683,17 @@ void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
|
|||
|
||||
bool KernelRuntime::LaunchKernels(const session::KernelGraph &graph) {
|
||||
UseMemSchedulerIfNeeded(graph);
|
||||
if (!LaunchKernelMod(graph)) {
|
||||
MS_LOG(ERROR) << "LaunchKernelMod failed!";
|
||||
return false;
|
||||
while (!LaunchKernelMod(graph)) {
|
||||
if (!UseMemScheduler()) {
|
||||
MS_LOG(ERROR) << "LaunchKernelMod failed!";
|
||||
return false;
|
||||
}
|
||||
auto mem_scheduler = mem_scheduler_manager_.GetMemScheduler(graph.graph_id());
|
||||
MS_EXCEPTION_IF_NULL(mem_scheduler);
|
||||
if (!mem_scheduler->Optimize()) {
|
||||
MS_LOG(ERROR) << "LaunchKernelMod failed!";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
|
|
|
@ -29,6 +29,7 @@ namespace {
|
|||
constexpr float kMaxMemReuseFactor = 1.0;
|
||||
constexpr float kMinMemReuseFactor = 0.5;
|
||||
constexpr float kRetryFactor = 0.1;
|
||||
constexpr size_t kMockTimes = 3;
|
||||
|
||||
double GetCurrentTime() {
|
||||
#ifdef _MSC_VER
|
||||
|
@ -232,37 +233,49 @@ void MemScheduler::OptMemUsage(float mem_used_factor) {
|
|||
|
||||
// Tries successive memory-usage factors until a trial replay of the recorded
// steps succeeds.
//
// Visible flow: AdjustFirstEventIndex() is called first; then, while the
// candidate factor is >= kMinMemReuseFactor, OptMemUsage(factor) is applied
// and a trial replay is run. On success optimized_ is set to true; on failure
// ClearAllocatedMem() is called and the factor is lowered by kRetryFactor.
//
// NOTE(review): this text is a rendered diff in which removed and added lines
// appear interleaved without +/- markers (e.g. two declarations of
// mem_used_factor, two `while` headers, and both an inlined per-step replay
// loop and a MockOneStep() loop). Presumably the first line of each such pair
// is the pre-change version -- confirm against the repository before relying
// on the exact control flow shown here.
bool MemScheduler::Optimize() {
|
||||
AdjustFirstEventIndex();
|
||||
float mem_used_factor = kMaxMemReuseFactor;
|
||||
while (!optimized_ && mem_used_factor >= kMinMemReuseFactor) {
|
||||
// When already optimized, start one kRetryFactor below the stored mem_used_factor_;
// otherwise start from kMaxMemReuseFactor.
float mem_used_factor = optimized_ ? mem_used_factor_ - kRetryFactor : kMaxMemReuseFactor;
|
||||
while (mem_used_factor >= kMinMemReuseFactor) {
|
||||
OptMemUsage(mem_used_factor);
|
||||
current_step_ = 0;
|
||||
bool ret = true;
|
||||
// Inlined replay over every step: PreCompute, then GetOrMalloc for each kGet event.
for (size_t step = 0; step < total_step_; ++step) {
|
||||
ret = PreCompute(nullptr);
|
||||
auto &step_events = step_events_[step];
|
||||
for (auto &event : step_events) {
|
||||
if (event->type != kGet) {
|
||||
continue;
|
||||
}
|
||||
auto ptr = GetOrMalloc(event->key, event->mem_size);
|
||||
if (ptr == nullptr) {
|
||||
ret = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Runs MockOneStep() up to kMockTimes times, stopping at the first failure.
for (size_t mock_time = 0; mock_time < kMockTimes; ++mock_time) {
|
||||
ret = MockOneStep();
|
||||
if (!ret) {
|
||||
break;
|
||||
}
|
||||
PostCompute(nullptr);
|
||||
}
|
||||
if (ret) {
|
||||
optimized_ = true;
|
||||
} else {
|
||||
ClearAllocatedMem();
|
||||
mem_used_factor -= kRetryFactor;
|
||||
return true;
|
||||
}
|
||||
// Trial failed: release what the trial allocated and retry with a smaller factor.
ClearAllocatedMem();
|
||||
mem_used_factor -= kRetryFactor;
|
||||
}
|
||||
// Reached once the factor has dropped below kMinMemReuseFactor.
return false;
|
||||
}
|
||||
|
||||
// Performs one trial pass over all total_step_ steps.
//
// Resets current_step_ to 0, then for each step calls PreCompute(nullptr),
// calls GetOrMalloc() for every event of type kGet in step_events_[step], and
// calls PostCompute(nullptr). Returns false as soon as PreCompute or
// PostCompute returns false, or GetOrMalloc returns nullptr.
//
// NOTE(review): two consecutive return statements follow the loop. This text
// is a rendered diff without +/- markers, so `return optimized_;` is
// presumably a removed line shown beside the added `return true;` -- confirm
// against the repository.
bool MemScheduler::MockOneStep() {
|
||||
current_step_ = 0;
|
||||
for (size_t step = 0; step < total_step_; ++step) {
|
||||
bool ret = PreCompute(nullptr);
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
auto &step_events = step_events_[step];
|
||||
for (auto &event : step_events) {
|
||||
// Only kGet events trigger an allocation attempt; all other event types are skipped.
if (event->type != kGet) {
|
||||
continue;
|
||||
}
|
||||
auto ptr = GetOrMalloc(event->key, event->mem_size);
|
||||
if (ptr == nullptr) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
ret = PostCompute(nullptr);
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return optimized_;
|
||||
return true;
|
||||
}
|
||||
|
||||
void MemScheduler::AdjustFirstEventIndex() {
|
||||
|
|
|
@ -45,6 +45,8 @@ class MemScheduler {
|
|||
|
||||
void set_need_record_event(bool flag) { need_record_event_ = flag; }
|
||||
|
||||
bool optimized() const { return optimized_; }
|
||||
|
||||
void Update();
|
||||
|
||||
void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }
|
||||
|
@ -89,6 +91,8 @@ class MemScheduler {
|
|||
|
||||
void OptMemUsage(float mem_used_factor = 1.0f);
|
||||
|
||||
bool MockOneStep();
|
||||
|
||||
void AdjustFirstEventIndex();
|
||||
|
||||
std::map<const void *, MemPriority> mem_priority_;
|
||||
|
@ -104,7 +108,7 @@ class MemScheduler {
|
|||
size_t current_step_{0};
|
||||
bool need_record_event_{true};
|
||||
bool optimized_{false};
|
||||
float mem_used_factor_{0.9};
|
||||
float mem_used_factor_{1.0};
|
||||
double compute_start_time_{0};
|
||||
std::vector<double> compute_time_;
|
||||
bool record_compute_time_{false};
|
||||
|
|
Loading…
Reference in New Issue