!35673 Unify the process of ms_function and bprop
Merge pull request !35673 from caifubi/master-pynative-unified-ms-function-bprop
commit 04dcc3c8f6
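The heart of this change is that the backend run entry points now take the raw `VectorRef` args and derive the per-graph tensor lists themselves through `GetRunGraphInputs`, so ms_function graphs and PyNative bprop graphs share one input-preparation path. A minimal stand-alone C++ sketch of that shape (the `Tensor`, `Args`, and `GraphInfo` types here are illustrative stand-ins, not MindSpore API):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative stand-ins, not MindSpore types: a "tensor" is a double and the raw
// front-end args (VectorRef in the diff) are just a flat vector here.
using Tensor = double;
using Args = std::vector<Tensor>;

struct GraphInfo {
  std::vector<std::size_t> input_positions;  // which args this kernel graph consumes
};

// Analogue of GetRunGraphInputs: derive the per-graph tensor lists from the raw args
// once, inside the backend, instead of asking every caller to precompute them.
std::vector<std::vector<Tensor>> GetRunGraphInputs(const std::vector<GraphInfo> &graphs, const Args &args) {
  std::vector<std::vector<Tensor>> per_graph_inputs;
  for (const auto &graph : graphs) {
    std::vector<Tensor> inputs;
    for (std::size_t pos : graph.input_positions) {
      // Missing parameters become empty slots, mirroring the nullptr pushed by PushTensor.
      inputs.push_back(pos < args.size() ? args[pos] : Tensor{});
    }
    per_graph_inputs.push_back(inputs);
  }
  return per_graph_inputs;
}

// New-style entry: like RunGraphByActors/RunGraphBySingleOp after this change,
// it accepts the raw args and derives the inputs itself.
void RunGraphByActors(const std::vector<GraphInfo> &graphs, const Args &args) {
  auto inputs = GetRunGraphInputs(graphs, args);
  std::cout << "prepared inputs for " << inputs.size() << " graphs\n";
}

int main() {
  RunGraphByActors({GraphInfo{{0, 1}}, GraphInfo{{1}}}, Args{1.0, 2.0});
  return 0;
}
```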
@@ -197,9 +197,66 @@ void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs,
  }
}

// Move these function to anonymous namespace
void FlatValueTupleValue(const ValuePtrList &value, ValuePtrList *flatted_value) {
  MS_EXCEPTION_IF_NULL(flatted_value);
  for (auto value_element : value) {
    MS_EXCEPTION_IF_NULL(value_element);
    if (utils::isa<tensor::TensorPtr>(value_element)) {
      (void)flatted_value->emplace_back(value_element);
    } else if (utils::isa<ValueTuplePtr>(value_element)) {
      auto value_tuple_element = value_element->cast<ValueTuplePtr>();
      MS_EXCEPTION_IF_NULL(value_tuple_element);
      FlatValueTupleValue(value_tuple_element->value(), flatted_value);
    } else {
      MS_LOG(EXCEPTION) << "The value input to FlatValueTupleValue should only contains Tensor and ValueTuple.";
    }
  }
}

void FlattenValue(const BaseRef &arg, ValuePtrList *flatted_value) {
  MS_EXCEPTION_IF_NULL(flatted_value);
  if (utils::isa<ValueSequencePtr>(arg)) {
    auto value_sequence = utils::cast<ValueSequencePtr>(arg);
    MS_EXCEPTION_IF_NULL(value_sequence);
    auto sequence_value = value_sequence->value();
    for (auto &value : sequence_value) {
      MS_EXCEPTION_IF_NULL(value);
      if (value->isa<tensor::Tensor>()) {
        (void)flatted_value->emplace_back(value);
      } else {
        FlattenValue(value, flatted_value);
      }
    }
  } else if (utils::isa<ValueDictionaryPtr>(arg)) {
    auto value_dict = utils::cast<ValueDictionaryPtr>(arg);
    MS_EXCEPTION_IF_NULL(value_dict);
    auto dict_value = value_dict->value();
    for (auto &iter : dict_value) {
      auto value = iter.second;
      MS_EXCEPTION_IF_NULL(value);
      if (value->isa<tensor::Tensor>()) {
        (void)flatted_value->emplace_back(value);
      } else {
        FlattenValue(value, flatted_value);
      }
    }
  } else if (utils::isa<tensor::CSRTensorPtr>(arg)) {
    auto csr_tensor = utils::cast<tensor::CSRTensorPtr>(arg);
    MS_EXCEPTION_IF_NULL(csr_tensor);
    for (size_t i = 0; i < csr_tensor->GetTensorLength(); ++i) {
      (void)flatted_value->emplace_back(csr_tensor->GetTensorAt(i));
    }
  } else {
    MS_LOG(EXCEPTION) << "The value input to flatten should only contains be sequence or dictionary, but it is "
                      << arg.ToString();
  }
}

// Insert the front_node related tensor in the input_tensor.
void PushTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const KernelWithIndex &front_node,
                std::vector<tensor::TensorPtr> *input_tensor) {
  MS_EXCEPTION_IF_NULL(input_tensor);
  const auto &iter = std::find(parameters.begin(), parameters.end(), front_node.first);
  if (iter == parameters.end()) {
    (void)((*input_tensor).emplace_back(nullptr));
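`FlattenValue` and `FlatValueTupleValue` above walk nested value sequences, tuples, and dictionaries depth-first and collect the tensor leaves in order, so that `PushTupleTensor` can index into one flat list. A minimal stand-alone sketch of that recursion over plain standard-library containers (the `Node` type is illustrative, not a MindSpore type):

```cpp
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Illustrative node: either a tensor leaf, a sequence, or a dictionary.
struct Node {
  enum class Kind { kTensor, kSequence, kDict };
  Kind kind;
  double tensor;                                       // valid when kind == kTensor
  std::vector<std::shared_ptr<Node>> seq;              // valid when kind == kSequence
  std::map<std::string, std::shared_ptr<Node>> dict;   // valid when kind == kDict
};

// Same shape as FlattenValue: depth-first walk, tensors are appended in visit order.
void Flatten(const Node &node, std::vector<double> *flat) {
  switch (node.kind) {
    case Node::Kind::kTensor:
      flat->push_back(node.tensor);
      break;
    case Node::Kind::kSequence:
      for (const auto &child : node.seq) Flatten(*child, flat);
      break;
    case Node::Kind::kDict:
      for (const auto &kv : node.dict) Flatten(*kv.second, flat);  // std::map iterates in key order
      break;
  }
}

int main() {
  auto leaf = [](double v) { return std::make_shared<Node>(Node{Node::Kind::kTensor, v, {}, {}}); };
  // A sequence holding two tensor leaves, roughly like a ValueTuple of Tensors.
  Node root{Node::Kind::kSequence, 0.0, {leaf(1.0), leaf(2.0)}, {}};
  std::vector<double> flat;
  Flatten(root, &flat);
  std::cout << "flattened " << flat.size() << " tensors\n";
  return 0;
}
```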
@@ -209,6 +266,31 @@ void PushTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters
  PushInputTensor(args[position], input_tensor, front_node.second);
}

void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
                     size_t index, std::vector<tensor::TensorPtr> *input_tensor) {
  MS_EXCEPTION_IF_NULL(input_tensor);
  const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
  const size_t position = iter - parameters.begin();
  // If the parameter is not found in the parameters of the root graph, it means that it is the input of the subgraph,
  // and there is no need to input a tensor.
  if (position >= args.size()) {
    MS_LOG(INFO) << "Position out of args range, position value is " << position << " and args size is " << args.size()
                 << ".";
    (void)input_tensor->emplace_back(nullptr);
    return;
  }
  ValuePtrList flatted_value_tuple_value;
  FlattenValue(args[position], &flatted_value_tuple_value);
  if (index >= flatted_value_tuple_value.size()) {
    MS_LOG(EXCEPTION) << "Index out of flatted_value_tuple_value range, index value is " << index
                      << " and flatted_value_tuple_value size is " << flatted_value_tuple_value.size() << ".";
  }
  auto input = flatted_value_tuple_value[index];
  MS_EXCEPTION_IF_NULL(input);
  auto tensor_input = input->cast<tensor::TensorPtr>();
  input_tensor->push_back(tensor_input);
}

void UpdateOutputAbstract(const KernelGraphPtr &kernel_graph, OpRunInfo *op_run_info) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(op_run_info);
@@ -334,6 +416,38 @@ bool EnablePyNativeSyncRunning() {
  MS_EXCEPTION_IF_NULL(ms_context);
  return ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
}

std::vector<std::vector<tensor::TensorPtr>> GetRunGraphInputs(const GraphCompilerInfo &graph_compiler_info,
                                                              const VectorRef &args) {
  const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;
  std::vector<std::vector<tensor::TensorPtr>> input_tensors;
  for (const auto &kernel_graph : graph_compiler_info.graphs_) {
    std::vector<tensor::TensorPtr> input_tensor;
    MS_EXCEPTION_IF_NULL(kernel_graph);
    for (const auto &input_node : kernel_graph->input_nodes()) {
      auto element_pair = kernel_graph->GetElementInTupleBackendFrontIndexMap(input_node);
      if (element_pair.first) {
        PushTupleTensor(args, origin_parameters, element_pair.first, element_pair.second, &input_tensor);
      } else {
        const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node);
        PushTensor(args, origin_parameters, {front_node, 0}, &input_tensor);
      }
    }
    (void)input_tensors.emplace_back(input_tensor);
  }

  // Input tensors of the control node.
  std::vector<tensor::TensorPtr> input_tensor;
  MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
  // Get inputs of control node which come from the host actor.
  const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
  for (const auto &parameter : control_node_parameters) {
    PushTensor(args, origin_parameters, parameter, &input_tensor);
  }
  (void)input_tensors.emplace_back(input_tensor);

  return input_tensors;
}
}  // namespace

VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
@@ -470,7 +584,7 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
  PROF_START(compile_func_graph);
  auto root_graph = WrapPrimitives(func_graph);
  MS_EXCEPTION_IF_NULL(root_graph);
  root_graph_ = root_graph.get();
  root_graph_ = root_graph;
  // Register a summary callback function, which is called in the final stages of summary.
  graph_compiler_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
@@ -504,8 +618,9 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
  // Construct the graph compiler info.
  auto graph_compiler_info = ConstructGraphCompilerInfo(root_graph);
  MS_EXCEPTION_IF_NULL(graph_compiler_info);
  if (real_execution_mode_ == kGraphMode && graph_compiler_info->graphs_.size() != 0) {
  if (real_execution_mode_ == kGraphMode && !graph_compiler_info->graphs_.empty()) {
    // Transform graph to actor DAG, and schedule the actor DAG.
    ParseControlNodes(*graph_compiler_info);
    const auto &actor_set = runtime::GraphScheduler::GetInstance().Transform(*graph_compiler_info);
    runtime::GraphScheduler::GetInstance().Schedule(actor_set);
  }
@@ -815,84 +930,7 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &graph_output, const Vec
}
}  // namespace

void FlatValueTupleValue(const ValuePtrList &value, ValuePtrList *flatted_value) {
  for (size_t i = 0; i < value.size(); ++i) {
    auto value_element = value[i];
    MS_EXCEPTION_IF_NULL(value_element);
    if (utils::isa<tensor::TensorPtr>(value_element)) {
      (void)flatted_value->emplace_back(value_element);
    } else if (utils::isa<ValueTuplePtr>(value_element)) {
      auto value_tuple_element = value_element->cast<ValueTuplePtr>();
      MS_EXCEPTION_IF_NULL(value_tuple_element);
      FlatValueTupleValue(value_tuple_element->value(), flatted_value);
    } else {
      MS_LOG(EXCEPTION) << "The value input to FlatValueTupleValue should only contains Tensor and ValueTuple.";
    }
  }
}

void FlattenValue(const BaseRef &arg, ValuePtrList *flatted_value) {
  if (utils::isa<ValueSequencePtr>(arg)) {
    auto value_sequence = utils::cast<ValueSequencePtr>(arg);
    MS_EXCEPTION_IF_NULL(value_sequence);
    auto sequence_value = value_sequence->value();
    for (auto &value : sequence_value) {
      MS_EXCEPTION_IF_NULL(value);
      if (value->isa<tensor::Tensor>()) {
        (void)flatted_value->emplace_back(value);
      } else {
        FlattenValue(value, flatted_value);
      }
    }
  } else if (utils::isa<ValueDictionaryPtr>(arg)) {
    auto value_dict = utils::cast<ValueDictionaryPtr>(arg);
    MS_EXCEPTION_IF_NULL(value_dict);
    auto dict_value = value_dict->value();
    for (auto &iter : dict_value) {
      auto value = iter.second;
      MS_EXCEPTION_IF_NULL(value);
      if (value->isa<tensor::Tensor>()) {
        (void)flatted_value->emplace_back(value);
      } else {
        FlattenValue(value, flatted_value);
      }
    }
  } else if (utils::isa<tensor::CSRTensorPtr>(arg)) {
    auto csr_tensor = utils::cast<tensor::CSRTensorPtr>(arg);
    for (size_t i = 0; i < csr_tensor->GetTensorLength(); ++i) {
      (void)flatted_value->emplace_back(csr_tensor->GetTensorAt(i));
    }
  } else {
    MS_LOG(EXCEPTION) << "The value input to flatten should only contains be sequence or dictionary, but it is "
                      << arg.ToString();
  }
}

void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
                     size_t index, std::vector<tensor::TensorPtr> *input_tensor) {
  const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
  const size_t position = iter - parameters.begin();
  // If the parameter is not found in the parameters of the root graph, it means that it is the input of the subgraph,
  // and there is no need to input a tensor.
  if (position >= args.size()) {
    MS_LOG(INFO) << "Position out of args range, position value is " << position << " and args size is " << args.size()
                 << ".";
    (void)input_tensor->emplace_back(nullptr);
    return;
  }
  ValuePtrList flatted_value_tuple_value;
  FlattenValue(args[position], &flatted_value_tuple_value);
  if (index >= flatted_value_tuple_value.size()) {
    MS_LOG(EXCEPTION) << "Index out of flatted_value_tuple_value range, index value is " << index
                      << " and flatted_value_tuple_value size is " << flatted_value_tuple_value.size() << ".";
  }
  auto input = flatted_value_tuple_value[index];
  MS_EXCEPTION_IF_NULL(input);
  auto tensor_input = input->cast<tensor::TensorPtr>();
  input_tensor->push_back(tensor_input);
}

void MindRTBackend::ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, FuncGraph *root_graph) {
void MindRTBackend::ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, const FuncGraphPtr &root_graph) {
  bool need_contruct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
                                distributed::recovery::RecoveryContext::GetInstance()->need_reset());
  if (need_contruct_output) {
@@ -911,8 +949,9 @@ void MindRTBackend::ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *ou
}

void MindRTBackend::RunGraphByActors(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
                                     const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs) {
                                     const VectorRef &args, VectorRef *outputs) {
  WaitTaskFinish();
  auto inputs = GetRunGraphInputs(graph_compiler_info, args);
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  auto graphs = graph_compiler_info.graphs_;
  if (graphs.size() > inputs.size()) {

@@ -927,15 +966,22 @@ void MindRTBackend::RunGraphByActors(const ActorInfo &actor_info, const GraphCom
    const auto &graph = graphs[i];
    MS_EXCEPTION_IF_NULL(graph);
    graph->set_flag(kFlagPyNativeRunInGraph, true);
    // Get real format from input tensor before compile graph
    pynative::GraphAdapter::ReplaceBpropGraphParameter(graph, inputs[i]);
    graph_compiler_->CompileGraphImpl(graph, graph_compiler_info.device_contexts_.front());

    // The size of control_nodes is at least 1 since there is return node in the graph.
    if (control_nodes_.size() == 1) {
      MS_LOG(INFO) << "Replace parameter format";
      // Input tensor is null if there is control-flow in graph.
      // Need to get tensor after ParseControlNodes.
      pynative::GraphAdapter::ReplaceBpropGraphParameter(graph, inputs.at(i));
    }
    graph_compiler_->CompileGraphImpl(graph, graph_compiler_info.device_contexts_.at(i));
    pynative::GraphAdapter::RemoveUnusedValueNodes(graph);
    graph->CacheGraphOutputToFrontNodeWithIndex({graph->output()}, graph->front_outputs());
    // Clear front outputs after the outputs is cached.
    graph->set_front_outputs({});
  }

  ParseControlNodes(graph_compiler_info);
  actor_set = runtime::GraphScheduler::GetInstance().Transform(graph_compiler_info);
  MS_EXCEPTION_IF_NULL(actor_set);
  // Multithreading can cause spikes in memory usage and performance fluctuations
@@ -952,9 +998,11 @@ void MindRTBackend::RunGraphByActors(const ActorInfo &actor_info, const GraphCom
    pynative::GraphAdapter::UpdateForwardOutputInBpropGraph(graph);
  }

  std::vector<std::vector<tensor::TensorPtr>> input_tensors = GetRunGraphInputs(graph_compiler_info, args);

  // Release GIL and run actor DAG.
  mindspore::ScopedLongRunning long_running;
  runtime::GraphScheduler::GetInstance().Run(actor_set, graph_compiler_info.device_contexts_, inputs);
  runtime::GraphScheduler::GetInstance().Run(actor_set, graph_compiler_info.device_contexts_, input_tensors);

  MS_EXCEPTION_IF_NULL(graph_compiler_);
  graph_compiler_->Summary(graph_compiler_info.graphs_);

@@ -977,12 +1025,15 @@ void MindRTBackend::RunGraphByActors(const ActorInfo &actor_info, const GraphCom
  MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
}

void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs,
                                       const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs) {
void MindRTBackend::RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_info, const VectorRef &args,
                                       VectorRef *outputs) {
  WaitTaskFinish();
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  auto &op_executor = runtime::OpExecutor::GetInstance();
  op_executor.Register([this]() { BatchBuildCallback(); });

  MS_EXCEPTION_IF_NULL(graph_compiler_);
  const auto &graphs = graph_compiler_info.graphs_;
  auto inputs = GetRunGraphInputs(graph_compiler_info, args);
  for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) {
    const auto &graph = graphs[graph_index];
    MS_EXCEPTION_IF_NULL(graph);
@@ -1043,13 +1094,12 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
}

void MindRTBackend::RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
                                        const std::vector<std::vector<tensor::TensorPtr>> &input_tensors,
                                        VectorRef *outputs) {
                                        const VectorRef &args, VectorRef *outputs) {
  bool contain_cut_graph = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(),
                                       [](const KernelGraphPtr &graph) { return graph->has_flag(kFlagsIsCutGraph); });
  if (contain_cut_graph) {
    // Python API will be called in cut_graph, so we cannot release gil here.
    RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
    RunGraphBySingleOp(graph_compiler_info, args, outputs);
  } else {
    bool is_dynamic_shape = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(),
                                        [](const KernelGraphPtr &graph) { return graph->is_dynamic_shape(); });

@@ -1061,10 +1111,12 @@ void MindRTBackend::RunGraphByCondition(const ActorInfo &actor_info, const Graph

    // TODO(caifubi): Need to support condition: 1. Dynamic shape. 2. AutoParallel.
    MS_EXCEPTION_IF_NULL(root_graph_);
    if (root_graph_->has_flag(kFlagIsDynamicStructure) || is_dynamic_shape || is_parallel) {
      RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
    if (root_graph_->has_flag(kFlagIsDynamicStructure) ||
        // `ms_function + dynamic_shape` is already supported, so there is no need to execute RunGraphBySingleOp.
        (is_dynamic_shape && root_graph_->has_flag(kFlagIsPynativeBpropGraph)) || is_parallel) {
      RunGraphBySingleOp(graph_compiler_info, args, outputs);
    } else {
      RunGraphByActors(actor_info, graph_compiler_info, input_tensors, outputs);
      RunGraphByActors(actor_info, graph_compiler_info, args, outputs);
    }
  }
  MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
@@ -1094,46 +1146,18 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
  }
  MS_EXCEPTION_IF_NULL(graph_iter->second);
  const auto &graph_compiler_info = *(graph_iter->second);
  const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;

  // For pynative and graph mix execution.
  WaitTaskFinish();

  // Transform args to input tensors.
  // Input tensors of the graph.
  std::vector<std::vector<tensor::TensorPtr>> input_tensors;
  for (const auto &kernel_graph : graph_compiler_info.graphs_) {
    std::vector<tensor::TensorPtr> input_tensor;
    MS_EXCEPTION_IF_NULL(kernel_graph);
    for (const auto &input_node : kernel_graph->input_nodes()) {
      auto element_pair = kernel_graph->GetElementInTupleBackendFrontIndexMap(input_node);
      if (element_pair.first) {
        PushTupleTensor(args, origin_parameters, element_pair.first, element_pair.second, &input_tensor);
      } else {
        const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node);
        PushTensor(args, origin_parameters, {front_node, 0}, &input_tensor);
      }
    }
    (void)input_tensors.emplace_back(input_tensor);
  }

  // Input tensors of the control node.
  std::vector<tensor::TensorPtr> input_tensor;
  MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
  // Get inputs of control node which come from the host actor.
  const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
  for (const auto &parameter : control_node_parameters) {
    PushTensor(args, origin_parameters, parameter, &input_tensor);
  }
  (void)input_tensors.emplace_back(input_tensor);
  // Run in the pynative mode.
  MS_EXCEPTION_IF_NULL(outputs);
  // There will be more than one kernel graph in heterogeneous scenario in a ms function of PyNative Mode.
  if (real_execution_mode_ == kPynativeMode) {
    RunGraphByCondition(actor_info, graph_compiler_info, input_tensors, outputs);
    RunGraphByCondition(actor_info, graph_compiler_info, args, outputs);
    return;
  }

  auto input_tensors = GetRunGraphInputs(graph_compiler_info, args);
  // Release python gil.
  mindspore::ScopedLongRunning long_running;
  // Run actor DAG.
@@ -1298,22 +1322,7 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con
    ++graph_index;
  }

  FuncGraphToKernelGraphGroup func_graph_to_kernel_graphs;
  for (const auto &func_graph_to_kernel_graph_ids : func_graph_to_kernel_graph_ids_) {
    const auto &func_graph = func_graph_to_kernel_graph_ids.first;
    for (const auto &sub_kernel_graphs_ids : func_graph_to_kernel_graph_ids.second) {
      std::vector<KernelGraphPtr> kernel_graphs;
      for (const auto &graph_id : sub_kernel_graphs_ids) {
        const auto &kernel_graph = graph_compiler_->Fetch(graph_id);
        MS_EXCEPTION_IF_NULL(kernel_graph);
        (void)kernel_graphs.emplace_back(kernel_graph);
      }
      (void)func_graph_to_kernel_graphs[func_graph].emplace_back(kernel_graphs);
    }
  }

  auto parser = std::make_shared<ControlNodeParser>();
  parser->Parse(control_nodes_, graphs, device_contexts, root_graph, func_graph_to_kernel_graphs);

  runtime::KernelMapPosition outputs_order;
  const auto &root_output =
@@ -1635,5 +1644,25 @@ void MindRTBackend::UpdateOutput(const std::vector<session::KernelWithIndex> &ou
    outputs->emplace_back(output_tensor);
  }
}

void MindRTBackend::ParseControlNodes(const GraphCompilerInfo &graph_compile_info) {
  FuncGraphToKernelGraphGroup func_graph_to_kernel_graphs;
  for (const auto &func_graph_to_kernel_graph_ids : func_graph_to_kernel_graph_ids_) {
    const auto &func_graph = func_graph_to_kernel_graph_ids.first;
    for (const auto &sub_kernel_graphs_ids : func_graph_to_kernel_graph_ids.second) {
      std::vector<KernelGraphPtr> kernel_graphs;
      for (const auto &graph_id : sub_kernel_graphs_ids) {
        const auto &kernel_graph = graph_compiler_->Fetch(graph_id);
        MS_EXCEPTION_IF_NULL(kernel_graph);
        (void)kernel_graphs.emplace_back(kernel_graph);
      }
      (void)func_graph_to_kernel_graphs[func_graph].emplace_back(kernel_graphs);
    }
  }

  graph_compile_info.control_node_parser_->Parse(control_nodes_, graph_compile_info.graphs_,
                                                 graph_compile_info.device_contexts_, root_graph_,
                                                 func_graph_to_kernel_graphs);
}
}  // namespace compile
}  // namespace mindspore
@@ -143,7 +143,7 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
  // Get saved OpBuildTask in OpExecutor and build all the kernels together in PyNative mode.
  void CompileSingleOpGraphs(const std::vector<std::shared_ptr<runtime::OpBuildTask>> &build_tasks);

  void ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, FuncGraph *root_graph);
  void ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, const FuncGraphPtr &root_graph);

  // Restore the outputs tuple by the origin funcGraph output node and output tensors.
  void ConstructOutputs(const AnfNodePtr &output_node, const std::vector<tensor::TensorPtr> &output_tensors,

@@ -160,6 +160,8 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
                        const std::vector<tensor::TensorPtr> *input_tensors,
                        bool need_erase);

  void ParseControlNodes(const GraphCompilerInfo &graph_compile_info);

  // In PyNative mode, the size of single op cache list will be increasing, which lead to memory cost increasing,
  // so the latest single op cache should be erased when cache list size exceeds threshold value.
  void EraseSingleOpCache(const ActorInfo &actor_info, const std::string &graph_info, const KernelGraphPtr &graph);

@@ -176,14 +178,13 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
                        OpRunInfo *op_run_info);

  void RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
                           const std::vector<std::vector<tensor::TensorPtr>> &input_tensors, VectorRef *outputs);
                           const VectorRef &args, VectorRef *outputs);
  // Split complete kernel graph to single op graph in PyNative back
  // propagation, then compile and run single op graph.
  void RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs,
                          const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs);
  void RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_info, const VectorRef &args, VectorRef *outputs);

  void RunGraphByActors(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
                        const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs);
                        const VectorRef &args, VectorRef *outputs);

  void UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, VectorRef *const outputs);

@@ -209,7 +210,7 @@ class BACKEND_EXPORT MindRTBackend : public Backend {
  // Cache forward op output value node tensor ref count of kernels for back propagation graph in PyNative mode.
  std::map<std::string, size_t> forward_op_output_tensor_id_;

  FuncGraph *root_graph_;
  FuncGraphPtr root_graph_;
  GraphPartitionPtr graph_partition_;
  std::shared_ptr<GraphCompiler> graph_compiler_;
  std::string device_name_;
@@ -955,16 +955,11 @@ void OriginSetRunMode(const ResourcePtr &resource) {
  }
}

void SetRunMode(const ResourcePtr &resource, bool pynative_switch_to_graph_mode) {
void SetRunMode(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT) && common::GetEnv("DISABLE_ASCEND_MINDRT") != "1") {
    // Run in GRAPH_MODE if the func_graph is ms_function or the func_graph contain multi-subgraph.
    if (pynative_switch_to_graph_mode) {
      context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
      MS_LOG(INFO) << "PyNative graph Compile and Run in GRAPH_MODE";
    }
    SetRunMode(resource->func_graph(), resource->GetResult(kBackend).cast<compile::BackendPtr>().get());
  } else {
    OriginSetRunMode(resource);

@@ -995,22 +990,14 @@ bool TaskEmitAction(const ResourcePtr &resource) {
    MS_LOG(WARNING) << "Multi device target is detected, CPU data is dumped in rank_0 directory";
  }
  DisableMindRT(resource);
  auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
  auto is_parallel = (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel);
  bool pynative_switch_to_graph_mode =
    context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
    (!func_graph->has_flag(kFlagIsPynativeBpropGraph) || func_graph->manager()->func_graphs().size() > 1) &&
    !is_parallel;
  SetRunMode(resource, pynative_switch_to_graph_mode);

  SetRunMode(resource);
  auto bc_ptr = resource->GetResult(kBackend).cast<compile::BackendPtr>();
  MS_EXCEPTION_IF_NULL(bc_ptr);
  std::string backend = context_ptr->backend_policy();
  // The graph compiling of mindRT.
  if ((backend == kMsConvert) && context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
    TaskEmitActionForMindRT(resource);
    if (pynative_switch_to_graph_mode) {
      context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, kPynativeMode);
    }
    return true;
  }
@@ -455,7 +455,7 @@ void AscendDeviceContext::PreprocessBeforeRunGraph(const KernelGraphPtr &graph)
    AllocateGraphMemory(NOT_NULL(graph));
    LoadModel(NOT_NULL(graph));
    AssignOutputNopNodeDeviceAddress(graph);
  } else if (graph->is_dynamic_shape() && IsGraphMode()) {
  } else if (graph->is_dynamic_shape() && (IsGraphMode() || graph->has_flag(kFlagPyNativeRunInGraph))) {
    device::ascend::InsertAtomicCleanOps(graph->execution_order(), &node_atomics_);
    SetAtomicCleanToNodes(graph, node_atomics_);  // graph mode may can do it too, instead of update execorder
    opt::DynamicShapeConvertPass(graph);
@@ -325,7 +325,8 @@ void CPUDeviceContext::PreprocessBeforeRun(const FuncGraphPtr &graph) const {

  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (kernel_graph->is_dynamic_shape() && ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
  if (kernel_graph->is_dynamic_shape() && (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode ||
                                           kernel_graph->has_flag(kFlagPyNativeRunInGraph))) {
    opt::DynamicShapeConvertPass(kernel_graph);
  }
}
@@ -225,7 +225,8 @@ void GPUDeviceContext::PreprocessBeforeRun(const FuncGraphPtr &graph) const {
  if (!kernel_graph->is_from_single_op()) {
    auto ms_context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(ms_context);
    if (kernel_graph->is_dynamic_shape() && ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
    if (kernel_graph->is_dynamic_shape() && (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode ||
                                             kernel_graph->has_flag(kFlagPyNativeRunInGraph))) {
      opt::DynamicShapeConvertPass(kernel_graph);
    }
  }
@@ -562,10 +562,14 @@ GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNod

  session_->DumpGraphs({graph});

  // The graph is not compiled yet in PyNative Mode.
  // Need to cache output latter when the graph is compiled.
  if (!run_in_pynative) {
    // Cache the backend graph output nodes to front nodes with output index.
    auto backend_node = graph->output();
    MS_EXCEPTION_IF_NULL(backend_node);
    graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, outputs);
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
@@ -44,13 +44,12 @@ def test_adaptive_max_pool2d():
    x = np.random.randn(32, 64, 128, 128).astype(np.float32)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    adaptive_max_pool2d = Net(64, True)
    output1, argmax1 = adaptive_max_pool2d(Tensor(x))
    output1, _ = adaptive_max_pool2d(Tensor(x))

    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    adaptive_max_pool2d = Net(64, True)
    output2, argmax2 = adaptive_max_pool2d(Tensor(x))
    output2, _ = adaptive_max_pool2d(Tensor(x))
    assert (output1.asnumpy() == output2.asnumpy()).all()
    assert (argmax1.asnumpy() == argmax2.asnumpy()).all()


@pytest.mark.level0
|
@ -20,6 +20,7 @@ from mindspore import context
|
|||
from mindspore.ops import GradOperation
|
||||
from mindspore.common import ParameterTuple
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore import ops as P
|
||||
|
||||
|
||||
def forward_pre_hook_fn_bn(cell_id, inp):
|
||||
|
@ -49,7 +50,7 @@ def forward_pre_hook_fn_multi_add(cell_id, inp):
|
|||
|
||||
|
||||
def forward_hook_fn_conv(cell_id, inp, outp):
|
||||
out = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")(outp)
|
||||
out = P.Log()(outp)
|
||||
return out
|
||||
|
||||
|
||||
|
@ -101,6 +102,10 @@ class SingleNet(nn.Cell):
|
|||
class SingleNetInConstruct(nn.Cell):
|
||||
def __init__(self):
|
||||
super(SingleNetInConstruct, self).__init__()
|
||||
self.handle1 = None
|
||||
self.handle2 = None
|
||||
self.handle3 = None
|
||||
self.handle4 = None
|
||||
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
|
@ -262,13 +267,14 @@ class CompareMultiNet1(nn.Cell):
|
|||
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
|
||||
self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
|
||||
self.relu = nn.ReLU()
|
||||
self.log = P.Log()
|
||||
|
||||
def construct(self, x, y):
|
||||
x = self.mul(x + x, x * y)
|
||||
x = self.conv(x)
|
||||
x = self.log(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
x = self.mul(x, x)
|
||||
y = self.relu(x)
|
||||
x = self.mul(y, x)
|
||||
x = x + x
|
||||
return x
|
||||
|
||||
|
@ -280,10 +286,11 @@ class CompareMultiNet2(nn.Cell):
|
|||
self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
|
||||
self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
|
||||
self.relu = nn.ReLU()
|
||||
self.log = P.Log()
|
||||
|
||||
def construct(self, x, y):
|
||||
x = self.mul(x, y)
|
||||
x = self.conv(x)
|
||||
x = self.log(x)
|
||||
x = self.bn(x)
|
||||
x = self.mul(x, x)
|
||||
return x
|
||||
|
|