From: @HulkTang
Reviewed-by: @zhoufeng54,@jjfeing
Signed-off-by: @zhoufeng54
This commit is contained in:
mindspore-ci-bot 2021-04-02 09:11:15 +08:00 committed by Gitee
commit cfe336e54e
8 changed files with 5 additions and 92 deletions

View File

@ -202,7 +202,9 @@ void GenOpOutputStubTensor(const KernelGraphPtr &single_op_graph, const CNodePtr
std::make_shared<device::ascend::AscendDeviceAddress>(nullptr, 0, output_format, output_type);
stub_output_tensor->set_device_address(device_address);
output_tensor_info.output_stub_tensor = stub_output_tensor;
output_tensor_info.is_weight = !dynamic_cast<device::KernelInfo *>(output_node->kernel_info())->is_feature_map();
auto kernel_info = dynamic_cast<const device::KernelInfo *>(output_node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
output_tensor_info.is_weight = !(kernel_info->is_feature_map());
(*op_output_info)[kernel_with_index] = output_tensor_info;
}
}

View File

@ -113,20 +113,10 @@ void CPUSession::CreateOutputTensors(const GraphId &graph_id, const std::vector<
runtime_.CreateOutputTensors(kernel_graph.get(), input_tensors, outputs, tensor_to_node);
}
// Forwards `kernel_graph` to runtime_.SyncValueNodeDeviceAddr, but only when the
// MS_CTX_EXECUTION_MODE context parameter equals kPynativeMode; otherwise does nothing.
// Raises through MS_EXCEPTION_IF_NULL when the MsContext instance is unavailable.
void CPUSession::SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const bool in_pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  if (in_pynative_mode) {
    runtime_.SyncValueNodeDeviceAddr(kernel_graph.get());
  }
}
void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs) {
auto kernel_graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(kernel_graph);
SyncValueNodeDeviceAddr(kernel_graph);
MS_LOG(INFO) << "Bind input output address";
runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs);

View File

@ -50,7 +50,6 @@ class CPUSession : public SessionBasic {
void SetKernelInfo(const KernelGraph *kernel_graph);
void BuildKernel(const KernelGraph *kernel_graph);
void SetOutputFlags(const VectorRef &base_ref, std::vector<tensor::TensorPtr> *outputs_tensors);
void SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph);
device::cpu::CPUKernelRuntime runtime_;
};
MS_REG_SESSION(kCPUDevice, CPUSession);

View File

@ -425,8 +425,6 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor:
VectorRef *outputs) {
auto &kernel_graph = graphs_[graph_id];
MS_LOG(INFO) << "RunGraph graph_id: " << graph_id;
// In pynative mode, device addresses of tensors in value nodes change.
SyncValueNodeDeviceAddr(kernel_graph);
// Load input data from user input
LoadInputData(kernel_graph, inputs);
if (debugger_) {
@ -449,8 +447,6 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor:
#endif
Execute(kernel_graph);
}
// In pynative mode, device addresses of tensors in value nodes need be clean.
CleanValueNodeDeviceAddr(kernel_graph);
// Summary
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
@ -540,28 +536,6 @@ void GPUSession::PostIterationDbg(const std::shared_ptr<KernelGraph> &kernel_gra
}
}
// Asks the GPU kernel runtime registered for device_id_ to run SyncValueNodeDeviceAddr on
// `kernel_graph`. The call is made only when the MS_CTX_EXECUTION_MODE context parameter equals
// kPynativeMode; in every other mode the function is a no-op.
// Raises through MS_EXCEPTION_IF_NULL when either the MsContext or the runtime instance is null.
void GPUSession::SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const bool in_pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  if (in_pynative_mode) {
    auto &runtime_manager = device::KernelRuntimeManager::Instance();
    auto gpu_runtime = runtime_manager.GetSingleKernelRuntime(kGPUDevice, device_id_);
    MS_EXCEPTION_IF_NULL(gpu_runtime);
    gpu_runtime->SyncValueNodeDeviceAddr(kernel_graph.get());
  }
}
// Asks the GPU kernel runtime registered for device_id_ to run CleanValueNodeDeviceAddr on
// `kernel_graph`. The call is made only when the MS_CTX_EXECUTION_MODE context parameter equals
// kPynativeMode; in every other mode the function is a no-op.
// Raises through MS_EXCEPTION_IF_NULL when either the MsContext or the runtime instance is null.
void GPUSession::CleanValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const bool in_pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  if (in_pynative_mode) {
    auto &runtime_manager = device::KernelRuntimeManager::Instance();
    auto gpu_runtime = runtime_manager.GetSingleKernelRuntime(kGPUDevice, device_id_);
    MS_EXCEPTION_IF_NULL(gpu_runtime);
    gpu_runtime->CleanValueNodeDeviceAddr(kernel_graph.get());
  }
}
void GPUSession::SyncStream() {
auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);

