Support RunOpsInGraph on CPU&GPU in pynative mode

This commit is contained in:
tanghuikang 2021-01-27 15:50:42 +08:00
parent a50a65adf9
commit 6f2cd92aba
7 changed files with 463 additions and 439 deletions

View File

@ -143,183 +143,6 @@ void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph) {
root_graph->set_output(make_tuple);
}
// Produces the output slot for a single (node, output index) pair of `graph`.
//
// - ValueNode: the node's value is returned directly.
// - Parameter: the tensor of `input_tensors` at the parameter's position in graph->inputs() is returned;
//   an exception is raised if the parameter is not found or the position exceeds `input_tensors`.
// - Anything else: `indexes` is recorded in `output_indexes` under `node_output_pair` and an empty
//   placeholder is returned.
BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
                                    const std::vector<tensor::TensorPtr> &input_tensors,
                                    const std::vector<size_t> &indexes,
                                    std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  const auto &output_node = node_output_pair.first;
  MS_EXCEPTION_IF_NULL(output_node);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(output_indexes);
  MS_LOG(INFO) << "Create placeholder for output[" << output_node->DebugString() << "] index["
               << node_output_pair.second << "]";

  // A value node already carries its result, so nothing has to be synced from device to host.
  if (output_node->isa<ValueNode>()) {
    auto constant = output_node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(constant);
    return constant->value();
  }

  // A parameter output is simply the matching graph input tensor.
  if (output_node->isa<Parameter>()) {
    const auto &graph_inputs = graph->inputs();
    const size_t tensor_num = input_tensors.size();
    for (size_t pos = 0; pos < graph_inputs.size(); ++pos) {
      if (pos >= tensor_num) {
        MS_LOG(EXCEPTION) << "Input idx:" << pos << "out of range:" << tensor_num;
      }
      if (graph_inputs[pos] == output_node) {
        return input_tensors[pos];
      }
    }
    MS_LOG(EXCEPTION) << "Parameter: " << output_node->DebugString() << " has no output addr";
  }

  // Remember the path of this slot so it can be located again through `output_indexes`.
  (*output_indexes)[node_output_pair].emplace_back(indexes);
  BaseRef empty_slot = std::make_shared<BaseRef>();
  return empty_slot;
}
// Produces the output slot (or nested slots) for the graph output `anf`.
//
// The node is first resolved with AnfAlgo::VisitKernelWithReturnType. A MakeTuple yields a VectorRef with one
// entry per tuple element, each built recursively with the element position appended to `indexes`. A node with
// zero output tensors yields an empty VectorRef. Every other node is delegated to the KernelWithIndex overload.
BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr &graph,
                                    const std::vector<tensor::TensorPtr> &input_tensors,
                                    const std::vector<size_t> &indexes,
                                    std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(output_indexes);
  MS_LOG(INFO) << "Create placeholder for output[" << anf->DebugString() << "]";
  auto real_output = AnfAlgo::VisitKernelWithReturnType(anf, 0);
  const auto &real_node = real_output.first;
  MS_EXCEPTION_IF_NULL(real_node);
  MS_LOG(INFO) << "Create placeholder for output after visit:" << real_node->DebugString();

  if (!AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) {
    // A node without any output tensor is represented by an empty list.
    if (AnfAlgo::GetOutputTensorNum(real_node) == 0) {
      return VectorRef();
    }
    return CreateNodeOutputPlaceholder(real_output, graph, input_tensors, indexes, output_indexes);
  }

  // MakeTuple: build one nested slot per element; input 0 is the primitive, so elements start at input 1.
  auto make_tuple = real_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(make_tuple);
  VectorRef tuple_slots;
  const size_t input_num = make_tuple->inputs().size();
  for (size_t pos = 1; pos < input_num; ++pos) {
    std::vector<size_t> element_path(indexes);
    element_path.emplace_back(pos - 1);
    auto element_slot =
      CreateNodeOutputPlaceholder(make_tuple->input(pos), graph, input_tensors, element_path, output_indexes);
    tuple_slots.push_back(element_slot);
  }
  return tuple_slots;
}
// Appends one output slot to `outputs` for every output of `kernel_graph`.
//
// The i-th graph output is built by CreateNodeOutputPlaceholder with the starting index path {i}; the paths of
// the slots that still have to be filled are collected in `output_indexes`.
void CreateOutputPlaceholder(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &input_tensors,
                             VectorRef *outputs,
                             std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(output_indexes);
  auto graph_outputs = kernel_graph->outputs();
  size_t next_slot = 0;
  for (auto &output_node : graph_outputs) {
    MS_EXCEPTION_IF_NULL(output_node);
    MS_LOG(INFO) << "Create node output placeholder[" << output_node->DebugString() << "]";
    std::vector<size_t> root_path{next_slot};
    ++next_slot;
    outputs->emplace_back(
      CreateNodeOutputPlaceholder(output_node, kernel_graph, input_tensors, root_path, output_indexes));
  }
}
// Counts, for every CNode output in `graph`, how many kernel inputs consume it.
//
// Each input (index 1 and up; index 0 is skipped) of every kernel in the execution order is resolved with
// AnfAlgo::VisitKernel, and the counter of the resulting (node, output index) pair is incremented in `ref_count`
// when the node is a CNode. Raises an exception on a null graph, map, kernel or resolved input node.
void GetRefCount(KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(ref_count);
  for (const auto &kernel : graph->execution_order()) {
    MS_EXCEPTION_IF_NULL(kernel);
    for (size_t i = 1; i < kernel->inputs().size(); i += 1) {
      const auto &input = kernel->input(i);
      auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
      const auto &node = kernel_with_index.first;
      // Fail with a readable exception instead of dereferencing a null node.
      MS_EXCEPTION_IF_NULL(node);
      if (node->isa<CNode>()) {
        (*ref_count)[kernel_with_index] += 1;
      }
    }
  }
}
// Maps every flattened input parameter of `graph` to the position of its tensor in `inputs`.
//
// Each node of graph->inputs() is expanded with AnfAlgo::GetAllOutput and the resulting parameters are numbered
// consecutively; the n-th parameter is paired with inputs[n] and stored in `parameter_index`.
// Raises an exception when an argument, parameter or tensor is null, when there are more parameters than input
// tensors, or when the shape of an input tensor differs from the inferred shape of its parameter.
void GetParameterIndex(KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
                       std::map<AnfNodePtr, size_t> *parameter_index) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(parameter_index);
  size_t index = 0;
  for (const auto &input_node : graph->inputs()) {
    auto params = AnfAlgo::GetAllOutput(input_node);
    for (const auto &param : params) {
      MS_EXCEPTION_IF_NULL(param);
      if (index >= inputs.size()) {
        MS_LOG(EXCEPTION) << "Parameter size out of range. Parameter index: " << index
                          << ", input size: " << inputs.size();
      }
      const auto &input = inputs[index];
      MS_EXCEPTION_IF_NULL(input);
      // Check shape of input and parameter: same rank, and every dimension non-negative and equal.
      const auto &input_shape = input->shape();
      const auto &param_shape = AnfAlgo::GetOutputInferShape(param, 0);
      bool shape_matched = input_shape.size() == param_shape.size();
      for (size_t i = 0; shape_matched && i < input_shape.size(); i += 1) {
        shape_matched = input_shape[i] >= 0 && static_cast<size_t>(input_shape[i]) == param_shape[i];
      }
      if (!shape_matched) {
        MS_LOG(EXCEPTION) << "Shapes of input and parameter are different, input index: " << index
                          << ", parameter: " << param->fullname_with_scope();
      }
      parameter_index->emplace(param, index++);
    }
  }
}
// Returns the tensor held by output `output_index` of a ValueNode.
//
// Returns nullptr when `node` is not a ValueNode or when the selected value is not a tensor.
// For a ValueTuple the element at `output_index` is used; for a plain Tensor `output_index` must be 0.
// Raises an exception on a null node/value/tuple element or an out-of-range `output_index`.
TensorPtr GetValueNodeOutputTensor(const AnfNodePtr &node, size_t output_index) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<ValueNode>()) {
    return nullptr;
  }
  auto value_node = node->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = GetValueNode(value_node);
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<ValueTuple>()) {
    auto value_tuple = value->cast<ValueTuplePtr>();
    MS_EXCEPTION_IF_NULL(value_tuple);
    if (output_index >= value_tuple->size()) {
      MS_LOG(EXCEPTION) << "Index " << output_index << " is out of value tuple range";
    }
    auto tensor_value = value_tuple->value()[output_index];
    // Guard the element itself: a null entry would otherwise be dereferenced by isa<>() below.
    MS_EXCEPTION_IF_NULL(tensor_value);
    if (tensor_value->isa<tensor::Tensor>()) {
      return tensor_value->cast<tensor::TensorPtr>();
    }
  } else if (value->isa<tensor::Tensor>()) {
    if (output_index != 0) {
      MS_LOG(EXCEPTION) << "Index should be 0 for Tensor ValueNode, but is " << output_index;
    }
    return value->cast<TensorPtr>();
  }
  return nullptr;
}
// Looks up the graph input tensor that feeds the Parameter `node`.
//
// Returns nullptr when `node` is not a Parameter. Raises an exception when the parameter has no entry in
// `parameter_index` or when its recorded position lies outside `graph_inputs`.
TensorPtr GetParameterOutputTensor(const AnfNodePtr &node, const std::map<AnfNodePtr, size_t> &parameter_index,
                                   const std::vector<tensor::TensorPtr> &graph_inputs) {
  MS_EXCEPTION_IF_NULL(node);
  const bool is_parameter = node->isa<Parameter>();
  if (!is_parameter) {
    return nullptr;
  }

  const auto found = parameter_index.find(node);
  if (found == parameter_index.end()) {
    MS_LOG(EXCEPTION) << "Can not find parameter input of cnode, parameter = " << node->DebugString();
  }

  const size_t tensor_pos = found->second;
  const size_t tensor_num = graph_inputs.size();
  if (tensor_pos >= tensor_num) {
    MS_LOG(EXCEPTION) << "Parameter index is greater than size of graph's input tensor, parameter index = "
                      << tensor_pos << ", input tensor size = " << tensor_num;
  }
  return graph_inputs[tensor_pos];
}
// Fetches the tensor recorded in `op_output` for the CNode output `kernel_with_index`.
// Raises an exception when no tensor has been recorded for that output.
TensorPtr GetCNodeOutputTensor(const KernelWithIndex &kernel_with_index,
                               const std::map<KernelWithIndex, tensor::TensorPtr> &op_output) {
  const auto found = op_output.find(kernel_with_index);
  const bool missing = (found == op_output.end());
  if (missing) {
    MS_LOG(EXCEPTION) << "Can not find output tensor of cnode, node = " << kernel_with_index.first->DebugString();
  }
  return found->second;
}
TensorPtr GetCNodeOutputStubTensor(const KernelWithIndex &kernel_with_index,
const std::map<KernelWithIndex, OutputTensorInfo> &node_output_info,
bool *output_is_weight) {
@ -332,144 +155,6 @@ TensorPtr GetCNodeOutputStubTensor(const KernelWithIndex &kernel_with_index,
return iter->second.output_stub_tensor;
}
// Collects the input tensors and tensor masks needed to run `cnode` as a single op.
//
// Every input (index 1 and up) is resolved with AnfAlgo::VisitKernel and its tensor is taken from the value
// node, from the graph inputs (Parameter) or from `op_output` (CNode). CNode producers are additionally
// recorded in input_tensor_info->input_kernel. Raises an exception for any other node kind or a null tensor.
void GetOpInputTensors(const CNodePtr &cnode, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
                       const std::map<AnfNodePtr, size_t> &parameter_index,
                       const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(input_tensor_info);
  const size_t input_num = cnode->inputs().size();
  for (size_t pos = 1; pos < input_num; ++pos) {
    const auto producer = AnfAlgo::VisitKernel(cnode->input(pos), 0);
    const auto &producer_node = producer.first;
    MS_EXCEPTION_IF_NULL(producer_node);

    tensor::TensorPtr input_tensor = nullptr;
    if (producer_node->isa<ValueNode>()) {
      input_tensor = GetValueNodeOutputTensor(producer_node, producer.second);
    } else if (producer_node->isa<Parameter>()) {
      input_tensor = GetParameterOutputTensor(producer_node, parameter_index, graph_inputs);
    } else if (producer_node->isa<CNode>()) {
      input_tensor = GetCNodeOutputTensor(producer, op_output);
      input_tensor_info->input_kernel.insert(producer);
    } else {
      MS_LOG(EXCEPTION) << "Invalid input node, node = " << producer_node->DebugString();
    }
    MS_EXCEPTION_IF_NULL(input_tensor);
    MS_LOG(DEBUG) << "Get" << pos << "th input tensor of " << cnode->fullname_with_scope() << " from "
                  << producer_node->fullname_with_scope() << "-" << producer.second;

    // The mask is chosen from the tensor's own is_parameter() flag.
    input_tensor_info->input_tensors_mask.emplace_back(input_tensor->is_parameter() ? kParameterWeightTensorMask
                                                                                    : kParameterDataTensorMask);
    input_tensor_info->input_tensors.emplace_back(input_tensor);
  }
}
// Collects stub input tensors and tensor masks for `cnode`.
//
// Every input (index 1 and up) is resolved with AnfAlgo::VisitKernel. The tensor comes from the value node,
// from the graph inputs (Parameter) or from the stub outputs in `node_output_info` (CNode). The mask is the
// weight mask for a Parameter with a default value or a CNode output flagged as weight, otherwise the data mask.
// Raises an exception for any other node kind or a null tensor.
void GetOpInputStubTensors(const CNodePtr &cnode, const std::map<AnfNodePtr, size_t> &parameter_index,
                           const std::vector<tensor::TensorPtr> &graph_inputs,
                           const std::map<KernelWithIndex, OutputTensorInfo> &node_output_info,
                           InputTensorInfo *input_tensor_info) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(input_tensor_info);
  const size_t input_num = cnode->inputs().size();
  for (size_t pos = 1; pos < input_num; ++pos) {
    const auto producer = AnfAlgo::VisitKernel(cnode->input(pos), 0);
    const auto &producer_node = producer.first;
    MS_EXCEPTION_IF_NULL(producer_node);

    tensor::TensorPtr stub_tensor = nullptr;
    bool treat_as_weight = false;
    if (producer_node->isa<ValueNode>()) {
      stub_tensor = GetValueNodeOutputTensor(producer_node, producer.second);
    } else if (producer_node->isa<Parameter>()) {
      stub_tensor = GetParameterOutputTensor(producer_node, parameter_index, graph_inputs);
      auto parameter = producer_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(parameter);
      treat_as_weight = parameter->has_default();
    } else if (producer_node->isa<CNode>()) {
      stub_tensor = GetCNodeOutputStubTensor(producer, node_output_info, &treat_as_weight);
    } else {
      MS_LOG(EXCEPTION) << "Invalid input node, node = " << producer_node->DebugString();
    }
    // The mask is stored before the tensor is validated, in the same order as the tensor list is filled.
    input_tensor_info->input_tensors_mask.emplace_back(treat_as_weight ? kParameterWeightTensorMask
                                                                       : kParameterDataTensorMask);
    MS_EXCEPTION_IF_NULL(stub_tensor);
    MS_LOG(DEBUG) << "Get" << pos << "th input tensor of " << cnode->fullname_with_scope() << " from "
                  << producer_node->fullname_with_scope() << "-" << producer.second;
    input_tensor_info->input_tensors.emplace_back(stub_tensor);
  }
}
// Releases the cached outputs of producer CNodes once their last consumer has run.
//
// For every CNode entry of `input_kernel` the counter in `ref_count` is decremented; when it reaches zero the
// counter and the matching tensor in `op_output_map` are erased. Raises an exception when a CNode entry is
// missing from `ref_count`, or from `op_output_map` at the moment it has to be released.
void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<KernelWithIndex, size_t> *ref_count,
                    std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) {
  MS_EXCEPTION_IF_NULL(ref_count);
  MS_EXCEPTION_IF_NULL(op_output_map);
  for (const auto &producer : input_kernel) {
    const auto &producer_node = producer.first;
    MS_EXCEPTION_IF_NULL(producer_node);
    if (!producer_node->isa<CNode>()) {
      continue;
    }

    auto count_it = ref_count->find(producer);
    if (count_it == ref_count->end()) {
      MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in cnode reference count map, input cnode = "
                        << producer_node->DebugString() << ", index = " << producer.second;
    }

    // One consumer less; the output stays alive while other consumers remain.
    count_it->second -= 1;
    const bool still_referenced = (count_it->second != 0);
    if (still_referenced) {
      continue;
    }
    ref_count->erase(count_it);

    auto tensor_it = op_output_map->find(producer);
    if (tensor_it == op_output_map->end()) {
      MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in op_output map, input cnode = "
                        << producer_node->DebugString() << ", index = " << producer.second;
    }
    op_output_map->erase(tensor_it);
  }
}
// Stores the outputs of a finished op and writes them into the graph output structure.
//
// Each output tensor of `kernel` is
//   - cached in `op_output_map` when `ref_count` has an entry for that output, and
//   - written into `outputs` at every index path recorded for it in `output_indexes`. All but the last index of
//     a path select nested VectorRefs; the last index selects the slot that receives the tensor.
// Raises an exception when the number of flattened tensors differs from the number of op outputs, or when an
// index path is empty, out of range, or leads through something that is not a VectorRef.
void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
                     const std::map<KernelWithIndex, std::vector<std::vector<size_t>>> &output_indexes,
                     const std::map<KernelWithIndex, size_t> &ref_count,
                     std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(op_output_map);
  MS_EXCEPTION_IF_NULL(outputs);
  auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
  if (output_tensors.size() != op_outputs.size()) {
    MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
  }
  size_t out_index = 0;
  for (const auto &output_tensor : output_tensors) {
    auto kernel_with_index = std::make_pair(kernel, out_index++);
    if (ref_count.find(kernel_with_index) != ref_count.end()) {
      (*op_output_map)[kernel_with_index] = output_tensor;
    }

    const auto &iter = output_indexes.find(kernel_with_index);
    if (iter == output_indexes.end()) {
      continue;
    }
    const std::vector<std::vector<size_t>> &multiple_ref_indexes = iter->second;
    for (const auto &ref_indexes : multiple_ref_indexes) {
      // An empty path would make `size() - 1` wrap around; reject it explicitly.
      if (ref_indexes.empty()) {
        MS_LOG(EXCEPTION) << "Get empty output ref indexes, node = " << kernel->DebugString();
      }
      size_t n = 0;
      const VectorRef *cur_vector_ref = outputs;
      // Walk every index except the last one down through the nested VectorRefs.
      for (; n + 1 < ref_indexes.size(); n += 1) {
        size_t index = ref_indexes.at(n);
        if (index >= cur_vector_ref->size()) {
          MS_LOG(EXCEPTION) << "Get invalid output ref index: " << index << ", size of vertor ref is "
                            << cur_vector_ref->size();
        }
        const BaseRef &base_ref = (*cur_vector_ref)[index];
        if (!utils::isa<VectorRef>(base_ref)) {
          MS_LOG(EXCEPTION) << "Get none VectorRef by ref index, index: " << index << "cur n: " << n;
        }
        cur_vector_ref = &utils::cast<VectorRef>(base_ref);
      }
      // The last index addresses the slot itself; validate it before writing through the const_cast.
      const size_t slot_index = ref_indexes.at(n);
      if (slot_index >= cur_vector_ref->size()) {
        MS_LOG(EXCEPTION) << "Get invalid output ref index: " << slot_index << ", size of vertor ref is "
                          << cur_vector_ref->size();
      }
      BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[slot_index];
      tensor_ref = output_tensor;
    }
  }
}
void GenOpOutputStubTensor(const KernelGraphPtr &single_op_graph, const CNodePtr &kernel,
std::map<KernelWithIndex, OutputTensorInfo> *op_output_info) {
MS_EXCEPTION_IF_NULL(single_op_graph);
@ -508,59 +193,6 @@ void GenOpOutputStubTensor(const KernelGraphPtr &single_op_graph, const CNodePtr
(*op_output_info)[kernel_with_index] = output_tensor_info;
}
}
// Fills `run_info` with the primitive, op name and abstract of `cnode`.
// Raises an exception when `cnode`, `run_info`, the node's primitive or the node's abstract is null.
void GetSingleOpRunInfo(const CNodePtr cnode, OpRunInfo *run_info) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(run_info);
  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  // The primitive is dereferenced below for its name, so it must be validated first.
  MS_EXCEPTION_IF_NULL(primitive);
  run_info->primitive = primitive;
  run_info->op_name = primitive->name();
  if (cnode->abstract() == nullptr) {
    MS_LOG(EXCEPTION) << "Abstract is nullptr, node = " << cnode->DebugString();
  }
  run_info->abstract = cnode->abstract();
}
// Builds the string that identifies the single-op graph of `kernel` for the given input tensors.
//
// The string is the concatenation, each part followed by "_", of:
//   - for every input tensor: its dimensions, its data type and, when it has a device address, the type id and
//     format of that address;
//   - the non-empty string form of every primitive attribute;
//   - the string form of the kernel's build shape;
//   - the inferred data type of every kernel output;
// and finally the primitive id (without a trailing "_").
// Raises an exception when the kernel, its primitive, its abstract, its build shape, an input tensor, an
// attribute value or a device address of unexpected type is null.
GraphInfo GetSingleOpGraphInfo(const CNodePtr &kernel, const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(kernel);
  auto prim = AnfAlgo::GetCNodePrimitive(kernel);
  MS_EXCEPTION_IF_NULL(prim);
  const AbstractBasePtr &abstract = kernel->abstract();
  MS_EXCEPTION_IF_NULL(abstract);
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel);
  GraphInfo graph_info;
  // get input tensor info
  for (const auto &tensor : input_tensors) {
    MS_EXCEPTION_IF_NULL(tensor);
    const auto &tensor_shape = tensor->shape();
    for (const auto &dim : tensor_shape) {
      (void)graph_info.append(std::to_string(dim) + "_");
    }
    (void)graph_info.append(std::to_string(tensor->data_type()) + "_");
    if (tensor->device_address() != nullptr) {
      // Cast once and verify the result: a failed dynamic cast would otherwise be dereferenced.
      const auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
      MS_EXCEPTION_IF_NULL(device_address);
      (void)graph_info.append(std::to_string(device_address->type_id()) + "_");
      (void)graph_info.append(device_address->format() + "_");
    }
  }
  // get attr info
  for (const auto &element : prim->attrs()) {
    MS_EXCEPTION_IF_NULL(element.second);
    // Convert each attribute only once; empty representations are skipped.
    const auto attr_text = element.second->ToString();
    if (attr_text.empty()) {
      continue;
    }
    (void)graph_info.append(attr_text + "_");
  }
  auto build_shape = abstract->BuildShape();
  MS_EXCEPTION_IF_NULL(build_shape);
  (void)graph_info.append(build_shape->ToString() + "_");
  for (size_t output_index = 0; output_index < output_num; output_index += 1) {
    const auto output_type = AnfAlgo::GetOutputInferDataType(kernel, output_index);
    (void)graph_info.append(std::to_string(output_type) + "_");
  }
  (void)graph_info.append(prim->id());
  return graph_info;
}
} // namespace
// Initialises the session for `device_id` by forwarding to InitExecutor with the kAscendDevice target.
void AscendSession::Init(uint32_t device_id) {
  InitExecutor(kAscendDevice, device_id);
}
@ -1028,8 +660,48 @@ KernelGraphPtr AscendSession::PreBuildOp(const OpRunInfo &op_run_info, const Gra
return graph;
}
void AscendSession::BuildOpsInGraph(KernelGraph *graph, const std::map<AnfNodePtr, size_t> &parameter_index,
// Collects stub input tensors and tensor masks for `cnode`.
//
// Every input (index 1 and up) is resolved with AnfAlgo::VisitKernel. The tensor comes from the value node,
// from the graph inputs (Parameter) or from the stub outputs in `node_output_info` (CNode). The mask is the
// weight mask for a Parameter with a default value or a CNode output flagged as weight, otherwise the data mask.
// Raises an exception for any other node kind or a null tensor.
void AscendSession::GetOpInputStubTensors(const CNodePtr &cnode, const std::map<AnfNodePtr, size_t> &parameter_index,
                                          const std::vector<tensor::TensorPtr> &graph_inputs,
                                          const std::map<KernelWithIndex, OutputTensorInfo> &node_output_info,
                                          InputTensorInfo *input_tensor_info) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(input_tensor_info);
  const size_t input_num = cnode->inputs().size();
  for (size_t pos = 1; pos < input_num; ++pos) {
    const auto producer = AnfAlgo::VisitKernel(cnode->input(pos), 0);
    const auto &producer_node = producer.first;
    MS_EXCEPTION_IF_NULL(producer_node);

    tensor::TensorPtr stub_tensor = nullptr;
    bool treat_as_weight = false;
    if (producer_node->isa<ValueNode>()) {
      stub_tensor = GetValueNodeOutputTensor(producer_node, producer.second);
    } else if (producer_node->isa<Parameter>()) {
      stub_tensor = GetParameterOutputTensor(producer_node, parameter_index, graph_inputs);
      auto parameter = producer_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(parameter);
      treat_as_weight = parameter->has_default();
    } else if (producer_node->isa<CNode>()) {
      stub_tensor = GetCNodeOutputStubTensor(producer, node_output_info, &treat_as_weight);
    } else {
      MS_LOG(EXCEPTION) << "Invalid input node, node = " << producer_node->DebugString();
    }
    // The mask is stored before the tensor is validated, in the same order as the tensor list is filled.
    input_tensor_info->input_tensors_mask.emplace_back(treat_as_weight ? kParameterWeightTensorMask
                                                                       : kParameterDataTensorMask);
    MS_EXCEPTION_IF_NULL(stub_tensor);
    MS_LOG(DEBUG) << "Get" << pos << "th input tensor of " << cnode->fullname_with_scope() << " from "
                  << producer_node->fullname_with_scope() << "-" << producer.second;
    input_tensor_info->input_tensors.emplace_back(stub_tensor);
  }
}
void AscendSession::BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs) {
if (built_graph_id_.find(graph_id) == built_graph_id_.end()) {
return;
}
auto graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(graph);
std::map<KernelWithIndex, OutputTensorInfo> op_output_info;
std::vector<CNodePtr> kernels;
@ -1079,44 +751,7 @@ void AscendSession::BuildOpsInGraph(KernelGraph *graph, const std::map<AnfNodePt
MS_LOG(DEBUG) << "Pre build op finished, graph info: " << single_op_graph.second;
}
}
}
// Runs the kernels of graph `graph_id` one by one as single ops and fills `outputs` with the results.
//
// Steps, in order: map graph parameters to positions in `inputs`, create the output placeholders, count the
// consumers of every CNode output, pre-build the ops on the first run of this graph id, then execute every
// kernel of the execution order through RunOpImpl while caching and releasing intermediate outputs.
// NOTE(review): kernel_graph is used without a null check — confirm GetGraph cannot return null here.
void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs) {
MS_LOG(INFO) << "Start!";
auto kernel_graph = GetGraph(graph_id);
// Position in `inputs` of the tensor bound to each graph parameter.
std::map<AnfNodePtr, size_t> parameter_index;
GetParameterIndex(kernel_graph.get(), inputs, &parameter_index);
// Index paths inside `outputs` that each CNode output has to be written to.
std::map<KernelWithIndex, std::vector<std::vector<size_t>>> output_indexes;
CreateOutputPlaceholder(kernel_graph, inputs, outputs, &output_indexes);
// Number of kernel inputs that consume each CNode output.
std::map<KernelWithIndex, size_t> cnode_ref;
GetRefCount(kernel_graph.get(), &cnode_ref);
// Build the ops only the first time this graph id is seen.
if (built_graph_id_.find(graph_id) == built_graph_id_.end()) {
BuildOpsInGraph(kernel_graph.get(), parameter_index, inputs);
built_graph_id_.insert(graph_id);
}
// Outputs of already executed kernels that later kernels still need.
std::map<KernelWithIndex, tensor::TensorPtr> op_output_map;
for (const auto &kernel : kernel_graph->execution_order()) {
// Generate input tensors, tensor masks and input kernel with index
InputTensorInfo input_tensor_info;
GetOpInputTensors(kernel, op_output_map, parameter_index, inputs, &input_tensor_info);
// Get OpRunInfo and GraphInfo
OpRunInfo run_info;
GetSingleOpRunInfo(kernel, &run_info);
GraphInfo graph_info = GetSingleOpGraphInfo(kernel, input_tensor_info.input_tensors);
// Build and run current single op
VectorRef op_outputs;
RunOpImpl(graph_info, &run_info, &input_tensor_info.input_tensors, &op_outputs,
input_tensor_info.input_tensors_mask);
// Handle inputs and outputs of current op
HandleOpInputs(input_tensor_info.input_kernel, &cnode_ref, &op_output_map);
HandleOpOutputs(kernel, op_outputs, output_indexes, cnode_ref, &op_output_map, outputs);
}
MS_LOG(INFO) << "Finish!";
// NOTE(review): graph_id is also inserted in the build branch above; this second insert looks redundant — confirm.
built_graph_id_.insert(graph_id);
}
// compile graph steps

