forked from mindspore-Ecosystem/mindspore
format some executor func name
This commit is contained in:
parent
eec1b4441d
commit
54ef8520ab
@@ -1517,11 +1517,6 @@ void AscendSession::UpdateRefOutputMap(NotNull<KernelGraphPtr> graph,
   }
 }
 
-GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const vector<tensor::TensorPtr> &inputs) {
-  RunInfer(func_graph, inputs);
-  return CompileGraphImpl(func_graph);
-}
-
 void AscendSession::SyncStream() {
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);
@@ -49,7 +49,6 @@ class AscendSession : public SessionBasic {
   void UnifyMindIR(const KernelGraphPtr &graph) override;
   GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
   GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override;
-  GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) override;
   bool IsSupportSummary() override;
   void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
   void BuildGraphImpl(GraphId) override;
@@ -32,7 +32,7 @@ namespace {
 void UpdateOutputTensors(const VectorRef *outputs,
                          const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
   MS_EXCEPTION_IF_NULL(outputs);
-  for (auto item : *outputs) {
+  for (auto &item : *outputs) {
     if (utils::isa<VectorRefPtr>(item)) {
       auto vector_ref = utils::cast<VectorRef>(item);
       UpdateOutputTensors(&vector_ref, tensor_to_node);
@@ -45,7 +45,6 @@ void UpdateOutputTensors(const VectorRef *outputs,
     auto &output_index = iter->second.second;
     auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
     tensor->set_device_address(address);
-
     if (AnfAlgo::IsDynamicShape(node)) {
       auto updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
       ShapeVector int_shape;
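A note on the "for (auto item : *outputs)" -> "for (auto &item : *outputs)" changes in this
file: binding a reference avoids copying each element on every iteration. A minimal standalone
sketch of the difference (illustrative only, not MindSpore code):

#include <memory>
#include <vector>

int main() {
  std::vector<std::shared_ptr<int>> items{std::make_shared<int>(1), std::make_shared<int>(2)};
  long sum = 0;
  for (auto item : items) {   // copies each shared_ptr: refcount bump per element
    sum += *item;
  }
  for (auto &item : items) {  // binds a reference: no copy, no refcount traffic
    sum += *item;
  }
  return static_cast<int>(sum);
}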
@@ -62,12 +61,12 @@ void UpdateOutputTensors(const VectorRef *outputs,
   }
 }
 
-void NotifyOutputTensors(const VectorRef *outputs) {
+void SetOutputTensorsWaitStatus(const VectorRef *outputs) {
   MS_EXCEPTION_IF_NULL(outputs);
-  for (auto item : *outputs) {
+  for (auto &item : *outputs) {
     if (utils::isa<VectorRefPtr>(item)) {
       auto vector_ref = utils::cast<VectorRef>(item);
-      NotifyOutputTensors(&vector_ref);
+      SetOutputTensorsWaitStatus(&vector_ref);
     } else if (utils::isa<tensor::TensorPtr>(item)) {
       auto tensor = utils::cast<tensor::TensorPtr>(item);
       MS_EXCEPTION_IF_NULL(tensor);
@@ -78,7 +77,7 @@ void NotifyOutputTensors(const VectorRef *outputs) {
 
 bool TensorInVector(const VectorRef *outputs) {
   MS_EXCEPTION_IF_NULL(outputs);
-  for (auto item : *outputs) {
+  for (auto &item : *outputs) {
     if (utils::isa<VectorRefPtr>(item)) {
       auto vector_ref = utils::cast<VectorRef>(item);
       if (TensorInVector(&vector_ref)) {
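SetOutputTensorsWaitStatus, UpdateOutputTensors, and TensorInVector all share the same
recursive walk: a VectorRef result tree whose elements are either nested VectorRefs or leaf
tensors. A simplified sketch of that traversal pattern (hypothetical types, not the actual
MindSpore structures):

#include <memory>
#include <vector>

struct Tensor { bool wait = true; };
struct Item {  // either a leaf tensor or a nested list of results
  std::shared_ptr<Tensor> tensor;
  std::vector<Item> children;
};

void SetWaitStatus(std::vector<Item> *outputs, bool wait) {
  for (auto &item : *outputs) {
    if (!item.children.empty()) {
      SetWaitStatus(&item.children, wait);  // recurse into a nested result list
    } else if (item.tensor != nullptr) {
      item.tensor->wait = wait;             // act on the leaf tensor
    }
  }
}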
@@ -90,6 +89,50 @@ bool TensorInVector(const VectorRef *outputs) {
   }
   return false;
 }
+
+bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
+  MS_EXCEPTION_IF_NULL(task);
+  for (auto &input : task->input_need_wait_tensors_) {
+    MS_EXCEPTION_IF_NULL(input);
+    if (input->NeedWait()) {
+      return false;
+    }
+  }
+  auto session = task->session_;
+  MS_EXCEPTION_IF_NULL(session);
+  auto graph = session->GetGraph(task->graph_id_);
+  if (graph != nullptr) {
+    return graph->IsPreGraphFinished();
+  }
+  return true;
+}
+
+void WaitLockedInputs(const SessionPtr &session, const std::shared_ptr<RunGraphTask> &task) {
+  bool need_lock = false;
+  for (auto &tensor : task->input_tensors_) {
+    if (tensor->NeedWait()) {
+      if (tensor->IsGraphOutput()) {
+        task->input_need_wait_tensors_.emplace_back(tensor);
+      } else {
+        need_lock = true;
+      }
+    }
+  }
+  if (need_lock) {
+    mindspore::ScopedLongRunning long_running;
+    for (auto &tensor : task->input_tensors_) {
+      if (tensor->NeedWait() && !tensor->IsGraphOutput()) {
+        MsException::Instance().CheckException();
+        tensor->Wait();
+      }
+    }
+    MsException::Instance().CheckException();
+  }
+  // need lock input parameters for optimizer
+  for (auto &tensor : task->input_need_lock_tensors_) {
+    tensor->SetNeedWait(true);
+  }
+}
 } // namespace
 
 void CompileNodesTask::Run() {
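IsTaskReady and WaitLockedInputs are added here verbatim as free functions in the file's
anonymous namespace (their former Executor method bodies are deleted in later hunks), which
gives them internal linkage and trims the class's private interface. A minimal sketch of the
pattern (hypothetical names, not MindSpore code):

namespace {
// Internal linkage: visible only inside this translation unit.
bool IsReady(int pending_count) { return pending_count == 0; }
}  // namespace

class Worker {
 public:
  bool Poll(int pending_count) const { return IsReady(pending_count); }
  // No private IsReady declaration is needed in the header anymore.
};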
@@ -129,7 +172,7 @@ void RunGraphTask::Run() {
   for (auto &tensor : input_need_lock_tensors_) {
     tensor->SetNeedWait(false);
   }
-  NotifyOutputTensors(&outputs_);
+  SetOutputTensorsWaitStatus(&outputs_);
   ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished);
 }
 
@@ -198,7 +241,7 @@ void Executor::WorkerLoop() {
   }
 }
 
-std::vector<std::shared_ptr<RunGraphTask>> Executor::GetNewReadyTasks() {
+std::vector<std::shared_ptr<RunGraphTask>> Executor::GetReadyTasksFromPendingList() {
   std::vector<std::shared_ptr<RunGraphTask>> new_ready_tasks;
   std::lock_guard<std::mutex> lock(pending_task_mutex_);
   for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) {
@@ -249,7 +292,7 @@ void Executor::OnException() {
 }
 
 void Executor::OnRunGraphFinished() {
-  auto new_ready_tasks = GetNewReadyTasks();
+  auto new_ready_tasks = GetReadyTasksFromPendingList();
   std::lock_guard<std::mutex> lock(task_mutex_);
   for (auto &task : new_ready_tasks) {
     ready_tasks_.push(task);
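GetReadyTasksFromPendingList (formerly GetNewReadyTasks) names what the loop above does: walk
pending_tasks_ under its mutex and move every task whose inputs are ready into the returned
vector. A simplified sketch of that drain pattern (hypothetical types, not the actual class):

#include <list>
#include <mutex>
#include <vector>

struct Task { bool ready = false; };

std::mutex pending_mutex;
std::list<Task> pending_tasks;

std::vector<Task> DrainReadyTasks() {
  std::vector<Task> ready;
  std::lock_guard<std::mutex> lock(pending_mutex);
  for (auto iter = pending_tasks.begin(); iter != pending_tasks.end();) {
    if (iter->ready) {
      ready.push_back(*iter);
      iter = pending_tasks.erase(iter);  // erase returns the next valid iterator
    } else {
      ++iter;
    }
  }
  return ready;
}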
@@ -260,23 +303,6 @@ void Executor::OnRunGraphFinished() {
   reenter_cond_var_.notify_all();
 }
 
-bool Executor::IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
-  MS_EXCEPTION_IF_NULL(task);
-  for (auto &input : task->input_need_wait_tensors_) {
-    MS_EXCEPTION_IF_NULL(input);
-    if (input->NeedWait()) {
-      return false;
-    }
-  }
-  auto session = task->session_;
-  MS_EXCEPTION_IF_NULL(session);
-  auto graph = session->GetGraph(task->graph_id_);
-  if (graph != nullptr) {
-    return graph->IsPreGraphFinished();
-  }
-  return true;
-}
-
 void Executor::ClearDoneTasks() {
   std::lock_guard<std::mutex> lock(done_task_mutex_);
   done_tasks_.clear();
@@ -341,33 +367,6 @@ void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id,
   RunTask(task, true, true);
 }
 
-void Executor::WaitLockedInputs(const SessionPtr &session, const std::shared_ptr<RunGraphTask> &task) {
-  bool need_lock = false;
-  for (auto &tensor : task->input_tensors_) {
-    if (tensor->NeedWait()) {
-      if (tensor->IsGraphOutput()) {
-        task->input_need_wait_tensors_.emplace_back(tensor);
-      } else {
-        need_lock = true;
-      }
-    }
-  }
-  if (need_lock) {
-    mindspore::ScopedLongRunning long_running;
-    for (auto &tensor : task->input_tensors_) {
-      if (tensor->NeedWait() && !tensor->IsGraphOutput()) {
-        MsException::Instance().CheckException();
-        tensor->Wait();
-      }
-    }
-    MsException::Instance().CheckException();
-  }
-  // need lock input parameters for optimizer
-  for (auto &tensor : task->input_need_lock_tensors_) {
-    tensor->SetNeedWait(true);
-  }
-}
-
 void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
                              const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
   MS_EXCEPTION_IF_NULL(session);
@@ -172,9 +172,7 @@ class Executor {
 
  private:
   void RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run = false);
-  std::vector<std::shared_ptr<RunGraphTask>> GetNewReadyTasks();
-  bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task);
-  void WaitLockedInputs(const SessionPtr &session, const std::shared_ptr<RunGraphTask> &task);
+  std::vector<std::shared_ptr<RunGraphTask>> GetReadyTasksFromPendingList();
   void OnWorkerExit();
   void OnClear();
   void OnRunGraphFinished();
@@ -118,68 +118,16 @@ ParamInfoPtr GetParamDefaultValue(const AnfNodePtr &node) {
   return parameter->param_info();
 }
 
-tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
-                                          const KernelGraphPtr &graph) {
-  auto &node = node_output_pair.first;
-  auto &output_index = node_output_pair.second;
-  MS_EXCEPTION_IF_NULL(node);
-  MS_EXCEPTION_IF_NULL(graph);
-  TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
-  if (type_id == kTypeUnknown) {
-    type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
-  }
-  tensor::TensorPtr tensor = nullptr;
-  std::vector<int64_t> temp_shape;
-  if (graph->IsUniqueTargetInternalOutput(node, output_index)) {
-    temp_shape.emplace_back(1);
-    tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
-    tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
-    tensor->set_sync_status(kNoNeedSync);
-    tensor->SetNeedWait(true);
-    tensor->SetIsGraphOutput();
-    return tensor;
-  }
-
-  tensor = graph->GetInternalOutputTensor(node, output_index);
-  if (tensor == nullptr) {
-    auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
-    (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
-    tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
-    bool is_internal_output = graph->IsInternalOutput(node, output_index);
-    if (is_internal_output) {
-      graph->AddInternalOutputTensor(node, output_index, tensor);
-    }
-  }
-  tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
-  // if in pynative mode,data only copied to host when user want to print data
-  auto ms_context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(ms_context);
-  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
-      ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
-    tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
-  } else {
-    tensor->set_sync_status(kNeedSyncDeviceToHost);
-  }
-  tensor->SetNeedWait(true);
-  tensor->SetIsGraphOutput();
-  return tensor;
-}
-
 static bool IsPynativeMode() {
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   return ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
 }
 
-BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
-                               const std::vector<tensor::TensorPtr> &input_tensors,
-                               std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
+BaseRef GetNodeOutputTensorFromInputs(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
+                                      const std::vector<tensor::TensorPtr> &input_tensors) {
   auto &node = node_output_pair.first;
   auto &output_index = node_output_pair.second;
   MS_EXCEPTION_IF_NULL(node);
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(tensor_to_node);
-  MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << node_output_pair.second << "]";
   if (HasAbstractMonad(node)) {
     return std::make_shared<tensor::Tensor>(int64_t(0), kBool);
   }
@@ -189,7 +137,8 @@ BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
     MS_EXCEPTION_IF_NULL(value_node);
     return value_node->value();
   }
-  bool output_addr_exist = AnfAlgo::OutputAddrExist(node, output_index);
+  MS_EXCEPTION_IF_NULL(graph);
+  bool output_addr_exist = AnfAlgo::OutputAddrExist(node, node_output_pair.second);
   if (!output_addr_exist || (CheckIfNeedCreateOutputTensor(node) && !IsPynativeMode())) {
     if (node->isa<Parameter>()) {
       for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
@@ -205,7 +154,56 @@ BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
       }
     }
   }
-  auto tensor = CreateCNodeOutputTensor(node_output_pair, graph);
+  return nullptr;
+}
+
+BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
+                               const std::vector<tensor::TensorPtr> &input_tensors,
+                               std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
+  auto &node = node_output_pair.first;
+  auto &output_index = node_output_pair.second;
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
+  auto tensor_from_input = GetNodeOutputTensorFromInputs(node_output_pair, graph, input_tensors);
+  if (tensor_from_input != nullptr) {
+    return tensor_from_input;
+  }
+  TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
+  if (type_id == kTypeUnknown) {
+    type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
+  }
+  tensor::TensorPtr tensor = nullptr;
+  std::vector<int64_t> temp_shape;
+  if (graph->IsUniqueTargetInternalOutput(node, output_index)) {
+    temp_shape.emplace_back(1);
+    tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
+    tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
+    tensor->set_sync_status(kNoNeedSync);
+  } else {
+    tensor = graph->GetInternalOutputTensor(node, output_index);
+    if (tensor == nullptr) {
+      auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
+      (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
+      tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
+      bool is_internal_output = graph->IsInternalOutput(node, output_index);
+      if (is_internal_output) {
+        graph->AddInternalOutputTensor(node, output_index, tensor);
+      }
+    }
+    tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
+    // if in pynative mode,data only copied to host when user want to print data
+    auto ms_context = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(ms_context);
+    if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
+        ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
+      tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
+    } else {
+      tensor->set_sync_status(kNeedSyncDeviceToHost);
+    }
+  }
+  tensor->SetNeedWait(true);
+  tensor->SetIsGraphOutput();
   (*tensor_to_node)[tensor] = node_output_pair;
   return tensor;
 }
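The refactoring above splits the old CreateNodeOutputTensor in two: GetNodeOutputTensorFromInputs
is the fast path that resolves an output from the graph's inputs or value nodes and returns a
null BaseRef as a "not found" sentinel, while the new CreateNodeOutputTensor checks that
sentinel and falls through to allocating a fresh tensor. A minimal sketch of the split
(hypothetical types, not the MindSpore API):

#include <memory>

struct Tensor { int id = 0; };
using TensorPtr = std::shared_ptr<Tensor>;

// Fast path: returns nullptr when the output cannot be resolved from existing inputs.
TensorPtr LookupFromInputs(int key) {
  if (key == 0) {
    return std::make_shared<Tensor>();  // resolved from an existing input
  }
  return nullptr;                       // sentinel: caller must allocate
}

TensorPtr GetOrCreate(int key) {
  if (auto tensor = LookupFromInputs(key)) {
    return tensor;
  }
  return std::make_shared<Tensor>();  // slow path: create a new output tensor
}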
@@ -1778,43 +1776,6 @@ void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
   summary_callback_ = callback;
 }
 
-void SessionBasic::RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) {
-  auto node_list = TopoSort(func_graph->get_return());
-  size_t tensor_index = 0;
-  for (const auto &node : node_list) {
-    MS_EXCEPTION_IF_NULL(node);
-    if (node->isa<CNode>()) {
-      AbstractBasePtrList input_abstracts;
-      size_t input_num = AnfAlgo::GetInputTensorNum(node);
-      for (size_t index = 0; index < input_num; ++index) {
-        auto input_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), index);
-        MS_EXCEPTION_IF_NULL(input_node);
-        auto abstract = input_node->abstract();
-        MS_EXCEPTION_IF_NULL(abstract);
-        input_abstracts.emplace_back(abstract);
-      }
-      auto prim = AnfAlgo::GetCNodePrimitive(node);
-      if (prim->isa<ops::PrimitiveC>()) {
-        auto prim_c = prim->cast<std::shared_ptr<ops::PrimitiveC>>();
-        MS_EXCEPTION_IF_NULL(prim_c);
-        auto abstract = prim_c->Infer(input_abstracts);
-        node->set_abstract(abstract);
-      }
-    } else if (node->isa<Parameter>()) {
-      if (tensor_index > inputs.size()) {
-        MS_EXCEPTION(IndexError) << "Index " << tensor_index << "is out of " << inputs.size() << "tensor's size";
-      }
-      node->set_abstract(inputs[tensor_index++]->ToAbstract());
-    } else {
-      auto value_node = node->cast<ValueNodePtr>();
-      MS_EXCEPTION_IF_NULL(value_node);
-      auto value = value_node->value();
-      MS_EXCEPTION_IF_NULL(value);
-      value_node->set_abstract(value->ToAbstract());
-    }
-  }
-}
-
 void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
   MS_LOG(DEBUG) << "Update summary Start";
   MS_EXCEPTION_IF_NULL(graph);
@@ -166,9 +166,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   virtual void UnifyMindIR(const KernelGraphPtr &graph) = 0;
   virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0;
   virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; }
-  virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) {
-    MS_EXCEPTION(NotExistsError) << "Call an empty function";
-  }
   virtual void BuildGraphImpl(GraphId) {}
   virtual void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                             VectorRef *outputs) = 0;
@@ -182,8 +179,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   virtual void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> &parameter_index,
                                const std::vector<tensor::TensorPtr> &graph_inputs,
                                const std::map<KernelWithIndex, size_t> &cnode_refcount) {}
-  void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs);
-
   virtual void SetSummaryNodes(KernelGraph *graph);
 
   virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
@@ -131,7 +131,7 @@ void AscendKernelRuntime::SetContext() {
   }
 }
 
-void AscendKernelRuntime::InnerSetContext() {
+void AscendKernelRuntime::SetCurrentContext() {
   if (rt_context_ == nullptr) {
     return;
   }
@@ -142,7 +142,7 @@ void AscendKernelRuntime::InnerSetContext() {
 }
 
 void AscendKernelRuntime::ClearGraphModelMap() {
-  InnerSetContext();
+  SetCurrentContext();
   for (auto &iter : graph_data_dumper_) {
     MS_LOG(INFO) << "[DataDump] Unload data dumper:" << iter.first;
     auto &data_dumper = iter.second;
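SetCurrentContext (formerly InnerSetContext) now says what the call does: make this runtime's
saved rt_context_ current for the calling thread before touching device resources, doing
nothing if the context has not been created yet. Every entry point in the hunks below
(ClearGraphRuntimeResource, Init, GenTask, LoadTask, RunTask, SyncStream, MemcpyAsync,
CreateContext, ResetDevice) begins with this call. A sketch of the idiom with a hypothetical
runtime API (the binding call below is assumed, not the real Ascend function):

struct rtContext;                         // opaque context handle (assumed)
void rtBindToThread(rtContext *context);  // hypothetical "make current" call

class Runtime {
 public:
  void SetCurrentContext() {
    if (rt_context_ == nullptr) {
      return;  // not initialized yet; context creation will set it up later
    }
    rtBindToThread(rt_context_);  // make this context current for the thread
  }

 private:
  rtContext *rt_context_ = nullptr;
};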
@@ -168,7 +168,7 @@ void AscendKernelRuntime::ClearGraphModelMap() {
 void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &,
                                                     const std::unordered_set<ValueNodePtr> &,
                                                     const std::vector<CNodePtr> &) {
-  InnerSetContext();
+  SetCurrentContext();
   MS_LOG(DEBUG) << "Clear graph:" << graph_id << " data dumper";
   if (auto dumper_iter = graph_data_dumper_.find(graph_id); dumper_iter != graph_data_dumper_.end()) {
     MS_LOG(DEBUG) << "Unload dump info " << graph_id;
@@ -247,7 +247,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
   if (!initialized_) {
     return;
   }
-  InnerSetContext();
+  SetCurrentContext();
   ReportProfilingData();
   // release ge runtime
   ClearGraphModelMap();
@@ -284,7 +284,7 @@ void AscendKernelRuntime::PreInit() {
 
 bool AscendKernelRuntime::Init() {
   if (initialized_) {
-    InnerSetContext();
+    SetCurrentContext();
     return true;
   }
   OpTilingCalculater::GetInstance().Init();
@@ -437,7 +437,7 @@ bool AscendKernelRuntime::GenDynamicKernel(const session::KernelGraph *graph) {
 
 bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
-  InnerSetContext();
+  SetCurrentContext();
   if (graph->is_dynamic_shape()) {
     if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE && (ConfigManager::GetInstance().iter_num() > 1)) {
       MS_LOG(EXCEPTION) << "Dynamic shape is not supported with sink mode.";
@@ -498,7 +498,7 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
 
 bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
-  InnerSetContext();
+  SetCurrentContext();
   if (graph->is_dynamic_shape()) {
     MS_LOG(INFO) << "Dynamic Shape Graph Skip Load Task Step";
     return true;
@@ -716,7 +716,7 @@ bool AscendKernelRuntime::RunDynamicKernelAsync(const session::KernelGraph *graph) {
 
 bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
   current_graph_ = graph;
-  InnerSetContext();
+  SetCurrentContext();
   MS_EXCEPTION_IF_NULL(graph);
   if (graph->is_dynamic_shape()) {
     MS_LOG(INFO) << "Dynamic Shape Graph Run Task Async";
@@ -761,7 +761,7 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
 }
 
 bool AscendKernelRuntime::SyncStream() {
-  InnerSetContext();
+  SetCurrentContext();
   if (stream_ == nullptr) {
     MS_LOG(ERROR) << "SyncStream failed. stream_ is nullptr";
     return false;
@@ -779,7 +779,7 @@ bool AscendKernelRuntime::SyncStream() {
 }
 
 bool AscendKernelRuntime::MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) {
-  InnerSetContext();
+  SetCurrentContext();
   if (stream_ == nullptr) {
     MS_LOG(ERROR) << "MemcpyAsync failed. stream_ is nullptr";
     return false;
@@ -803,7 +803,7 @@ void AscendKernelRuntime::CreateContext() {
     MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]";
   }
 }
-  InnerSetContext();
+  SetCurrentContext();
 }
 
 bool AscendKernelRuntime::InitDevice() {
@@ -850,7 +850,7 @@ bool AscendKernelRuntime::InitDevice() {
 }
 
 bool AscendKernelRuntime::ResetDevice(uint32_t device_id) {
-  InnerSetContext();
+  SetCurrentContext();
   if (stream_ != nullptr) {
     auto ret = rtStreamDestroy(stream_);
     if (ret != RT_ERROR_NONE) {
@@ -76,7 +76,7 @@ class AscendKernelRuntime : public KernelRuntime {
   static bool NeedDestroyHccl();
   static bool DestroyHccl();
   static bool DestroySingleOpHccl();
-  void InnerSetContext();
+  void SetCurrentContext();
 
   void ClearGraphModelMap();
   void ReleaseDeviceRes() override;
@@ -121,10 +121,8 @@ class KernelRuntime {
 
   void AssignStaticMemory(session::KernelGraph *graph);
   void AssignDynamicMemory(session::KernelGraph *graph);
-  void ReuseAssignDynamicMemory(session::KernelGraph *graph);
   void AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index);
   void AssignWorkSpaceMem(MemType type, const AnfNodePtr &node);
-  void AssignReuseWorkSpaceMem(const AnfNodePtr &node);
 
   void UpdateRefNodeOutputMem(const session::KernelGraph *graph);