!35673 Unify the process of ms_function and bprop

Merge pull request !35673 from caifubi/master-pynative-unified-ms-function-bprop
i-robot 2022-06-14 01:05:19 +00:00 committed by Gitee
commit 04dcc3c8f6
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 205 additions and 176 deletions


@ -197,9 +197,66 @@ void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs,
}
}
// Move these functions to the anonymous namespace.
void FlatValueTupleValue(const ValuePtrList &value, ValuePtrList *flatted_value) {
MS_EXCEPTION_IF_NULL(flatted_value);
for (auto value_element : value) {
MS_EXCEPTION_IF_NULL(value_element);
if (utils::isa<tensor::TensorPtr>(value_element)) {
(void)flatted_value->emplace_back(value_element);
} else if (utils::isa<ValueTuplePtr>(value_element)) {
auto value_tuple_element = value_element->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple_element);
FlatValueTupleValue(value_tuple_element->value(), flatted_value);
} else {
MS_LOG(EXCEPTION) << "The value input to FlatValueTupleValue should only contains Tensor and ValueTuple.";
}
}
}
void FlattenValue(const BaseRef &arg, ValuePtrList *flatted_value) {
MS_EXCEPTION_IF_NULL(flatted_value);
if (utils::isa<ValueSequencePtr>(arg)) {
auto value_sequence = utils::cast<ValueSequencePtr>(arg);
MS_EXCEPTION_IF_NULL(value_sequence);
auto sequence_value = value_sequence->value();
for (auto &value : sequence_value) {
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>()) {
(void)flatted_value->emplace_back(value);
} else {
FlattenValue(value, flatted_value);
}
}
} else if (utils::isa<ValueDictionaryPtr>(arg)) {
auto value_dict = utils::cast<ValueDictionaryPtr>(arg);
MS_EXCEPTION_IF_NULL(value_dict);
auto dict_value = value_dict->value();
for (auto &iter : dict_value) {
auto value = iter.second;
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>()) {
(void)flatted_value->emplace_back(value);
} else {
FlattenValue(value, flatted_value);
}
}
} else if (utils::isa<tensor::CSRTensorPtr>(arg)) {
auto csr_tensor = utils::cast<tensor::CSRTensorPtr>(arg);
MS_EXCEPTION_IF_NULL(csr_tensor);
for (size_t i = 0; i < csr_tensor->GetTensorLength(); ++i) {
(void)flatted_value->emplace_back(csr_tensor->GetTensorAt(i));
}
} else {
MS_LOG(EXCEPTION) << "The value input to flatten should only contains be sequence or dictionary, but it is "
<< arg.ToString();
}
}
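// Illustrative standalone sketch (not part of this commit): FlatValueTupleValue and
// FlattenValue above collect tensor leaves from arbitrarily nested containers in
// depth-first order. The analogue below mirrors that recursion with plain std types
// instead of MindSpore's Value/Tensor classes; ValueNode and FlattenNode are
// placeholder names, not MindSpore APIs.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct ValueNode {
  std::string tensor;                                 // non-empty only for leaves
  std::vector<std::shared_ptr<ValueNode>> children;   // non-empty only for sequences
};

void FlattenNode(const std::shared_ptr<ValueNode> &node, std::vector<std::string> *flat) {
  if (node->children.empty()) {
    flat->push_back(node->tensor);                    // leaf: like the Tensor branch
    return;
  }
  for (const auto &child : node->children) {
    FlattenNode(child, flat);                         // nested sequence: recurse
  }
}

int main() {
  // ((a, b), c) flattens to [a, b, c], i.e. the order the flattened list is indexed by.
  auto a = std::make_shared<ValueNode>(ValueNode{"a", {}});
  auto b = std::make_shared<ValueNode>(ValueNode{"b", {}});
  auto c = std::make_shared<ValueNode>(ValueNode{"c", {}});
  auto inner = std::make_shared<ValueNode>(ValueNode{"", {a, b}});
  auto root = std::make_shared<ValueNode>(ValueNode{"", {inner, c}});
  std::vector<std::string> flat;
  FlattenNode(root, &flat);
  for (const auto &name : flat) {
    std::cout << name << " ";                         // prints: a b c
  }
  std::cout << std::endl;
  return 0;
}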
// Insert the front_node related tensor in the input_tensor.
void PushTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const KernelWithIndex &front_node,
std::vector<tensor::TensorPtr> *input_tensor) {
MS_EXCEPTION_IF_NULL(input_tensor);
const auto &iter = std::find(parameters.begin(), parameters.end(), front_node.first);
if (iter == parameters.end()) {
(void)((*input_tensor).emplace_back(nullptr));
@ -209,6 +266,31 @@ void PushTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters
PushInputTensor(args[position], input_tensor, front_node.second);
}
void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
size_t index, std::vector<tensor::TensorPtr> *input_tensor) {
MS_EXCEPTION_IF_NULL(input_tensor);
const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
const size_t position = iter - parameters.begin();
// If the parameter is not found in the parameters of the root graph, it means that it is the input of the subgraph,
// and there is no need to input a tensor.
if (position >= args.size()) {
MS_LOG(INFO) << "Position out of args range, position value is " << position << " and args size is " << args.size()
<< ".";
(void)input_tensor->emplace_back(nullptr);
return;
}
ValuePtrList flatted_value_tuple_value;
FlattenValue(args[position], &flatted_value_tuple_value);
if (index >= flatted_value_tuple_value.size()) {
MS_LOG(EXCEPTION) << "Index out of flatted_value_tuple_value range, index value is " << index
<< " and flatted_value_tuple_value size is " << flatted_value_tuple_value.size() << ".";
}
auto input = flatted_value_tuple_value[index];
MS_EXCEPTION_IF_NULL(input);
auto tensor_input = input->cast<tensor::TensorPtr>();
input_tensor->push_back(tensor_input);
}
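// Illustrative standalone sketch (not part of this commit): PushTupleTensor above first
// finds the front node's position in the root-graph parameter list, then flattens
// args[position] and selects the element at `index`. The analogue below keeps only that
// position/index lookup, using strings for parameters and ints for tensors;
// PickTupleElement is a placeholder name, not a MindSpore API.
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Returns -1 as the analogue of pushing a nullptr placeholder tensor.
int PickTupleElement(const std::vector<std::vector<int>> &args,
                     const std::vector<std::string> &parameters,
                     const std::string &front_node, size_t index) {
  const auto iter = std::find(parameters.begin(), parameters.end(), front_node);
  const size_t position = iter - parameters.begin();
  if (position >= args.size()) {
    return -1;                                        // parameter belongs to a subgraph: no tensor needed
  }
  const std::vector<int> &flattened = args[position]; // args[position] is already flat in this analogue
  if (index >= flattened.size()) {
    return -1;                                        // analogue of the out-of-range exception
  }
  return flattened[index];
}

int main() {
  // Parameter "y" is bound to the flattened tuple (10, 20, 30); index 2 selects 30.
  std::cout << PickTupleElement({{1}, {10, 20, 30}}, {"x", "y"}, "y", 2) << std::endl;
  return 0;
}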
void UpdateOutputAbstract(const KernelGraphPtr &kernel_graph, OpRunInfo *op_run_info) {
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(op_run_info);
@ -334,6 +416,38 @@ bool EnablePyNativeSyncRunning() {
MS_EXCEPTION_IF_NULL(ms_context);
return ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
}
std::vector<std::vector<tensor::TensorPtr>> GetRunGraphInputs(const GraphCompilerInfo &graph_compiler_info,
const VectorRef &args) {
const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;
std::vector<std::vector<tensor::TensorPtr>> input_tensors;
for (const auto &kernel_graph : graph_compiler_info.graphs_) {
std::vector<tensor::TensorPtr> input_tensor;
MS_EXCEPTION_IF_NULL(kernel_graph);
for (const auto &input_node : kernel_graph->input_nodes()) {
auto element_pair = kernel_graph->GetElementInTupleBackendFrontIndexMap(input_node);
if (element_pair.first) {
PushTupleTensor(args, origin_parameters, element_pair.first, element_pair.second, &input_tensor);
} else {
const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node);
PushTensor(args, origin_parameters, {front_node, 0}, &input_tensor);
}
}
(void)input_tensors.emplace_back(input_tensor);
}
// Input tensors of the control node.
std::vector<tensor::TensorPtr> input_tensor;
MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
// Get inputs of control node which come from the host actor.
const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
for (const auto &parameter : control_node_parameters) {
PushTensor(args, origin_parameters, parameter, &input_tensor);
}
(void)input_tensors.emplace_back(input_tensor);
return input_tensors;
}
} // namespace
VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
@ -470,7 +584,7 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
PROF_START(compile_func_graph);
auto root_graph = WrapPrimitives(func_graph);
MS_EXCEPTION_IF_NULL(root_graph);
root_graph_ = root_graph.get();
root_graph_ = root_graph;
// Register a summary callback function, which is called in the final stages of summary.
graph_compiler_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
@ -504,8 +618,9 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
// Construct the graph compiler info.
auto graph_compiler_info = ConstructGraphCompilerInfo(root_graph);
MS_EXCEPTION_IF_NULL(graph_compiler_info);
if (real_execution_mode_ == kGraphMode && graph_compiler_info->graphs_.size() != 0) {
if (real_execution_mode_ == kGraphMode && !graph_compiler_info->graphs_.empty()) {
// Transform graph to actor DAG, and schedule the actor DAG.
ParseControlNodes(*graph_compiler_info);
const auto &actor_set = runtime::GraphScheduler::GetInstance().Transform(*graph_compiler_info);
runtime::GraphScheduler::GetInstance().Schedule(actor_set);
}
@ -815,84 +930,7 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &graph_output, const Vec
}
} // namespace
void FlatValueTupleValue(const ValuePtrList &value, ValuePtrList *flatted_value) {
for (size_t i = 0; i < value.size(); ++i) {
auto value_element = value[i];
MS_EXCEPTION_IF_NULL(value_element);
if (utils::isa<tensor::TensorPtr>(value_element)) {
(void)flatted_value->emplace_back(value_element);
} else if (utils::isa<ValueTuplePtr>(value_element)) {
auto value_tuple_element = value_element->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple_element);
FlatValueTupleValue(value_tuple_element->value(), flatted_value);
} else {
MS_LOG(EXCEPTION) << "The value input to FlatValueTupleValue should only contains Tensor and ValueTuple.";
}
}
}
void FlattenValue(const BaseRef &arg, ValuePtrList *flatted_value) {
if (utils::isa<ValueSequencePtr>(arg)) {
auto value_sequence = utils::cast<ValueSequencePtr>(arg);
MS_EXCEPTION_IF_NULL(value_sequence);
auto sequence_value = value_sequence->value();
for (auto &value : sequence_value) {
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>()) {
(void)flatted_value->emplace_back(value);
} else {
FlattenValue(value, flatted_value);
}
}
} else if (utils::isa<ValueDictionaryPtr>(arg)) {
auto value_dict = utils::cast<ValueDictionaryPtr>(arg);
MS_EXCEPTION_IF_NULL(value_dict);
auto dict_value = value_dict->value();
for (auto &iter : dict_value) {
auto value = iter.second;
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>()) {
(void)flatted_value->emplace_back(value);
} else {
FlattenValue(value, flatted_value);
}
}
} else if (utils::isa<tensor::CSRTensorPtr>(arg)) {
auto csr_tensor = utils::cast<tensor::CSRTensorPtr>(arg);
for (size_t i = 0; i < csr_tensor->GetTensorLength(); ++i) {
(void)flatted_value->emplace_back(csr_tensor->GetTensorAt(i));
}
} else {
MS_LOG(EXCEPTION) << "The value input to flatten should only contains be sequence or dictionary, but it is "
<< arg.ToString();
}
}
void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
size_t index, std::vector<tensor::TensorPtr> *input_tensor) {
const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
const size_t position = iter - parameters.begin();
// If the parameter is not found in the parameters of the root graph, it means that it is the input of the subgraph,
// and there is no need to input a tensor.
if (position >= args.size()) {
MS_LOG(INFO) << "Position out of args range, position value is " << position << " and args size is " << args.size()
<< ".";
(void)input_tensor->emplace_back(nullptr);
return;
}
ValuePtrList flatted_value_tuple_value;
FlattenValue(args[position], &flatted_value_tuple_value);
if (index >= flatted_value_tuple_value.size()) {
MS_LOG(EXCEPTION) << "Index out of flatted_value_tuple_value range, index value is " << index
<< " and flatted_value_tuple_value size is " << flatted_value_tuple_value.size() << ".";
}
auto input = flatted_value_tuple_value[index];
MS_EXCEPTION_IF_NULL(input);
auto tensor_input = input->cast<tensor::TensorPtr>();
input_tensor->push_back(tensor_input);
}
void MindRTBackend::ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, FuncGraph *root_graph) {
void MindRTBackend::ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, const FuncGraphPtr &root_graph) {
bool need_contruct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
distributed::recovery::RecoveryContext::GetInstance()->need_reset());
if (need_contruct_output) {
@ -911,8 +949,9 @@ void MindRTBackend::ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *ou
}
void MindRTBackend::RunGraphByActors(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs) {
const VectorRef &args, VectorRef *outputs) {
WaitTaskFinish();
auto inputs = GetRunGraphInputs(graph_compiler_info, args);
MS_EXCEPTION_IF_NULL(graph_compiler_);
auto graphs = graph_compiler_info.graphs_;
if (graphs.size() > inputs.size()) {
@ -927,15 +966,22 @@ void MindRTBackend::RunGraphByActors(const ActorInfo &actor_info, const GraphCom
const auto &graph = graphs[i];
MS_EXCEPTION_IF_NULL(graph);
graph->set_flag(kFlagPyNativeRunInGraph, true);
// Get the real format from the input tensor before compiling the graph
pynative::GraphAdapter::ReplaceBpropGraphParameter(graph, inputs[i]);
graph_compiler_->CompileGraphImpl(graph, graph_compiler_info.device_contexts_.front());
// The size of control_nodes is at least 1 since there is a return node in the graph.
if (control_nodes_.size() == 1) {
MS_LOG(INFO) << "Replace parameter format";
// The input tensor is null if there is control flow in the graph.
// The tensor needs to be fetched after ParseControlNodes.
pynative::GraphAdapter::ReplaceBpropGraphParameter(graph, inputs.at(i));
}
graph_compiler_->CompileGraphImpl(graph, graph_compiler_info.device_contexts_.at(i));
pynative::GraphAdapter::RemoveUnusedValueNodes(graph);
graph->CacheGraphOutputToFrontNodeWithIndex({graph->output()}, graph->front_outputs());
// Clear the front outputs after the outputs are cached.
graph->set_front_outputs({});
}
ParseControlNodes(graph_compiler_info);
actor_set = runtime::GraphScheduler::GetInstance().Transform(graph_compiler_info);
MS_EXCEPTION_IF_NULL(actor_set);
// Multithreading can cause spikes in memory usage and performance fluctuations
@ -952,9 +998,11 @@ void MindRTBackend::RunGraphByActors(const ActorInfo &actor_info, const GraphCom
pynative::GraphAdapter::UpdateForwardOutputInBpropGraph(graph);
}
std::vector<std::vector<tensor::TensorPtr>> input_tensors = GetRunGraphInputs(graph_compiler_info, args);
// Release GIL and run actor DAG.
mindspore::ScopedLongRunning long_running;
runtime::GraphScheduler::GetInstance().Run(actor_set, graph_compiler_info.device_contexts_, inputs);
runtime::GraphScheduler::GetInstance().Run(actor_set, graph_compiler_info.device_contexts_, input_tensors);
MS_EXCEPTION_IF_NULL(graph_compiler_);
graph_compiler_->Summary(graph_compiler_info.graphs_);
@ -977,12 +1025,15 @@ void MindRTBackend::RunGraphByActors(const ActorInfo &actor_info, const GraphCom
MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
}
void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs,
const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs) {
void MindRTBackend::RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_info, const VectorRef &args,
VectorRef *outputs) {
WaitTaskFinish();
MS_EXCEPTION_IF_NULL(graph_compiler_);
auto &op_executor = runtime::OpExecutor::GetInstance();
op_executor.Register([this]() { BatchBuildCallback(); });
MS_EXCEPTION_IF_NULL(graph_compiler_);
const auto &graphs = graph_compiler_info.graphs_;
auto inputs = GetRunGraphInputs(graph_compiler_info, args);
for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) {
const auto &graph = graphs[graph_index];
MS_EXCEPTION_IF_NULL(graph);
@ -1043,13 +1094,12 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
}
void MindRTBackend::RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
const std::vector<std::vector<tensor::TensorPtr>> &input_tensors,
VectorRef *outputs) {
const VectorRef &args, VectorRef *outputs) {
bool contain_cut_graph = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(),
[](const KernelGraphPtr &graph) { return graph->has_flag(kFlagsIsCutGraph); });
if (contain_cut_graph) {
// Python API will be called in cut_graph, so we cannot release the GIL here.
RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
RunGraphBySingleOp(graph_compiler_info, args, outputs);
} else {
bool is_dynamic_shape = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(),
[](const KernelGraphPtr &graph) { return graph->is_dynamic_shape(); });
@ -1061,10 +1111,12 @@ void MindRTBackend::RunGraphByCondition(const ActorInfo &actor_info, const Graph
// TODO(caifubi): Need to support condition: 1. Dynamic shape. 2. AutoParallel.
MS_EXCEPTION_IF_NULL(root_graph_);
if (root_graph_->has_flag(kFlagIsDynamicStructure) || is_dynamic_shape || is_parallel) {
RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
if (root_graph_->has_flag(kFlagIsDynamicStructure) ||
// `ms_function + dynamic_shape` is already supported, so there is no need to execute RunGraphBySingleOp.
(is_dynamic_shape && root_graph_->has_flag(kFlagIsPynativeBpropGraph)) || is_parallel) {
RunGraphBySingleOp(graph_compiler_info, args, outputs);
} else {
RunGraphByActors(actor_info, graph_compiler_info, input_tensors, outputs);
RunGraphByActors(actor_info, graph_compiler_info, args, outputs);
}
}
MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
@ -1094,46 +1146,18 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
}
MS_EXCEPTION_IF_NULL(graph_iter->second);
const auto &graph_compiler_info = *(graph_iter->second);
const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;
// For mixed PyNative and graph execution.
WaitTaskFinish();
// Transform args to input tensors.
// Input tensors of the graph.
std::vector<std::vector<tensor::TensorPtr>> input_tensors;
for (const auto &kernel_graph : graph_compiler_info.graphs_) {
std::vector<tensor::TensorPtr> input_tensor;
MS_EXCEPTION_IF_NULL(kernel_graph);
for (const auto &input_node : kernel_graph->input_nodes()) {
auto element_pair = kernel_graph->GetElementInTupleBackendFrontIndexMap(input_node);
if (element_pair.first) {
PushTupleTensor(args, origin_parameters, element_pair.first, element_pair.second, &input_tensor);
} else {
const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node);
PushTensor(args, origin_parameters, {front_node, 0}, &input_tensor);
}
}
(void)input_tensors.emplace_back(input_tensor);
}
// Input tensors of the control node.
std::vector<tensor::TensorPtr> input_tensor;
MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
// Get inputs of control node which come from the host actor.
const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
for (const auto &parameter : control_node_parameters) {
PushTensor(args, origin_parameters, parameter, &input_tensor);
}
(void)input_tensors.emplace_back(input_tensor);
// Run in the pynative mode.
MS_EXCEPTION_IF_NULL(outputs);
// There will be more than one kernel graph in a heterogeneous scenario for an ms_function in PyNative mode.
if (real_execution_mode_ == kPynativeMode) {
RunGraphByCondition(actor_info, graph_compiler_info, input_tensors, outputs);
RunGraphByCondition(actor_info, graph_compiler_info, args, outputs);
return;
}
auto input_tensors = GetRunGraphInputs(graph_compiler_info, args);
// Release the Python GIL.
mindspore::ScopedLongRunning long_running;
// Run actor DAG.
@ -1298,22 +1322,7 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
++graph_index;
}
FuncGraphToKernelGraphGroup func_graph_to_kernel_graphs;
for (const auto &func_graph_to_kernel_graph_ids : func_graph_to_kernel_graph_ids_) {
const auto &func_graph = func_graph_to_kernel_graph_ids.first;
for (const auto &sub_kernel_graphs_ids : func_graph_to_kernel_graph_ids.second) {
std::vector<KernelGraphPtr> kernel_graphs;
for (const auto &graph_id : sub_kernel_graphs_ids) {
const auto &kernel_graph = graph_compiler_->Fetch(graph_id);
MS_EXCEPTION_IF_NULL(kernel_graph);
(void)kernel_graphs.emplace_back(kernel_graph);
}
(void)func_graph_to_kernel_graphs[func_graph].emplace_back(kernel_graphs);
}
}
auto parser = std::make_shared<ControlNodeParser>();
parser->Parse(control_nodes_, graphs, device_contexts, root_graph, func_graph_to_kernel_graphs);
runtime::KernelMapPosition outputs_order;
const auto &root_output =
@ -1635,5 +1644,25 @@ void MindRTBackend::UpdateOutput(const std::vector<session::KernelWithIndex> &ou
outputs->emplace_back(output_tensor);
}
}
void MindRTBackend::ParseControlNodes(const GraphCompilerInfo &graph_compile_info) {
FuncGraphToKernelGraphGroup func_graph_to_kernel_graphs;
for (const auto &func_graph_to_kernel_graph_ids : func_graph_to_kernel_graph_ids_) {
const auto &func_graph = func_graph_to_kernel_graph_ids.first;
for (const auto &sub_kernel_graphs_ids : func_graph_to_kernel_graph_ids.second) {
std::vector<KernelGraphPtr> kernel_graphs;
for (const auto &graph_id : sub_kernel_graphs_ids) {
const auto &kernel_graph = graph_compiler_->Fetch(graph_id);
MS_EXCEPTION_IF_NULL(kernel_graph);
(void)kernel_graphs.emplace_back(kernel_graph);
}
(void)func_graph_to_kernel_graphs[func_graph].emplace_back(kernel_graphs);
}
}
graph_compile_info.control_node_parser_->Parse(control_nodes_, graph_compile_info.graphs_,
graph_compile_info.device_contexts_, root_graph_,
func_graph_to_kernel_graphs);
}
} // namespace compile
} // namespace mindspore


@ -143,7 +143,7 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
// Get saved OpBuildTask in OpExecutor and build all the kernels together in PyNative mode.
void CompileSingleOpGraphs(const std::vector<std::shared_ptr<runtime::OpBuildTask>> &build_tasks);
void ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, FuncGraph *root_graph);
void ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, const FuncGraphPtr &root_graph);
// Restore the outputs tuple by the origin funcGraph output node and output tensors.
void ConstructOutputs(const AnfNodePtr &output_node, const std::vector<tensor::TensorPtr> &output_tensors,
@ -160,6 +160,8 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
const std::vector<tensor::TensorPtr> *input_tensors,
bool need_erase);
void ParseControlNodes(const GraphCompilerInfo &graph_compile_info);
// In PyNative mode, the size of the single-op cache list keeps increasing, which leads to growing memory cost,
// so the latest single-op cache entries should be erased when the cache list size exceeds the threshold.
void EraseSingleOpCache(const ActorInfo &actor_info, const std::string &graph_info, const KernelGraphPtr &graph);
@ -176,14 +178,13 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
OpRunInfo *op_run_info);
void RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
const std::vector<std::vector<tensor::TensorPtr>> &input_tensors, VectorRef *outputs);
const VectorRef &args, VectorRef *outputs);
// Split the complete kernel graph into single-op graphs in PyNative back
// propagation, then compile and run the single-op graphs.
void RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs,
const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs);
void RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_info, const VectorRef &args, VectorRef *outputs);
void RunGraphByActors(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs);
const VectorRef &args, VectorRef *outputs);
void UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, VectorRef *const outputs);
@ -209,7 +210,7 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
// Cache forward op output value node tensor ref count of kernels for back propagation graph in PyNative mode.
std::map<std::string, size_t> forward_op_output_tensor_id_;
FuncGraph *root_graph_;
FuncGraphPtr root_graph_;
GraphPartitionPtr graph_partition_;
std::shared_ptr<GraphCompiler> graph_compiler_;
std::string device_name_;


@ -955,16 +955,11 @@ void OriginSetRunMode(const ResourcePtr &resource) {
}
}
void SetRunMode(const ResourcePtr &resource, bool pynative_switch_to_graph_mode) {
void SetRunMode(const ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT) && common::GetEnv("DISABLE_ASCEND_MINDRT") != "1") {
// Run in GRAPH_MODE if the func_graph is an ms_function or the func_graph contains multiple subgraphs.
if (pynative_switch_to_graph_mode) {
context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
MS_LOG(INFO) << "PyNative graph Compile and Run in GRAPH_MODE";
}
SetRunMode(resource->func_graph(), resource->GetResult(kBackend).cast<compile::BackendPtr>().get());
} else {
OriginSetRunMode(resource);
@ -995,22 +990,14 @@ bool TaskEmitAction(const ResourcePtr &resource) {
MS_LOG(WARNING) << "Multi device target is detected, CPU data is dumped in rank_0 directory";
}
DisableMindRT(resource);
auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
auto is_parallel = (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel);
bool pynative_switch_to_graph_mode =
context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
(!func_graph->has_flag(kFlagIsPynativeBpropGraph) || func_graph->manager()->func_graphs().size() > 1) &&
!is_parallel;
SetRunMode(resource, pynative_switch_to_graph_mode);
SetRunMode(resource);
auto bc_ptr = resource->GetResult(kBackend).cast<compile::BackendPtr>();
MS_EXCEPTION_IF_NULL(bc_ptr);
std::string backend = context_ptr->backend_policy();
// The graph compiling of mindRT.
if ((backend == kMsConvert) && context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
TaskEmitActionForMindRT(resource);
if (pynative_switch_to_graph_mode) {
context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, kPynativeMode);
}
return true;
}


@ -455,7 +455,7 @@ void AscendDeviceContext::PreprocessBeforeRunGraph(const KernelGraphPtr &graph)
AllocateGraphMemory(NOT_NULL(graph));
LoadModel(NOT_NULL(graph));
AssignOutputNopNodeDeviceAddress(graph);
} else if (graph->is_dynamic_shape() && IsGraphMode()) {
} else if (graph->is_dynamic_shape() && (IsGraphMode() || graph->has_flag(kFlagPyNativeRunInGraph))) {
device::ascend::InsertAtomicCleanOps(graph->execution_order(), &node_atomics_);
SetAtomicCleanToNodes(graph, node_atomics_); // graph mode may do this too, instead of updating the exec order
opt::DynamicShapeConvertPass(graph);


@ -325,7 +325,8 @@ void CPUDeviceContext::PreprocessBeforeRun(const FuncGraphPtr &graph) const {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (kernel_graph->is_dynamic_shape() && ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
if (kernel_graph->is_dynamic_shape() && (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode ||
kernel_graph->has_flag(kFlagPyNativeRunInGraph))) {
opt::DynamicShapeConvertPass(kernel_graph);
}
}


@ -225,7 +225,8 @@ void GPUDeviceContext::PreprocessBeforeRun(const FuncGraphPtr &graph) const {
if (!kernel_graph->is_from_single_op()) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (kernel_graph->is_dynamic_shape() && ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
if (kernel_graph->is_dynamic_shape() && (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode ||
kernel_graph->has_flag(kFlagPyNativeRunInGraph))) {
opt::DynamicShapeConvertPass(kernel_graph);
}
}


@ -562,10 +562,14 @@ GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNod
session_->DumpGraphs({graph});
// Cache the backend graph output nodes to front nodes with output index.
auto backend_node = graph->output();
MS_EXCEPTION_IF_NULL(backend_node);
graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, outputs);
// The graph is not compiled yet in PyNative mode.
// The output needs to be cached later, after the graph is compiled.
if (!run_in_pynative) {
// Cache the backend graph output nodes to front nodes with output index.
auto backend_node = graph->output();
MS_EXCEPTION_IF_NULL(backend_node);
graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, outputs);
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);


@ -44,13 +44,12 @@ def test_adaptive_max_pool2d():
x = np.random.randn(32, 64, 128, 128).astype(np.float32)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
adaptive_max_pool2d = Net(64, True)
output1, argmax1 = adaptive_max_pool2d(Tensor(x))
output1, _ = adaptive_max_pool2d(Tensor(x))
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
adaptive_max_pool2d = Net(64, True)
output2, argmax2 = adaptive_max_pool2d(Tensor(x))
output2, _ = adaptive_max_pool2d(Tensor(x))
assert (output1.asnumpy() == output2.asnumpy()).all()
assert (argmax1.asnumpy() == argmax2.asnumpy()).all()
@pytest.mark.level0


@ -20,6 +20,7 @@ from mindspore import context
from mindspore.ops import GradOperation
from mindspore.common import ParameterTuple
from mindspore.common.api import ms_function
from mindspore import ops as P
def forward_pre_hook_fn_bn(cell_id, inp):
@ -49,7 +50,7 @@ def forward_pre_hook_fn_multi_add(cell_id, inp):
def forward_hook_fn_conv(cell_id, inp, outp):
out = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")(outp)
out = P.Log()(outp)
return out
@ -101,6 +102,10 @@ class SingleNet(nn.Cell):
class SingleNetInConstruct(nn.Cell):
def __init__(self):
super(SingleNetInConstruct, self).__init__()
self.handle1 = None
self.handle2 = None
self.handle3 = None
self.handle4 = None
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
self.relu = nn.ReLU()
@ -262,13 +267,14 @@ class CompareMultiNet1(nn.Cell):
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
self.relu = nn.ReLU()
self.log = P.Log()
def construct(self, x, y):
x = self.mul(x + x, x * y)
x = self.conv(x)
x = self.log(x)
x = self.bn(x)
x = self.relu(x)
x = self.mul(x, x)
y = self.relu(x)
x = self.mul(y, x)
x = x + x
return x
@ -280,10 +286,11 @@ class CompareMultiNet2(nn.Cell):
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
self.relu = nn.ReLU()
self.log = P.Log()
def construct(self, x, y):
x = self.mul(x, y)
x = self.conv(x)
x = self.log(x)
x = self.bn(x)
x = self.mul(x, x)
return x