Run ops one by one in pynative bp graph

This commit is contained in:
HulkTang 2020-11-26 15:19:17 +08:00
parent d0d5a8b878
commit c36b477568
19 changed files with 672 additions and 227 deletions

View File

@ -327,6 +327,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
ir_fusion_pm->AddPass(std::make_shared<SplitFission>());
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
ir_fusion_pm->AddPass(std::make_shared<LinSpaceFission>());

View File

@ -22,6 +22,7 @@
#include <list>
#include "base/core_ops.h"
#include "base/base_ref_utils.h"
#include "ir/tensor.h"
#include "ir/anf.h"
#include "common/trans.h"
@ -123,6 +124,284 @@ void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph) {
{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()});
root_graph->set_output(make_tuple);
}
BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<size_t> &indexes,
std::map<KernelWithIndex, std::vector<size_t>> *output_indexes) {
auto &node = node_output_pair.first;
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(output_indexes);
MS_LOG(INFO) << "Create placeholder for output[" << node->DebugString() << "] index[" << node_output_pair.second
<< "]";
// if node is a value node, no need sync addr from device to host
if (node->isa<ValueNode>()) {
auto value_node = node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
return value_node->value();
}
if (node->isa<Parameter>()) {
for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
if (input_idx >= input_tensors.size()) {
MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size();
}
if (graph->inputs()[input_idx] == node) {
return input_tensors[input_idx];
}
}
MS_LOG(EXCEPTION) << "Parameter: " << node->DebugString() << " has no output addr";
}
(*output_indexes)[node_output_pair] = indexes;
BaseRef output_placeholder = std::make_shared<BaseRef>();
return output_placeholder;
}
BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr &graph,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<size_t> &indexes,
std::map<KernelWithIndex, std::vector<size_t>> *output_indexes) {
MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(output_indexes);
MS_LOG(INFO) << "Create placeholder for output[" << anf->DebugString() << "]";
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
MS_EXCEPTION_IF_NULL(item_with_index.first);
MS_LOG(INFO) << "Create placeholder for output after visit:" << item_with_index.first->DebugString();
// special handle for maketuple
if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
auto cnode = item_with_index.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
VectorRef ret;
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
std::vector<size_t> cur_index = indexes;
cur_index.emplace_back(i - 1);
auto out = CreateNodeOutputPlaceholder(cnode->input(i), graph, input_tensors, cur_index, output_indexes);
ret.push_back(out);
}
return ret;
}
// if is graph return nothing ,the function should return a null anylist
size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
if (size == 0) {
return VectorRef();
}
return CreateNodeOutputPlaceholder(item_with_index, graph, input_tensors, indexes, output_indexes);
}
void CreateOutputPlaceholder(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &input_tensors,
VectorRef *outputs, std::map<KernelWithIndex, std::vector<size_t>> *output_indexes) {
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(outputs);
MS_EXCEPTION_IF_NULL(output_indexes);
auto anf_outputs = kernel_graph->outputs();
size_t index = 0;
for (auto &item : anf_outputs) {
MS_EXCEPTION_IF_NULL(item);
MS_LOG(INFO) << "Create node output placeholder[" << item->DebugString() << "]";
std::vector<size_t> indexes{index++};
outputs->emplace_back(CreateNodeOutputPlaceholder(item, kernel_graph, input_tensors, indexes, output_indexes));
}
}
void GetRefCount(KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count) {
MS_EXCEPTION_IF_NULL(graph);
for (const auto &kernel : graph->execution_order()) {
for (size_t i = 1; i < kernel->inputs().size(); i += 1) {
const auto &input = kernel->input(i);
auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
const auto &node = kernel_with_index.first;
if (node->isa<CNode>()) {
(*ref_count)[kernel_with_index] += 1;
}
}
}
}
void GetParameterIndex(KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
std::map<AnfNodePtr, size_t> *parameter_index) {
size_t index = 0;
for (const auto &input_node : graph->inputs()) {
auto params = AnfAlgo::GetAllOutput(input_node);
for (const auto &param : params) {
if (index >= inputs.size()) {
MS_LOG(EXCEPTION) << "Parameter size out of range. Parameter index: " << index
<< ", input size: " << inputs.size();
}
const auto &input = inputs[index];
// Check shape of input and parameter
const auto &input_shape = input->shape();
const auto &param_shape = AnfAlgo::GetOutputInferShape(param, 0);
if (input_shape.size() != param_shape.size()) {
MS_LOG(EXCEPTION) << "Shapes of input and parameter are different, input index: " << index
<< ", parameter: " << param->fullname_with_scope();
}
for (size_t i = 0; i < input_shape.size(); i += 1) {
if (input_shape[i] < 0 || static_cast<size_t>(input_shape[i]) != param_shape[i]) {
MS_LOG(EXCEPTION) << "Shapes of input and parameter are different, input index: " << index
<< ", parameter: " << param->fullname_with_scope();
}
}
parameter_index->emplace(param, index++);
}
}
}
void GetOpInputTensors(const CNodePtr &cnode, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info) {
MS_EXCEPTION_IF_NULL(cnode);
for (size_t i = 1; i < cnode->inputs().size(); i += 1) {
const auto &input = cnode->input(i);
auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
auto real_input = kernel_with_index.first;
MS_EXCEPTION_IF_NULL(real_input);
tensor::TensorPtr tensor = nullptr;
if (real_input->isa<ValueNode>()) {
auto value_node = input->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value = GetValueNode(value_node);
MS_EXCEPTION_IF_NULL(value_node);
if (value->isa<ValueTuple>()) {
auto value_tuple = value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
if (kernel_with_index.second >= value_tuple->size()) {
MS_LOG(EXCEPTION) << "Index " << kernel_with_index.second << "is out of value tuple range";
}
auto tensor_value = value_tuple->value()[kernel_with_index.second];
if (tensor_value->isa<tensor::Tensor>()) {
tensor = tensor_value->cast<tensor::TensorPtr>();
}
} else if (value->isa<tensor::Tensor>()) {
if (kernel_with_index.second != 0) {
MS_LOG(EXCEPTION) << "Index should be 0 for Tensor ValueNode, but is " << kernel_with_index.second;
}
tensor = GetValueNode<TensorPtr>(value_node);
}
} else if (real_input->isa<Parameter>()) {
const auto &iter = parameter_index.find(real_input);
if (iter == parameter_index.end()) {
MS_LOG(EXCEPTION) << "Can not find parameter input of cnode, node = " << cnode->DebugString();
}
const size_t index = iter->second;
if (index >= graph_inputs.size()) {
MS_LOG(EXCEPTION) << "Parameter index is greater than size of graph's input tensor, parameter index = "
<< cnode->DebugString() << "input tensor size = " << graph_inputs.size();
}
tensor = graph_inputs[index];
} else if (real_input->isa<CNode>()) {
const auto &iter = op_output.find(kernel_with_index);
if (iter == op_output.end()) {
MS_LOG(EXCEPTION) << "Can not find output tensor of cnode, node = " << real_input->DebugString();
}
tensor = iter->second;
input_tensor_info->input_kernel.insert(kernel_with_index);
} else {
MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
}
MS_EXCEPTION_IF_NULL(tensor);
MS_LOG(DEBUG) << "Get" << i << "th input tensor of " << cnode->fullname_with_scope() << " from "
<< real_input->fullname_with_scope() << "-" << kernel_with_index.second;
input_tensor_info->input_tensors_mask.emplace_back(tensor->is_parameter() ? kParameterWeightTensorMask
: kParameterDataTensorMask);
input_tensor_info->input_tensors.emplace_back(tensor);
}
}
void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<KernelWithIndex, size_t> *ref_count,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) {
for (auto &kernel_with_index : input_kernel) {
if (!kernel_with_index.first->isa<CNode>()) {
continue;
}
auto ref_iter = ref_count->find(kernel_with_index);
if (ref_iter == ref_count->end()) {
MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in cnode reference count map, input cnode = "
<< kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second;
}
ref_iter->second -= 1;
if (ref_iter->second != 0) {
continue;
}
ref_count->erase(ref_iter);
auto output_iter = op_output_map->find(kernel_with_index);
if (output_iter == op_output_map->end()) {
MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in op_output map, input cnode = "
<< kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second;
}
op_output_map->erase(output_iter);
}
}
void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
const std::map<KernelWithIndex, std::vector<size_t>> &output_indexes,
const std::map<KernelWithIndex, size_t> &ref_count,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, VectorRef *outputs) {
auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
if (output_tensors.size() != op_outputs.size()) {
MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
}
size_t out_index = 0;
for (const auto &output_tensor : output_tensors) {
auto kernel_with_index = make_pair(kernel, out_index++);
if (ref_count.find(kernel_with_index) != ref_count.end()) {
(*op_output_map)[kernel_with_index] = output_tensor;
}
const auto &iter = output_indexes.find(kernel_with_index);
if (iter == output_indexes.end()) {
continue;
}
const std::vector<size_t> &ref_indexes = iter->second;
size_t n = 0;
const VectorRef *cur_vector_ref = outputs;
while (n != ref_indexes.size() - 1) {
size_t index = ref_indexes.at(n++);
const BaseRef &base_ref = (*cur_vector_ref)[index];
if (!utils::isa<VectorRef>(base_ref)) {
MS_LOG(EXCEPTION) << "Get none VectorRef by ref index, indexes: " << ref_indexes << "cur n: " << n - 1;
}
cur_vector_ref = &utils::cast<VectorRef>(base_ref);
}
BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
tensor_ref = output_tensor;
}
}
void GetSingleOpRunInfo(const CNodePtr cnode, OpRunInfo *run_info) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(run_info);
auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
run_info->primitive = primitive;
run_info->op_name = primitive->name();
if (cnode->abstract() == nullptr) {
MS_LOG(EXCEPTION) << "Abstract is nullptr, node = " << cnode->DebugString();
}
run_info->abstract = cnode->abstract();
}
GraphInfo GetSingleOpGraphInfo(const PrimitivePtr &prim, const std::vector<tensor::TensorPtr> &input_tensors) {
MS_EXCEPTION_IF_NULL(prim);
GraphInfo graph_info;
// get input tensor info
for (const auto &tensor : input_tensors) {
MS_EXCEPTION_IF_NULL(tensor);
auto tensor_shape = tensor->shape();
(void)std::for_each(tensor_shape.begin(), tensor_shape.end(),
[&](const auto &dim) { (void)graph_info.append(std::to_string(dim) + "_"); });
(void)graph_info.append(std::to_string(tensor->data_type()) + "_");
if (tensor->device_address() != nullptr) {
const auto type_id = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->type_id();
(void)graph_info.append(std::to_string(type_id) + "_");
const auto format = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->format();
(void)graph_info.append(format + "_");
}
}
// get attr info
const auto &attr_map = prim->evaluate_added_attrs();
(void)std::for_each(attr_map.begin(), attr_map.end(),
[&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); });
graph_info.append(prim->id());
return graph_info;
}
} // namespace
void AscendSession::Init(uint32_t device_id) {
@ -417,7 +696,7 @@ void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &gra
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!";
// malloc mem
RunOpMemoryAlloc(op_run_info.value, input_tensors, graph.get());
RunOpMemoryAlloc(input_tensors, graph.get());
// Build dynamic kernel
if (op_run_info.is_dynamic_shape) {
BuildDynamicKernel(graph);
@ -432,6 +711,39 @@ void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &gra
MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!";
}
void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs) {
MS_LOG(INFO) << "Start";
auto kernel_graph = GetGraph(graph_id);
std::map<AnfNodePtr, size_t> parameter_index;
GetParameterIndex(kernel_graph.get(), inputs, &parameter_index);
std::map<KernelWithIndex, std::vector<size_t>> output_indexes;
CreateOutputPlaceholder(kernel_graph, inputs, outputs, &output_indexes);
std::map<KernelWithIndex, size_t> cnode_ref;
GetRefCount(kernel_graph.get(), &cnode_ref);
std::map<KernelWithIndex, tensor::TensorPtr> op_output_map;
for (const auto &kernel : kernel_graph->execution_order()) {
// Generate input tensors, tensor masks and input kernel with index
InputTensorInfo input_tensor_info;
GetOpInputTensors(kernel, op_output_map, parameter_index, inputs, &input_tensor_info);
// Get OpRunInfo and GraphInfo
OpRunInfo run_info;
GetSingleOpRunInfo(kernel, &run_info);
GraphInfo graph_info = GetSingleOpGraphInfo(run_info.primitive, input_tensor_info.input_tensors);
// Build and run current single op
BuildOpImpl(run_info, graph_info, input_tensor_info.input_tensors, input_tensor_info.input_tensors_mask);
VectorRef op_outputs;
RunOpImpl(run_info, graph_info, input_tensor_info.input_tensors, &op_outputs);
// Handle inputs and outputs of current op
HandleOpInputs(input_tensor_info.input_kernel, &cnode_ref, &op_output_map);
HandleOpOutputs(kernel, op_outputs, output_indexes, cnode_ref, &op_output_map, outputs);
}
}
// compile graph steps
void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const {
MS_LOG(INFO) << "Start!";
@ -591,15 +903,14 @@ void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const {
MS_LOG(INFO) << "Finish!";
}
void AscendSession::RunOpMemoryAlloc(const ValuePtr &pre_output_value,
const std::vector<tensor::TensorPtr> &input_tensors,
void AscendSession::RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors,
KernelGraph *kernel_graph) const {
MS_LOG(INFO) << "Start memory alloc!";
MS_EXCEPTION_IF_NULL(kernel_graph);
opt::RemoveNopNode(kernel_graph);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
runtime_instance->RunOpAssignMemory(pre_output_value, input_tensors, kernel_graph);
runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph);
MS_LOG(INFO) << "Finish!";
}

