!14538 clean code
From: @HulkTang Reviewed-by: @zhoufeng54,@jjfeing Signed-off-by: @zhoufeng54
This commit is contained in:
commit
cfe336e54e
|
@ -202,7 +202,9 @@ void GenOpOutputStubTensor(const KernelGraphPtr &single_op_graph, const CNodePtr
|
|||
std::make_shared<device::ascend::AscendDeviceAddress>(nullptr, 0, output_format, output_type);
|
||||
stub_output_tensor->set_device_address(device_address);
|
||||
output_tensor_info.output_stub_tensor = stub_output_tensor;
|
||||
output_tensor_info.is_weight = !dynamic_cast<device::KernelInfo *>(output_node->kernel_info())->is_feature_map();
|
||||
auto kernel_info = dynamic_cast<const device::KernelInfo *>(output_node->kernel_info());
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
output_tensor_info.is_weight = !(kernel_info->is_feature_map());
|
||||
(*op_output_info)[kernel_with_index] = output_tensor_info;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -113,20 +113,10 @@ void CPUSession::CreateOutputTensors(const GraphId &graph_id, const std::vector<
|
|||
runtime_.CreateOutputTensors(kernel_graph.get(), input_tensors, outputs, tensor_to_node);
|
||||
}
|
||||
|
||||
void CPUSession::SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
|
||||
return;
|
||||
}
|
||||
runtime_.SyncValueNodeDeviceAddr(kernel_graph.get());
|
||||
}
|
||||
|
||||
void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||
VectorRef *outputs) {
|
||||
auto kernel_graph = GetGraph(graph_id);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
SyncValueNodeDeviceAddr(kernel_graph);
|
||||
MS_LOG(INFO) << "Bind input output address";
|
||||
runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs);
|
||||
|
||||
|
|
|
@ -50,7 +50,6 @@ class CPUSession : public SessionBasic {
|
|||
void SetKernelInfo(const KernelGraph *kernel_graph);
|
||||
void BuildKernel(const KernelGraph *kernel_graph);
|
||||
void SetOutputFlags(const VectorRef &base_ref, std::vector<tensor::TensorPtr> *outputs_tensors);
|
||||
void SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||
device::cpu::CPUKernelRuntime runtime_;
|
||||
};
|
||||
MS_REG_SESSION(kCPUDevice, CPUSession);
|
||||
|
|
|
@ -425,8 +425,6 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor:
|
|||
VectorRef *outputs) {
|
||||
auto &kernel_graph = graphs_[graph_id];
|
||||
MS_LOG(INFO) << "RunGraph graph_id: " << graph_id;
|
||||
// In pynative mode, device addresses of tensors in value nodes change.
|
||||
SyncValueNodeDeviceAddr(kernel_graph);
|
||||
// Load input data from user input
|
||||
LoadInputData(kernel_graph, inputs);
|
||||
if (debugger_) {
|
||||
|
@ -449,8 +447,6 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor:
|
|||
#endif
|
||||
Execute(kernel_graph);
|
||||
}
|
||||
// In pynative mode, device addresses of tensors in value nodes need be clean.
|
||||
CleanValueNodeDeviceAddr(kernel_graph);
|
||||
// Summary
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
|
@ -540,28 +536,6 @@ void GPUSession::PostIterationDbg(const std::shared_ptr<KernelGraph> &kernel_gra
|
|||
}
|
||||
}
|
||||
|
||||
void GPUSession::SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
|
||||
return;
|
||||
}
|
||||
auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
runtime_instance->SyncValueNodeDeviceAddr(kernel_graph.get());
|
||||
}
|
||||
|
||||
void GPUSession::CleanValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
|
||||
return;
|
||||
}
|
||||
auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
runtime_instance->CleanValueNodeDeviceAddr(kernel_graph.get());
|
||||
}
|
||||
|
||||
void GPUSession::SyncStream() {
|
||||
auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
|
||||
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||
|
|
|
@ -82,10 +82,6 @@ class GPUSession : public SessionBasic {
|
|||
|
||||
void PostIterationDbg(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
|
||||
void SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
|
||||
void CleanValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
|
||||
GraphId CompileGraphImpl(KernelGraphPtr kernel_graph);
|
||||
};
|
||||
using GPUSessionPtr = std::shared_ptr<GPUSession>;
|
||||
|
|
|
@ -404,7 +404,7 @@ bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) {
|
|||
return true;
|
||||
}
|
||||
|
||||
void GetParameterIndex(KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
|
||||
void GetParameterIndex(const KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
|
||||
std::map<AnfNodePtr, size_t> *parameter_index) {
|
||||
size_t index = 0;
|
||||
for (const auto &input_node : graph->inputs()) {
|
||||
|
@ -512,7 +512,7 @@ void CreateOutputPlaceholder(const KernelGraphPtr &kernel_graph, const std::vect
|
|||
}
|
||||
}
|
||||
|
||||
void GetRefCount(KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count) {
|
||||
void GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
for (const auto &kernel : graph->execution_order()) {
|
||||
for (size_t i = 1; i < kernel->inputs().size(); i += 1) {
|
||||
|
|
|
@ -712,52 +712,6 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
|
|||
MS_LOG(INFO) << "AssignStaticMemoryValueNode end";
|
||||
}
|
||||
|
||||
void KernelRuntime::SyncValueNodeDeviceAddr(session::KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_LOG(INFO) << "SyncValueNodeDeviceAddr start";
|
||||
for (auto &value_node : graph->graph_value_nodes()) {
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto &node_value = value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(node_value);
|
||||
if (!node_value->isa<Tensor>() && !node_value->isa<ValueTuple>()) {
|
||||
continue;
|
||||
}
|
||||
std::vector<tensor::TensorPtr> tensors;
|
||||
TensorValueToTensor(node_value, &tensors);
|
||||
for (size_t index = 0; index < tensors.size(); index += 1) {
|
||||
const auto &tensor = tensors[index];
|
||||
if (tensor->device_address() != nullptr) {
|
||||
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), index,
|
||||
value_node.get());
|
||||
} else {
|
||||
MS_LOG(INFO) << "Tensor of ValueNode[" << value_node->fullname_with_scope() << "]'s device address is nullptr.";
|
||||
}
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "SyncValueNodeDeviceAddr end";
|
||||
}
|
||||
|
||||
void KernelRuntime::CleanValueNodeDeviceAddr(session::KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_LOG(INFO) << "CleanValueNodeDeviceAddr start";
|
||||
for (auto &value_node : graph->graph_value_nodes()) {
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto &node_value = value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(node_value);
|
||||
if (!node_value->isa<Tensor>() && !node_value->isa<ValueTuple>()) {
|
||||
continue;
|
||||
}
|
||||
std::vector<tensor::TensorPtr> tensors;
|
||||
TensorValueToTensor(node_value, &tensors);
|
||||
for (size_t index = 0; index < tensors.size(); index += 1) {
|
||||
if (tensors[index]->device_address() != nullptr) {
|
||||
AnfAlgo::SetOutputAddr(nullptr, index, value_node.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "CleanValueNodeDeviceAddr end";
|
||||
}
|
||||
|
||||
void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||
|
|
|
@ -67,8 +67,6 @@ class KernelRuntime {
|
|||
const AddressPtrList &kernel_workspaces) const;
|
||||
virtual void AssignStaticMemoryInput(const session::KernelGraph *graph);
|
||||
virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph);
|
||||
virtual void SyncValueNodeDeviceAddr(session::KernelGraph *graph);
|
||||
virtual void CleanValueNodeDeviceAddr(session::KernelGraph *graph);
|
||||
virtual void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs,
|
||||
const std::unordered_set<ValueNodePtr> &value_nodes,
|
||||
const std::vector<CNodePtr> &execution_order);
|
||||
|
|
Loading…
Reference in New Issue