!28983 ms_function needs to be compiled and executed in graph

Merge pull request !28983 from caifubi/master-pynative-add-bprop-flag
This commit is contained in:
i-robot 2022-01-22 11:17:09 +00:00 committed by Gitee
commit 3ed062fcfa
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 80 additions and 22 deletions

View File

@ -2725,6 +2725,8 @@ void GradExecutor::GradNetInner(py::object *ret, const prim::GradOperationPtr &g
}
// Get bprop graph of top cell
auto bprop_graph = GetBpropGraph(grad, cell, w_args, p_args, size, args);
MS_EXCEPTION_IF_NULL(bprop_graph);
bprop_graph->set_is_bprop(true);
resource->set_func_graph(bprop_graph);
auto manager = resource->manager();
MS_EXCEPTION_IF_NULL(manager);

View File

@ -448,12 +448,7 @@ bool AscendDeviceAddress::SyncDeviceToDeviceWithSameFormatType(const ShapeVector
return false;
}
BindDevice();
auto ret_rt_memcpy = aclrtMemcpy(ptr_, size, src_ptr, size, ACL_MEMCPY_DEVICE_TO_DEVICE);
if (ret_rt_memcpy != RT_ERROR_NONE) {
MS_LOG(ERROR) << "SyncDeviceToDevice failed, rtMemcpy mem size [" << size << "], ret [" << ret_rt_memcpy << "]";
return false;
}
return true;
return AsyncDeviceToDevice(shape, size, type, src_ptr, format);
}
bool AscendDeviceAddress::SyncDeviceToDeviceWithDiffFormatType(const DeviceSync *src_device_addr) const {

View File

@ -345,7 +345,7 @@ void UpdateRefCountForGraphOutput(const std::vector<KernelWithIndex> &output_wit
// On destruction, unregister this compiler info's kernel graphs from the global scheduler.
GraphCompilerInfo::~GraphCompilerInfo() { GraphScheduler::GetInstance().Clear(name_, graphs_); }
GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs,
const DeviceContext *device_context) {
const DeviceContext *device_context, bool run_in_pynative) {
MS_EXCEPTION_IF_NULL(session_);
MS_EXCEPTION_IF_NULL(segment);
MS_LOG(INFO) << "Status record: start compile graph.";
@ -372,7 +372,17 @@ GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNod
session_->SetInputNodeUsage(graph, manager);
graph->SetOptimizerFlag();
auto graph_id = CompileGraphImpl(graph, device_context);
GraphId graph_id;
if (run_in_pynative) {
MS_EXCEPTION_IF_NULL(session_);
// GraphKernel does not support PyNative mode yet, so when users enable GraphKernel in PyNative mode
// we should print a warning log to remind them, by calling the GetInstance func.
(void)graphkernel::GraphKernelFlags::GetInstance();
session_->InitAllBucket(graph, device_context);
graph_id = graph->graph_id();
} else {
graph_id = CompileGraphImpl(graph, device_context);
}
session_->DumpGraphs({graph});
@ -435,14 +445,6 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
MS_EXCEPTION_IF_NULL(device_context);
const auto &ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
// GraphKernel does not support PyNative mode yet, so when users enable GraphKernel
// in PyNative mode we should print a warning log to remind them, by calling the GetInstance func.
graphkernel::GraphKernelFlags::GetInstance();
MS_EXCEPTION_IF_NULL(session_);
session_->InitAllBucket(graph, device_context);
return graph->graph_id();
}
#ifdef ENABLE_DUMP_IR
bool save_graphs = ms_context->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);

View File

@ -98,7 +98,7 @@ class GraphCompiler {
// Construct kernel graph from anf nodes list and compile kernel graph in Graph mode,
// the detailed implementation of compiling graph is in 'CompileGraphImpl'.
GraphId CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs,
const DeviceContext *device_context);
const DeviceContext *device_context, bool run_in_pynative = false);
// Construct kernel graph from function graph and compile kernel graph in Graph mode,
// the detailed implementation of compiling graph is in 'CompileGraphImpl'.

