forked from mindspore-Ecosystem/mindspore
run dynamic shape in existing context
This commit is contained in:
parent
58e5d8b3f0
commit
7a2a714a11
|
@ -316,6 +316,7 @@ bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) {
|
|||
if (!is_task_sink) {
|
||||
return true;
|
||||
}
|
||||
rtCtxSetCurrent(rt_context_hccl_);
|
||||
// Do HcomExecutorInitialize
|
||||
if (graph->is_dynamic_shape() && !HcclExecutorManager::GetInstance().Initialize()) {
|
||||
MS_LOG(ERROR) << "Init Hccl Executor Failed";
|
||||
|
@ -651,6 +652,11 @@ bool AscendKernelRuntime::InitDevice() {
|
|||
}
|
||||
}
|
||||
|
||||
ret = rtCtxGetCurrent(&rt_context_hccl_);
|
||||
if (ret != RT_ERROR_NONE || rt_context_hccl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Call rtCtxGetCurrent failed, ret[" << ret << "]";
|
||||
}
|
||||
|
||||
ret = rtCtxCreate(&rt_context_, 0, device_id_);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]";
|
||||
|
|
|
@ -76,6 +76,7 @@ class AscendKernelRuntime : public KernelRuntime {
|
|||
void LaunchDataDump(GraphId graph_id);
|
||||
|
||||
rtContext_t rt_context_{nullptr};
|
||||
rtContext_t rt_context_hccl_{nullptr};
|
||||
bool initialized_{false};
|
||||
unordered_map<GraphId, vector<std::shared_ptr<TaskInfo>>> task_map_;
|
||||
unordered_map<GraphId, std::shared_ptr<ge::model_runner::DavinciModel>> graph_model_map_;
|
||||
|
|
|
@ -87,11 +87,7 @@ void HcclDynamicKernel::StaticShapeExecute() {
|
|||
|
||||
void HcclDynamicKernel::Execute() {
|
||||
MS_LOG(INFO) << "Start Execute";
|
||||
if (!is_dynamic_shape_) {
|
||||
MS_LOG(INFO) << "Not Dynamic, call hcom api";
|
||||
StaticShapeExecute();
|
||||
return;
|
||||
}
|
||||
|
||||
auto handle = HcclExecutorManager::GetInstance().handle();
|
||||
auto EnqueueHcomOperation =
|
||||
(HcclResult(*)(ge::HcomOpertion, std::function<void(HcclResult status)>))dlsym(handle, "EnqueueHcomOpertion");
|
||||
|
|
Loading…
Reference in New Issue