Pre build all ops in bprop graph under PyNative

This commit is contained in:
tanghuikang 2021-01-18 11:30:33 +08:00
parent 6899c46ffd
commit 5dc66a82ce
4 changed files with 267 additions and 60 deletions

View File

@ -257,10 +257,78 @@ void GetParameterIndex(KernelGraph *graph, const std::vector<tensor::TensorPtr>
}
}
TensorPtr GetValueNodeOutputTensor(const AnfNodePtr &node, size_t output_index) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<ValueNode>()) {
return nullptr;
}
auto value_node = node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value = GetValueNode(value_node);
MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueTuple>()) {
auto value_tuple = value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
if (output_index >= value_tuple->size()) {
MS_LOG(EXCEPTION) << "Index " << output_index << "is out of value tuple range";
}
auto tensor_value = value_tuple->value()[output_index];
if (tensor_value->isa<tensor::Tensor>()) {
return tensor_value->cast<tensor::TensorPtr>();
}
} else if (value->isa<tensor::Tensor>()) {
if (output_index != 0) {
MS_LOG(EXCEPTION) << "Index should be 0 for Tensor ValueNode, but is " << output_index;
}
return value->cast<TensorPtr>();
}
return nullptr;
}
TensorPtr GetParameterOutputTensor(const AnfNodePtr &node, const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<Parameter>()) {
return nullptr;
}
const auto &iter = parameter_index.find(node);
if (iter == parameter_index.end()) {
MS_LOG(EXCEPTION) << "Can not find parameter input of cnode, parameter = " << node->DebugString();
}
const size_t index = iter->second;
if (index >= graph_inputs.size()) {
MS_LOG(EXCEPTION) << "Parameter index is greater than size of graph's input tensor, parameter index = " << index
<< ", input tensor size = " << graph_inputs.size();
}
return graph_inputs[index];
}
TensorPtr GetCNodeOutputTensor(const KernelWithIndex &kernel_with_index,
const std::map<KernelWithIndex, tensor::TensorPtr> &op_output) {
const auto &iter = op_output.find(kernel_with_index);
if (iter == op_output.end()) {
MS_LOG(EXCEPTION) << "Can not find output tensor of cnode, node = " << kernel_with_index.first->DebugString();
}
return iter->second;
}
TensorPtr GetCNodeOutputStubTensor(const KernelWithIndex &kernel_with_index,
const std::map<KernelWithIndex, OutputTensorInfo> &node_output_info,
bool *output_is_weight) {
MS_EXCEPTION_IF_NULL(output_is_weight);
const auto &iter = node_output_info.find(kernel_with_index);
if (iter == node_output_info.end()) {
MS_LOG(EXCEPTION) << "Can not find output stub tensor of cnode " << kernel_with_index.first->DebugString();
}
*output_is_weight = iter->second.is_weight;
return iter->second.output_stub_tensor;
}
void GetOpInputTensors(const CNodePtr &cnode, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(input_tensor_info);
for (size_t i = 1; i < cnode->inputs().size(); i += 1) {
const auto &input = cnode->input(i);
auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
@ -268,43 +336,11 @@ void GetOpInputTensors(const CNodePtr &cnode, const std::map<KernelWithIndex, te
MS_EXCEPTION_IF_NULL(real_input);
tensor::TensorPtr tensor = nullptr;
if (real_input->isa<ValueNode>()) {
auto value_node = real_input->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value = GetValueNode(value_node);
MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueTuple>()) {
auto value_tuple = value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
if (kernel_with_index.second >= value_tuple->size()) {
MS_LOG(EXCEPTION) << "Index " << kernel_with_index.second << "is out of value tuple range";
}
auto tensor_value = value_tuple->value()[kernel_with_index.second];
if (tensor_value->isa<tensor::Tensor>()) {
tensor = tensor_value->cast<tensor::TensorPtr>();
}
} else if (value->isa<tensor::Tensor>()) {
if (kernel_with_index.second != 0) {
MS_LOG(EXCEPTION) << "Index should be 0 for Tensor ValueNode, but is " << kernel_with_index.second;
}
tensor = GetValueNode<TensorPtr>(value_node);
}
tensor = GetValueNodeOutputTensor(real_input, kernel_with_index.second);
} else if (real_input->isa<Parameter>()) {
const auto &iter = parameter_index.find(real_input);
if (iter == parameter_index.end()) {
MS_LOG(EXCEPTION) << "Can not find parameter input of cnode, node = " << cnode->DebugString();
}
const size_t index = iter->second;
if (index >= graph_inputs.size()) {
MS_LOG(EXCEPTION) << "Parameter index is greater than size of graph's input tensor, parameter index = "
<< cnode->DebugString() << "input tensor size = " << graph_inputs.size();
}
tensor = graph_inputs[index];
tensor = GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
} else if (real_input->isa<CNode>()) {
const auto &iter = op_output.find(kernel_with_index);
if (iter == op_output.end()) {
MS_LOG(EXCEPTION) << "Can not find output tensor of cnode, node = " << real_input->DebugString();
}
tensor = iter->second;
tensor = GetCNodeOutputTensor(kernel_with_index, op_output);
input_tensor_info->input_kernel.insert(kernel_with_index);
} else {
MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
@ -318,9 +354,48 @@ void GetOpInputTensors(const CNodePtr &cnode, const std::map<KernelWithIndex, te
}
}
void GetOpInputStubTensors(const CNodePtr &cnode, const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs,
const std::map<KernelWithIndex, OutputTensorInfo> &node_output_info,
InputTensorInfo *input_tensor_info) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(input_tensor_info);
for (size_t i = 1; i < cnode->inputs().size(); i += 1) {
const auto &input = cnode->input(i);
auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
auto real_input = kernel_with_index.first;
MS_EXCEPTION_IF_NULL(real_input);
tensor::TensorPtr tensor = nullptr;
if (real_input->isa<ValueNode>()) {
tensor = GetValueNodeOutputTensor(real_input, kernel_with_index.second);
input_tensor_info->input_tensors_mask.emplace_back(kParameterDataTensorMask);
} else if (real_input->isa<Parameter>()) {
tensor = GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
auto parameter = real_input->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter);
input_tensor_info->input_tensors_mask.emplace_back(parameter->has_default() ? kParameterWeightTensorMask
: kParameterDataTensorMask);
} else if (real_input->isa<CNode>()) {
bool output_is_weight = false;
tensor = GetCNodeOutputStubTensor(kernel_with_index, node_output_info, &output_is_weight);
input_tensor_info->input_tensors_mask.emplace_back(output_is_weight ? kParameterWeightTensorMask
: kParameterDataTensorMask);
} else {
MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
}
MS_EXCEPTION_IF_NULL(tensor);
MS_LOG(DEBUG) << "Get" << i << "th input tensor of " << cnode->fullname_with_scope() << " from "
<< real_input->fullname_with_scope() << "-" << kernel_with_index.second;
input_tensor_info->input_tensors.emplace_back(tensor);
}
}
void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<KernelWithIndex, size_t> *ref_count,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) {
MS_EXCEPTION_IF_NULL(ref_count);
MS_EXCEPTION_IF_NULL(op_output_map);
for (auto &kernel_with_index : input_kernel) {
MS_EXCEPTION_IF_NULL(kernel_with_index.first);
if (!kernel_with_index.first->isa<CNode>()) {
continue;
}
@ -348,6 +423,9 @@ void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
const std::map<KernelWithIndex, std::vector<std::vector<size_t>>> &output_indexes,
const std::map<KernelWithIndex, size_t> &ref_count,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(op_output_map);
MS_EXCEPTION_IF_NULL(outputs);
auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
if (output_tensors.size() != op_outputs.size()) {
MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
@ -384,6 +462,45 @@ void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
}
}
void GenOpOutputStubTensor(const KernelGraphPtr &single_op_graph, const CNodePtr &kernel,
std::map<KernelWithIndex, OutputTensorInfo> *op_output_info) {
MS_EXCEPTION_IF_NULL(single_op_graph);
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(op_output_info);
OutputTensorInfo output_tensor_info;
size_t out_idx = 0;
for (const auto &output : single_op_graph->outputs()) {
const auto &output_kernel_with_index = AnfAlgo::VisitKernel(output, 0);
const auto &output_node = output_kernel_with_index.first;
const auto &output_index = output_kernel_with_index.second;
auto out_abstract = output_node->abstract();
MS_EXCEPTION_IF_NULL(out_abstract);
if (out_abstract->isa<abstract::AbstractTuple>()) {
out_abstract = out_abstract->cast<abstract::AbstractTuplePtr>()->elements()[output_index];
MS_EXCEPTION_IF_NULL(out_abstract);
}
abstract::AbstractTensorPtr tensor_abstract = out_abstract->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(tensor_abstract);
const auto &infer_type = AnfAlgo::GetOutputInferDataType(output_node, output_index);
tensor::TensorPtr stub_output_tensor =
std::make_shared<tensor::Tensor>(infer_type, tensor_abstract->shape()->shape(), nullptr);
const auto &output_type = AnfAlgo::GetOutputDeviceDataType(output_node, output_index);
const auto &output_shape = AnfAlgo::GetOutputDeviceShape(output_node, output_index);
const auto &output_format = AnfAlgo::GetOutputFormat(output_node, output_index);
tensor::DeviceInfo device_info;
device_info.format_ = output_format;
device_info.data_type_ = TypeIdToType(output_type);
stub_output_tensor->set_device_info(device_info);
device::DeviceAddressPtr device_address =
std::make_shared<device::ascend::AscendDeviceAddress>(nullptr, 0, output_format, output_type);
stub_output_tensor->set_device_address(device_address);
KernelWithIndex kernel_with_index = std::make_pair(kernel, out_idx++);
output_tensor_info.output_stub_tensor = stub_output_tensor;
output_tensor_info.is_weight = !dynamic_cast<device::KernelInfo *>(output_node->kernel_info())->is_feature_map();
(*op_output_info)[kernel_with_index] = output_tensor_info;
}
}
void GetSingleOpRunInfo(const CNodePtr cnode, OpRunInfo *run_info) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(run_info);
@ -396,8 +513,13 @@ void GetSingleOpRunInfo(const CNodePtr cnode, OpRunInfo *run_info) {
run_info->abstract = cnode->abstract();
}
GraphInfo GetSingleOpGraphInfo(const PrimitivePtr &prim, const std::vector<tensor::TensorPtr> &input_tensors) {
GraphInfo GetSingleOpGraphInfo(const CNodePtr &kernel, const std::vector<tensor::TensorPtr> &input_tensors) {
MS_EXCEPTION_IF_NULL(kernel);
auto prim = AnfAlgo::GetCNodePrimitive(kernel);
MS_EXCEPTION_IF_NULL(prim);
const AbstractBasePtr &abstract = kernel->abstract();
MS_EXCEPTION_IF_NULL(abstract);
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel);
GraphInfo graph_info;
// get input tensor info
for (const auto &tensor : input_tensors) {
@ -415,11 +537,19 @@ GraphInfo GetSingleOpGraphInfo(const PrimitivePtr &prim, const std::vector<tenso
}
// get attr info
const auto &attr_map = prim->attrs();
(void)std::for_each(attr_map.begin(), attr_map.end(),
[&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); });
const auto &added_attr_map = prim->evaluate_added_attrs();
(void)std::for_each(added_attr_map.begin(), added_attr_map.end(),
[&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); });
(void)std::for_each(attr_map.begin(), attr_map.end(), [&](const auto &element) {
if (element.second->ToString().empty()) {
return;
}
(void)graph_info.append(element.second->ToString() + "_");
});
auto build_shape = abstract->BuildShape();
MS_EXCEPTION_IF_NULL(build_shape);
(void)graph_info.append(build_shape->ToString() + "_");
for (size_t output_index = 0; output_index < output_num; output_index += 1) {
const auto output_type = AnfAlgo::GetOutputInferDataType(kernel, output_index);
(void)graph_info.append(std::to_string(output_type) + "_");
}
graph_info.append(prim->id());
return graph_info;
}
@ -831,14 +961,8 @@ void AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &g
return;
}
// construct graph include one op
auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask, true);
const auto &graph = PreBuildOp(op_run_info, graph_info, input_tensors, tensors_mask);
MS_EXCEPTION_IF_NULL(graph);
opt::RunOpAscendBackendIRFusionOptimization(graph);
// kernel select
SelectKernel(*graph);
// optimize
RunOpHardwareOptimize(graph);
// init runtime resource
InitRuntimeResource();
// build kernel
@ -880,6 +1004,71 @@ void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_inf
MS_LOG(INFO) << "Run op " << op_run_info->op_name << " finish!";
}
KernelGraphPtr AscendSession::PreBuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask) {
// Construct graph include one op
auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask, true);
MS_EXCEPTION_IF_NULL(graph);
opt::RunOpAscendBackendIRFusionOptimization(graph);
SelectKernel(*graph);
RunOpHardwareOptimize(graph);
return graph;
}
void AscendSession::BuildOpsInGraph(KernelGraph *graph, const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs) {
MS_EXCEPTION_IF_NULL(graph);
std::map<KernelWithIndex, OutputTensorInfo> op_output_info;
std::vector<CNodePtr> kernels;
std::unordered_map<KernelGraphPtr, std::vector<GraphInfo>> single_op_graphs;
// Collect kernels need to be built in single op graphs
for (const auto &kernel : graph->execution_order()) {
// Generate fake input tensors, tensor masks and input kernel with index
InputTensorInfo input_tensor_info;
GetOpInputStubTensors(kernel, parameter_index, graph_inputs, op_output_info, &input_tensor_info);
// Get OpRunInfo and GraphInfo
OpRunInfo op_run_info;
GetSingleOpRunInfo(kernel, &op_run_info);
const GraphInfo &graph_info = GetSingleOpGraphInfo(kernel, input_tensor_info.input_tensors);
const auto &single_op_graph_iter = run_op_graphs_.find(graph_info);
if (single_op_graph_iter != run_op_graphs_.end()) {
// if graph of same single op exists, the output tensor of current op should be generated
const auto &single_op_graph = single_op_graph_iter->second;
GenOpOutputStubTensor(single_op_graph, kernel, &op_output_info);
continue;
}
const auto &single_op_graph =
PreBuildOp(op_run_info, graph_info, input_tensor_info.input_tensors, input_tensor_info.input_tensors_mask);
MS_EXCEPTION_IF_NULL(single_op_graph);
GenOpOutputStubTensor(single_op_graph, kernel, &op_output_info);
opt::HideNopNode(single_op_graph.get());
// The graph info could have been changed in PreBuildOp
const GraphInfo &new_graph_info = GetSingleOpGraphInfo(kernel, input_tensor_info.input_tensors);
single_op_graphs.insert({single_op_graph, {graph_info, new_graph_info}});
const auto &execution_order = single_op_graph->execution_order();
std::copy(execution_order.begin(), execution_order.end(), std::back_inserter(kernels));
}
InitRuntimeResource();
// Compile all kernels parallel
BuildKernel(kernels);
// Some new kernel may be added after KernelBuildPreprocess, so collect and build kernels again
kernels.clear();
for (const auto &single_op_graph : single_op_graphs) {
device::ascend::KernelBuildPreprocess(single_op_graph.first.get());
const auto &execution_order = single_op_graph.first->execution_order();
std::copy(execution_order.begin(), execution_order.end(), std::back_inserter(kernels));
}
BuildKernel(kernels);
// Record single op graphs in run_op_graphs_ so that these graphs can be reused in BuildOpImpl
for (const auto &single_op_graph : single_op_graphs) {
for (const auto &graph_info : single_op_graph.second) {
run_op_graphs_[graph_info] = single_op_graph.first;
MS_LOG(DEBUG) << "Pre build op finished, graph info: " << single_op_graph.second;
}
}
}
void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs) {
MS_LOG(INFO) << "Start!";
@ -890,6 +1079,10 @@ void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector
CreateOutputPlaceholder(kernel_graph, inputs, outputs, &output_indexes);
std::map<KernelWithIndex, size_t> cnode_ref;
GetRefCount(kernel_graph.get(), &cnode_ref);
if (built_graph_id_.find(graph_id) == built_graph_id_.end()) {
BuildOpsInGraph(kernel_graph.get(), parameter_index, inputs);
built_graph_id_.insert(graph_id);
}
std::map<KernelWithIndex, tensor::TensorPtr> op_output_map;
for (const auto &kernel : kernel_graph->execution_order()) {
@ -900,7 +1093,7 @@ void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector
// Get OpRunInfo and GraphInfo
OpRunInfo run_info;
GetSingleOpRunInfo(kernel, &run_info);
GraphInfo graph_info = GetSingleOpGraphInfo(run_info.primitive, input_tensor_info.input_tensors);
GraphInfo graph_info = GetSingleOpGraphInfo(kernel, input_tensor_info.input_tensors);
// Build and run current single op
VectorRef op_outputs;
@ -1034,10 +1227,14 @@ void AscendSession::AssignStream(NotNull<KernelGraphPtr> kernel_graph) const {
}
void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
BuildKernel(kernel_graph->execution_order());
}
void AscendSession::BuildKernel(const std::vector<CNodePtr> &kernels) const {
MS_LOG(INFO) << "Start!";
struct timeval start_time, end_time;
(void)gettimeofday(&start_time, nullptr);
auto ret = device::ascend::KernelBuild(kernel_graph.get());
auto ret = device::ascend::KernelBuild(kernels);
if (!ret) {
MS_LOG(EXCEPTION) << "Kernel build error.";
}

View File

@ -41,6 +41,11 @@ struct InputTensorInfo {
std::set<KernelWithIndex> input_kernel;
};
struct OutputTensorInfo {
tensor::TensorPtr output_stub_tensor;
bool is_weight;
};
class AscendSession : public SessionBasic {
public:
AscendSession() { final_graph_id_ = kInvalidGraphId; }
@ -79,6 +84,7 @@ class AscendSession : public SessionBasic {
void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void AssignStream(NotNull<KernelGraphPtr> kernel_graph) const;
void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void BuildKernel(const std::vector<CNodePtr> &kernels) const;
void BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void MemoryAlloc(KernelGraph *kernel_graph) const;
void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;
@ -119,7 +125,11 @@ class AscendSession : public SessionBasic {
void LoadGraphsToDbg(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
void AssignStaticMemory(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
void UpdateRefOutputMap(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
KernelGraphPtr PreBuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<int64_t> &tensors_mask);
void BuildOpsInGraph(KernelGraph *graph, const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs);
// key is final_graph_id,value is child graph execute order of final graph
std::unordered_map<GraphId, std::vector<GraphId>> graph_execute_orders_;
// key is final_graph_id,value is the graph types of child graphs
@ -128,6 +138,8 @@ class AscendSession : public SessionBasic {
std::map<std::pair<GraphId, size_t>, tensor::TensorPtr> initial_tenosrs_;
// final_graph_id is used in every root graph has it's own session situation
GraphId final_graph_id_;
// record graph ids of bp graphs that has been built in PyNative mode
std::set<GraphId> built_graph_id_;
};
MS_REG_SESSION(kAscendDevice, AscendSession);
} // namespace session

View File

@ -67,12 +67,11 @@ static kernel::KernelModPtr SerialCompileImpl(const AnfNodePtr &anf_node) {
return kernel_mod_ptr;
}
static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) {
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
static bool KernelBuildParallelCompile(const std::vector<CNodePtr> &kernels) {
std::vector<AnfNodePtr> tbe_nodes;
std::vector<AnfNodePtr> akg_nodes;
std::vector<AnfNodePtr> other_nodes;
for (const auto &anf_node : kernel_graph_ptr->execution_order()) {
for (const auto &anf_node : kernels) {
MS_EXCEPTION_IF_NULL(anf_node);
if (!AnfAlgo::IsRealKernel(anf_node)) {
continue;
@ -217,12 +216,9 @@ static bool IsAtomicNode(const CNodePtr &kernel_node) {
return !(workspace_indexs.empty() && output_indexs.empty());
}
bool KernelBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) {
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
bool KernelBuild(const std::vector<CNodePtr> &kernels) {
TbeUtils::LoadCache();
bool ret;
ret = device::ascend::KernelBuildParallelCompile(kernel_graph_ptr);
return ret;
return device::ascend::KernelBuildParallelCompile(kernels);
}
std::map<AnfNodePtr, std::vector<size_t>> GetCommunicationOpInputInfo(

View File

@ -17,6 +17,8 @@
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_
#include <vector>
#include "backend/session/kernel_graph.h"
namespace mindspore {
@ -25,7 +27,7 @@ namespace ascend {
/**
* @brief kernel build for ascend.
*/
bool KernelBuild(const mindspore::session::KernelGraph *kernel_graph_ptr);
bool KernelBuild(const std::vector<CNodePtr> &kernels);
/**
* @brief preporcess of kernel build for ascend, e.g. inserting clear_zero node for maxpool, bn.
* Must DO these changes just before kernel build, and after all of other optimizations on AnfGraph