!30277 PyNative ms_function compile and run in GRAPH_MODE

Merge pull request !30277 from caifubi/master-pynative-run-in-graph
i-robot 2022-02-22 09:28:50 +00:00 committed by Gitee
commit 403fbc2a21
6 changed files with 126 additions and 23 deletions
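For orientation: ms_function marks a Python function for whole-graph compilation inside PyNative mode, and this PR routes such graphs through the GRAPH_MODE compile/run pipeline. A minimal usage sketch (illustrative only, not part of the commit; the function name and values are invented):

import mindspore as ms
from mindspore import context, Tensor, ms_function
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

add = P.TensorAdd()

@ms_function
def add_twice(x, y):
    # With this change, the ms_function graph below is compiled and executed
    # under GRAPH_MODE even though the surrounding process runs in PyNative mode.
    return add(add(x, y), y)

out = add_twice(Tensor(1, ms.int32), Tensor(2, ms.int32))
assert out.asnumpy() == 5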

View File

@@ -24,6 +24,7 @@
 #include "runtime/op_builder/op_lazy_builder.h"
 #include "backend/common/optimizer/helper.h"
 #include "pipeline/pynative/pynative_execute.h"
+#include "pipeline/jit/action.h"
 #include "pipeline/jit/parse/data_converter.h"
 #include "ir/anf.h"
 #include "pybind_api/ir/base_ref_py.h"
@@ -451,6 +452,15 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
   ms_execution_mode_ = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
   real_execution_mode_ = ms_execution_mode_;
+
+  // Run in GRAPH_MODE if the func_graph is an ms_function graph or contains multiple subgraphs.
+  if (ms_execution_mode_ == kPynativeMode &&
+      (!func_graph->is_bprop() || func_graph->manager()->func_graphs().size() > 1)) {
+    real_execution_mode_ = kGraphMode;
+    context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
+    pipeline::SetRunMode(func_graph, this);
+    MS_LOG(INFO) << "PyNative graph is compiled and run in GRAPH_MODE";
+  }
   // Compile root graph.
   graph_id_to_device_context_.clear();
   func_graph_to_kernel_graph_ids_.clear();
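The condition above covers two cases: a non-bprop func_graph (the forward graph of an ms_function) and a graph with more than one subgraph, which Python control flow produces. A hedged sketch of the multi-subgraph case (an invented example, not from this commit):

import mindspore as ms
from mindspore import context, Tensor, ms_function

context.set_context(mode=context.PYNATIVE_MODE)

@ms_function
def clip_to_limit(x, limit):
    # The tensor-dependent if/else is partitioned into separate subgraphs,
    # so func_graph->manager()->func_graphs().size() > 1 holds and the graph
    # takes the GRAPH_MODE path above.
    if x > limit:
        return limit
    return x

assert clip_to_limit(Tensor(7, ms.int32), Tensor(5, ms.int32)).asnumpy() == 5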
@@ -469,16 +479,20 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
   // Construct the graph compiler info.
   auto graph_compiler_info = ConstructGraphCompilerInfo(root_graph);
-  if (real_execution_mode_ == kGraphMode) {
+  MS_EXCEPTION_IF_NULL(graph_compiler_info);
+  if (real_execution_mode_ == kGraphMode && graph_compiler_info->graphs_.size() != 0) {
     // Transform graph to actor DAG, and schedule the actor DAG.
     const auto &actor_set = runtime::GraphScheduler::GetInstance().Transform(*graph_compiler_info);
     runtime::GraphScheduler::GetInstance().Schedule(actor_set);
   }
-  MS_EXCEPTION_IF_NULL(graph_compiler_info);
   const ActorInfo &actor_info = graph_compiler_info->name_;
   (void)actor_to_graph_compiler_info_.emplace(graph_compiler_info->name_, std::move(graph_compiler_info));
   PROF_END(compile_func_graph);
+
+  if (ms_execution_mode_ != real_execution_mode_) {
+    context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, ms_execution_mode_);
+  }
   MS_LOG(INFO) << "Status record: end compile function graph: " << func_graph->ToString()
                << ", produce actor: " << actor_info;
   return actor_info;
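The new graphs_.size() != 0 guard handles graphs that compile to nothing: when every value in an ms_function is a compile-time constant, the frontend folds the computation away and no kernel graph is left to transform or schedule. The test_pynative_ms_function_empty_graph case added below exercises this; a condensed sketch of the same idea (names invented):

import mindspore as ms
from mindspore import context, Tensor, ms_function

context.set_context(mode=context.PYNATIVE_MODE)

X = Tensor(5, ms.float32)   # captured free variables become graph constants
Y = Tensor(10, ms.float32)

@ms_function
def pick_max():
    # Both operands are constants, so the comparison folds at compile time
    # and the resulting list of kernel graphs can be empty.
    if X > Y:
        return X
    return Y

assert pick_max().asnumpy() == 10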
@@ -507,12 +521,12 @@ bool MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) {
   // Foreach the segments to compile graph.
   for (const auto &segment : new_segments) {
-    CompileGraph(segment, contain_multi_target, func_graph->is_bprop());
+    CompileGraph(segment);
   }
   return true;
 }
 
-void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target, bool run_in_pynative) {
+void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment) {
   MS_EXCEPTION_IF_NULL(segment);
   // Compile the normal nodes, which doesn't contain the cut node.
   if (segment->nodes_.size() == 0) {
@@ -537,19 +551,9 @@ void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target, bool run_in_pynative) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
-  // There will be more than one kernel graph in heterogeneous scenario in a ms function of PyNative Mode.
-  if ((contain_multi_target || !run_in_pynative) && ms_execution_mode_ == kPynativeMode) {
-    real_execution_mode_ = kGraphMode;
-    context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
-    MS_LOG(INFO) << "PyNative graph Compile and Run in GRAPH_MODE";
-  }
   // Compile graph.
-  auto graph_id = graph_compiler_->CompileGraph(segment, outputs, device_context, run_in_pynative);
-  if (ms_execution_mode_ != real_execution_mode_) {
-    context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, ms_execution_mode_);
-  }
+  auto graph_id =
+    graph_compiler_->CompileGraph(segment, outputs, device_context, real_execution_mode_ == kPynativeMode);
   graph_id_to_device_context_[graph_id] = device_context;

View File

@@ -134,7 +134,7 @@ class MindRTBackend : public Backend {
   bool CompileGraph(const FuncGraphPtr &func_graph);
 
   // Compile the kernel graph by the segment which is from the function graph partition.
-  void CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target, bool run_in_pynative);
+  void CompileGraph(const GraphSegmentPtr &segment);
 
   // CreateKernel, Transform and Schedule have not been finished when LazyBuild is enabled in PyNative mode.
   void CompileSingleOpGraph(const KernelGraphPtr &graph, const DeviceContext *device_context,

View File

@@ -786,13 +786,10 @@ bool ExistTarget(const std::vector<AnfNodePtr> &all_nodes, const std::string &target) {
   return false;
 }
 
-void SetRunMode(const ResourcePtr &res) {
-  MS_EXCEPTION_IF_NULL(res);
+void SetRunMode(const FuncGraphPtr &func_graph, compile::Backend *backend_ptr) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
-  FuncGraphPtr func_graph = res->func_graph();
   MS_EXCEPTION_IF_NULL(func_graph);
-  auto backend_ptr = res->GetResult(kBackend).cast<compile::BackendPtr>();
   MS_EXCEPTION_IF_NULL(backend_ptr);
   std::string backend = context_ptr->backend_policy();
   auto set_ctx = [&context_ptr, &backend_ptr](bool task_sink, bool is_multi_graph_sink, bool enable_loop_sink) {
@@ -905,7 +902,7 @@ bool TaskEmitAction(const ResourcePtr &res) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT) && common::GetEnv("DISABLE_ASCEND_MINDRT") != "1") {
-    SetRunMode(res);
+    SetRunMode(res->func_graph(), res->GetResult(kBackend).cast<compile::BackendPtr>().get());
   } else {
     OriginSetRunMode(res);
   }

View File

@@ -23,6 +23,7 @@
 #include <string>
 #include "pipeline/jit/resource.h"
 #include "backend/graph_compiler/segment_runner.h"
+#include "backend/graph_compiler/backend.h"
 
 namespace mindspore {
 extern const char kMsConvert[];
@@ -60,6 +61,7 @@ FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
                                const abstract::AnalysisContextPtr &context);
 FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
                          const abstract::AbstractBasePtrList &args_spec);
+void SetRunMode(const FuncGraphPtr &func_graph, compile::Backend *backend_ptr);
 }  // namespace pipeline
 }  // namespace mindspore

View File

@@ -1706,6 +1706,11 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
   // The graph output is from device tensor store.
   if (IsPersistentDeviceTensor(output_with_index.first)) {
     (void)to_actor->device_tensor_store_keys_.emplace_back(output_position, output_with_index.first);
+    if (!AnfAlgo::OutputAddrExist(output_with_index.first, 0, false)) {
+      MS_EXCEPTION_IF_NULL(output_with_index.first);
+      MS_LOG(WARNING) << output_with_index.first->DebugString() << " device address does not exist";
+      continue;
+    }
     // In the scenario where the ValueTuple is expanded, the output_with_index.second may be incorrect, so use 0
     // as output_idx directly.
     auto device_tensor = AnfAlgo::GetMutableOutputAddr(output_with_index.first, 0, false);
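This guard skips output arrows for persistent device tensors (weights and constants) that never received a device address, which can happen when a graph output comes straight from the device tensor store rather than from a kernel. A likely trigger, as an assumption rather than something stated in the commit, is an ms_function whose output is a captured constant:

import mindspore as ms
from mindspore import context, Tensor, ms_function

context.set_context(mode=context.PYNATIVE_MODE)

C = Tensor(3, ms.int32)

@ms_function
def constant_out():
    # The graph output is the constant itself; no kernel produces it, so its
    # device address may not exist when output arrows are linked.
    return C

assert constant_out().asnumpy() == 3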

View File

@@ -289,3 +289,98 @@ def test_pynative_ms_function_mix_execute():
     b = Tensor(2)
     output = net(a, b)
     assert output == 8
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_pynative_ms_function_empty_graph():
+    """
+    Feature: PyNative ms_function.
+    Description: Empty ms_function graph.
+    Expectation: The calculation result is correct.
+    """
+    class Net(nn.Cell):
+        def __init__(self, x, y):
+            super().__init__()
+            self.x = x
+            self.y = y
+            self.relu = P.ReLU()
+
+        @ms_function
+        def max(self):
+            if self.x > self.y:
+                return self.x
+            return self.y
+
+        def construct(self):
+            a = self.max()
+            return self.relu(a)
+
+    net = Net(Tensor(5, ms.float32), Tensor(10, ms.float32))
+    output = net()
+    assert output.asnumpy() == 10
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_pynative_ms_function_control_flow_if_break():
+    """
+    Feature: PyNative ms_function.
+    Description: PyNative ms_function with control flow.
+    Expectation: The calculation result is correct.
+    """
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.relu = P.ReLU()
+            self.add = P.TensorAdd()
+
+        @ms_function
+        def construct(self, x, y, z):
+            out = z
+            for _ in range(5):
+                if 2 * x < y:
+                    if 3 * x < y:
+                        out = self.add(out, out)
+                    x = x + 1
+                    out = self.relu(out)
+                if x + 6 == y:
+                    break
+            out = self.relu(out)
+            return out
+
+    net = Net()
+    x = Tensor(2, ms.int32)
+    y = Tensor(10, ms.int32)
+    z = Tensor(np.ones([4, 4, 4]), ms.float32)
+    output = net(x, y, z)
+    assert (output.asnumpy() == z.asnumpy() * 4).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_pynative_ms_function_with_dynamic_shape():
+    """
+    Feature: PyNative ms_function.
+    Description: PyNative ms_function with dynamic shape.
+    Expectation: The calculation result is correct.
+    """
+    @ms_function()
+    def test(x):
+        return ms.numpy.unique(x, return_inverse=True)
+
+    x = Tensor([[1, 1, 2], [3, 3, 5]], ms.int32)
+    output = test(x)
+    assert (output[0].asnumpy() == np.array([1, 2, 3, 5])).all()
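As a closing sanity check (not part of the commit), the same ms_function should return identical results whichever mode the process runs in, since after this change its graph is compiled through the graph pipeline either way:

import numpy as np
import mindspore as ms
from mindspore import context, Tensor, ms_function

@ms_function
def double(x):
    # A trivial graph run once per execution mode; results must match.
    return x + x

for mode in (context.PYNATIVE_MODE, context.GRAPH_MODE):
    context.set_context(mode=mode)
    out = double(Tensor(np.ones((2, 2)), ms.float32))
    assert (out.asnumpy() == 2).all()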