View File

@ -35,6 +35,11 @@
namespace mindspore {
namespace session {
enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2, BRANCH_END = 3 };
struct InputTensorInfo {
std::vector<tensor::TensorPtr> input_tensors;
std::vector<int64_t> input_tensors_mask;
std::set<KernelWithIndex> input_kernel;
};
class AscendSession : public SessionBasic {
public:
@ -56,6 +61,8 @@ class AscendSession : public SessionBasic {
const std::vector<int64_t> &tensors_mask) override;
void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) override;
void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs) override;
private:
// compile child graph when session have multiple child graphs
@ -72,8 +79,7 @@ class AscendSession : public SessionBasic {
void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void MemoryAlloc(KernelGraph *kernel_graph) const;
void RunOpMemoryAlloc(const ValuePtr &pre_output_value, const std::vector<tensor::TensorPtr> &input_tensors,
KernelGraph *kernel_graph) const;
void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;
void RunOpMemoryClear(const KernelGraph *kernel_graph) const;
void Load(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) const;

View File

@ -16,8 +16,9 @@
#include "backend/session/executor.h"
#include <algorithm>
#include <exception>
#include "runtime/device/kernel_runtime_manager.h"
#include "backend/session/executor_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "utils/comm_manager.h"
#include "utils/scoped_long_running.h"
@ -134,6 +135,11 @@ void RunOpTask::Run() {
session_->RunOpImpl(*op_run_info_, graph_info_, input_tensors_, &outputs_);
}
void RunOpsInGraphTask::Run() {
MS_EXCEPTION_IF_NULL(session_);
session_->RunOpsInGraphImpl(graph_id_, input_tensors_, &outputs_);
}
void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
@ -361,6 +367,18 @@ void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const Gr
*outputs = task->outputs_;
}
void Executor::RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id,
const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(session);
MS_EXCEPTION_IF_NULL(outputs);
auto task = std::make_shared<RunOpsInGraphTask>();
task->session_ = session;
task->graph_id_ = graph_id;
task->input_tensors_ = inputs;
SyncRunTask(task);
*outputs = task->outputs_;
}
bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) {
auto task = std::make_shared<CreateCommGroupTask>();
task->group_name_ = group_name;

View File

@ -16,22 +16,23 @@
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
#define MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
#include <vector>
#include <string>
#include <utility>
#include <memory>
#include <list>
#include <queue>
#include <map>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "backend/session/session_basic.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "utils/any.h"
#include "utils/contract.h"
#include "utils/comm_manager.h"
#include "utils/contract.h"
namespace mindspore {
namespace session {
@ -45,7 +46,8 @@ enum TaskType {
kRunGraph,
kRunOp,
kCreateCommGroup,
kDestroyCommGroup
kDestroyCommGroup,
kRunOpsInGraph
};
class Task {
@ -98,6 +100,16 @@ class RunGraphTask : public Task {
std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node_;
};
class RunOpsInGraphTask : public Task {
public:
RunOpsInGraphTask() { type_ = kRunOpsInGraph; }
~RunOpsInGraphTask() override = default;
void Run() override;
std::vector<tensor::TensorPtr> input_tensors_;
VectorRef outputs_;
GraphId graph_id_{0};
};
class BuildOpTask : public Task {
public:
BuildOpTask() { type_ = kBuildOp; }
@ -162,6 +174,8 @@ class Executor {
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int64_t> &tensors_mask);
void RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs);
void RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs);
void OnRunGraphFinished();
bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks);
bool DestroyCommGroup(const std::string &group_name);

View File

@ -198,13 +198,12 @@ void GPUSession::AllocateMemory(KernelGraph *kernel_graph) const {
runtime_instance->AssignMemory(kernel_graph);
}
void GPUSession::RunOpAllocateMemory(const ValuePtr &pre_output_value,
const std::vector<tensor::TensorPtr> &input_tensors,
void GPUSession::RunOpAllocateMemory(const std::vector<tensor::TensorPtr> &input_tensors,
KernelGraph *kernel_graph) const {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
runtime_instance->RunOpAssignMemory(pre_output_value, input_tensors, kernel_graph);
runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph);
}
void GPUSession::RunOpClearMemory(KernelGraph *kernel_graph) const {
@ -351,6 +350,8 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor:
VectorRef *outputs) {
auto &kernel_graph = graphs_[graph_id];
MS_LOG(INFO) << "RunGraph graph_id: " << graph_id;
// In pynative mode, device addresses of tensors in value nodes change.
SyncValueNodeDeviceAddr(kernel_graph);
// Load input data from user input
LoadInputData(kernel_graph, inputs);
PreIterationDbg(kernel_graph);
@ -366,6 +367,8 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor:
Execute(kernel_graph);
}
PostLoadTensor(kernel_graph);
// In pynative mode, device addresses of tensors in value nodes need be clean.
CleanValueNodeDeviceAddr(kernel_graph);
// Summary
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
@ -400,7 +403,7 @@ void GPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_
MS_EXCEPTION_IF_NULL(kernel_graph);
// Remove NopOp from execution graph
opt::RemoveNopNode(kernel_graph.get());
RunOpAllocateMemory(op_run_info.value, input_tensors, kernel_graph.get());
RunOpAllocateMemory(input_tensors, kernel_graph.get());
// Execute the computation
LoadInputData(kernel_graph, input_tensors);
Execute(kernel_graph);
@ -471,6 +474,28 @@ void GPUSession::PostLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph
TensorLoader *tensor_loader = debug_services->tensor_loader();
tensor_loader->EmptyPrevTensor();
}
void GPUSession::SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
return;
}
auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
runtime_instance->SyncValueNodeDeviceAddr(kernel_graph.get());
}
void GPUSession::CleanValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
return;
}
auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
runtime_instance->CleanValueNodeDeviceAddr(kernel_graph.get());
}
} // namespace gpu
} // namespace session
} // namespace mindspore