View File

@ -82,10 +82,6 @@ class GPUSession : public SessionBasic {
void PostIterationDbg(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void CleanValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const;
GraphId CompileGraphImpl(KernelGraphPtr kernel_graph);
};
using GPUSessionPtr = std::shared_ptr<GPUSession>;

View File

@ -404,7 +404,7 @@ bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) {
return true;
}
void GetParameterIndex(KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
void GetParameterIndex(const KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
std::map<AnfNodePtr, size_t> *parameter_index) {
size_t index = 0;
for (const auto &input_node : graph->inputs()) {
@ -512,7 +512,7 @@ void CreateOutputPlaceholder(const KernelGraphPtr &kernel_graph, const std::vect
}
}
void GetRefCount(KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count) {
void GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count) {
MS_EXCEPTION_IF_NULL(graph);
for (const auto &kernel : graph->execution_order()) {
for (size_t i = 1; i < kernel->inputs().size(); i += 1) {

View File

@ -712,52 +712,6 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
MS_LOG(INFO) << "AssignStaticMemoryValueNode end";
}
// Walks every value node of `graph` and, for each tensor contained in the node's value, installs
// the tensor's current device address as the node's output address at the matching index.
//
// Only values that are a Tensor or a ValueTuple are considered; other value kinds are skipped.
// Tensors whose device address is null are left untouched and reported at INFO level.
//
// Raises through MS_EXCEPTION_IF_NULL when `graph`, a value node, a node value, or one of the
// tensors produced by TensorValueToTensor is null.
void KernelRuntime::SyncValueNodeDeviceAddr(session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "SyncValueNodeDeviceAddr start";
  for (auto &value_node : graph->graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node);
    auto &node_value = value_node->value();
    MS_EXCEPTION_IF_NULL(node_value);
    if (!node_value->isa<Tensor>() && !node_value->isa<ValueTuple>()) {
      continue;
    }
    std::vector<tensor::TensorPtr> tensors;
    TensorValueToTensor(node_value, &tensors);
    for (size_t index = 0; index < tensors.size(); index += 1) {
      const auto &tensor = tensors[index];
      // The tensor is dereferenced right below; fail with a diagnosable exception instead of
      // undefined behavior if the list ever contains a null entry.
      MS_EXCEPTION_IF_NULL(tensor);
      const auto tensor_address = tensor->device_address();
      if (tensor_address != nullptr) {
        // The output index of the value node mirrors the tensor's position in the flattened list.
        AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address), index,
                               value_node.get());
      } else {
        MS_LOG(INFO) << "Tensor of ValueNode[" << value_node->fullname_with_scope() << "]'s device address is nullptr.";
      }
    }
  }
  MS_LOG(INFO) << "SyncValueNodeDeviceAddr end";
}
// Walks every value node of `graph` and, for each contained tensor that currently holds a device
// address, resets the node's output address at the matching index to nullptr.
//
// Only values that are a Tensor or a ValueTuple are considered; other value kinds are skipped.
// Output slots whose tensor has no device address are left as they are.
//
// Raises through MS_EXCEPTION_IF_NULL when `graph`, a value node, a node value, or one of the
// tensors produced by TensorValueToTensor is null.
void KernelRuntime::CleanValueNodeDeviceAddr(session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "CleanValueNodeDeviceAddr start";
  for (auto &value_node : graph->graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node);
    auto &node_value = value_node->value();
    MS_EXCEPTION_IF_NULL(node_value);
    if (!node_value->isa<Tensor>() && !node_value->isa<ValueTuple>()) {
      continue;
    }
    std::vector<tensor::TensorPtr> tensors;
    TensorValueToTensor(node_value, &tensors);
    for (size_t index = 0; index < tensors.size(); index += 1) {
      const auto &tensor = tensors[index];
      // The tensor is dereferenced right below; fail with a diagnosable exception instead of
      // undefined behavior if the list ever contains a null entry.
      MS_EXCEPTION_IF_NULL(tensor);
      if (tensor->device_address() != nullptr) {
        AnfAlgo::SetOutputAddr(nullptr, index, value_node.get());
      }
    }
  }
  MS_LOG(INFO) << "CleanValueNodeDeviceAddr end";
}
void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(mem_manager_);

View File

@ -67,8 +67,6 @@ class KernelRuntime {
const AddressPtrList &kernel_workspaces) const;
virtual void AssignStaticMemoryInput(const session::KernelGraph *graph);
virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph);
virtual void SyncValueNodeDeviceAddr(session::KernelGraph *graph);
virtual void CleanValueNodeDeviceAddr(session::KernelGraph *graph);
virtual void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs,
const std::unordered_set<ValueNodePtr> &value_nodes,
const std::vector<CNodePtr> &execution_order);