!705 add pynative cache
Merge pull request !705 from chujinjin/add_pynative_cache
commit 64abbeaa89
@@ -22,7 +22,7 @@ namespace mindspore {
 namespace device {
 namespace ascend {
 const uint64_t kAscendDeviceMemGB = 20;
-const uint64_t kAscendMemPoolGB = 5;
+const uint64_t kAscendMemPoolGB = 10;
 const uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << 30);
 const uint64_t kAscendMemPoolSize = (kAscendMemPoolGB << 30);

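For reference, the `<< 30` in the size constants converts a GiB count to bytes, since 1 GiB = 2^30 bytes. A minimal standalone sketch of the same arithmetic (not MindSpore code):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // 1 GiB = 2^30 bytes, so (n << 30) turns a GiB count into a byte count.
  const uint64_t kAscendMemPoolGB = 10;
  const uint64_t kAscendMemPoolSize = kAscendMemPoolGB << 30;
  std::cout << kAscendMemPoolSize << " bytes\n";  // 10737418240
  return 0;
}
```

So raising `kAscendMemPoolGB` from 5 to 10 doubles the Ascend memory pool from roughly 5.4 GB to 10.7 GB.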
@@ -38,6 +38,7 @@
 #include "parallel/graph_util/get_parallel_info.h"
 #include "device/kernel_runtime_manager.h"
 #include "debug/trace.h"
+#include "pynative/pynative_execute.h"

 #if (ENABLE_GE || ENABLE_D)
 #include "pipeline/pipeline_ge.h"

@@ -829,6 +830,7 @@ void FinalizeBackend() {

 void ClearResAtexit() {
   MS_LOG(DEBUG) << "Pipeline clear all resource";
+  pynative::ClearPyNativeSession();
   device::KernelRuntimeManager::Instance().ClearRuntimeResource();

   ad::g_k_prims.clear();

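The new `pynative::ClearPyNativeSession()` call runs before `ClearRuntimeResource()`, so the cached session is dropped while the kernel runtime it depends on is still alive. A reduced sketch of that teardown-ordering pattern, with placeholder types (this is not MindSpore's actual API):

```cpp
#include <memory>

struct Session {};  // stand-in for session::SessionBasic

static std::shared_ptr<Session> g_session = nullptr;

void ClearPyNativeSession() { g_session = nullptr; }   // release the cached session
void ClearRuntimeResource() { /* free device runtime state */ }

// Mirrors ClearResAtexit: the session goes first, so its destructor can still
// reach the runtime resources it references.
void ClearResAtexit() {
  ClearPyNativeSession();
  ClearRuntimeResource();
}
```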
@@ -44,6 +44,7 @@ const std::set<std::string> vm_operators = {"partial", "depend", "make_ref", "ze

 namespace mindspore {
 namespace pynative {
+static std::shared_ptr<session::SessionBasic> session = nullptr;
 inline ValuePtr PyAttrValue(const py::object &obj) {
   ValuePtr converted_ret = nullptr;
   bool converted = parse::ConvertData(obj, &converted_ret);

@@ -310,7 +311,11 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
   if (device_target != kAscendDevice && device_target != kGPUDevice) {
     MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode";
   }
-  std::shared_ptr<session::SessionBasic> session = session::SessionFactory::Get().Create(device_target);
+
+  if (session == nullptr) {
+    session = session::SessionFactory::Get().Create(device_target);
+  }
+
   MS_EXCEPTION_IF_NULL(session);
   session->Init(ms_context->device_id());

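This is the core of the commit on the PyNative side: the session, previously constructed on every `RunOpInMs` call, is now a file-level static that is created once and reused. A reduced sketch of the lazy-initialization pattern with stand-in types (the factory here is hypothetical):

```cpp
#include <memory>
#include <string>

struct Session {
  explicit Session(std::string target) : target_(std::move(target)) {}
  std::string target_;
};

// File-level cache, mirroring the static `session` added above.
static std::shared_ptr<Session> session = nullptr;

std::shared_ptr<Session> GetCachedSession(const std::string &device_target) {
  if (session == nullptr) {
    // First call pays the construction cost; later calls reuse the instance.
    session = std::make_shared<Session>(device_target);
  }
  return session;
}
```

Note that the cache keys on nothing: whichever `device_target` arrives first wins. The surrounding check already restricts the target to Ascend or GPU, but a target change mid-process would not be picked up.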
@@ -407,5 +412,7 @@ py::tuple RunOp(const py::args &args) {
   MS_LOG(INFO) << "RunOp end";
   return result;
 }
+
+void ClearPyNativeSession() { session = nullptr; }
 } // namespace pynative
 } // namespace mindspore

@@ -36,6 +36,9 @@ namespace py = pybind11;
 py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);

 py::tuple RunOp(const py::args &args);
+
+void ClearPyNativeSession();
+
 } // namespace pynative
 } // namespace mindspore

@@ -249,10 +249,23 @@ void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_gra
   MS_LOG(INFO) << "Finish!";
 }

+bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {
+  if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) {
+    return true;
+  }
+
+  return false;
+}
+
 void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                             const std::vector<tensor::TensorPtr> &input_tensors,
                             const std::vector<bool> &tensors_mask) {
   MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !";
+  if (GraphCacheExist(graph_info)) {
+    MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !";
+    return;
+  }
+
   // construct graph include one op
   auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
   MS_EXCEPTION_IF_NULL(graph);

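`GraphCacheExist` is a plain membership test on `run_op_graphs_`, and the early return in `BuildOp` means a single-op graph is constructed and compiled only once per `GraphInfo` key. A self-contained sketch of that lookup-then-insert flow, with placeholder types for the key, the graph, and the map:

```cpp
#include <memory>
#include <string>
#include <unordered_map>

using GraphInfo = std::string;  // placeholder; the real key encodes the op and its inputs
struct KernelGraph {};          // placeholder graph type
using KernelGraphPtr = std::shared_ptr<KernelGraph>;

static std::unordered_map<GraphInfo, KernelGraphPtr> run_op_graphs_;

bool GraphCacheExist(const GraphInfo &graph_info) {
  return run_op_graphs_.find(graph_info) != run_op_graphs_.end();
}

KernelGraphPtr BuildOp(const GraphInfo &graph_info) {
  if (GraphCacheExist(graph_info)) {
    return run_op_graphs_[graph_info];  // cache hit: skip construction and compilation
  }
  auto graph = std::make_shared<KernelGraph>();  // stand-in for ConstructSingleOpGraph + BuildKernel
  run_op_graphs_[graph_info] = graph;
  return graph;
}
```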
@@ -267,6 +280,7 @@ void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph
   RunOpAdjustKernel(graph);
   BuildKernel(graph);
+  run_op_graphs_[graph_info] = graph;
   MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !";
 }

 py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,

@@ -291,7 +305,6 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr
   }
   py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
   py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj);
-  run_op_graphs_.clear();
   MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!";
   return tuple_tensors;
 }

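Removing `run_op_graphs_.clear()` from `RunOp` is what makes the cache pay off: before this change, every `RunOp` call threw away the graphs `BuildOp` had just compiled, so repeating the same op rebuilt it from scratch each time. Continuing the sketch above, the observable effect is that repeated keys now return the same compiled graph:

```cpp
#include <cassert>

void RunSameOpTwice() {
  // Hypothetical GraphInfo key; the real one is derived from the op and inputs.
  auto first = BuildOp("Add-float32-2x2");
  // With the old run_op_graphs_.clear() between calls, this would rebuild.
  auto second = BuildOp("Add-float32-2x2");
  assert(first == second);  // cache retained across calls: same graph object
}
```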
@@ -111,6 +111,8 @@ class AscendSession : public SessionBasic {
   std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id);
   // copy output of if and else
   void CopyOutputOfIf(GraphId false_graph_id);
+  // check if graph cache exist
+  bool GraphCacheExist(const GraphInfo &graph_info) const;

   // member variables
   // key is final_graph_id,value is child graph execute order of final graph

@@ -125,7 +125,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
   // in pynative mode, data is only copied to host when the user wants to print it
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
-  if (ms_context->enable_pynative_infer()) {
+  if (ms_context->execution_mode() == kPynativeMode) {
     tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index));
   } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
                                         LongToSize(tensor->data().nbytes()), tensor->data_type(),

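The last hunk switches the branch from the `enable_pynative_infer` flag to the global execution mode: in PyNative (eager) mode the output tensor just keeps a handle to its device memory, deferring the device-to-host copy until the data is actually needed, while other modes sync immediately. A generic sketch of that deferred-copy idea; all types and methods here are illustrative, not MindSpore's API:

```cpp
#include <memory>
#include <vector>

enum ExecutionMode { kGraphMode, kPynativeMode };

struct DeviceAddress {
  // Stand-in for SyncDeviceToHost: copies device memory into a host buffer.
  std::vector<float> SyncToHost() const { return host_mirror_; }
  std::vector<float> host_mirror_;
};

struct Tensor {
  std::shared_ptr<DeviceAddress> device_address_;  // kept for on-demand sync
  std::vector<float> host_data_;
};

void FillOutputTensor(Tensor *tensor, const std::shared_ptr<DeviceAddress> &addr,
                      ExecutionMode mode) {
  if (mode == kPynativeMode) {
    tensor->device_address_ = addr;           // eager mode: defer the copy
  } else {
    tensor->host_data_ = addr->SyncToHost();  // graph mode: copy to host now
  }
}
```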