View File

@ -61,8 +61,7 @@ class GPUSession : public SessionBasic {
void AllocateMemory(KernelGraph *kernel_graph) const;
void RunOpAllocateMemory(const ValuePtr &pre_output_value, const std::vector<tensor::TensorPtr> &input_tensors,
KernelGraph *kernel_graph) const;
void RunOpAllocateMemory(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;
void RunOpClearMemory(KernelGraph *kernel_graph) const;
@ -82,6 +81,10 @@ class GPUSession : public SessionBasic {
void PreLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void PostLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void CleanValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const;
};
using GPUSessionPtr = std::shared_ptr<GPUSession>;
MS_REG_SESSION(kGPUDevice, GPUSession);

View File

@ -14,9 +14,11 @@
* limitations under the License.
*/
#include "backend/session/session_basic.h"
#include <utility>
#include <algorithm>
#include <set>
#include <unordered_map>
#include <utility>
#include "c_ops/primitive_c.h"
#include "ir/manager.h"
@ -1606,6 +1608,12 @@ void SessionBasic::RunOp(OpRunInfo *op_run_info, const GraphInfo &graph_info,
executor_->RunOp(shared_from_this(), op_run_info, graph_info, input_tensors, outputs);
}
void SessionBasic::RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(executor_);
executor_->RunOpsInGraph(shared_from_this(), graph_id, inputs, outputs);
}
void SessionBasic::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(executor_);
executor_->RunGraph(shared_from_this(), graph_id, inputs, outputs);

View File

@ -22,6 +22,7 @@
#include <utility>
#include <memory>
#include <map>
#include <set>
#include "backend/session/session_context.h"
#include "backend/session/kernel_graph.h"
#include "backend/session/anf_runtime_algorithm.h"
@ -49,7 +50,6 @@ struct OpRunInfo {
std::string op_name;
PrimitivePtr primitive;
AbstractBasePtr abstract;
ValuePtr value = nullptr;
bool is_dynamic_shape = false;
bool is_auto_mixed_precision = false;
std::string next_op_name = "";
@ -79,6 +79,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
void BuildOp(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask);
void RunOp(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs);
void RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
@ -138,6 +139,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
friend class RunGraphTask;
friend class BuildOpTask;
friend class RunOpTask;
friend class RunOpsInGraphTask;
virtual bool IsSupportSummary() { return true; }
virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
VectorRef *outputs,
@ -155,6 +157,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
const std::vector<int64_t> &tensors_mask) {}
virtual void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {}
virtual void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs) {}
void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs);
virtual void SetSummaryNodes(KernelGraph *graph);

