forked from mindspore-Ecosystem/mindspore
!8109 add cpu pynative mode
From: @chujinjin Reviewed-by: @kisnwang Signed-off-by:
This commit is contained in:
commit
98d6198e6b
|
@ -93,10 +93,20 @@ void CPUSession::CreateOutputTensors(const GraphId &graph_id, const std::vector<
|
|||
runtime_.CreateOutputTensors(kernel_graph.get(), input_tensors, outputs, tensor_to_node);
|
||||
}
|
||||
|
||||
void CPUSession::SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
|
||||
return;
|
||||
}
|
||||
runtime_.SyncValueNodeDeviceAddr(kernel_graph.get());
|
||||
}
|
||||
|
||||
void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||
VectorRef *outputs) {
|
||||
auto kernel_graph = GetGraph(graph_id);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
SyncValueNodeDeviceAddr(kernel_graph);
|
||||
MS_LOG(INFO) << "Bind input output address";
|
||||
runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs);
|
||||
|
||||
|
@ -130,6 +140,65 @@ void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor:
|
|||
MS_LOG(INFO) << "Run graph end";
|
||||
}
|
||||
|
||||
void CPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
const std::vector<int64_t> &tensors_mask) {
|
||||
// Check if the graph cache exists.
|
||||
if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) {
|
||||
return;
|
||||
}
|
||||
// Prepare the graph
|
||||
auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
SetKernelInfo(kernel_graph.get());
|
||||
BuildKernel(kernel_graph.get());
|
||||
run_op_graphs_[graph_info] = kernel_graph;
|
||||
}
|
||||
|
||||
void CPUSession::SetOutputFlags(const VectorRef &base_ref, std::vector<tensor::TensorPtr> *outputs_tensors) {
|
||||
for (size_t i = 0; i < base_ref.size(); ++i) {
|
||||
if (utils::isa<VectorRef>(base_ref[i])) {
|
||||
auto ref_iter = utils::cast<VectorRef>(base_ref[i]);
|
||||
SetOutputFlags(ref_iter, outputs_tensors);
|
||||
} else if (utils::isa<tensor::TensorPtr>(base_ref[i])) {
|
||||
auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(base_ref[i]);
|
||||
tensor_ptr->SetNeedWait(false);
|
||||
outputs_tensors->push_back(tensor_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask) {
|
||||
MS_EXCEPTION_IF_NULL(input_tensors);
|
||||
BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask);
|
||||
EraseValueNodeTensor(tensors_mask, input_tensors);
|
||||
|
||||
auto kernel_graph = run_op_graphs_[graph_info];
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
|
||||
runtime_.AssignKernelAddress(kernel_graph.get());
|
||||
std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node;
|
||||
runtime_.CreateOutputTensors(kernel_graph.get(), *input_tensors, outputs, &tensor_to_node);
|
||||
runtime_.BindInputOutput(kernel_graph.get(), *input_tensors, outputs);
|
||||
|
||||
MS_LOG(INFO) << "Run Op start";
|
||||
auto execution_order = kernel_graph->execution_order();
|
||||
Reorder(&execution_order);
|
||||
|
||||
kernel_graph->set_execution_order(execution_order);
|
||||
|
||||
bool ret = runtime_.Run(kernel_graph.get(), false);
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "Run Op failed";
|
||||
}
|
||||
|
||||
std::vector<tensor::TensorPtr> output_tensors;
|
||||
SetOutputFlags(*outputs, &output_tensors);
|
||||
MS_LOG(INFO) << "Run Op end";
|
||||
}
|
||||
|
||||
void CPUSession::SetKernelInfo(const KernelGraph *kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto &kernel_nodes = kernel_graph->execution_order();
|
||||
|
|
|
@ -38,10 +38,18 @@ class CPUSession : public SessionBasic {
|
|||
void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
|
||||
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) override;
|
||||
void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||
void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
const std::vector<int64_t> &tensors_mask) override;
|
||||
void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask) override;
|
||||
|
||||
private:
|
||||
void SetKernelInfo(const KernelGraph *kernel_graph);
|
||||
void BuildKernel(const KernelGraph *kernel_graph);
|
||||
void SetOutputFlags(const VectorRef &base_ref, std::vector<tensor::TensorPtr> *outputs_tensors);
|
||||
void SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||
device::cpu::CPUKernelRuntime runtime_;
|
||||
};
|
||||
MS_REG_SESSION(kCPUDevice, CPUSession);
|
||||
|
|
|
@ -994,6 +994,8 @@ void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_ex
|
|||
});
|
||||
return;
|
||||
}
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
auto target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
const auto &tensor_id_list = op_index_with_tensor_id_[op_index];
|
||||
for (size_t i = 0; i < tensor_id_list.size(); ++i) {
|
||||
auto tensor_id = tensor_id_list[i];
|
||||
|
@ -1003,7 +1005,20 @@ void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_ex
|
|||
std::for_each(tensors_in_value_node.begin(), tensors_in_value_node.end(), [&](tensor::TensorPtr &tensor) {
|
||||
tensor->set_shape(new_tensor->shape());
|
||||
tensor->set_data_type(new_tensor->data_type());
|
||||
tensor->set_device_address(new_tensor->device_address());
|
||||
if (target != kCPUDevice) {
|
||||
tensor->set_device_address(new_tensor->device_address());
|
||||
} else {
|
||||
auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
|
||||
auto new_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
|
||||
auto old_ptr = old_device_address->GetMutablePtr();
|
||||
auto new_ptr = new_device_address->GetPtr();
|
||||
MS_EXCEPTION_IF_NULL(old_ptr);
|
||||
MS_EXCEPTION_IF_NULL(new_ptr);
|
||||
auto ret = memcpy_s(old_ptr, old_device_address->GetSize(), new_ptr, new_device_address->GetSize());
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "Memory copy failed. ret: " << ret;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -1264,12 +1279,9 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati
|
|||
MS_LOG(INFO) << "Start run op [" << op_exec_info->op_name << "] with backend policy ms";
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);
|
||||
std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
if (device_target != kAscendDevice && device_target != kGPUDevice) {
|
||||
MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode";
|
||||
}
|
||||
|
||||
if (session == nullptr) {
|
||||
std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
session = session::SessionFactory::Get().Create(device_target);
|
||||
MS_EXCEPTION_IF_NULL(session);
|
||||
session->Init(ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID));
|
||||
|
|
|
@ -56,6 +56,11 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph
|
|||
}
|
||||
auto tensor = node_value->cast<TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
if (tensor->device_address() != nullptr) {
|
||||
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), 0,
|
||||
item_node.get());
|
||||
continue;
|
||||
}
|
||||
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item_node, 0);
|
||||
if (output_type_id == kTypeUnknown) {
|
||||
output_type_id = AnfAlgo::GetOutputInferDataType(item_node, 0);
|
||||
|
|
Loading…
Reference in New Issue