View File

@ -35,16 +35,6 @@
namespace mindspore {
namespace session {
// Kind of a child graph; the underlying type is int and the values are fixed.
enum GraphType : int {
  COMMON_GRAPH = 0,
  CONDITION_GRAPH = 1,
  BRANCH_START = 2,
  BRANCH_END = 3,
};
// Inputs gathered for running one kernel as a single op.
struct InputTensorInfo {
// Input tensors of the op, appended in the order of the kernel's inputs.
std::vector<tensor::TensorPtr> input_tensors;
// One mask per input tensor (kParameterWeightTensorMask or kParameterDataTensorMask).
std::vector<int64_t> input_tensors_mask;
// (node, output index) pairs of the CNodes whose outputs are consumed as inputs.
std::set<KernelWithIndex> input_kernel;
};
// Stub description of one op output.
struct OutputTensorInfo {
  // Stand-in tensor for the output; this is what GetCNodeOutputStubTensor returns.
  tensor::TensorPtr output_stub_tensor;
  // Weight flag of the output (presumably what selects kParameterWeightTensorMask for consumers — confirm).
  // Initialised so that a default-constructed info never carries an indeterminate value.
  bool is_weight{false};
};
class AscendSession : public SessionBasic {
public:
@ -68,8 +58,8 @@ class AscendSession : public SessionBasic {
const std::vector<int64_t> &tensors_mask) override;
void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs) override;
void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs) override;
private:
// compile child graph when session have multiple child graphs
@ -112,7 +102,7 @@ class AscendSession : public SessionBasic {
const std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id) const;
// check if graph cache exist
bool GraphCacheExist(const GraphInfo &graph_info) const;
// sync intial tensors' data to device
// sync initial tensors' data to device
void SyncInitialTenosrToDevice();
void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);
// create parameter to receive data from multiple branch output
@ -128,8 +118,10 @@ class AscendSession : public SessionBasic {
KernelGraphPtr PreBuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask);
void BuildOpsInGraph(KernelGraph *graph, const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs);
void GetOpInputStubTensors(const CNodePtr &cnode, const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs,
const std::map<KernelWithIndex, OutputTensorInfo> &node_output_info,
InputTensorInfo *input_tensor_info);
// key is final_graph_id,value is child graph execute order of final graph
std::unordered_map<GraphId, std::vector<GraphId>> graph_execute_orders_;
// key is final_graph_id,value is the graph types of child graphs

