forked from mindspore-Ecosystem/mindspore
!30277 PyNative ms_function compile and run in GRAPH_MODE
Merge pull request !30277 from caifubi/master-pynative-run-in-graph
commit 403fbc2a21
@@ -24,6 +24,7 @@
 #include "runtime/op_builder/op_lazy_builder.h"
 #include "backend/common/optimizer/helper.h"
 #include "pipeline/pynative/pynative_execute.h"
+#include "pipeline/jit/action.h"
 #include "pipeline/jit/parse/data_converter.h"
 #include "ir/anf.h"
 #include "pybind_api/ir/base_ref_py.h"
@@ -451,6 +452,15 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
   ms_execution_mode_ = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
   real_execution_mode_ = ms_execution_mode_;
 
+  // Run in GRAPH_MODE if the func_graph is an ms_function or contains multiple subgraphs.
+  if (ms_execution_mode_ == kPynativeMode &&
+      (!func_graph->is_bprop() || func_graph->manager()->func_graphs().size() > 1)) {
+    real_execution_mode_ = kGraphMode;
+    context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
+    pipeline::SetRunMode(func_graph, this);
+    MS_LOG(INFO) << "PyNative graph Compile and Run in GRAPH_MODE";
+  }
+
   // Compile root graph.
   graph_id_to_device_context_.clear();
   func_graph_to_kernel_graph_ids_.clear();
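The block above is the core of this change: when an ms_function graph (or a graph with more than one subgraph) reaches the MindRT backend while the frontend is in PyNative mode, the backend temporarily switches to GRAPH_MODE so the whole FuncGraph is compiled and scheduled as one actor DAG, then restores the original mode after compilation. A minimal sketch of the user-level scenario this targets; the network, the method name compiled_part, and the toy computation are illustrative assumptions, not code from this PR:

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context, ms_function
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)


class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.relu = P.ReLU()

    # The decorated method is traced into a FuncGraph; with this change the
    # backend compiles and runs it in GRAPH_MODE even though the rest of the
    # network executes op by op in PyNative mode.
    @ms_function
    def compiled_part(self, x):
        return self.relu(x + x)

    def construct(self, x):
        return self.compiled_part(x) * 2  # eager code calling the compiled graph


net = Net()
out = net(Tensor(np.ones((2, 2)), ms.float32))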
@@ -469,16 +479,20 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
 
   // Construct the graph compiler info.
   auto graph_compiler_info = ConstructGraphCompilerInfo(root_graph);
 
-  if (real_execution_mode_ == kGraphMode) {
+  MS_EXCEPTION_IF_NULL(graph_compiler_info);
+  if (real_execution_mode_ == kGraphMode && graph_compiler_info->graphs_.size() != 0) {
     // Transform graph to actor DAG, and schedule the actor DAG.
     const auto &actor_set = runtime::GraphScheduler::GetInstance().Transform(*graph_compiler_info);
     runtime::GraphScheduler::GetInstance().Schedule(actor_set);
   }
-  MS_EXCEPTION_IF_NULL(graph_compiler_info);
   const ActorInfo &actor_info = graph_compiler_info->name_;
   (void)actor_to_graph_compiler_info_.emplace(graph_compiler_info->name_, std::move(graph_compiler_info));
   PROF_END(compile_func_graph);
 
+  if (ms_execution_mode_ != real_execution_mode_) {
+    context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, ms_execution_mode_);
+  }
+
   MS_LOG(INFO) << "Status record: end compile function graph: " << func_graph->ToString()
                << ", produce actor: " << actor_info;
   return actor_info;
@@ -507,12 +521,12 @@ bool MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) {
 
   // Foreach the segments to compile graph.
   for (const auto &segment : new_segments) {
-    CompileGraph(segment, contain_multi_target, func_graph->is_bprop());
+    CompileGraph(segment);
   }
   return true;
 }
 
-void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target, bool run_in_pynative) {
+void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment) {
   MS_EXCEPTION_IF_NULL(segment);
   // Compile the normal nodes, which doesn't contain the cut node.
   if (segment->nodes_.size() == 0) {
@@ -537,19 +551,9 @@ void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_mu
 
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  // There will be more than one kernel graph in heterogeneous scenario in a ms function of PyNative Mode.
-  if ((contain_multi_target || !run_in_pynative) && ms_execution_mode_ == kPynativeMode) {
-    real_execution_mode_ = kGraphMode;
-    context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
-    MS_LOG(INFO) << "PyNative graph Compile and Run in GRAPH_MODE";
-  }
-
   // Compile graph.
-  auto graph_id = graph_compiler_->CompileGraph(segment, outputs, device_context, run_in_pynative);
-
-  if (ms_execution_mode_ != real_execution_mode_) {
-    context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, ms_execution_mode_);
-  }
+  auto graph_id =
+      graph_compiler_->CompileGraph(segment, outputs, device_context, real_execution_mode_ == kPynativeMode);
 
   graph_id_to_device_context_[graph_id] = device_context;
 
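The removed block handled the heterogeneous case at per-segment compile time: an ms_function whose operators target different devices is partitioned into more than one kernel graph, which previously forced the switch to GRAPH_MODE here; with this change the decision is made once in CompileGraphs. A hedged sketch of the kind of heterogeneous ms_function the removed comment refers to; pinning an operator with the primitive_target attribute is my assumption of one way to trigger the split, and HeteroNet and cast_on_cpu are hypothetical names:

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context, ms_function
from mindspore.ops import operations as P

# Assume the default device is Ascend or GPU; one operator is pinned to CPU,
# so the traced graph may be split into several kernel graphs.
context.set_context(mode=context.PYNATIVE_MODE)


class HeteroNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.relu = P.ReLU()
        self.cast_on_cpu = P.Cast().add_prim_attr("primitive_target", "CPU")

    @ms_function
    def construct(self, x):
        y = self.cast_on_cpu(x, ms.float16)  # intended to run on CPU
        return self.relu(y)                  # runs on the default device


net = HeteroNet()
out = net(Tensor(np.ones((2, 2)), ms.float32))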
@@ -134,7 +134,7 @@ class MindRTBackend : public Backend {
   bool CompileGraph(const FuncGraphPtr &func_graph);
 
   // Compile the kernel graph by the segment which is from the function graph partition.
-  void CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target, bool run_in_pynative);
+  void CompileGraph(const GraphSegmentPtr &segment);
 
   // CreateKernel, Transform and Schedule have not been finished when LazyBuild is enabled in PyNative mode.
   void CompileSingleOpGraph(const KernelGraphPtr &graph, const DeviceContext *device_context,
@@ -786,13 +786,10 @@ bool ExistTarget(const std::vector<AnfNodePtr> &all_nodes, const std::string &ta
   return false;
 }
 
-void SetRunMode(const ResourcePtr &res) {
-  MS_EXCEPTION_IF_NULL(res);
+void SetRunMode(const FuncGraphPtr &func_graph, compile::Backend *backend_ptr) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
-  FuncGraphPtr func_graph = res->func_graph();
   MS_EXCEPTION_IF_NULL(func_graph);
-  auto backend_ptr = res->GetResult(kBackend).cast<compile::BackendPtr>();
   MS_EXCEPTION_IF_NULL(backend_ptr);
   std::string backend = context_ptr->backend_policy();
   auto set_ctx = [&context_ptr, &backend_ptr](bool task_sink, bool is_multi_graph_sink, bool enable_loop_sink) {
@@ -905,7 +902,7 @@ bool TaskEmitAction(const ResourcePtr &res) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT) && common::GetEnv("DISABLE_ASCEND_MINDRT") != "1") {
-    SetRunMode(res);
+    SetRunMode(res->func_graph(), res->GetResult(kBackend).cast<compile::BackendPtr>().get());
   } else {
     OriginSetRunMode(res);
   }
@@ -23,6 +23,7 @@
 #include <string>
 #include "pipeline/jit/resource.h"
 #include "backend/graph_compiler/segment_runner.h"
+#include "backend/graph_compiler/backend.h"
 
 namespace mindspore {
 extern const char kMsConvert[];
@@ -60,6 +61,7 @@ FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_
                                const abstract::AnalysisContextPtr &context);
 FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
                          const abstract::AbstractBasePtrList &args_spec);
+void SetRunMode(const FuncGraphPtr &func_graph, compile::Backend *backend_ptr);
 }  // namespace pipeline
 }  // namespace mindspore
@@ -1706,6 +1706,11 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
     // The graph output is from device tensor store.
     if (IsPersistentDeviceTensor(output_with_index.first)) {
       (void)to_actor->device_tensor_store_keys_.emplace_back(output_position, output_with_index.first);
+      if (!AnfAlgo::OutputAddrExist(output_with_index.first, 0, false)) {
+        MS_EXCEPTION_IF_NULL(output_with_index.first);
+        MS_LOG(WARNING) << output_with_index.first->DebugString() << " device address does not exist";
+        continue;
+      }
       // In the scenario where the ValueTuple is expanded, the output_with_index.second may be incorrect, so use 0
       // as output_idx directly.
       auto device_tensor = AnfAlgo::GetMutableOutputAddr(output_with_index.first, 0, false);
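For context, this branch covers graph outputs that come straight from the device tensor store (weights and constants) rather than from a kernel's output; if no device address has been assigned yet, the added code now warns and skips linking the output arrow instead of failing. A hedged guess at user code that yields such an output, namely an ms_function whose return value includes a Parameter; whether this exact case reaches the new branch is an assumption, and the network and the name w are hypothetical:

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context, ms_function

context.set_context(mode=context.PYNATIVE_MODE)


class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        # A weight; once the graph is compiled it is held in the device tensor store.
        self.w = Parameter(Tensor(np.ones((2, 2)), ms.float32), name="w")

    @ms_function
    def construct(self, x):
        # One graph output is a computed tensor, the other is the Parameter itself.
        return x + self.w, self.w


net = Net()
out_sum, out_w = net(Tensor(np.ones((2, 2)), ms.float32))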
@@ -289,3 +289,98 @@ def test_pynative_ms_function_mix_execute():
     b = Tensor(2)
     output = net(a, b)
     assert output == 8
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_pynative_ms_function_empty_graph():
+    """
+    Feature: PyNative ms_function.
+    Description: Empty ms_function graph.
+    Expectation: The calculation result is correct.
+    """
+
+    class Net(nn.Cell):
+        def __init__(self, x, y):
+            super().__init__()
+            self.x = x
+            self.y = y
+            self.relu = P.ReLU()
+
+        @ms_function
+        def max(self):
+            if self.x > self.y:
+                return self.x
+            return self.y
+
+        def construct(self):
+            a = self.max()
+            return self.relu(a)
+
+    net = Net(Tensor(5, ms.float32), Tensor(10, ms.float32))
+    output = net()
+    assert output.asnumpy() == 10
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_pynative_ms_function_control_flow_if_break():
+    """
+    Feature: PyNative ms_function.
+    Description: PyNative ms_function with control flow.
+    Expectation: The calculation result is correct.
+    """
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.relu = P.ReLU()
+            self.add = P.TensorAdd()
+
+        @ms_function
+        def construct(self, x, y, z):
+            out = z
+            for _ in range(5):
+                if 2 * x < y:
+                    if 3 * x < y:
+                        out = self.add(out, out)
+                        x = x + 1
+                    out = self.relu(out)
+                if x + 6 == y:
+                    break
+            out = self.relu(out)
+            return out
+
+    net = Net()
+    x = Tensor(2, ms.int32)
+    y = Tensor(10, ms.int32)
+    z = Tensor(np.ones([4, 4, 4]), ms.float32)
+    output = net(x, y, z)
+    assert (output.asnumpy() == z.asnumpy() * 4).all()
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_pynative_ms_function_with_dynamic_shape():
+    """
+    Feature: PyNative ms_function.
+    Description: PyNative ms_function with dynamic shape.
+    Expectation: The calculation result is correct.
+    """
+
+    @ms_function()
+    def test(x):
+        return ms.numpy.unique(x, return_inverse=True)
+
+    x = Tensor([[1, 1, 2], [3, 3, 5]], ms.int32)
+    output = test(x)
+    assert (output[0].asnumpy() == np.array([1, 2, 3, 5])).all()