View File

@ -281,24 +281,6 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
return node_adjoint;
}
void TensorSetAddress(const ValuePtr &value, std::map<std::string, tensor::TensorPtr> *tuple_tensors) {
MS_LOG(DEBUG) << "Start set tensor address" << value->ToString() << value->isa<tensor::Tensor>();
if (value->isa<tensor::Tensor>()) {
auto tnode = value->cast<tensor::TensorPtr>();
if (tuple_tensors->find(tnode->id()) != tuple_tensors->end()) {
MS_LOG(DEBUG) << "Set tensor" << tnode->device_address();
(*tuple_tensors)[tnode->id()]->set_device_address(tnode->device_address());
}
}
if (value->isa<ValueTuple>()) {
auto tuple = value->cast<ValueTuplePtr>();
for (size_t i = 0; i < tuple->size(); i++) {
MS_LOG(DEBUG) << "Set tuple tensor" << (*tuple)[i]->ToString();
TensorSetAddress((*tuple)[i], tuple_tensors);
}
}
}
ValuePtr GenNewTensorInner(const ValuePtr &value) {
std::vector<ValuePtr> value_list;
if (value->isa<tensor::Tensor>()) {
@ -328,7 +310,6 @@ ValuePtr GenNewTensor(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, co
void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph) {
auto forward = cnode_morph->forward().first;
auto forward_id = cnode_morph->forward().second;
if (forward == nullptr) {
return;
}
@ -337,6 +318,7 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
return;
}
auto fg = GetValueNode<FuncGraphPtr>(input);
// {prim::maketuple, forward_output, bprop_graph}
auto output = fg->output();
if (!output->isa<CNode>()) {
return;
@ -350,25 +332,22 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
if (!IsValueNode<FuncGraph>(input_fg)) {
return;
}
std::map<std::string, tensor::TensorPtr> tuple_tensors;
// replace forward output with value node
auto equivdout = cnode_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(equivdout);
auto func_graph = GetValueNode<FuncGraphPtr>(input_fg);
MS_EXCEPTION_IF_NULL(func_graph);
auto manager = Manage({fg, func_graph}, false);
auto ref_size = manager->node_users()[equivdout].size();
auto forward_value = forward;
if (!forward_id.empty() && ref_size > 1) {
auto inst = pynative::PynativeExecutor::GetInstance();
inst->SaveOpForwardValue(forward_id, forward_value, &tuple_tensors);
}
forward_value = GenNewTensor(manager, equivdout, forward);
auto forward_value = GenNewTensor(manager, equivdout, forward);
MS_LOG(DEBUG) << "Replace: " << equivdout->ToString() << " with " << forward;
auto value_node = NewValueNode(forward_value);
value_node->set_has_new_value(true);
manager->Replace(equivdout, value_node);
// replace input object with value node
auto paras = fg->parameters();
auto inputs_value = cnode_morph->inputs_value();
if (inputs_value.size() == 0) {
if (inputs_value.empty()) {
return;
}
if (inputs_value.size() != paras.size()) {
@ -379,10 +358,6 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
auto input_value = inputs_value[i];
if (para_ref_size > 0 && input_value.first != nullptr) {
MS_LOG(DEBUG) << "Replace: " << paras[i]->ToString() << " with " << input_value.first;
auto inst = pynative::PynativeExecutor::GetInstance();
if (!input_value.second.empty()) {
inst->SaveOpForwardValue(input_value.second, input_value.first, &tuple_tensors);
}
auto input_value_node = NewValueNode(input_value.first);
input_value_node->set_has_new_value(true);
manager->Replace(paras[i], input_value_node);
@ -394,30 +369,19 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
res->set_func_graph(fg);
PynativeElimOpt(res);
auto out = fg->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out);
auto c_input = out->input(1);
MS_EXCEPTION_IF_NULL(c_input);
if (!c_input->isa<ValueNode>()) {
return;
}
auto out_node = c_input->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(out_node);
out_node->set_value(GenNewTensor(manager, out_node, out_node->value()));
// clear resource
cnode_morph->clear_inputs_value();
if (tuple_tensors.size() != 0) {
MS_LOG(DEBUG) << "Start tuple out" << fg->output()->DebugString(4);
for (auto &g : manager->func_graphs()) {
for (auto &node : g->value_nodes()) {
MS_LOG(DEBUG) << "Set Tensor addr" << node.first->ToString();
auto vnode = node.first->cast<ValueNodePtr>()->value();
TensorSetAddress(vnode, &tuple_tensors);
}
}
}
fg->ClearAllManagerInfo();
func_graph->ClearAllManagerInfo();
return;
}
bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {

View File

@ -298,14 +298,29 @@ class PynativeEliminater : public OptimizerCaller {
return out;
}
void OnlySaveAbstractInfo(const ValueNodePtr &value_node) {
MS_EXCEPTION_IF_NULL(value_node);
auto &value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>()) {
auto tensor = value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
auto new_tensor = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape());
value_node->set_value(MakeValue(new_tensor));
}
}
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
MS_LOG(DEBUG) << "Start replace node " << node->DebugString(4);
PatternNode<AnfNodePtr> symbol_str_vnode, c_vnode, zeros_like_vnode, getitem_vnode, arg, arg1;
PatternNode<AnfNodePtr> symbol_str_vnode;
PatternNode<AnfNodePtr> c_vnode;
PatternNode<AnfNodePtr> zeros_like_vnode;
PatternNode<AnfNodePtr> arg;
auto resolve = PPrimitive(prim::kPrimResolve, symbol_str_vnode, c_vnode);
auto getattr = PPrimitive(prim::kPrimGetAttr, resolve, zeros_like_vnode);
auto pattern = PCNode(getattr, arg);
// {{prim:getattr, {prim::resolve, SymbolStr, C}, zeros_like}, Xy} ->Tensor(0, shape(Xy))
if ((pattern).TryCapture(node) &&
(CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
CheckSymbolVNode(c_vnode.GetNode(node), "C") && CheckStrVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) {
@ -320,8 +335,8 @@ class PynativeEliminater : public OptimizerCaller {
}
}
}
MS_LOG(DEBUG) << "End replace 1 " << node->DebugString(4);
// {prim:getattr, {prim::resolve, SymbolStr, zeros_like}, Xy} ->Tensor(0, shape(Xy))
auto resolve1 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, zeros_like_vnode);
auto pattern1 = PCNode(resolve1, arg);
@ -338,7 +353,13 @@ class PynativeEliminater : public OptimizerCaller {
}
}
}
// {prim:getattr, {prim::resolve, SymbolStr, binop_grad_common}, x, y, out, dout} -> {shape(x), shape(y), out, dout}
PatternNode<AnfNodePtr> binop_grad_common;
PatternNode<AnfNodePtr> getitem_vnode;
PatternNode<AnfNodePtr> arg1;
PatternNode<AnfNodePtr> arg2;
PatternNode<AnfNodePtr> arg3;
PatternNode<AnfNodePtr> arg4;
// resolve(CommonOPS, getitem)((tensors), 3)
auto resolve2 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, getitem_vnode);
auto pattern2 = PCNode(resolve2, arg, arg1);

