Move BuildOp into RunOp
This commit is contained in:
parent
1ffecf1874
commit
d44dd4f786
|
@ -691,22 +691,27 @@ void AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &g
|
|||
}
|
||||
|
||||
void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
|
||||
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask) {
|
||||
MS_EXCEPTION_IF_NULL(input_tensors);
|
||||
BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask);
|
||||
EraseValueNodeTensor(tensors_mask, input_tensors);
|
||||
|
||||
auto graph = run_op_graphs_[graph_info];
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!";
|
||||
// malloc mem
|
||||
RunOpMemoryAlloc(input_tensors, graph.get());
|
||||
RunOpMemoryAlloc(*input_tensors, graph.get());
|
||||
// Build dynamic kernel
|
||||
if (op_run_info.is_dynamic_shape) {
|
||||
BuildDynamicKernel(graph);
|
||||
}
|
||||
// load input data to device
|
||||
LoadInputData(graph, input_tensors);
|
||||
LoadInputData(graph, *input_tensors);
|
||||
// run op
|
||||
Execute(graph, false);
|
||||
// get output
|
||||
UpdateOutputs(graph, outputs, input_tensors);
|
||||
UpdateOutputs(graph, outputs, *input_tensors);
|
||||
RunOpMemoryClear(graph.get());
|
||||
MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!";
|
||||
}
|
||||
|
@ -736,7 +741,8 @@ void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector
|
|||
// Build and run current single op
|
||||
BuildOpImpl(run_info, graph_info, input_tensor_info.input_tensors, input_tensor_info.input_tensors_mask);
|
||||
VectorRef op_outputs;
|
||||
RunOpImpl(run_info, graph_info, input_tensor_info.input_tensors, &op_outputs);
|
||||
RunOpImpl(run_info, graph_info, &input_tensor_info.input_tensors, &op_outputs,
|
||||
input_tensor_info.input_tensors_mask);
|
||||
|
||||
// Handle inputs and outputs of current op
|
||||
HandleOpInputs(input_tensor_info.input_kernel, &cnode_ref, &op_output_map);
|
||||
|
|
|
@ -60,7 +60,8 @@ class AscendSession : public SessionBasic {
|
|||
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
const std::vector<int64_t> &tensors_mask) override;
|
||||
void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) override;
|
||||
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask) override;
|
||||
void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||
VectorRef *outputs) override;
|
||||
|
||||
|
|
|
@ -125,14 +125,9 @@ void RunGraphTask::Run() {
|
|||
ExecutorManager::Instance().OnRunGraphFinished();
|
||||
}
|
||||
|
||||
void BuildOpTask::Run() {
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
session_->BuildOpImpl(*op_run_info_, graph_info_, input_tensors_, tensors_mask_);
|
||||
}
|
||||
|
||||
void RunOpTask::Run() {
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
session_->RunOpImpl(*op_run_info_, graph_info_, input_tensors_, &outputs_);
|
||||
session_->RunOpImpl(*op_run_info_, graph_info_, input_tensors_, &outputs_, tensors_mask_);
|
||||
}
|
||||
|
||||
void RunOpsInGraphTask::Run() {
|
||||
|
@ -340,25 +335,16 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
|
|||
task_cond_var_.notify_all();
|
||||
}
|
||||
|
||||
void Executor::BuildOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int64_t> &tensors_mask) {
|
||||
auto task = std::make_shared<BuildOpTask>();
|
||||
task->session_ = session;
|
||||
task->op_run_info_ = op_run_info;
|
||||
task->graph_info_ = graph_info;
|
||||
task->input_tensors_ = input_tensors;
|
||||
task->tensors_mask_ = tensors_mask;
|
||||
SyncRunTask(task);
|
||||
}
|
||||
|
||||
void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
|
||||
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask) {
|
||||
auto task = std::make_shared<RunOpTask>();
|
||||
task->session_ = session;
|
||||
task->op_run_info_ = op_run_info;
|
||||
task->graph_info_ = graph_info;
|
||||
task->input_tensors_ = input_tensors;
|
||||
for (auto &tensor : input_tensors) {
|
||||
task->tensors_mask_ = tensors_mask;
|
||||
for (auto &tensor : *input_tensors) {
|
||||
if (tensor->NeedWait()) {
|
||||
tensor->Wait();
|
||||
}
|
||||
|
|
|
@ -110,17 +110,6 @@ class RunOpsInGraphTask : public Task {
|
|||
GraphId graph_id_{0};
|
||||
};
|
||||
|
||||
class BuildOpTask : public Task {
|
||||
public:
|
||||
BuildOpTask() { type_ = kBuildOp; }
|
||||
~BuildOpTask() override = default;
|
||||
void Run() override;
|
||||
OpRunInfo *op_run_info_{nullptr};
|
||||
GraphInfo graph_info_;
|
||||
std::vector<tensor::TensorPtr> input_tensors_;
|
||||
std::vector<int64_t> tensors_mask_;
|
||||
};
|
||||
|
||||
class RunOpTask : public Task {
|
||||
public:
|
||||
RunOpTask() { type_ = kRunOp; }
|
||||
|
@ -128,8 +117,9 @@ class RunOpTask : public Task {
|
|||
void Run() override;
|
||||
OpRunInfo *op_run_info_{nullptr};
|
||||
GraphInfo graph_info_;
|
||||
std::vector<tensor::TensorPtr> input_tensors_;
|
||||
std::vector<tensor::TensorPtr> *input_tensors_;
|
||||
VectorRef outputs_;
|
||||
std::vector<int64_t> tensors_mask_;
|
||||
};
|
||||
|
||||
class CreateCommGroupTask : public Task {
|
||||
|
@ -170,10 +160,9 @@ class Executor {
|
|||
VectorRef *outputs);
|
||||
void RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||
VectorRef *outputs);
|
||||
void BuildOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int64_t> &tensors_mask);
|
||||
void RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs);
|
||||
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask);
|
||||
void RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||
VectorRef *outputs);
|
||||
void OnRunGraphFinished();
|
||||
|
|
|
@ -398,17 +398,22 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap
|
|||
}
|
||||
|
||||
void GPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
|
||||
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask) {
|
||||
MS_EXCEPTION_IF_NULL(input_tensors);
|
||||
BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask);
|
||||
EraseValueNodeTensor(tensors_mask, input_tensors);
|
||||
|
||||
auto kernel_graph = run_op_graphs_[graph_info];
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
// Remove NopOp from execution graph
|
||||
opt::RemoveNopNode(kernel_graph.get());
|
||||
RunOpAllocateMemory(input_tensors, kernel_graph.get());
|
||||
RunOpAllocateMemory(*input_tensors, kernel_graph.get());
|
||||
// Execute the computation
|
||||
LoadInputData(kernel_graph, input_tensors);
|
||||
LoadInputData(kernel_graph, *input_tensors);
|
||||
Execute(kernel_graph);
|
||||
// Fetch outputs
|
||||
UpdateOutputs(kernel_graph, outputs, input_tensors);
|
||||
UpdateOutputs(kernel_graph, outputs, *input_tensors);
|
||||
RunOpClearMemory(kernel_graph.get());
|
||||
}
|
||||
|
||||
|
|
|
@ -40,7 +40,8 @@ class GPUSession : public SessionBasic {
|
|||
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
const std::vector<int64_t> &tensors_mask) override;
|
||||
void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) override;
|
||||
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask) override;
|
||||
|
||||
private:
|
||||
void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
|
|
|
@ -1593,17 +1593,11 @@ void SessionBasic::BuildGraph(GraphId graph_id) {
|
|||
executor_->BuildGraph(shared_from_this(), graph_id);
|
||||
}
|
||||
|
||||
void SessionBasic::BuildOp(OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
const std::vector<int64_t> &tensors_mask) {
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
executor_->BuildOp(shared_from_this(), op_run_info, graph_info, input_tensors, tensors_mask);
|
||||
}
|
||||
|
||||
void SessionBasic::RunOp(OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
|
||||
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask) {
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
executor_->RunOp(shared_from_this(), op_run_info, graph_info, input_tensors, outputs);
|
||||
executor_->RunOp(shared_from_this(), op_run_info, graph_info, input_tensors, outputs, tensors_mask);
|
||||
}
|
||||
|
||||
void SessionBasic::RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||
|
@ -1623,6 +1617,22 @@ void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tens
|
|||
executor_->RunGraphAsync(shared_from_this(), graph_id, inputs, outputs);
|
||||
}
|
||||
|
||||
void SessionBasic::EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask,
|
||||
std::vector<tensor::TensorPtr> *input_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(input_tensors);
|
||||
if (input_tensors->size() != tensors_mask.size()) {
|
||||
MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size "
|
||||
<< tensors_mask.size();
|
||||
}
|
||||
std::vector<tensor::TensorPtr> new_input_tensors;
|
||||
for (size_t index = 0; index < tensors_mask.size(); ++index) {
|
||||
if (tensors_mask[index] != kValueNodeTensorMask) {
|
||||
new_input_tensors.emplace_back(input_tensors->at(index));
|
||||
}
|
||||
}
|
||||
*input_tensors = new_input_tensors;
|
||||
}
|
||||
|
||||
void SessionBasic::UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs) {
|
||||
bool is_dynamic = false;
|
||||
for (const auto &graph : all_graphs) {
|
||||
|
|
|
@ -76,9 +76,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
void BuildGraph(GraphId graphId);
|
||||
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
|
||||
void RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
|
||||
void BuildOp(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
const std::vector<int64_t> &tensors_mask);
|
||||
void RunOp(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs);
|
||||
void RunOp(OpRunInfo *, const GraphInfo &, std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask);
|
||||
void RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
|
||||
|
||||
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
|
||||
|
@ -137,7 +136,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
friend class CompileGraphTask;
|
||||
friend class BuildGraphTask;
|
||||
friend class RunGraphTask;
|
||||
friend class BuildOpTask;
|
||||
friend class RunOpTask;
|
||||
friend class RunOpsInGraphTask;
|
||||
virtual bool IsSupportSummary() { return true; }
|
||||
|
@ -156,7 +154,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
const std::vector<int64_t> &tensors_mask) {}
|
||||
virtual void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {}
|
||||
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||
const std::vector<int64_t> &tensors_mask) {}
|
||||
virtual void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||
VectorRef *outputs) {}
|
||||
void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs);
|
||||
|
@ -165,6 +164,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
|
||||
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &inputs_const) const;
|
||||
void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors);
|
||||
void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
|
||||
const std::vector<tensor::TensorPtr> &input_tensors) const;
|
||||
void Reorder(std::vector<CNodePtr> *node_list);
|
||||
|
|
|
@ -471,21 +471,6 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t>
|
|||
op_prim->EndRecordAddAttr();
|
||||
}
|
||||
|
||||
void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(input_tensors);
|
||||
if (input_tensors->size() != tensors_mask.size()) {
|
||||
MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size "
|
||||
<< tensors_mask.size();
|
||||
}
|
||||
std::vector<tensor::TensorPtr> new_input_tensors;
|
||||
for (size_t index = 0; index < tensors_mask.size(); ++index) {
|
||||
if (tensors_mask[index] != kValueNodeTensorMask) {
|
||||
new_input_tensors.emplace_back(input_tensors->at(index));
|
||||
}
|
||||
}
|
||||
*input_tensors = new_input_tensors;
|
||||
}
|
||||
|
||||
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
|
||||
if (utils::isa<VectorRef>(base_ref)) {
|
||||
auto ref_list = utils::cast<VectorRef>(base_ref);
|
||||
|
@ -1301,10 +1286,8 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati
|
|||
op_exec_info->is_mixed_precision_cast,
|
||||
op_exec_info->next_op_name,
|
||||
op_exec_info->next_input_index};
|
||||
session->BuildOp(&op_run_info, graph_info, input_tensors, tensors_mask);
|
||||
EraseValueNodeTensor(tensors_mask, &input_tensors);
|
||||
VectorRef outputs;
|
||||
session->RunOp(&op_run_info, graph_info, input_tensors, &outputs);
|
||||
session->RunOp(&op_run_info, graph_info, &input_tensors, &outputs, tensors_mask);
|
||||
auto result = BaseRefToPyData(outputs);
|
||||
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
|
||||
*status = PYNATIVE_SUCCESS;
|
||||
|
|
Loading…
Reference in New Issue