!28983 ms_function needs to be compiled and executed in graph
Merge pull request !28983 from caifubi/master-pynative-add-bprop-flag
Commit 3ed062fcfa
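For context, the scenario this change targets is an @ms_function block invoked from a cell running in PyNative mode: the decorated block is compiled and executed as a graph, while the surrounding code keeps running op by op. A minimal sketch of that usage (a trimmed variant of the new test added at the end of this diff; the method name mul_in_graph is mine, and the standard mindspore imports are assumed):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context, ms_function

    context.set_context(mode=context.PYNATIVE_MODE)

    class Net(nn.Cell):
        @ms_function  # this block is compiled and executed as a graph
        def mul_in_graph(self, x, y):
            return x * y

        def construct(self, x, y):
            z = x * y                       # runs op by op in PyNative
            return self.mul_in_graph(z, x)  # runs as a compiled graph

    net = Net()
    print(net(Tensor(np.array(2)), Tensor(np.array(2))))  # expected: 8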
@@ -2725,6 +2725,8 @@ void GradExecutor::GradNetInner(py::object *ret, const prim::GradOperationPtr &g
   }
   // Get bprop graph of top cell
   auto bprop_graph = GetBpropGraph(grad, cell, w_args, p_args, size, args);
+  MS_EXCEPTION_IF_NULL(bprop_graph);
+  bprop_graph->set_is_bprop(true);
   resource->set_func_graph(bprop_graph);
   auto manager = resource->manager();
   MS_EXCEPTION_IF_NULL(manager);
@@ -448,12 +448,7 @@ bool AscendDeviceAddress::SyncDeviceToDeviceWithSameFormatType(const ShapeVector
     return false;
   }
   BindDevice();
-  auto ret_rt_memcpy = aclrtMemcpy(ptr_, size, src_ptr, size, ACL_MEMCPY_DEVICE_TO_DEVICE);
-  if (ret_rt_memcpy != RT_ERROR_NONE) {
-    MS_LOG(ERROR) << "SyncDeviceToDevice failed, rtMemcpy mem size [" << size << "], ret [" << ret_rt_memcpy << "]";
-    return false;
-  }
-  return true;
+  return AsyncDeviceToDevice(shape, size, type, src_ptr, format);
 }
 
 bool AscendDeviceAddress::SyncDeviceToDeviceWithDiffFormatType(const DeviceSync *src_device_addr) const {
@@ -345,7 +345,7 @@ void UpdateRefCountForGraphOutput(const std::vector<KernelWithIndex> &output_wit
 GraphCompilerInfo::~GraphCompilerInfo() { GraphScheduler::GetInstance().Clear(name_, graphs_); }
 
 GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs,
-                                    const DeviceContext *device_context) {
+                                    const DeviceContext *device_context, bool run_in_pynative) {
   MS_EXCEPTION_IF_NULL(session_);
   MS_EXCEPTION_IF_NULL(segment);
   MS_LOG(INFO) << "Status record: start compile graph.";
@@ -372,7 +372,17 @@ GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNod
   session_->SetInputNodeUsage(graph, manager);
   graph->SetOptimizerFlag();
 
-  auto graph_id = CompileGraphImpl(graph, device_context);
+  GraphId graph_id;
+  if (run_in_pynative) {
+    MS_EXCEPTION_IF_NULL(session_);
+    // Graph kernel does not support PyNative mode yet, so when users enable graph kernel in PyNative mode
+    // a warning log should be printed to remind them; calling GetInstance triggers that check.
+    (void)graphkernel::GraphKernelFlags::GetInstance();
+    session_->InitAllBucket(graph, device_context);
+    graph_id = graph->graph_id();
+  } else {
+    graph_id = CompileGraphImpl(graph, device_context);
+  }
 
   session_->DumpGraphs({graph});
 
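The comment in the new run_in_pynative branch refers to users enabling graph kernel fusion while still in PyNative mode; reading the flags via GraphKernelFlags::GetInstance() is what surfaces the warning. A minimal way to reach that path from Python (a sketch, assuming the usual context API; not part of this commit):

    import mindspore.context as context

    # Graph kernel fusion is a GRAPH_MODE feature; enabling it while in
    # PYNATIVE_MODE only triggers the warning emitted when the flags are read.
    context.set_context(mode=context.PYNATIVE_MODE, enable_graph_kernel=True)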
@@ -435,14 +445,6 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
   MS_EXCEPTION_IF_NULL(device_context);
   const auto &ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
-  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
-    // graphkernel not support pynative mode now, so when users open graphkernel
-    // in pynative mode should print a warning log to reminder users by using GetInstance func.
-    graphkernel::GraphKernelFlags::GetInstance();
-    MS_EXCEPTION_IF_NULL(session_);
-    session_->InitAllBucket(graph, device_context);
-    return graph->graph_id();
-  }
 
 #ifdef ENABLE_DUMP_IR
   bool save_graphs = ms_context->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
@@ -98,7 +98,7 @@ class GraphCompiler {
   // Construct kernel graph from anf nodes list and compile kernel graph in Graph mode,
   // the detailed implementation of compiling graph is in 'CompileGraphImpl'.
   GraphId CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs,
-                       const DeviceContext *device_context);
+                       const DeviceContext *device_context, bool run_in_pynative = false);
 
   // Construct kernel graph from function graph and compile kernel graph in Graph mode,
   // the detailed implementation of compiling graph is in 'CompileGraphImpl'.
@@ -31,6 +31,7 @@
 #include "runtime/hardware/ascend/ascend_graph_optimization.h"
 #include "backend/kernel_compiler/ascend_kernel_mod.h"
 #include "backend/kernel_compiler/aicpu/aicpu_kernel_load.h"
+#include "backend/kernel_compiler/tbe/tbe_kernel_compile.h"
 #include "runtime/device/ascend/ascend_bucket.h"
 #include "common/util/error_manager/error_manager.h"
 #include "runtime/device/ascend/ascend_memory_adapter.h"
@@ -263,6 +264,9 @@ void AscendDeviceContext::Initialize() {
   compute_stream_ = runtime_instance_->compute_stream();
   communication_stream_ = runtime_instance_->communication_stream();
 
+  // Initialize tbe using HCCL rank_id
+  kernel::ascend::TbeKernelCompileManager::GetInstance().TbeInitialize();
+
   initialized_ = true;
   MS_LOG(INFO) << "Status record: Initialize success.";
 }
@@ -279,6 +283,7 @@ void AscendDeviceContext::Destroy() {
     return;
   }
   MS_LOG(INFO) << "Status record: Destroy start...";
+  graph_event_.clear();
   rank_id_ = 0;
   if (runtime_instance_) {
     // TODO(lzlang): Destroy runtime instance after fully support MindRT, otherwise runtime will be destructed
@@ -550,6 +555,8 @@ bool AscendDeviceContext::ExecuteGraph(const KernelGraphPtr &graph) const {
   const uint64_t kUSecondInSecond = 1000000;
   bool ret = false;
   if (graph->is_executing_sink()) {
+    InsertEventBeforeRunTask(graph);
+
 #if defined(_WIN32) || defined(_WIN64)
     auto start_time = std::chrono::steady_clock::now();
 #else
@@ -870,6 +877,23 @@ bool AscendDeviceContext::LaunchAtomicClean(const CNodePtr &node, const std::vec
   return kernel_mod->Launch(atomic_inputs, {}, {}, GetKernelStream(atomic_node));
 }
 
+void AscendDeviceContext::InsertEventBeforeRunTask(const KernelGraphPtr &graph) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  if (!graph->is_executing_sink() || graph->is_dynamic_shape()) {
+    return;
+  }
+  MS_LOG(DEBUG) << "Insert event between PyNative and Graph";
+  MS_EXCEPTION_IF_NULL(runtime_instance_);
+  auto model_stream = runtime_instance_->GetModelStream(graph->graph_id());
+  auto compute_event = runtime_instance_->CreateDeviceEvent();
+  MS_EXCEPTION_IF_NULL(compute_event);
+  compute_event->set_wait_stream(model_stream);
+  compute_event->set_record_stream(compute_stream_);
+  compute_event->RecordEvent();
+  compute_event->WaitEvent();
+  graph_event_[graph->graph_id()] = compute_event;
+}
+
 MS_REGISTER_DEVICE(kAscendDevice, AscendDeviceContext);
 }  // namespace ascend
 }  // namespace device
@@ -141,6 +141,7 @@ class AscendDeviceContext : public DeviceContext {
   bool PySyncRuning() const;
   bool MemoryCopyAsync(const CNodePtr &node, const vector<AddressPtr> &inputs, const vector<AddressPtr> &outputs) const;
   void GenKernelEvents(const NotNull<KernelGraphPtr> &root_graph) const;
+  void InsertEventBeforeRunTask(const KernelGraphPtr &graph) const;
 
   void ReportErrorMessage() const;
   void ReportWarningMessage() const;
@@ -166,6 +167,8 @@ class AscendDeviceContext : public DeviceContext {
   // node_atomics_ will be cleaned up in CompileGraph.
   mutable std::map<CNodePtr, std::vector<CNodePtr>> node_atomics_persistent_cache_;
   mutable std::set<CNodePtr> nop_op_to_memcpy_;
+  // Event for multi-stream
+  mutable std::map<uint32_t, std::shared_ptr<DeviceEvent>> graph_event_;
   // Some NOP nodes have be hide in execution order, it doesn't have output device address, this function creates
   // output device address for these nodes, and the output device address is the same with input device address.
   void AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph) const;
@@ -20,6 +20,7 @@
 #include <string>
 #include <vector>
 #include <memory>
+#include <map>
 #include "runtime/device/device_address.h"
 #include "runtime/device/bucket.h"
 #include "runtime/hardware/collective/collective_communication_lib.h"
@@ -517,12 +517,12 @@ bool MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) {
 
   // Foreach the segments to compile graph.
   for (const auto &segment : new_segments) {
-    CompileGraph(segment, contain_multi_target);
+    CompileGraph(segment, contain_multi_target, func_graph->is_bprop());
   }
   return true;
 }
 
-void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target) {
+void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target, bool run_in_pynative) {
   MS_EXCEPTION_IF_NULL(segment);
   // Compile the normal nodes, which doesn't contain the cut node.
   if (segment->nodes_.size() == 0) {
@@ -548,13 +548,14 @@ void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment, bool contain_mu
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
-  if (contain_multi_target && ms_execution_mode_ == kPynativeMode) {
+  // There will be more than one kernel graph in the heterogeneous scenario of a ms_function in PyNative mode.
+  if ((contain_multi_target || !run_in_pynative) && ms_execution_mode_ == kPynativeMode) {
     real_execution_mode_ = kGraphMode;
     context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
     MS_LOG(INFO) << "PyNative graph Compile and Run in GRAPH_MODE";
   }
 
   // Compile graph.
-  auto graph_id = graph_compiler_->CompileGraph(segment, outputs, device_context);
+  auto graph_id = graph_compiler_->CompileGraph(segment, outputs, device_context, run_in_pynative);
 
   if (ms_execution_mode_ != real_execution_mode_) {
     context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, ms_execution_mode_);
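With the extra flag, the mode switch above works as follows: an ms_function graph compiled from PyNative (run_in_pynative == false) is forced into GRAPH_MODE, which is what the PR title asks for, while the top cell's bprop graph (flagged by set_is_bprop earlier in this diff, so run_in_pynative == true) keeps the PyNative path unless it contains multiple targets. A hypothetical Python restatement of the condition, for illustration only (names are mine, not MindSpore code):

    def switch_to_graph_mode(contain_multi_target, run_in_pynative, in_pynative_mode):
        # Mirrors: (contain_multi_target || !run_in_pynative) && ms_execution_mode_ == kPynativeMode
        return (contain_multi_target or not run_in_pynative) and in_pynative_mode

    # ms_function forward graph under PyNative: compiled and run in GRAPH_MODE.
    assert switch_to_graph_mode(False, False, True)
    # bprop graph of the top cell: stays on the PyNative path.
    assert not switch_to_graph_mode(False, True, True)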
@@ -905,6 +906,8 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
   const auto &graph_compiler_info = *(graph_iter->second);
   const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;
 
+  SyncLazyTasks();
+
   // Transform args to input tensors.
   // Input tensors of the graph.
   std::vector<std::vector<tensor::TensorPtr>> input_tensors;
@@ -134,7 +134,7 @@ class MindRTBackend : public Backend {
   bool CompileGraph(const FuncGraphPtr &func_graph);
 
   // Compile the kernel graph by the segment which is from the function graph partition.
-  void CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target);
+  void CompileGraph(const GraphSegmentPtr &segment, bool contain_multi_target, bool run_in_pynative);
 
   // CreateKernel, Transform and Schedule have not been finished when LazyBuild is enabled in PyNative mode.
   void CompileSingleOpGraph(const KernelGraphPtr &graph, const DeviceContext *device_context,
@@ -261,3 +261,31 @@ def test_pynative_ms_function():
     out_b = grad(net_b, params_b)(input_data)
     assert np.allclose(out_a[0][0].asnumpy(), out_b[0][0].asnumpy(), 0.0001, 0.0001)
     assert np.allclose(out_a[1][0].asnumpy(), out_b[1][0].asnumpy(), 0.0001, 0.0001)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_pynative_ms_function_mix_execute():
+    """
+    Feature: PyNative ms_function.
+    Description: Mixed execution of PyNative and ms_function.
+    Expectation: The calculation result is correct.
+    """
+
+    class Net(nn.Cell):
+        @ms_function
+        def test_ms_function(self, x, y):
+            return x * y
+
+        def construct(self, x, y):
+            z = x * y
+            return self.test_ms_function(z, x)
+
+    net = Net()
+    a = Tensor(2)
+    b = Tensor(2)
+    output = net(a, b)
+    assert output == 8
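For reference, the expected value in the new test follows from a = b = 2: construct computes z = a * b = 4 op by op in PyNative, and the ms_function block then returns z * a = 8 from the compiled graph.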