View File

@ -51,21 +51,19 @@ enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };
struct OpExecInfo {
std::string op_name;
std::string op_index;
std::string prim_id;
PrimitivePyPtr py_primitive;
AbstractBasePtr abstract;
bool is_dynamic_shape = false;
ValuePtr value = nullptr;
py::list op_inputs;
py::dict op_attrs;
std::vector<bool> inputs_mask;
bool is_dynamic_shape = false;
std::string next_op_name = "";
bool is_mixed_precision_cast = false;
size_t next_input_index = 0;
};
using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
const std::set<std::string> force_infer_prim = {"TopK", "DropoutGenMask"};

View File

@ -149,12 +149,6 @@ static std::string GetId(const py::object &obj) {
return py::cast<std::string>(ret);
}
static std::string GetOpId(const OpExecInfoPtr &op_exec_info) {
auto id = GetId(op_exec_info->py_primitive->GetPyObj());
op_exec_info->prim_id = id;
return id;
}
std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes) {
std::map<SignatureEnumDType, std::vector<size_t>> type_indexes;
for (size_t i = 0; i < dtypes.size(); ++i) {
@ -260,24 +254,6 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
MS_LOG(DEBUG) << "Prim " << prim->name() << " infer result " << op_exec_info->abstract->ToString();
}
OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
if (args.size() != PY_ARGS_NUM) {
MS_LOG(ERROR) << "Three args are needed by RunOp";
return nullptr;
}
auto op_exec_info = std::make_shared<OpExecInfo>();
MS_EXCEPTION_IF_NULL(op_exec_info);
op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]);
auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
if (!prim->HasPyObj()) {
MS_LOG(EXCEPTION) << "Pyobj is empty";
}
op_exec_info->py_primitive = prim;
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
op_exec_info->op_inputs = args[PY_INPUTS];
return op_exec_info;
}
std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
const std::vector<tensor::TensorPtr> &input_tensors) {
MS_EXCEPTION_IF_NULL(op_exec_info);
@ -580,7 +556,7 @@ py::tuple RunOp(const py::args &args) {
auto executor = PynativeExecutor::GetInstance();
MS_EXCEPTION_IF_NULL(executor);
MS_LOG(DEBUG) << "RunOp start " << args.size();
OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args);
OpExecInfoPtr op_exec_info = executor->GenerateOpExecInfo(args);
try {
return executor->RunOpInner(op_exec_info);
} catch (const py::error_already_set &ex) {
@ -608,16 +584,17 @@ py::tuple RunOp(const py::args &args) {
}
py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
auto prim = op_exec_info->py_primitive;
auto name = op_exec_info->op_name;
if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) {
return RunOpWithInitBackendPolicy(op_exec_info);
}
// make cnode for building grad graph if grad flag is set.
abstract::AbstractBasePtrList args_spec_list;
std::vector<bool> op_masks;
auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, &op_masks, &args_spec_list);
auto cnode = MakeCNode(op_exec_info, &op_masks, &args_spec_list);
op_exec_info->inputs_mask = op_masks;
// get output abstract info
bool is_find = false;
auto prim = op_exec_info->py_primitive;
if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) {
auto abs_list = prim_abs_list_[prim->id()];
MS_LOG(DEBUG) << "Match prim input args " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
@ -629,7 +606,6 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
is_find = true;
}
}
if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_exec_info->op_name) != force_infer_prim.end()) {
// use python infer method
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
@ -648,11 +624,10 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
if (cnode != nullptr) {
cnode->set_abstract(op_exec_info->abstract);
}
op_exec_info->inputs_mask = op_masks;
// infer output value for const prim
MS_EXCEPTION_IF_NULL(op_exec_info);
if (op_exec_info->abstract != nullptr) {
MS_LOG(DEBUG) << "Run op infer " << name << " " << op_exec_info->abstract->ToString();
MS_LOG(DEBUG) << "Run op infer " << op_exec_info->op_name << " " << op_exec_info->abstract->ToString();
py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract);
if (!output["value"].is_none()) {
py::tuple value_ret(1);
@ -665,7 +640,7 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
return value_ret;
}
}
// add output abstract info into cache
if (!is_find) {
// const_value need infer every step
auto &out = prim_abs_list_[prim->id()];
@ -674,13 +649,7 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
out[args_spec_list].attrs = prim->evaluate_added_attrs();
MS_LOG(DEBUG) << "Set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
}
if (PynativeExecutor::GetInstance()->grad_flag()) {
op_exec_info->value = PynativeExecutor::GetInstance()->GetForwardValue(op_exec_info);
} else {
(void)GetOpId(op_exec_info);
}
// run op with selected backend
auto result = RunOpWithInitBackendPolicy(op_exec_info);
py::object out_real = result;
if (result.size() == 1) {
@ -689,13 +658,38 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
}
std::string obj_id = GetId(out_real);
node_abs_map_[obj_id] = op_exec_info->abstract;
PynativeExecutor::GetInstance()->SaveOutputNodeMap(obj_id, out_real, cnode);
if (cnode != nullptr) {
PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode->cast<CNodePtr>(), result);
}
SaveOutputNodeMap(obj_id, out_real, cnode);
SaveAllResult(op_exec_info, cnode, out_real);
// Update the abstract and device address of value node with tensor in grad graph
UpdateAbstractAndDeviceAddress(op_exec_info, out_real);
return result;
}
OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) {
if (args.size() != PY_ARGS_NUM) {
MS_LOG(ERROR) << "Three args are needed by RunOp";
return nullptr;
}
auto op_exec_info = std::make_shared<OpExecInfo>();
auto op_name = py::cast<std::string>(args[PY_NAME]);
op_exec_info->op_name = op_name;
if (grad_flag_) {
MS_EXCEPTION_IF_NULL(resource_);
int64_t graph_id = resource_->results()[pipeline::kPynativeGraphId].cast<int64_t>();
op_exec_info->op_index = std::to_string(graph_id) + op_name + std::to_string(op_index_map_[op_name]);
op_index_map_[op_name]++;
}
auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
MS_EXCEPTION_IF_NULL(prim);
if (!prim->HasPyObj()) {
MS_LOG(EXCEPTION) << "Pyobj is empty";
}
op_exec_info->prim_id = GetId(prim->GetPyObj());
op_exec_info->py_primitive = prim;
op_exec_info->op_inputs = args[PY_INPUTS];
return op_exec_info;
}
AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
abstract::AbstractBasePtrList *args_spec_list) {
MS_EXCEPTION_IF_NULL(op_masks);
@ -997,6 +991,56 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
return node;
}
void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real) {
MS_EXCEPTION_IF_NULL(op_exec_info);
if (!grad_flag_) {
return;
}
auto op_index = op_exec_info->op_index;
auto output_value = PyAttrValue(out_real);
MS_EXCEPTION_IF_NULL(output_value);
std::vector<tensor::TensorPtr> output_tensors;
TensorValueToTensor(output_value, &output_tensors);
if (op_index_with_tensor_id_.find(op_index) == op_index_with_tensor_id_.end()) {
// first step
std::for_each(output_tensors.begin(), output_tensors.end(), [&](const tensor::TensorPtr &tensor) {
op_index_with_tensor_id_[op_index].emplace_back(tensor->id());
});
return;
}
const auto &tensor_id_list = op_index_with_tensor_id_[op_index];
for (size_t i = 0; i < tensor_id_list.size(); ++i) {
auto tensor_id = tensor_id_list[i];
if (tensor_id_with_tensor_.find(tensor_id) != tensor_id_with_tensor_.end()) {
auto &new_tensor = output_tensors[i];
auto &tensors_in_value_node = tensor_id_with_tensor_[tensor_id];
std::for_each(tensors_in_value_node.begin(), tensors_in_value_node.end(), [&](tensor::TensorPtr &tensor) {
tensor->set_shape(new_tensor->shape());
tensor->set_data_type(new_tensor->data_type());
tensor->set_device_address(new_tensor->device_address());
});
}
}
}
void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
tensor_id_with_tensor_.clear();
const auto &func_graph = resource->func_graph();
const auto &value_node_list = func_graph->value_nodes();
for (const auto &elem : value_node_list) {
auto value_node = elem.first->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
std::vector<tensor::TensorPtr> tensors;
TensorValueToTensor(value_node->value(), &tensors);
for (const auto &tensor : tensors) {
if (tensor->device_address() != nullptr) {
tensor_id_with_tensor_[tensor->id()].emplace_back(tensor);
}
}
}
}
AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) {
auto &out = graph_info_map_[curr_g_].node_map[obj_id];
if (out.second.size() == 1 && out.second[0] == -1) {
@ -1054,23 +1098,6 @@ AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::str
return node;
}
ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) {
auto id = GetOpId(op_exec_info);
int64_t graph_id = resource_->results()[pipeline::kPynativeGraphId].cast<int64_t>();
auto op = std::to_string(graph_id) + id;
op.append(std::to_string(op_id_map_[id]));
auto iter = op_forward_map_.find(op);
if (iter != op_forward_map_.end()) {
++op_id_map_[id];
MS_LOG(DEBUG) << "Get: " << op_exec_info->op_name << "(" << op << "), " << iter->second;
return iter->second;
}
if (!first_grad_step_) {
++op_id_map_[id];
}
return nullptr;
}
void PynativeExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real,
const AnfNodePtr &cnode) {
if (!grad_flag_ || graph_info_map_.empty()) {
@ -1093,16 +1120,16 @@ void PynativeExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::ob
SetPyObjInGraphInfoMap(curr_g_, obj_id);
}
void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out) {
if (!grad_flag_ || op_exec_info->value != nullptr || cnode == nullptr) {
void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node,
const py::object &out_real) {
if (!grad_flag_ || node == nullptr) {
return;
}
py::object out_real = out;
if (out.size() == 1) {
out_real = out[0];
}
auto value = PyAttrValue(out_real);
MS_EXCEPTION_IF_NULL(op_exec_info);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
// save input object
size_t size = op_exec_info->op_inputs.size();
for (size_t i = 0; i < size; i++) {
auto obj = op_exec_info->op_inputs[i];
@ -1113,59 +1140,19 @@ void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CN
cnode->add_input_value(nullptr, "");
}
}
std::string id = GetOpId(op_exec_info);
int64_t graph_id = resource_->results()[pipeline::kPynativeGraphId].cast<int64_t>();
auto op_id = std::to_string(graph_id) + id;
op_id.append(std::to_string(op_id_map_[id]));
cnode->set_forward(value, op_id);
++op_id_map_[id];
// save output object
auto output_value = PyAttrValue(out_real);
MS_EXCEPTION_IF_NULL(output_value);
cnode->set_forward(output_value, op_exec_info->op_index);
auto out_id = GetId(out_real);
if (py::isinstance<py::tuple>(out_real)) {
auto tuple_item = py::cast<py::tuple>(out_real);
for (size_t i = 0; i < tuple_item.size(); i++) {
auto tuple_item_id = GetId(tuple_item[i]);
obj_to_forward_id_[tuple_item_id] = op_id;
obj_to_forward_id_[tuple_item_id] = op_exec_info->op_index;
}
SaveOpForwardValue(op_id, value, nullptr);
}
obj_to_forward_id_[out_id] = op_id;
}
void PynativeExecutor::SaveOpForwardValue(const std::string &id, const ValuePtr &value,
std::map<std::string, tensor::TensorPtr> *t_map) {
if (op_forward_map_.find(id) != op_forward_map_.end()) {
// for one op have multi outputs but save only one tensor
if (op_forward_map_[id]->isa<ValueTuple>() && value->isa<tensor::Tensor>()) {
auto tuple = op_forward_map_[id]->cast<ValueTuplePtr>();
auto value_t = value->cast<tensor::TensorPtr>();
for (size_t i = 0; i < tuple->size(); i++) {
if ((*tuple)[i]->isa<tensor::Tensor>()) {
auto tuple_t = (*tuple)[i]->cast<tensor::TensorPtr>();
if (value_t->id() == tuple_t->id()) {
tuple_t->set_device_address(value_t->device_address());
MS_LOG(DEBUG) << "After Saveop " << tuple_t->ToString();
break;
}
}
}
}
if (value->isa<ValueTuple>() && t_map != nullptr) {
GenTupleMap(op_forward_map_[id]->cast<ValueTuplePtr>(), t_map);
}
MS_LOG(DEBUG) << "Save op forward value: "
<< "(" << id << "), " << op_forward_map_[id]->ToString();
return;
}
if (value->isa<ValueTuple>() && t_map == nullptr) {
// make cnode gen all tuple node and set device_address be null
op_forward_map_[id] = CleanTupleAddr(value->cast<ValueTuplePtr>());
} else {
op_forward_map_[id] = value;
}
MS_LOG(DEBUG) << "Save op forward value: "
<< "(" << id << "), " << value->ToString();
obj_to_forward_id_[out_id] = op_exec_info->op_index;
}
void PynativeExecutor::GenTupleMap(const ValueTuplePtr &tuple, std::map<std::string, tensor::TensorPtr> *t_map) {
@ -1307,10 +1294,13 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati
ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors);
// get graph info for checking it whether existing in the cache
std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors);
session::OpRunInfo op_run_info = {op_exec_info->op_name, op_exec_info->py_primitive,
op_exec_info->abstract, op_exec_info->value,
op_exec_info->is_dynamic_shape, op_exec_info->is_mixed_precision_cast,
op_exec_info->next_op_name, op_exec_info->next_input_index};
session::OpRunInfo op_run_info = {op_exec_info->op_name,
op_exec_info->py_primitive,
op_exec_info->abstract,
op_exec_info->is_dynamic_shape,
op_exec_info->is_mixed_precision_cast,
op_exec_info->next_op_name,
op_exec_info->next_input_index};
session->BuildOp(&op_run_info, graph_info, input_tensors, tensors_mask);
EraseValueNodeTensor(tensors_mask, &input_tensors);
VectorRef outputs;
@ -1524,6 +1514,7 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
if (it != cell_resource_map_.end()) {
resource_ = it->second;
MS_EXCEPTION_IF_NULL(resource_);
op_index_map_.clear();
}
MS_LOG(DEBUG) << "Graph already compiled";
return;
@ -1571,7 +1562,8 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar
resource_->results()[pipeline::kPynativeGraphId] = graph_id_++;
cell_resource_map_[cell_id] = resource_;
MS_LOG(DEBUG) << "New top graph for " << cell_id;
first_grad_step_ = true;
op_index_map_.clear();
op_index_with_tensor_id_.clear();
top_graph_cells_.emplace(cell_id);
}
@ -1770,6 +1762,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
MS_LOG(DEBUG) << "Start opt";
PynativeOptimizeAction(resource_);
SaveTensorsInValueNode(resource_);
TaskEmitAction(resource_);
ExecuteAction(resource_);
cell_graph_map_[cell_id].second = true;
@ -2021,7 +2014,6 @@ void PynativeExecutor::Clear(const std::string &flag) {
}
ConfigManager::GetInstance().ResetIterNum();
if (top_graph_cells_.find(flag) != top_graph_cells_.end()) {
op_forward_map_.clear();
Clean();
}
node_abs_map_.clear();
@ -2033,9 +2025,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
top_g_ = nullptr;
df_builder_ = nullptr;
curr_g_ = nullptr;
first_grad_step_ = false;
graph_info_map_.clear();
op_id_map_.clear();
obj_to_forward_id_.clear();
node_abs_map_.clear();
std::stack<FuncGraphPtr>().swap(graph_stack_);