View File

@ -333,8 +333,10 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
// Update Graph Dynamic Shape Attr.
UpdateGraphDynamicShapeAttr(NOT_NULL(graph));
graph->UpdateGraphDynamicAttr();
// Hide NopOp from execution graph
opt::HideNopNode(graph.get());
// Hide NopOp from execution graph in graph mode
if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
opt::HideNopNode(graph.get());
}
// Build kernel if node is cnode
BuildKernel(graph);
// Set graph execution order before memory alloc, ensure that memory alloc is according to the reorder graph

View File

@ -25,6 +25,7 @@
#include "abstract/utils.h"
#include "backend/kernel_compiler/common_utils.h"
#include "base/core_ops.h"
#include "base/base_ref_utils.h"
#include "common/trans.h"
#include "utils/config_manager.h"
#include "backend/session/anf_runtime_algorithm.h"
@ -37,7 +38,6 @@
#include "ir/func_graph_cloner.h"
#include "utils/utils.h"
#include "debug/anf_ir_dump.h"
#include "mindspore/core/base/base_ref_utils.h"
#include "utils/trace_base.h"
#ifdef ENABLE_DUMP_IR
#include "debug/rdr/running_data_recorder.h"
@ -393,6 +393,200 @@ bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) {
}
return true;
}
// Maps every flattened input parameter of `graph` to the position of its tensor in `inputs`.
//
// Each node of graph->inputs() is expanded with AnfAlgo::GetAllOutput and the resulting parameters are numbered
// consecutively; the n-th parameter is paired with inputs[n] and stored in `parameter_index`.
// Raises an exception when an argument, parameter or tensor is null, when there are more parameters than input
// tensors, or when the shape of an input tensor differs from the inferred shape of its parameter.
void GetParameterIndex(KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
                       std::map<AnfNodePtr, size_t> *parameter_index) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(parameter_index);
  size_t index = 0;
  for (const auto &input_node : graph->inputs()) {
    auto params = AnfAlgo::GetAllOutput(input_node);
    for (const auto &param : params) {
      MS_EXCEPTION_IF_NULL(param);
      if (index >= inputs.size()) {
        MS_LOG(EXCEPTION) << "Parameter size out of range. Parameter index: " << index
                          << ", input size: " << inputs.size();
      }
      const auto &input = inputs[index];
      MS_EXCEPTION_IF_NULL(input);
      // Check shape of input and parameter: same rank, and every dimension non-negative and equal.
      const auto &input_shape = input->shape();
      const auto &param_shape = AnfAlgo::GetOutputInferShape(param, 0);
      bool shape_matched = input_shape.size() == param_shape.size();
      for (size_t i = 0; shape_matched && i < input_shape.size(); i += 1) {
        shape_matched = input_shape[i] >= 0 && static_cast<size_t>(input_shape[i]) == param_shape[i];
      }
      if (!shape_matched) {
        MS_LOG(EXCEPTION) << "Shapes of input and parameter are different, input index: " << index
                          << ", parameter: " << param->fullname_with_scope();
      }
      parameter_index->emplace(param, index++);
    }
  }
}
// Produces the output slot for a single (node, output index) pair of `graph`.
//
// - ValueNode: the node's value is returned directly.
// - Parameter: the tensor of `input_tensors` at the parameter's position in graph->inputs() is returned;
//   an exception is raised if the parameter is not found or the position exceeds `input_tensors`.
// - Anything else: `indexes` is recorded in `output_indexes` under `node_output_pair` and an empty
//   placeholder is returned.
BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
                                    const std::vector<tensor::TensorPtr> &input_tensors,
                                    const std::vector<size_t> &indexes,
                                    std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  const auto &output_node = node_output_pair.first;
  MS_EXCEPTION_IF_NULL(output_node);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(output_indexes);
  MS_LOG(INFO) << "Create placeholder for output[" << output_node->DebugString() << "] index["
               << node_output_pair.second << "]";

  // A value node already carries its result, so nothing has to be synced from device to host.
  if (output_node->isa<ValueNode>()) {
    auto constant = output_node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(constant);
    return constant->value();
  }

  // A parameter output is simply the matching graph input tensor.
  if (output_node->isa<Parameter>()) {
    const auto &graph_inputs = graph->inputs();
    const size_t tensor_num = input_tensors.size();
    for (size_t pos = 0; pos < graph_inputs.size(); ++pos) {
      if (pos >= tensor_num) {
        MS_LOG(EXCEPTION) << "Input idx:" << pos << "out of range:" << tensor_num;
      }
      if (graph_inputs[pos] == output_node) {
        return input_tensors[pos];
      }
    }
    MS_LOG(EXCEPTION) << "Parameter: " << output_node->DebugString() << " has no output addr";
  }

  // Remember the path of this slot so it can be located again through `output_indexes`.
  (*output_indexes)[node_output_pair].emplace_back(indexes);
  BaseRef empty_slot = std::make_shared<BaseRef>();
  return empty_slot;
}
// Produces the output slot (or nested slots) for the graph output `anf`.
//
// The node is first resolved with AnfAlgo::VisitKernelWithReturnType. A MakeTuple yields a VectorRef with one
// entry per tuple element, each built recursively with the element position appended to `indexes`. A node with
// zero output tensors yields an empty VectorRef. Every other node is delegated to the KernelWithIndex overload.
BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr &graph,
                                    const std::vector<tensor::TensorPtr> &input_tensors,
                                    const std::vector<size_t> &indexes,
                                    std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(output_indexes);
  MS_LOG(INFO) << "Create placeholder for output[" << anf->DebugString() << "]";
  auto real_output = AnfAlgo::VisitKernelWithReturnType(anf, 0);
  const auto &real_node = real_output.first;
  MS_EXCEPTION_IF_NULL(real_node);
  MS_LOG(INFO) << "Create placeholder for output after visit:" << real_node->DebugString();

  if (!AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple)) {
    // A node without any output tensor is represented by an empty list.
    if (AnfAlgo::GetOutputTensorNum(real_node) == 0) {
      return VectorRef();
    }
    return CreateNodeOutputPlaceholder(real_output, graph, input_tensors, indexes, output_indexes);
  }

  // MakeTuple: build one nested slot per element; input 0 is the primitive, so elements start at input 1.
  auto make_tuple = real_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(make_tuple);
  VectorRef tuple_slots;
  const size_t input_num = make_tuple->inputs().size();
  for (size_t pos = 1; pos < input_num; ++pos) {
    std::vector<size_t> element_path(indexes);
    element_path.emplace_back(pos - 1);
    auto element_slot =
      CreateNodeOutputPlaceholder(make_tuple->input(pos), graph, input_tensors, element_path, output_indexes);
    tuple_slots.push_back(element_slot);
  }
  return tuple_slots;
}
// Appends one output slot to `outputs` for every output of `kernel_graph`.
//
// The i-th graph output is built by CreateNodeOutputPlaceholder with the starting index path {i}; the paths of
// the slots that still have to be filled are collected in `output_indexes`.
void CreateOutputPlaceholder(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &input_tensors,
                             VectorRef *outputs,
                             std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(output_indexes);
  auto graph_outputs = kernel_graph->outputs();
  size_t next_slot = 0;
  for (auto &output_node : graph_outputs) {
    MS_EXCEPTION_IF_NULL(output_node);
    MS_LOG(INFO) << "Create node output placeholder[" << output_node->DebugString() << "]";
    std::vector<size_t> root_path{next_slot};
    ++next_slot;
    outputs->emplace_back(
      CreateNodeOutputPlaceholder(output_node, kernel_graph, input_tensors, root_path, output_indexes));
  }
}
// Counts, for every CNode output in `graph`, how many kernel inputs consume it.
//
// Each input (index 1 and up; index 0 is skipped) of every kernel in the execution order is resolved with
// AnfAlgo::VisitKernel, and the counter of the resulting (node, output index) pair is incremented in `ref_count`
// when the node is a CNode. Raises an exception on a null graph, map, kernel or resolved input node.
void GetRefCount(KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(ref_count);
  for (const auto &kernel : graph->execution_order()) {
    MS_EXCEPTION_IF_NULL(kernel);
    for (size_t i = 1; i < kernel->inputs().size(); i += 1) {
      const auto &input = kernel->input(i);
      auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
      const auto &node = kernel_with_index.first;
      // Fail with a readable exception instead of dereferencing a null node.
      MS_EXCEPTION_IF_NULL(node);
      if (node->isa<CNode>()) {
        (*ref_count)[kernel_with_index] += 1;
      }
    }
  }
}
// Releases the cached outputs of producer CNodes once their last consumer has run.
//
// For every CNode entry of `input_kernel` the counter in `ref_count` is decremented; when it reaches zero the
// counter and the matching tensor in `op_output_map` are erased. Raises an exception when a CNode entry is
// missing from `ref_count`, or from `op_output_map` at the moment it has to be released.
void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<KernelWithIndex, size_t> *ref_count,
                    std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) {
  MS_EXCEPTION_IF_NULL(ref_count);
  MS_EXCEPTION_IF_NULL(op_output_map);
  for (const auto &producer : input_kernel) {
    const auto &producer_node = producer.first;
    MS_EXCEPTION_IF_NULL(producer_node);
    if (!producer_node->isa<CNode>()) {
      continue;
    }

    auto count_it = ref_count->find(producer);
    if (count_it == ref_count->end()) {
      MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in cnode reference count map, input cnode = "
                        << producer_node->DebugString() << ", index = " << producer.second;
    }

    // One consumer less; the output stays alive while other consumers remain.
    count_it->second -= 1;
    const bool still_referenced = (count_it->second != 0);
    if (still_referenced) {
      continue;
    }
    ref_count->erase(count_it);

    auto tensor_it = op_output_map->find(producer);
    if (tensor_it == op_output_map->end()) {
      MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in op_output map, input cnode = "
                        << producer_node->DebugString() << ", index = " << producer.second;
    }
    op_output_map->erase(tensor_it);
  }
}
// Stores the outputs of a finished op and writes them into the graph output structure.
//
// Each output tensor of `kernel` is
//   - cached in `op_output_map` when `ref_count` has an entry for that output, and
//   - written into `outputs` at every index path recorded for it in `output_indexes`. All but the last index of
//     a path select nested VectorRefs; the last index selects the slot that receives the tensor.
// Raises an exception when there are more flattened tensors than op outputs, or when an index path is empty,
// out of range, or leads through something that is not a VectorRef.
void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
                     const std::map<KernelWithIndex, std::vector<std::vector<size_t>>> &output_indexes,
                     const std::map<KernelWithIndex, size_t> &ref_count,
                     std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(op_output_map);
  MS_EXCEPTION_IF_NULL(outputs);
  auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
  if (output_tensors.size() > op_outputs.size()) {
    MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
  }
  size_t out_index = 0;
  for (const auto &output_tensor : output_tensors) {
    auto kernel_with_index = std::make_pair(kernel, out_index++);
    if (ref_count.find(kernel_with_index) != ref_count.end()) {
      (*op_output_map)[kernel_with_index] = output_tensor;
    }

    const auto &iter = output_indexes.find(kernel_with_index);
    if (iter == output_indexes.end()) {
      continue;
    }
    const std::vector<std::vector<size_t>> &multiple_ref_indexes = iter->second;
    for (const auto &ref_indexes : multiple_ref_indexes) {
      // An empty path would make `size() - 1` wrap around; reject it explicitly.
      if (ref_indexes.empty()) {
        MS_LOG(EXCEPTION) << "Get empty output ref indexes, node = " << kernel->DebugString();
      }
      size_t n = 0;
      const VectorRef *cur_vector_ref = outputs;
      // Walk every index except the last one down through the nested VectorRefs.
      for (; n + 1 < ref_indexes.size(); n += 1) {
        size_t index = ref_indexes.at(n);
        if (index >= cur_vector_ref->size()) {
          MS_LOG(EXCEPTION) << "Get invalid output ref index: " << index << ", size of vertor ref is "
                            << cur_vector_ref->size();
        }
        const BaseRef &base_ref = (*cur_vector_ref)[index];
        if (!utils::isa<VectorRef>(base_ref)) {
          MS_LOG(EXCEPTION) << "Get none VectorRef by ref index, index: " << index << "cur n: " << n;
        }
        cur_vector_ref = &utils::cast<VectorRef>(base_ref);
      }
      // The last index addresses the slot itself; validate it before writing through the const_cast.
      const size_t slot_index = ref_indexes.at(n);
      if (slot_index >= cur_vector_ref->size()) {
        MS_LOG(EXCEPTION) << "Get invalid output ref index: " << slot_index << ", size of vertor ref is "
                          << cur_vector_ref->size();
      }
      BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[slot_index];
      tensor_ref = output_tensor;
    }
  }
}
} // namespace
// Out-of-class definition of the static member SessionBasic::graph_sum_, initialised to 0.
GraphId SessionBasic::graph_sum_ = 0;
@ -1058,6 +1252,148 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
return graph;
}
// Build the GraphInfo string that describes a single-op graph for `kernel` fed with `input_tensors`.
// The string is the concatenation, each part followed by "_", of:
//   - for every input tensor: its shape dims, its data type and, when a device address is attached,
//     the device type id and device format;
//   - every non-empty primitive attribute value (ToString);
//   - the build shape of the kernel's abstract;
//   - the inferred data type of every kernel output;
// and finally the primitive id (without trailing "_").
// Raises an exception when kernel, its primitive, its abstract, its build shape, an input tensor or a
// non-DeviceAddress device address is encountered as null.
GraphInfo SessionBasic::GetSingleOpGraphInfo(const CNodePtr &kernel,
                                             const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(kernel);
  auto prim = AnfAlgo::GetCNodePrimitive(kernel);
  MS_EXCEPTION_IF_NULL(prim);
  const AbstractBasePtr &abstract = kernel->abstract();
  MS_EXCEPTION_IF_NULL(abstract);
  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel);
  GraphInfo graph_info;
  // get input tensor info
  for (const auto &tensor : input_tensors) {
    MS_EXCEPTION_IF_NULL(tensor);
    const auto &tensor_shape = tensor->shape();
    for (const auto &dim : tensor_shape) {
      (void)graph_info.append(std::to_string(dim) + "_");
    }
    (void)graph_info.append(std::to_string(tensor->data_type()) + "_");
    if (tensor->device_address() != nullptr) {
      // Cast once and verify it: dynamic_pointer_cast yields null when the dynamic type does not match.
      const auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
      MS_EXCEPTION_IF_NULL(device_address);
      (void)graph_info.append(std::to_string(device_address->type_id()) + "_");
      (void)graph_info.append(device_address->format() + "_");
    }
  }
  // get attr info
  const auto &attr_map = prim->attrs();
  for (const auto &element : attr_map) {
    const auto attr_text = element.second->ToString();
    if (attr_text.empty()) {
      continue;
    }
    (void)graph_info.append(attr_text + "_");
  }
  // get output info
  auto build_shape = abstract->BuildShape();
  MS_EXCEPTION_IF_NULL(build_shape);
  (void)graph_info.append(build_shape->ToString() + "_");
  for (size_t output_index = 0; output_index < output_num; output_index += 1) {
    const auto output_type = AnfAlgo::GetOutputInferDataType(kernel, output_index);
    (void)graph_info.append(std::to_string(output_type) + "_");
  }
  (void)graph_info.append(prim->id());
  return graph_info;
}
// Fill `run_info` (primitive, op_name and abstract) from `cnode`.
// Raises an exception when cnode, run_info, the cnode's primitive or the cnode's abstract is null.
void SessionBasic::GetSingleOpRunInfo(const CNodePtr cnode, OpRunInfo *run_info) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(run_info);
  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  // Checked before use below, consistent with GetSingleOpGraphInfo.
  MS_EXCEPTION_IF_NULL(primitive);
  run_info->primitive = primitive;
  run_info->op_name = primitive->name();
  if (cnode->abstract() == nullptr) {
    MS_LOG(EXCEPTION) << "Abstract is nullptr, node = " << cnode->DebugString();
  }
  run_info->abstract = cnode->abstract();
}
// Return the tensor held by a value node at `output_index`.
//  - Returns nullptr when `node` is not a ValueNode, or when the selected value is not a tensor.
//  - For a ValueTuple the element at `output_index` is inspected; an out-of-range index raises an exception.
//  - For a plain Tensor value `output_index` must be 0, otherwise an exception is raised.
TensorPtr SessionBasic::GetValueNodeOutputTensor(const AnfNodePtr &node, size_t output_index) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<ValueNode>()) {
    return nullptr;
  }
  auto value_node = node->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = GetValueNode(value_node);
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<ValueTuple>()) {
    auto value_tuple = value->cast<ValueTuplePtr>();
    MS_EXCEPTION_IF_NULL(value_tuple);
    if (output_index >= value_tuple->size()) {
      MS_LOG(EXCEPTION) << "Index " << output_index << " is out of value tuple range";
    }
    auto tensor_value = value_tuple->value()[output_index];
    // Guard the element itself before querying its type.
    MS_EXCEPTION_IF_NULL(tensor_value);
    if (tensor_value->isa<tensor::Tensor>()) {
      return tensor_value->cast<tensor::TensorPtr>();
    }
  } else if (value->isa<tensor::Tensor>()) {
    if (output_index != 0) {
      MS_LOG(EXCEPTION) << "Index should be 0 for Tensor ValueNode, but is " << output_index;
    }
    return value->cast<TensorPtr>();
  }
  return nullptr;
}
// Look up the graph input tensor that backs a Parameter node.
// Returns nullptr when `node` is not a Parameter. Raises an exception when the parameter is missing from
// `parameter_index` or when its recorded position lies outside `graph_inputs`.
TensorPtr SessionBasic::GetParameterOutputTensor(const AnfNodePtr &node,
                                                 const std::map<AnfNodePtr, size_t> &parameter_index,
                                                 const std::vector<tensor::TensorPtr> &graph_inputs) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<Parameter>()) {
    const auto found = parameter_index.find(node);
    if (found == parameter_index.end()) {
      MS_LOG(EXCEPTION) << "Can not find parameter input of cnode, parameter = " << node->DebugString();
    }
    const size_t input_pos = found->second;
    const size_t input_count = graph_inputs.size();
    if (input_pos >= input_count) {
      MS_LOG(EXCEPTION) << "Parameter index is greater than size of graph's input tensor, parameter index = "
                        << input_pos << ", input tensor size = " << input_count;
    }
    return graph_inputs[input_pos];
  }
  // Anything that is not a Parameter has no graph input tensor.
  return nullptr;
}
// Fetch the cached output tensor for the given (cnode, output index) pair from `op_output`.
// Raises an exception when the pair has no entry in the map.
TensorPtr SessionBasic::GetCNodeOutputTensor(const KernelWithIndex &kernel_with_index,
                                             const std::map<KernelWithIndex, tensor::TensorPtr> &op_output) {
  const auto cached = op_output.find(kernel_with_index);
  const bool is_missing = (cached == op_output.end());
  if (is_missing) {
    MS_LOG(EXCEPTION) << "Can not find output tensor of cnode, node = " << kernel_with_index.first->DebugString();
  }
  return cached->second;
}
// Gather the input tensors of `cnode` into `input_tensor_info`.
// Every input (starting at position 1) is resolved through AnfAlgo::VisitKernel and its tensor is taken from:
//   - the value node itself (ValueNode),
//   - the graph inputs via `parameter_index` (Parameter),
//   - the cached op outputs in `op_output` (CNode); such inputs are also recorded in input_kernel.
// For each tensor a mask is appended: kParameterWeightTensorMask when the tensor is a parameter,
// kParameterDataTensorMask otherwise. Raises an exception on any other node kind or on a null tensor.
void SessionBasic::GetOpInputTensors(const CNodePtr &cnode,
                                     const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
                                     const std::map<AnfNodePtr, size_t> &parameter_index,
                                     const std::vector<tensor::TensorPtr> &graph_inputs,
                                     InputTensorInfo *input_tensor_info) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(input_tensor_info);
  const size_t input_count = cnode->inputs().size();
  for (size_t input_pos = 1; input_pos < input_count; ++input_pos) {
    const auto source_with_index = AnfAlgo::VisitKernel(cnode->input(input_pos), 0);
    const auto &source_node = source_with_index.first;
    MS_EXCEPTION_IF_NULL(source_node);

    tensor::TensorPtr input_tensor = nullptr;
    if (source_node->isa<ValueNode>()) {
      input_tensor = GetValueNodeOutputTensor(source_node, source_with_index.second);
    } else if (source_node->isa<Parameter>()) {
      input_tensor = GetParameterOutputTensor(source_node, parameter_index, graph_inputs);
    } else if (source_node->isa<CNode>()) {
      input_tensor = GetCNodeOutputTensor(source_with_index, op_output);
      // Remember which op outputs were consumed by this op.
      input_tensor_info->input_kernel.insert(source_with_index);
    } else {
      MS_LOG(EXCEPTION) << "Invalid input node, node = " << source_node->DebugString();
    }
    MS_EXCEPTION_IF_NULL(input_tensor);

    MS_LOG(DEBUG) << "Get" << input_pos << "th input tensor of " << cnode->fullname_with_scope() << " from "
                  << source_node->fullname_with_scope() << "-" << source_with_index.second;
    const auto tensor_mask = input_tensor->is_parameter() ? kParameterWeightTensorMask : kParameterDataTensorMask;
    input_tensor_info->input_tensors_mask.emplace_back(tensor_mask);
    input_tensor_info->input_tensors.emplace_back(input_tensor);
  }
}
bool SessionBasic::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
@ -1812,6 +2148,42 @@ void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tens
executor_->RunGraphAsync(shared_from_this(), graph_id, inputs, outputs);
}
// Run the graph identified by `graph_id` one op at a time, in the graph's execution order.
// Before the loop: the parameter index map, the output placeholders and the cnode reference counts are
// prepared and BuildOpsInGraph is invoked. For each kernel: its input tensors are gathered, the single-op
// run info and graph info are derived, the op is run through RunOpImpl, and its inputs/outputs are handed
// to HandleOpInputs / HandleOpOutputs.
void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                                     VectorRef *outputs) {
  MS_LOG(INFO) << "Start!";
  auto graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(graph);

  // Preparation shared by all ops of the graph.
  std::map<AnfNodePtr, size_t> param_positions;
  GetParameterIndex(graph.get(), inputs, &param_positions);
  std::map<KernelWithIndex, std::vector<std::vector<size_t>>> graph_output_paths;
  CreateOutputPlaceholder(graph, inputs, outputs, &graph_output_paths);
  std::map<KernelWithIndex, size_t> pending_refs;
  GetRefCount(graph.get(), &pending_refs);
  BuildOpsInGraph(graph_id, param_positions, inputs);

  std::map<KernelWithIndex, tensor::TensorPtr> produced_tensors;
  for (const auto &op_node : graph->execution_order()) {
    // Collect input tensors, their masks and the producing kernels.
    InputTensorInfo op_inputs;
    GetOpInputTensors(op_node, produced_tensors, param_positions, inputs, &op_inputs);

    // Describe the op for single-op execution.
    OpRunInfo op_run_info;
    GetSingleOpRunInfo(op_node, &op_run_info);
    GraphInfo op_graph_info = GetSingleOpGraphInfo(op_node, op_inputs.input_tensors);

    // Build and run the op.
    VectorRef op_results;
    RunOpImpl(op_graph_info, &op_run_info, &op_inputs.input_tensors, &op_results, op_inputs.input_tensors_mask);

    // Update bookkeeping for the consumed inputs and the produced outputs.
    HandleOpInputs(op_inputs.input_kernel, &pending_refs, &produced_tensors);
    HandleOpOutputs(op_node, op_results, graph_output_paths, pending_refs, &produced_tensors, outputs);
  }
  MS_LOG(INFO) << "Finish!";
}
void SessionBasic::EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask,
std::vector<tensor::TensorPtr> *input_tensors) {
MS_EXCEPTION_IF_NULL(input_tensors);

View File

@ -59,8 +59,21 @@ struct OpRunInfo {
size_t next_input_index = 0;
#endif
};
// Inputs collected for running one op (filled by SessionBasic::GetOpInputTensors).
struct InputTensorInfo {
// One tensor per op input, in input order.
std::vector<tensor::TensorPtr> input_tensors;
// One mask per entry of input_tensors: kParameterWeightTensorMask when the tensor is a parameter,
// kParameterDataTensorMask otherwise.
std::vector<int64_t> input_tensors_mask;
// (node, output index) pairs of the CNode inputs whose tensors were taken from previous op outputs.
std::set<KernelWithIndex> input_kernel;
};
// Information about one output tensor.
// NOTE(review): not referenced by the code visible in this chunk; field meanings below are taken from the
// names only and should be confirmed against the users of this struct.
struct OutputTensorInfo {
  // Presumably a placeholder tensor standing in for the real output — TODO confirm.
  tensor::TensorPtr output_stub_tensor;
  // Default-initialised so a freshly constructed instance never carries an indeterminate value.
  bool is_weight = false;
};
using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
class Executor;
class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
public:
SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {
@ -163,8 +176,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
virtual void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
const std::vector<int64_t> &tensors_mask) {}
virtual void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs) {}
void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
virtual void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs) {}
void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs);
virtual void SetSummaryNodes(KernelGraph *graph);
@ -184,6 +198,18 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask, bool is_ascend = false);
// Generate graph info for a single op graph
GraphInfo GetSingleOpGraphInfo(const CNodePtr &kernel, const std::vector<tensor::TensorPtr> &input_tensors);
// Fill run_info (primitive, op name, abstract) from cnode
void GetSingleOpRunInfo(const CNodePtr cnode, OpRunInfo *run_info);
// Get the tensor held by a value node at output_index; nullptr if node is not a ValueNode or holds no tensor there
tensor::TensorPtr GetValueNodeOutputTensor(const AnfNodePtr &node, size_t output_index);
// Get the graph input tensor of a Parameter node through parameter_index; nullptr if node is not a Parameter
tensor::TensorPtr GetParameterOutputTensor(const AnfNodePtr &node,
const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs);
// Get the output tensor recorded in op_output for kernel_with_index; raises if there is no such entry
tensor::TensorPtr GetCNodeOutputTensor(const KernelWithIndex &kernel_with_index,
const std::map<KernelWithIndex, tensor::TensorPtr> &op_output);
// Collect input tensors, tensor masks and input kernels of cnode into input_tensor_info
void GetOpInputTensors(const CNodePtr &cnode, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info);
// create a new kernel graph and update the graph sum
KernelGraphPtr NewKernelGraph();
std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph);

