forked from mindspore-Ecosystem/mindspore
!9390 Pynative support dynamic op run in gpu
From: @joylvliang Reviewed-by: @chujinjin,@jjfeing Signed-off-by: @chujinjin
This commit is contained in:
commit
1a5dd4a711
|
@ -700,20 +700,21 @@ void AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &g
|
|||
MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !";
|
||||
}
|
||||
|
||||
void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
|
||||
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask) {
|
||||
MS_EXCEPTION_IF_NULL(input_tensors);
|
||||
BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask);
|
||||
MS_EXCEPTION_IF_NULL(op_run_info);
|
||||
BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
|
||||
EraseValueNodeTensor(tensors_mask, input_tensors);
|
||||
|
||||
// Run op
|
||||
auto graph = run_op_graphs_[graph_info];
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!";
|
||||
MS_LOG(INFO) << "Run op " << op_run_info->op_name << " start!";
|
||||
// malloc mem
|
||||
RunOpMemoryAlloc(*input_tensors, graph.get());
|
||||
// Build dynamic kernel
|
||||
if (op_run_info.is_dynamic_shape) {
|
||||
if (op_run_info->is_dynamic_shape) {
|
||||
BuildDynamicKernel(graph);
|
||||
}
|
||||
// load input data to device
|
||||
|
@ -722,8 +723,12 @@ void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &gra
|
|||
Execute(graph, false);
|
||||
// get output
|
||||
UpdateOutputs(graph, outputs, *input_tensors);
|
||||
// update output abstract of dynamic op to op_run_info
|
||||
if (op_run_info->is_dynamic_shape) {
|
||||
UpdateOutputAbstract(graph, op_run_info);
|
||||
}
|
||||
RunOpMemoryClear(graph.get());
|
||||
MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!";
|
||||
MS_LOG(INFO) << "Run op " << op_run_info->op_name << " finish!";
|
||||
}
|
||||
|
||||
void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||
|
@ -750,7 +755,7 @@ void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector
|
|||
|
||||
// Build and run current single op
|
||||
VectorRef op_outputs;
|
||||
RunOpImpl(run_info, graph_info, &input_tensor_info.input_tensors, &op_outputs,
|
||||
RunOpImpl(graph_info, &run_info, &input_tensor_info.input_tensors, &op_outputs,
|
||||
input_tensor_info.input_tensors_mask);
|
||||
|
||||
// Handle inputs and outputs of current op
|
||||
|
|
|
@ -59,9 +59,8 @@ class AscendSession : public SessionBasic {
|
|||
void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
const std::vector<int64_t> &tensors_mask) override;
|
||||
void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask) override;
|
||||
void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
|
||||
VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
|
||||
void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||
VectorRef *outputs) override;
|
||||
|
||||
|
|
|
@ -170,11 +170,12 @@ void CPUSession::SetOutputFlags(const VectorRef &base_ref, std::vector<tensor::T
|
|||
}
|
||||
}
|
||||
|
||||
void CPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
void CPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
|
||||
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask) {
|
||||
MS_EXCEPTION_IF_NULL(input_tensors);
|
||||
BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask);
|
||||
MS_EXCEPTION_IF_NULL(op_run_info);
|
||||
BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
|
||||
EraseValueNodeTensor(tensors_mask, input_tensors);
|
||||
|
||||
auto kernel_graph = run_op_graphs_[graph_info];
|
||||
|
|
|
@ -41,9 +41,8 @@ class CPUSession : public SessionBasic {
|
|||
void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
const std::vector<int64_t> &tensors_mask) override;
|
||||
void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask) override;
|
||||
void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
|
||||
VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
|
||||
|
||||
private:
|
||||
void SetKernelInfo(const KernelGraph *kernel_graph);
|
||||
|
|
|
@ -130,7 +130,7 @@ void RunGraphTask::Run() {
|
|||
|
||||
void RunOpTask::Run() {
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
session_->RunOpImpl(*op_run_info_, graph_info_, input_tensors_, &outputs_, tensors_mask_);
|
||||
session_->RunOpImpl(graph_info_, op_run_info_, input_tensors_, &outputs_, tensors_mask_);
|
||||
}
|
||||
|
||||
void RunOpsInGraphTask::Run() {
|
||||
|
|
|
@ -403,13 +403,14 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap
|
|||
run_op_graphs_[graph_info] = kernel_graph;
|
||||
}
|
||||
|
||||
void GPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
|
||||
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask) {
|
||||
MS_EXCEPTION_IF_NULL(input_tensors);
|
||||
BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask);
|
||||
MS_EXCEPTION_IF_NULL(op_run_info);
|
||||
BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
|
||||
EraseValueNodeTensor(tensors_mask, input_tensors);
|
||||
|
||||
// run op
|
||||
auto kernel_graph = run_op_graphs_[graph_info];
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
// Remove NopOp from execution graph
|
||||
|
@ -420,6 +421,10 @@ void GPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_
|
|||
Execute(kernel_graph);
|
||||
// Fetch outputs
|
||||
UpdateOutputs(kernel_graph, outputs, *input_tensors);
|
||||
// update output abstract of dynamic op to op_run_info
|
||||
if (op_run_info->is_dynamic_shape) {
|
||||
UpdateOutputAbstract(kernel_graph, op_run_info);
|
||||
}
|
||||
RunOpClearMemory(kernel_graph.get());
|
||||
}
|
||||
|
||||
|
|
|
@ -39,9 +39,8 @@ class GPUSession : public SessionBasic {
|
|||
void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
const std::vector<int64_t> &tensors_mask) override;
|
||||
void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask) override;
|
||||
void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
|
||||
VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
|
||||
|
||||
private:
|
||||
void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
|
|
|
@ -1142,6 +1142,19 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
|
|||
}
|
||||
}
|
||||
|
||||
void SessionBasic::UpdateOutputAbstract(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
OpRunInfo *op_run_info) const {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
MS_EXCEPTION_IF_NULL(op_run_info);
|
||||
const auto &kernels = kernel_graph->execution_order();
|
||||
for (const auto &kernel : kernels) {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
if (AnfAlgo::GetCNodeName(kernel) == op_run_info->op_name) {
|
||||
op_run_info->abstract = kernel->abstract();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<tensor::TensorPtr> SessionBasic::GetInputNeedLockTensors(const GraphId &graph_id,
|
||||
const std::vector<tensor::TensorPtr> &inputs) {
|
||||
auto graph = GetGraph(graph_id);
|
||||
|
|
|
@ -153,7 +153,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
virtual void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
const std::vector<int64_t> &tensors_mask) {}
|
||||
virtual void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
virtual void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
|
||||
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask) {}
|
||||
virtual void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||
|
@ -167,6 +167,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors);
|
||||
void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors) const;
|
||||
void UpdateOutputAbstract(const std::shared_ptr<KernelGraph> &kernel_graph, OpRunInfo *op_run_info) const;
|
||||
void Reorder(std::vector<CNodePtr> *node_list);
|
||||
void Summary(KernelGraph *graph);
|
||||
// create graph output for RunOp
|
||||
|
|
|
@ -70,10 +70,12 @@ const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_ca
|
|||
const std::set<std::string> force_infer_prim = {"TopK", "DropoutGenMask"};
|
||||
const std::set<std::string> ignore_judge_dynamic_cell = {
|
||||
"Cell mindspore.nn.layer.basic.Dense", "Cell mindspore.nn.probability.distribution.normal.Normal",
|
||||
"Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask"};
|
||||
"Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask", "Cell mindspore.nn.layer.math.MatMul"};
|
||||
const std::set<std::string> unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE,
|
||||
parse::NAMED_PRIMITIVE_NAMECONSTANT,
|
||||
parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR};
|
||||
const std::set<std::string> dynamic_shape_const_input_to_attr = {"Cast", "ExpandDims", "Reshape", "EmbeddingLookup",
|
||||
"Transpose"};
|
||||
} // namespace pynative
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -467,6 +467,11 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t>
|
|||
|
||||
opt::ConstInputToAttrInfoRegister reg;
|
||||
bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, ®);
|
||||
if (op_run_info->is_dynamic_shape &&
|
||||
dynamic_shape_const_input_to_attr.find(op_run_info->op_name) == dynamic_shape_const_input_to_attr.end()) {
|
||||
MS_LOG(INFO) << "current node is dynamic shape: " << op_run_info->op_name;
|
||||
reg_exist = false;
|
||||
}
|
||||
if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) {
|
||||
reg_exist = false;
|
||||
}
|
||||
|
@ -594,6 +599,7 @@ py::tuple RunOp(const py::args &args) {
|
|||
}
|
||||
|
||||
py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
|
||||
MS_EXCEPTION_IF_NULL(op_exec_info);
|
||||
if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) {
|
||||
return RunOpWithInitBackendPolicy(op_exec_info);
|
||||
}
|
||||
|
@ -604,58 +610,27 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
|
|||
op_exec_info->inputs_mask = op_masks;
|
||||
// get output abstract info
|
||||
bool is_find = false;
|
||||
auto prim = op_exec_info->py_primitive;
|
||||
if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) {
|
||||
auto abs_list = prim_abs_list_[prim->id()];
|
||||
MS_LOG(DEBUG) << "Match prim input args " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
|
||||
if (abs_list.find(args_spec_list) != abs_list.end()) {
|
||||
MS_LOG(DEBUG) << "Match prim ok " << op_exec_info->op_name;
|
||||
op_exec_info->abstract = abs_list[args_spec_list].abs;
|
||||
op_exec_info->is_dynamic_shape = abs_list[args_spec_list].is_dynamic_shape;
|
||||
prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs);
|
||||
is_find = true;
|
||||
}
|
||||
}
|
||||
if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_exec_info->op_name) != force_infer_prim.end()) {
|
||||
// use python infer method
|
||||
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
|
||||
PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list);
|
||||
}
|
||||
// get output dynamic shape info
|
||||
auto abstract = op_exec_info->abstract;
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
auto shape = abstract->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
auto shape_info = shape->ToString();
|
||||
if (shape_info.find("-1") != string::npos) {
|
||||
op_exec_info->is_dynamic_shape = true;
|
||||
}
|
||||
}
|
||||
if (cnode != nullptr) {
|
||||
cnode->set_abstract(op_exec_info->abstract);
|
||||
}
|
||||
GetOpOutputAbstract(op_exec_info, args_spec_list, &is_find);
|
||||
MS_LOG(DEBUG) << "Run op infer " << op_exec_info->op_name << " " << op_exec_info->abstract->ToString();
|
||||
// infer output value for const prim
|
||||
MS_EXCEPTION_IF_NULL(op_exec_info);
|
||||
if (op_exec_info->abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "Run op infer " << op_exec_info->op_name << " " << op_exec_info->abstract->ToString();
|
||||
py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
|
||||
if (!output["value"].is_none()) {
|
||||
py::tuple value_ret(1);
|
||||
value_ret[0] = output["value"];
|
||||
return value_ret;
|
||||
}
|
||||
if (op_exec_info->py_primitive->is_const_prim()) {
|
||||
py::tuple value_ret(1);
|
||||
value_ret[0] = "";
|
||||
return value_ret;
|
||||
}
|
||||
auto prim = op_exec_info->py_primitive;
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
|
||||
if (!output["value"].is_none()) {
|
||||
py::tuple value_ret(1);
|
||||
value_ret[0] = output["value"];
|
||||
return value_ret;
|
||||
}
|
||||
if (prim->is_const_prim()) {
|
||||
py::tuple value_ret(1);
|
||||
value_ret[0] = "";
|
||||
return value_ret;
|
||||
}
|
||||
// add output abstract info into cache
|
||||
if (!is_find) {
|
||||
if (!is_find && !op_exec_info->is_dynamic_shape) {
|
||||
// const_value need infer every step
|
||||
auto &out = prim_abs_list_[prim->id()];
|
||||
out[args_spec_list].abs = op_exec_info->abstract;
|
||||
out[args_spec_list].is_dynamic_shape = op_exec_info->is_dynamic_shape;
|
||||
out[args_spec_list].attrs = prim->evaluate_added_attrs();
|
||||
MS_LOG(DEBUG) << "Set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
|
||||
}
|
||||
|
@ -666,8 +641,13 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
|
|||
MS_LOG(DEBUG) << "Output size is 1";
|
||||
out_real = result[0];
|
||||
}
|
||||
// update output abstract for cnode
|
||||
if (cnode != nullptr) {
|
||||
cnode->set_abstract(op_exec_info->abstract);
|
||||
}
|
||||
std::string obj_id = GetId(out_real);
|
||||
node_abs_map_[obj_id] = op_exec_info->abstract;
|
||||
// save info for building grad graph
|
||||
SaveOutputNodeMap(obj_id, out_real, cnode);
|
||||
SaveAllResult(op_exec_info, cnode, out_real);
|
||||
// Update the abstract and device address of value node with tensor in grad graph
|
||||
|
@ -784,6 +764,49 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
|
|||
return cnode;
|
||||
}
|
||||
|
||||
void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
|
||||
const abstract::AbstractBasePtrList &args_spec_list, bool *is_find) {
|
||||
MS_EXCEPTION_IF_NULL(is_find);
|
||||
MS_EXCEPTION_IF_NULL(op_exec_info);
|
||||
*is_find = false;
|
||||
auto op_name = op_exec_info->op_name;
|
||||
auto prim = op_exec_info->py_primitive;
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) {
|
||||
auto abs_list = prim_abs_list_[prim->id()];
|
||||
MS_LOG(DEBUG) << "Match prim input args " << op_name << mindspore::ToString(args_spec_list);
|
||||
if (abs_list.find(args_spec_list) != abs_list.end()) {
|
||||
MS_LOG(DEBUG) << "Match prim ok " << op_name;
|
||||
op_exec_info->abstract = abs_list[args_spec_list].abs;
|
||||
prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs);
|
||||
*is_find = true;
|
||||
}
|
||||
}
|
||||
if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_name) != force_infer_prim.end()) {
|
||||
// use python infer method
|
||||
if (ignore_infer_prim.find(op_name) == ignore_infer_prim.end()) {
|
||||
PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list);
|
||||
}
|
||||
}
|
||||
// get output dynamic shape info
|
||||
auto py_abstract = op_exec_info->abstract;
|
||||
MS_EXCEPTION_IF_NULL(py_abstract);
|
||||
auto py_shape = py_abstract->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(py_shape);
|
||||
auto py_shape_info = py_shape->ToString();
|
||||
if (py_shape_info.find("-1") != string::npos) {
|
||||
auto c_abstract = abstract::CppInferShape(prim, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(c_abstract);
|
||||
auto c_shape = c_abstract->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(c_shape);
|
||||
auto c_shape_info = c_shape->ToString();
|
||||
MS_LOG(DEBUG) << "Final infer output shape: " << c_shape_info;
|
||||
if (c_shape_info.find("-1") != string::npos) {
|
||||
op_exec_info->is_dynamic_shape = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
py::object PynativeExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name,
|
||||
size_t index) {
|
||||
py::tuple cast_args(3);
|
||||
|
@ -1326,6 +1349,9 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati
|
|||
op_exec_info->next_input_index};
|
||||
VectorRef outputs;
|
||||
session->RunOp(&op_run_info, graph_info, &input_tensors, &outputs, tensors_mask);
|
||||
if (op_exec_info->is_dynamic_shape) {
|
||||
op_exec_info->abstract = op_run_info.abstract;
|
||||
}
|
||||
auto result = BaseRefToPyData(outputs);
|
||||
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
|
||||
*status = PYNATIVE_SUCCESS;
|
||||
|
|
|
@ -129,6 +129,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
|
||||
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
|
||||
abstract::AbstractBasePtrList *args_spec_list);
|
||||
void GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
|
||||
bool *is_find);
|
||||
void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode);
|
||||
|
||||
// replace for grad graph
|
||||
|
|
|
@ -577,6 +577,7 @@ void AscendKernelRuntime::DumpTaskExceptionInfo(const session::KernelGraph *grap
|
|||
}
|
||||
|
||||
bool AscendKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
bool ret = false;
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
auto start_time = std::chrono::steady_clock::now();
|
||||
|
|
|
@ -336,6 +336,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
|
|||
}
|
||||
|
||||
bool GPUKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
struct timeval start_time, end_time;
|
||||
(void)gettimeofday(&start_time, nullptr);
|
||||
bool ret = true;
|
||||
|
@ -360,7 +361,12 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) {
|
|||
|
||||
ret = RunOneStep(graph);
|
||||
} else {
|
||||
ret = LaunchKernel(graph);
|
||||
if (graph->is_dynamic_shape()) {
|
||||
// run dynamic shape graph in pynative
|
||||
ret = RunOpLaunchKernelDynamic(graph);
|
||||
} else {
|
||||
ret = LaunchKernel(graph);
|
||||
}
|
||||
}
|
||||
(void)gettimeofday(&end_time, nullptr);
|
||||
const uint64_t kUSecondInSecond = 1000000;
|
||||
|
@ -674,6 +680,42 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, bo
|
|||
return true;
|
||||
}
|
||||
|
||||
bool GPUKernelRuntime::RunOpLaunchKernelDynamic(const session::KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
const auto &kernels = graph->execution_order();
|
||||
for (const auto &kernel : kernels) {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
// akg kernel do not support dynamic shape by now.
|
||||
device::DynamicKernelPtr dynamic_kernel = nullptr;
|
||||
kernel::GpuKernel *gpu_kernel = nullptr;
|
||||
if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) != KernelType::AKG_KERNEL) {
|
||||
gpu_kernel = dynamic_cast<kernel::GpuKernel *>(kernel_mod);
|
||||
dynamic_kernel = gpu_kernel->DynamicKernel();
|
||||
}
|
||||
// pre-processing for dynamic shape kernel
|
||||
if (dynamic_kernel && dynamic_kernel->is_dynamic_shape()) {
|
||||
dynamic_kernel->InferShape();
|
||||
dynamic_kernel->UpdateArgs();
|
||||
}
|
||||
// alloc kernel res
|
||||
AddressPtrList kernel_inputs;
|
||||
AddressPtrList kernel_workspaces;
|
||||
AddressPtrList kernel_outputs;
|
||||
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
||||
auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Launch kernel failed.";
|
||||
return false;
|
||||
}
|
||||
if (gpu_kernel && dynamic_kernel && dynamic_kernel->is_dynamic_shape()) {
|
||||
gpu_kernel->PostExecute();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void GPUKernelRuntime::LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs,
|
||||
const AddressPtrList &workspace, const AddressPtrList &outputs) {
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
|
||||
|
|
|
@ -73,6 +73,7 @@ class GPUKernelRuntime : public KernelRuntime {
|
|||
bool SearchMemSwapScheme(const session::KernelGraph *graph);
|
||||
bool RefineMemSwapScheme(const session::KernelGraph *graph);
|
||||
bool LaunchKernelDynamic(const session::KernelGraph *graph, bool mock = false, bool profiling = false);
|
||||
bool RunOpLaunchKernelDynamic(const session::KernelGraph *graph);
|
||||
void LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs,
|
||||
const AddressPtrList &workspace, const AddressPtrList &outputs);
|
||||
bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size, bool mock);
|
||||
|
|
|
@ -875,7 +875,7 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList
|
|||
}
|
||||
|
||||
bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
|
||||
auto &kernels = graph.execution_order();
|
||||
const auto &kernels = graph.execution_order();
|
||||
std::vector<DynamicKernelPtr> dynamic_kernel_list;
|
||||
auto iter = graph_dynamic_kernel_map_.find(graph.graph_id());
|
||||
if (iter != graph_dynamic_kernel_map_.end()) {
|
||||
|
|
Loading…
Reference in New Issue