View File

@ -83,13 +83,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void set_grad_flag(bool flag) { grad_flag_ = flag; }
py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info);
OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
void NewGraph(const py::object &cell, const py::args &args);
py::object Run(const py::tuple &args, const py::object &phase);
py::object CheckGraph(const py::object &cell, const py::args &args);
void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
void SaveOpForwardValue(const std::string &id, const ValuePtr &value,
std::map<std::string, tensor::TensorPtr> *t_map);
// Call by python
void Clear(const std::string &flag = "");
@ -134,9 +133,11 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
// replace for grad graph
ValuePtr CleanTupleAddr(const ValueTuplePtr &tuple);
ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info);
void GenTupleMap(const ValueTuplePtr &tuple, std::map<std::string, tensor::TensorPtr> *t_map);
void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out);
void SaveAllResult(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, const py::object &out_real);
// Update the abstract and device address info of value node and tensors in bprop graph
void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real);
void SaveTensorsInValueNode(const ResourcePtr &resource);
// construct grad graph
void PushCurrentGraphToStack();
@ -175,7 +176,6 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
static int64_t graph_id_;
bool grad_flag_{false};
bool dynamic_cell_{false};
bool first_grad_step_{false};
bool grad_is_running{false};
// Used for construct grad graph
@ -199,9 +199,10 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
std::unordered_map<std::string, std::pair<FuncGraphPtr, FuncGraphPtr>> df_builder_map_;
// used for runop and replace forward result of grad graph
std::unordered_map<std::string, ValuePtr> op_forward_map_;
std::unordered_map<std::string, size_t> op_id_map_;
std::unordered_map<std::string, size_t> op_index_map_;
std::unordered_map<std::string, std::string> obj_to_forward_id_;
std::unordered_map<std::string, std::vector<std::string>> op_index_with_tensor_id_;
std::unordered_map<std::string, std::vector<tensor::TensorPtr>> tensor_id_with_tensor_;
std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
const inline static std::string kOpsFunctionModelName = "mindspore.ops.functional";