View File

@ -351,7 +351,8 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) {
MS_EXCEPTION_IF_NULL(context_ptr);
bool is_enable_dynamic_mem = context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL);
bool is_enable_pynative_infer = context_ptr->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
if (is_enable_dynamic_mem && !is_enable_pynative_infer) {
bool is_pynative_mode = (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
if (is_enable_dynamic_mem && !is_pynative_mode && !is_enable_pynative_infer) {
auto graph_id = graph->graph_id();
auto iter = mem_swap_map_.find(graph_id);
if (iter == mem_swap_map_.end()) {
@ -851,7 +852,7 @@ void GPUKernelRuntime::UpdateHostSwapInQueue(const DeviceAddressPtr device_addre
MS_LOG(WARNING) << "Unexpected device address status: " << status;
break;
default:
MS_LOG(EXCEPTION) << "Invaild device address status: " << status;
MS_LOG(EXCEPTION) << "Invalid device address status: " << status;
}
}

View File

@ -160,19 +160,15 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s
PushInputTensor(arg, &inputs);
}
VectorRef outputs;
// Call ms RunGraphAsync or RunOpsInGraph (graphId, input ,output)
const session::SessionPtr &exe_session = ((target != target_device_ && !target.empty()) ? other_sess_ : target_sess_);
auto ms_context = MsContext::GetInstance();
const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
VectorRef outputs;
// call ms rungraph (graphId, input ,output)
if (target != target_device_ && !target.empty()) {
other_sess_->RunGraphAsync(g, inputs, &outputs);
if (pynative_mode) {
exe_session->RunOpsInGraph(g, inputs, &outputs);
} else {
if (pynative_mode && target == "Ascend") {
target_sess_->RunOpsInGraph(g, inputs, &outputs);
} else {
target_sess_->RunGraphAsync(g, inputs, &outputs);
}
exe_session->RunGraphAsync(g, inputs, &outputs);
}
MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();