Move BuildOp into RunOp
This commit is contained in:
parent
1ffecf1874
commit
d44dd4f786
|
@ -691,22 +691,27 @@ void AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &g
|
||||||
}
|
}
|
||||||
|
|
||||||
void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||||
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
|
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||||
|
const std::vector<int64_t> &tensors_mask) {
|
||||||
|
MS_EXCEPTION_IF_NULL(input_tensors);
|
||||||
|
BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask);
|
||||||
|
EraseValueNodeTensor(tensors_mask, input_tensors);
|
||||||
|
|
||||||
auto graph = run_op_graphs_[graph_info];
|
auto graph = run_op_graphs_[graph_info];
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!";
|
MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!";
|
||||||
// malloc mem
|
// malloc mem
|
||||||
RunOpMemoryAlloc(input_tensors, graph.get());
|
RunOpMemoryAlloc(*input_tensors, graph.get());
|
||||||
// Build dynamic kernel
|
// Build dynamic kernel
|
||||||
if (op_run_info.is_dynamic_shape) {
|
if (op_run_info.is_dynamic_shape) {
|
||||||
BuildDynamicKernel(graph);
|
BuildDynamicKernel(graph);
|
||||||
}
|
}
|
||||||
// load input data to device
|
// load input data to device
|
||||||
LoadInputData(graph, input_tensors);
|
LoadInputData(graph, *input_tensors);
|
||||||
// run op
|
// run op
|
||||||
Execute(graph, false);
|
Execute(graph, false);
|
||||||
// get output
|
// get output
|
||||||
UpdateOutputs(graph, outputs, input_tensors);
|
UpdateOutputs(graph, outputs, *input_tensors);
|
||||||
RunOpMemoryClear(graph.get());
|
RunOpMemoryClear(graph.get());
|
||||||
MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!";
|
MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!";
|
||||||
}
|
}
|
||||||
|
@ -736,7 +741,8 @@ void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector
|
||||||
// Build and run current single op
|
// Build and run current single op
|
||||||
BuildOpImpl(run_info, graph_info, input_tensor_info.input_tensors, input_tensor_info.input_tensors_mask);
|
BuildOpImpl(run_info, graph_info, input_tensor_info.input_tensors, input_tensor_info.input_tensors_mask);
|
||||||
VectorRef op_outputs;
|
VectorRef op_outputs;
|
||||||
RunOpImpl(run_info, graph_info, input_tensor_info.input_tensors, &op_outputs);
|
RunOpImpl(run_info, graph_info, &input_tensor_info.input_tensors, &op_outputs,
|
||||||
|
input_tensor_info.input_tensors_mask);
|
||||||
|
|
||||||
// Handle inputs and outputs of current op
|
// Handle inputs and outputs of current op
|
||||||
HandleOpInputs(input_tensor_info.input_kernel, &cnode_ref, &op_output_map);
|
HandleOpInputs(input_tensor_info.input_kernel, &cnode_ref, &op_output_map);
|
||||||
|
|
|
@ -60,7 +60,8 @@ class AscendSession : public SessionBasic {
|
||||||
const std::vector<tensor::TensorPtr> &input_tensors,
|
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||||
const std::vector<int64_t> &tensors_mask) override;
|
const std::vector<int64_t> &tensors_mask) override;
|
||||||
void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||||
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) override;
|
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||||
|
const std::vector<int64_t> &tensors_mask) override;
|
||||||
void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||||
VectorRef *outputs) override;
|
VectorRef *outputs) override;
|
||||||
|
|
||||||
|
|
|
@ -125,14 +125,9 @@ void RunGraphTask::Run() {
|
||||||
ExecutorManager::Instance().OnRunGraphFinished();
|
ExecutorManager::Instance().OnRunGraphFinished();
|
||||||
}
|
}
|
||||||
|
|
||||||
void BuildOpTask::Run() {
|
|
||||||
MS_EXCEPTION_IF_NULL(session_);
|
|
||||||
session_->BuildOpImpl(*op_run_info_, graph_info_, input_tensors_, tensors_mask_);
|
|
||||||
}
|
|
||||||
|
|
||||||
void RunOpTask::Run() {
|
void RunOpTask::Run() {
|
||||||
MS_EXCEPTION_IF_NULL(session_);
|
MS_EXCEPTION_IF_NULL(session_);
|
||||||
session_->RunOpImpl(*op_run_info_, graph_info_, input_tensors_, &outputs_);
|
session_->RunOpImpl(*op_run_info_, graph_info_, input_tensors_, &outputs_, tensors_mask_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void RunOpsInGraphTask::Run() {
|
void RunOpsInGraphTask::Run() {
|
||||||
|
@ -340,25 +335,16 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
|
||||||
task_cond_var_.notify_all();
|
task_cond_var_.notify_all();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Executor::BuildOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
|
||||||
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int64_t> &tensors_mask) {
|
|
||||||
auto task = std::make_shared<BuildOpTask>();
|
|
||||||
task->session_ = session;
|
|
||||||
task->op_run_info_ = op_run_info;
|
|
||||||
task->graph_info_ = graph_info;
|
|
||||||
task->input_tensors_ = input_tensors;
|
|
||||||
task->tensors_mask_ = tensors_mask;
|
|
||||||
SyncRunTask(task);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||||
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
|
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||||
|
const std::vector<int64_t> &tensors_mask) {
|
||||||
auto task = std::make_shared<RunOpTask>();
|
auto task = std::make_shared<RunOpTask>();
|
||||||
task->session_ = session;
|
task->session_ = session;
|
||||||
task->op_run_info_ = op_run_info;
|
task->op_run_info_ = op_run_info;
|
||||||
task->graph_info_ = graph_info;
|
task->graph_info_ = graph_info;
|
||||||
task->input_tensors_ = input_tensors;
|
task->input_tensors_ = input_tensors;
|
||||||
for (auto &tensor : input_tensors) {
|
task->tensors_mask_ = tensors_mask;
|
||||||
|
for (auto &tensor : *input_tensors) {
|
||||||
if (tensor->NeedWait()) {
|
if (tensor->NeedWait()) {
|
||||||
tensor->Wait();
|
tensor->Wait();
|
||||||
}
|
}
|
||||||
|
|
|
@ -110,17 +110,6 @@ class RunOpsInGraphTask : public Task {
|
||||||
GraphId graph_id_{0};
|
GraphId graph_id_{0};
|
||||||
};
|
};
|
||||||
|
|
||||||
class BuildOpTask : public Task {
|
|
||||||
public:
|
|
||||||
BuildOpTask() { type_ = kBuildOp; }
|
|
||||||
~BuildOpTask() override = default;
|
|
||||||
void Run() override;
|
|
||||||
OpRunInfo *op_run_info_{nullptr};
|
|
||||||
GraphInfo graph_info_;
|
|
||||||
std::vector<tensor::TensorPtr> input_tensors_;
|
|
||||||
std::vector<int64_t> tensors_mask_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class RunOpTask : public Task {
|
class RunOpTask : public Task {
|
||||||
public:
|
public:
|
||||||
RunOpTask() { type_ = kRunOp; }
|
RunOpTask() { type_ = kRunOp; }
|
||||||
|
@ -128,8 +117,9 @@ class RunOpTask : public Task {
|
||||||
void Run() override;
|
void Run() override;
|
||||||
OpRunInfo *op_run_info_{nullptr};
|
OpRunInfo *op_run_info_{nullptr};
|
||||||
GraphInfo graph_info_;
|
GraphInfo graph_info_;
|
||||||
std::vector<tensor::TensorPtr> input_tensors_;
|
std::vector<tensor::TensorPtr> *input_tensors_;
|
||||||
VectorRef outputs_;
|
VectorRef outputs_;
|
||||||
|
std::vector<int64_t> tensors_mask_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class CreateCommGroupTask : public Task {
|
class CreateCommGroupTask : public Task {
|
||||||
|
@ -170,10 +160,9 @@ class Executor {
|
||||||
VectorRef *outputs);
|
VectorRef *outputs);
|
||||||
void RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
void RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||||
VectorRef *outputs);
|
VectorRef *outputs);
|
||||||
void BuildOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
|
||||||
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int64_t> &tensors_mask);
|
|
||||||
void RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
void RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||||
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs);
|
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||||
|
const std::vector<int64_t> &tensors_mask);
|
||||||
void RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
void RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||||
VectorRef *outputs);
|
VectorRef *outputs);
|
||||||
void OnRunGraphFinished();
|
void OnRunGraphFinished();
|
||||||
|
|
|
@ -398,17 +398,22 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap
|
||||||
}
|
}
|
||||||
|
|
||||||
void GPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
void GPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||||
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
|
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||||
|
const std::vector<int64_t> &tensors_mask) {
|
||||||
|
MS_EXCEPTION_IF_NULL(input_tensors);
|
||||||
|
BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask);
|
||||||
|
EraseValueNodeTensor(tensors_mask, input_tensors);
|
||||||
|
|
||||||
auto kernel_graph = run_op_graphs_[graph_info];
|
auto kernel_graph = run_op_graphs_[graph_info];
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
// Remove NopOp from execution graph
|
// Remove NopOp from execution graph
|
||||||
opt::RemoveNopNode(kernel_graph.get());
|
opt::RemoveNopNode(kernel_graph.get());
|
||||||
RunOpAllocateMemory(input_tensors, kernel_graph.get());
|
RunOpAllocateMemory(*input_tensors, kernel_graph.get());
|
||||||
// Execute the computation
|
// Execute the computation
|
||||||
LoadInputData(kernel_graph, input_tensors);
|
LoadInputData(kernel_graph, *input_tensors);
|
||||||
Execute(kernel_graph);
|
Execute(kernel_graph);
|
||||||
// Fetch outputs
|
// Fetch outputs
|
||||||
UpdateOutputs(kernel_graph, outputs, input_tensors);
|
UpdateOutputs(kernel_graph, outputs, *input_tensors);
|
||||||
RunOpClearMemory(kernel_graph.get());
|
RunOpClearMemory(kernel_graph.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,8 @@ class GPUSession : public SessionBasic {
|
||||||
const std::vector<tensor::TensorPtr> &input_tensors,
|
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||||
const std::vector<int64_t> &tensors_mask) override;
|
const std::vector<int64_t> &tensors_mask) override;
|
||||||
void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||||
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) override;
|
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||||
|
const std::vector<int64_t> &tensors_mask) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||||
|
|
|
@ -1593,17 +1593,11 @@ void SessionBasic::BuildGraph(GraphId graph_id) {
|
||||||
executor_->BuildGraph(shared_from_this(), graph_id);
|
executor_->BuildGraph(shared_from_this(), graph_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SessionBasic::BuildOp(OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
|
||||||
const std::vector<tensor::TensorPtr> &input_tensors,
|
|
||||||
const std::vector<int64_t> &tensors_mask) {
|
|
||||||
MS_EXCEPTION_IF_NULL(executor_);
|
|
||||||
executor_->BuildOp(shared_from_this(), op_run_info, graph_info, input_tensors, tensors_mask);
|
|
||||||
}
|
|
||||||
|
|
||||||
void SessionBasic::RunOp(OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
void SessionBasic::RunOp(OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||||
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
|
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||||
|
const std::vector<int64_t> &tensors_mask) {
|
||||||
MS_EXCEPTION_IF_NULL(executor_);
|
MS_EXCEPTION_IF_NULL(executor_);
|
||||||
executor_->RunOp(shared_from_this(), op_run_info, graph_info, input_tensors, outputs);
|
executor_->RunOp(shared_from_this(), op_run_info, graph_info, input_tensors, outputs, tensors_mask);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SessionBasic::RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
void SessionBasic::RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||||
|
@ -1623,6 +1617,22 @@ void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tens
|
||||||
executor_->RunGraphAsync(shared_from_this(), graph_id, inputs, outputs);
|
executor_->RunGraphAsync(shared_from_this(), graph_id, inputs, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SessionBasic::EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask,
|
||||||
|
std::vector<tensor::TensorPtr> *input_tensors) {
|
||||||
|
MS_EXCEPTION_IF_NULL(input_tensors);
|
||||||
|
if (input_tensors->size() != tensors_mask.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size "
|
||||||
|
<< tensors_mask.size();
|
||||||
|
}
|
||||||
|
std::vector<tensor::TensorPtr> new_input_tensors;
|
||||||
|
for (size_t index = 0; index < tensors_mask.size(); ++index) {
|
||||||
|
if (tensors_mask[index] != kValueNodeTensorMask) {
|
||||||
|
new_input_tensors.emplace_back(input_tensors->at(index));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*input_tensors = new_input_tensors;
|
||||||
|
}
|
||||||
|
|
||||||
void SessionBasic::UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs) {
|
void SessionBasic::UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs) {
|
||||||
bool is_dynamic = false;
|
bool is_dynamic = false;
|
||||||
for (const auto &graph : all_graphs) {
|
for (const auto &graph : all_graphs) {
|
||||||
|
|
|
@ -76,9 +76,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
||||||
void BuildGraph(GraphId graphId);
|
void BuildGraph(GraphId graphId);
|
||||||
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
|
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
|
||||||
void RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
|
void RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
|
||||||
void BuildOp(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
|
void RunOp(OpRunInfo *, const GraphInfo &, std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||||
const std::vector<int64_t> &tensors_mask);
|
const std::vector<int64_t> &tensors_mask);
|
||||||
void RunOp(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs);
|
|
||||||
void RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
|
void RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
|
||||||
|
|
||||||
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
|
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
|
||||||
|
@ -137,7 +136,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
||||||
friend class CompileGraphTask;
|
friend class CompileGraphTask;
|
||||||
friend class BuildGraphTask;
|
friend class BuildGraphTask;
|
||||||
friend class RunGraphTask;
|
friend class RunGraphTask;
|
||||||
friend class BuildOpTask;
|
|
||||||
friend class RunOpTask;
|
friend class RunOpTask;
|
||||||
friend class RunOpsInGraphTask;
|
friend class RunOpsInGraphTask;
|
||||||
virtual bool IsSupportSummary() { return true; }
|
virtual bool IsSupportSummary() { return true; }
|
||||||
|
@ -156,7 +154,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
||||||
const std::vector<tensor::TensorPtr> &input_tensors,
|
const std::vector<tensor::TensorPtr> &input_tensors,
|
||||||
const std::vector<int64_t> &tensors_mask) {}
|
const std::vector<int64_t> &tensors_mask) {}
|
||||||
virtual void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
virtual void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||||
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {}
|
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
|
||||||
|
const std::vector<int64_t> &tensors_mask) {}
|
||||||
virtual void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
virtual void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||||
VectorRef *outputs) {}
|
VectorRef *outputs) {}
|
||||||
void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs);
|
void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs);
|
||||||
|
@ -165,6 +164,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
||||||
|
|
||||||
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||||
const std::vector<tensor::TensorPtr> &inputs_const) const;
|
const std::vector<tensor::TensorPtr> &inputs_const) const;
|
||||||
|
void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors);
|
||||||
void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
|
void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
|
||||||
const std::vector<tensor::TensorPtr> &input_tensors) const;
|
const std::vector<tensor::TensorPtr> &input_tensors) const;
|
||||||
void Reorder(std::vector<CNodePtr> *node_list);
|
void Reorder(std::vector<CNodePtr> *node_list);
|
||||||
|
|
|
@ -471,21 +471,6 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t>
|
||||||
op_prim->EndRecordAddAttr();
|
op_prim->EndRecordAddAttr();
|
||||||
}
|
}
|
||||||
|
|
||||||
void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors) {
|
|
||||||
MS_EXCEPTION_IF_NULL(input_tensors);
|
|
||||||
if (input_tensors->size() != tensors_mask.size()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size "
|
|
||||||
<< tensors_mask.size();
|
|
||||||
}
|
|
||||||
std::vector<tensor::TensorPtr> new_input_tensors;
|
|
||||||
for (size_t index = 0; index < tensors_mask.size(); ++index) {
|
|
||||||
if (tensors_mask[index] != kValueNodeTensorMask) {
|
|
||||||
new_input_tensors.emplace_back(input_tensors->at(index));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*input_tensors = new_input_tensors;
|
|
||||||
}
|
|
||||||
|
|
||||||
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
|
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
|
||||||
if (utils::isa<VectorRef>(base_ref)) {
|
if (utils::isa<VectorRef>(base_ref)) {
|
||||||
auto ref_list = utils::cast<VectorRef>(base_ref);
|
auto ref_list = utils::cast<VectorRef>(base_ref);
|
||||||
|
@ -1301,10 +1286,8 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati
|
||||||
op_exec_info->is_mixed_precision_cast,
|
op_exec_info->is_mixed_precision_cast,
|
||||||
op_exec_info->next_op_name,
|
op_exec_info->next_op_name,
|
||||||
op_exec_info->next_input_index};
|
op_exec_info->next_input_index};
|
||||||
session->BuildOp(&op_run_info, graph_info, input_tensors, tensors_mask);
|
|
||||||
EraseValueNodeTensor(tensors_mask, &input_tensors);
|
|
||||||
VectorRef outputs;
|
VectorRef outputs;
|
||||||
session->RunOp(&op_run_info, graph_info, input_tensors, &outputs);
|
session->RunOp(&op_run_info, graph_info, &input_tensors, &outputs, tensors_mask);
|
||||||
auto result = BaseRefToPyData(outputs);
|
auto result = BaseRefToPyData(outputs);
|
||||||
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
|
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
|
||||||
*status = PYNATIVE_SUCCESS;
|
*status = PYNATIVE_SUCCESS;
|
||||||
|
|
Loading…
Reference in New Issue