View File

@ -81,15 +81,13 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
UpdateRefNodeOutputMem(graph);
}
void KernelRuntime::RunOpAssignMemory(const ValuePtr &pre_output_value,
const std::vector<tensor::TensorPtr> &input_tensors,
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(mem_manager_);
mem_manager_->ResetDynamicMemory();
RunOpAssignInputMemory(input_tensors, graph);
AssignStaticMemoryValueNode(graph);
RunOpAssignOutputNodeMemory(pre_output_value, graph);
for (const auto &cnode : graph->execution_order()) {
RunOpAssignOutputMemory(cnode);
RunOpAssignWorkSpaceMemory(cnode);
@ -680,6 +678,52 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
MS_LOG(INFO) << "AssignStaticMemoryValueNode end";
}
void KernelRuntime::SyncValueNodeDeviceAddr(session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "SyncValueNodeDeviceAddr start";
for (auto &value_node : graph->graph_value_nodes()) {
MS_EXCEPTION_IF_NULL(value_node);
auto &node_value = value_node->value();
MS_EXCEPTION_IF_NULL(node_value);
if (!node_value->isa<Tensor>() && !node_value->isa<ValueTuple>()) {
continue;
}
std::vector<tensor::TensorPtr> tensors;
TensorValueToTensor(node_value, &tensors);
for (size_t index = 0; index < tensors.size(); index += 1) {
const auto &tensor = tensors[index];
if (tensor->device_address() != nullptr) {
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), index,
value_node.get());
} else {
MS_LOG(INFO) << "Tensor of ValueNode[" << value_node->fullname_with_scope() << "]'s device address is nullptr.";
}
}
}
MS_LOG(INFO) << "SyncValueNodeDeviceAddr end";
}
void KernelRuntime::CleanValueNodeDeviceAddr(session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "CleanValueNodeDeviceAddr start";
for (auto &value_node : graph->graph_value_nodes()) {
MS_EXCEPTION_IF_NULL(value_node);
auto &node_value = value_node->value();
MS_EXCEPTION_IF_NULL(node_value);
if (!node_value->isa<Tensor>() && !node_value->isa<ValueTuple>()) {
continue;
}
std::vector<tensor::TensorPtr> tensors;
TensorValueToTensor(node_value, &tensors);
for (size_t index = 0; index < tensors.size(); index += 1) {
if (tensors[index]->device_address() != nullptr) {
AnfAlgo::SetOutputAddr(nullptr, index, value_node.get());
}
}
}
MS_LOG(INFO) << "CleanValueNodeDeviceAddr end";
}
void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(mem_manager_);