View File

@ -31,6 +31,7 @@
#include "runtime/hardware/ascend/ascend_graph_optimization.h"
#include "backend/kernel_compiler/ascend_kernel_mod.h"
#include "backend/kernel_compiler/aicpu/aicpu_kernel_load.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_compile.h"
#include "runtime/device/ascend/ascend_bucket.h"
#include "common/util/error_manager/error_manager.h"
#include "runtime/device/ascend/ascend_memory_adapter.h"
@ -263,6 +264,9 @@ void AscendDeviceContext::Initialize() {
compute_stream_ = runtime_instance_->compute_stream();
communication_stream_ = runtime_instance_->communication_stream();
// Initialize tbe using HCCL rank_id
kernel::ascend::TbeKernelCompileManager::GetInstance().TbeInitialize();
initialized_ = true;
MS_LOG(INFO) << "Status record: Initialize success.";
}
@ -279,6 +283,7 @@ void AscendDeviceContext::Destroy() {
return;
}
MS_LOG(INFO) << "Status record: Destroy start...";
graph_event_.clear();
rank_id_ = 0;
if (runtime_instance_) {
// TODO(lzlang): Destroy runtime instance after fully support MindRT, otherwise runtime will be destructed
@ -550,6 +555,8 @@ bool AscendDeviceContext::ExecuteGraph(const KernelGraphPtr &graph) const {
const uint64_t kUSecondInSecond = 1000000;
bool ret = false;
if (graph->is_executing_sink()) {
InsertEventBeforeRunTask(graph);
#if defined(_WIN32) || defined(_WIN64)
auto start_time = std::chrono::steady_clock::now();
#else
@ -870,6 +877,23 @@ bool AscendDeviceContext::LaunchAtomicClean(const CNodePtr &node, const std::vec
return kernel_mod->Launch(atomic_inputs, {}, {}, GetKernelStream(atomic_node));
}
// Record a device event on the compute stream and make the graph's model stream wait on it,
// so work already queued by PyNative single-op execution finishes before the sink graph runs.
void AscendDeviceContext::InsertEventBeforeRunTask(const KernelGraphPtr &graph) const {
  MS_EXCEPTION_IF_NULL(graph);
  // Only static-shape graphs executed in sink mode need this synchronization.
  if (!graph->is_executing_sink() || graph->is_dynamic_shape()) {
    return;
  }
  MS_LOG(DEBUG) << "Insert event between PyNative and Graph";
  MS_EXCEPTION_IF_NULL(runtime_instance_);
  const auto graph_id = graph->graph_id();
  auto model_stream = runtime_instance_->GetModelStream(graph_id);
  auto event = runtime_instance_->CreateDeviceEvent();
  MS_EXCEPTION_IF_NULL(event);
  event->set_wait_stream(model_stream);
  event->set_record_stream(compute_stream_);
  // Order matters: record on the compute stream first, then block the model stream on it.
  event->RecordEvent();
  event->WaitEvent();
  // Keep the event alive for the graph's lifetime; cleared in Destroy().
  graph_event_[graph_id] = event;
}
MS_REGISTER_DEVICE(kAscendDevice, AscendDeviceContext);
} // namespace ascend
} // namespace device

View File

@ -141,6 +141,7 @@ class AscendDeviceContext : public DeviceContext {
bool PySyncRuning() const;
bool MemoryCopyAsync(const CNodePtr &node, const vector<AddressPtr> &inputs, const vector<AddressPtr> &outputs) const;
void GenKernelEvents(const NotNull<KernelGraphPtr> &root_graph) const;
void InsertEventBeforeRunTask(const KernelGraphPtr &graph) const;
void ReportErrorMessage() const;
void ReportWarningMessage() const;
@ -166,6 +167,8 @@ class AscendDeviceContext : public DeviceContext {
// node_atomics_ will be cleaned up in CompileGraph.
mutable std::map<CNodePtr, std::vector<CNodePtr>> node_atomics_persistent_cache_;
mutable std::set<CNodePtr> nop_op_to_memcpy_;
// Event for multi-stream
mutable std::map<uint32_t, std::shared_ptr<DeviceEvent>> graph_event_;
// Some NOP nodes have been hidden in the execution order; they don't have an output device address,
// so this function creates output device addresses for these nodes, with each output device address
// the same as the corresponding input device address.
void AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph) const;

View File

@ -20,6 +20,7 @@
#include <string>
#include <vector>
#include <memory>
#include <map>
#include "runtime/device/device_address.h"
#include "runtime/device/bucket.h"
#include "runtime/hardware/collective/collective_communication_lib.h"

View File

@ -517,12 +517,12 @@ bool MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) {
// Foreach the segments to compile graph.
for (const auto &segment : new_segments) {
CompileGraph(segment, contain_multi_target);
CompileGraph(segment, contain_multi_target, func_graph->is_bprop());
}
return true;
}
void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target) {
void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target, bool run_in_pynative) {
MS_EXCEPTION_IF_NULL(segment);
// Compile the normal nodes, which doesn't contain the cut node.
if (segment->nodes_.size() == 0) {
@ -548,13 +548,14 @@ void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_mu
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
// There can be more than one kernel graph in a heterogeneous scenario within an ms_function of PyNative mode.
if (contain_multi_target && ms_execution_mode_ == kPynativeMode) {
if ((contain_multi_target || !run_in_pynative) && ms_execution_mode_ == kPynativeMode) {
real_execution_mode_ = kGraphMode;
context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
MS_LOG(INFO) << "PyNative graph Compile and Run in GRAPH_MODE";
}
// Compile graph.
auto graph_id = graph_compiler_->CompileGraph(segment, outputs, device_context);
auto graph_id = graph_compiler_->CompileGraph(segment, outputs, device_context, run_in_pynative);
if (ms_execution_mode_ != real_execution_mode_) {
context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, ms_execution_mode_);
@ -905,6 +906,8 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
const auto &graph_compiler_info = *(graph_iter->second);
const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;
SyncLazyTasks();
// Transform args to input tensors.
// Input tensors of the graph.
std::vector<std::vector<tensor::TensorPtr>> input_tensors;

View File

@ -134,7 +134,7 @@ class MindRTBackend : public Backend {
bool CompileGraph(const FuncGraphPtr &func_graph);
// Compile the kernel graph by the segment which is from the function graph partition.
void CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target);
void CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target, bool run_in_pynative);
// CreateKernel, Transform and Schedule have not been finished when LazyBuild is enabled in PyNative mode.
void CompileSingleOpGraph(const KernelGraphPtr &graph, const DeviceContext *device_context,

View File

@ -261,3 +261,31 @@ def test_pynative_ms_function():
out_b = grad(net_b, params_b)(input_data)
assert np.allclose(out_a[0][0].asnumpy(), out_b[0][0].asnumpy(), 0.0001, 0.0001)
assert np.allclose(out_a[1][0].asnumpy(), out_b[1][0].asnumpy(), 0.0001, 0.0001)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pynative_ms_function_mix_execute():
    """
    Feature: PyNative ms_function.
    Description: Mixed execution of PyNative and ms_function.
    Expectation: The calculation result is correct.
    """
    class MixedNet(nn.Cell):
        # Compiled as a graph while the rest of construct runs in PyNative mode.
        @ms_function
        def test_ms_function(self, x, y):
            return x * y

        def construct(self, x, y):
            z = x * y
            return self.test_ms_function(z, x)

    net = MixedNet()
    first = Tensor(2)
    second = Tensor(2)
    result = net(first, second)
    # (2 * 2) * 2 == 8
    assert result == 8