View File

@ -51,8 +51,7 @@ class KernelRuntime {
virtual ~KernelRuntime();
virtual bool Init() = 0;
virtual void AssignMemory(session::KernelGraph *graph);
void RunOpAssignMemory(const ValuePtr &pre_output_value, const std::vector<tensor::TensorPtr> &input_tensors,
session::KernelGraph *graph);
void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, session::KernelGraph *graph);
void RunOpClearMemory(const session::KernelGraph *graph);
static bool DumpDataEnabled();
static bool DumpDataEnabledIteration();
@ -67,6 +66,8 @@ class KernelRuntime {
const AddressPtrList &kernel_workspaces) const;
virtual void AssignStaticMemoryInput(const session::KernelGraph *graph);
virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph);
virtual void SyncValueNodeDeviceAddr(session::KernelGraph *graph);
virtual void CleanValueNodeDeviceAddr(session::KernelGraph *graph);
virtual void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs,
const std::unordered_set<ValueNodePtr> &value_nodes,
const std::vector<CNodePtr> &execution_order);

View File

@ -18,13 +18,13 @@
#include <algorithm>
#include <vector>
#include "utils/log_adapter.h"
#include "backend/session/session_factory.h"
#include "ir/anf.h"
#include "pybind_api/ir/base_ref_py.h"
#include "utils/callbacks.h"
#include "utils/convert_utils.h"
#include "backend/session/session_factory.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "pybind_api/ir/base_ref_py.h"
#ifdef ENABLE_GE
#include "utils/callbacks_ge.h"
#endif
@ -83,10 +83,14 @@ LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std:
MS_LOG(INFO) << "PrecompileOnly, stop run graph";
return result;
}
if (target != target_device_ && !target.empty()) {
other_sess_->BuildGraph(graph_id);
} else if (!is_multi_graph_sink_) {
target_sess_->BuildGraph(graph_id);
auto ms_context = MsContext::GetInstance();
const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
if (!pynative_mode || target != "Ascend") {
if (target != target_device_ && !target.empty()) {
other_sess_->BuildGraph(graph_id);
} else if (!is_multi_graph_sink_) {
target_sess_->BuildGraph(graph_id);
}
}
result.run = std::make_shared<RunFunc>(
[graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); });
@ -154,12 +158,19 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s
PushInputTensor(arg, &inputs);
}
auto ms_context = MsContext::GetInstance();
const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
VectorRef outputs;
// call ms rungraph (graphId, input ,output)
if (target != target_device_ && !target.empty()) {
other_sess_->RunGraphAsync(g, inputs, &outputs);
} else {
target_sess_->RunGraphAsync(g, inputs, &outputs);
if (pynative_mode && target == "Ascend") {
target_sess_->RunOpsInGraph(g, inputs, &outputs);
} else {
target_sess_->RunGraphAsync(g, inputs, &outputs);
}
}
MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();

View File

@ -134,7 +134,6 @@ class MulAdd(nn.Cell):
assert dout.asnumpy() == 1.0
return dout, y
class Ms_Cell(nn.Cell):
def __init__(self):
super(Ms_Cell, self).__init__()
@ -143,6 +142,19 @@ class Ms_Cell(nn.Cell):
def construct(self, x):
return self.relu(x)
def bprop(self, x, out, dout):
dout = Tensor(np.float32(0.0))
assert dout.shape == ()
return dout
class Ms_Cell_Change_Shape(nn.Cell):
def __init__(self):
super(Ms_Cell_Change_Shape, self).__init__()
self.relu = P.ReLU()
def construct(self, x):
return self.relu(x)
def bprop(self, x, out, dout):
dout = Tensor(np.ones([5, 5]).astype(np.float32))
assert dout.shape == (5, 5)
@ -186,6 +198,19 @@ def test_pynative_custom_bprop_and_Cell_MulAdd():
(Tensor(1.0, mstype.float32), Tensor(2.0, mstype.float32))
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_custom_bprop_and_Cell_Ms_Cell_Change_Shape():
custom_cell = test_custom_cell_base()
ms_Cell = custom_cell.test_custom_cell_function(Ms_Cell_Change_Shape())
ms_Cell.bprop_debug = True
with pytest.raises(RuntimeError) as ex:
grad_all(ms_Cell)(Tensor(1, mstype.float32))
assert "Shapes of input and parameter are different, input index" in str(ex.value)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@ -194,5 +219,5 @@ def test_pynative_custom_bprop_and_Cell_Ms_Cell():
custom_cell = test_custom_cell_base()
ms_Cell = custom_cell.test_custom_cell_function(Ms_Cell())
ms_Cell.bprop_debug = True
assert grad_all(ms_Cell)(Tensor(1, mstype.float32)) == (Tensor(1.0, mstype.float32),)
assert grad_all(ms_Cell)(Tensor(1, mstype.float32)) == (Tensor(0.0, mstype.float32),)

View File

@ -65,7 +65,7 @@ OpExecInfoPtr ConstructOpExecInfo() {
py::none py_none;
py::args args = py::make_tuple(conv_obj, op_name, op_inputs);
py::list args_input = args[PY_INPUTS];
return GenerateOpExecInfo(args);
return PynativeExecutor::GetInstance()->GenerateOpExecInfo(args);
}
TEST_F(TestPynativeExecute, TestCreateContext) {