forked from mindspore-Ecosystem/mindspore

commit 620ba53725

!16628 Unify runtime for PyNative distributed mode

From: @zyli2020
Reviewed-by: @limingqi107, @cristoval
Signed-off-by: @cristoval
@@ -1590,7 +1590,18 @@ void AscendSession::SyncStream() {
 }
 
 std::shared_ptr<device::Bucket> AscendSession::CreateBucket(uint32_t bucket_id, uint32_t bucket_size) {
-  return std::make_shared<device::ascend::AscendBucket>(bucket_id, bucket_size);
+  auto bucket = std::make_shared<device::ascend::AscendBucket>(bucket_id, bucket_size);
+
+  auto kernel_runtime = device::KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
+  MS_EXCEPTION_IF_NULL(kernel_runtime);
+  auto compute_stream = kernel_runtime->compute_stream();
+  auto communication_stream = kernel_runtime->communication_stream();
+  MS_EXCEPTION_IF_NULL(compute_stream);
+  MS_EXCEPTION_IF_NULL(communication_stream);
+
+  MS_EXCEPTION_IF_NULL(bucket);
+  bucket->Init({compute_stream}, {communication_stream});
+  return bucket;
 }
 }  // namespace session
 }  // namespace mindspore
@@ -610,7 +610,18 @@ void GPUSession::SyncStream() {
 }
 
 std::shared_ptr<device::Bucket> GPUSession::CreateBucket(uint32_t bucket_id, uint32_t bucket_size) {
-  return std::make_shared<device::gpu::GPUBucket>(bucket_id, bucket_size);
+  auto bucket = std::make_shared<device::gpu::GPUBucket>(bucket_id, bucket_size);
+
+  auto kernel_runtime = device::KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
+  MS_EXCEPTION_IF_NULL(kernel_runtime);
+  auto compute_stream = kernel_runtime->compute_stream();
+  auto communication_stream = kernel_runtime->communication_stream();
+  MS_EXCEPTION_IF_NULL(compute_stream);
+  MS_EXCEPTION_IF_NULL(communication_stream);
+
+  MS_EXCEPTION_IF_NULL(bucket);
+  bucket->Init({compute_stream}, {communication_stream});
+  return bucket;
 }
 }  // namespace gpu
 }  // namespace session
@@ -2305,7 +2305,7 @@ void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split
   }
 }
 
-void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) {
+void SessionBasic::InitAllBucket(const KernelGraphPtr &graph, const device::DeviceContext *device_context) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_LOG(INFO) << "Init Bucket start, graph_id:" << graph->graph_id();
   auto ms_context = MsContext::GetInstance();
@@ -2331,8 +2331,12 @@ void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) {
   uint32_t bucket_id = 0;
   for (auto bucket_size : bucket_size_list) {
     MS_LOG(INFO) << "Create new bucket:" << bucket_id;
-    auto bucket = CreateBucket(bucket_id++, bucket_size);
-    bucket->Init();
+    std::shared_ptr<device::Bucket> bucket = nullptr;
+    if (device_context != nullptr) {
+      bucket = device_context->CreateBucket(bucket_id++, bucket_size);
+    } else {
+      bucket = CreateBucket(bucket_id++, bucket_size);
+    }
     bucket_list.emplace_back(bucket);
   }
 
@@ -36,6 +36,7 @@
 #if !defined(_WIN32) && !defined(_WIN64)
 #include "debug/debugger/debugger.h"
 #endif
+#include "runtime/hardware/device_context.h"
 
 namespace mindspore {
 namespace runtime {
@@ -255,7 +256,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph);
   void UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs);
   virtual std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) { return nullptr; }
-  void InitAllBucket(const KernelGraphPtr &graph);
+  void InitAllBucket(const KernelGraphPtr &graph, const device::DeviceContext *device_context = nullptr);
   void AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor);
   void ClearAllBucket(const GraphId &graph_id);
   std::vector<uint32_t> GetAllReduceSplitIndex();
@@ -184,14 +184,18 @@ std::shared_ptr<LaunchKernel> AscendBucket::CreateLaunchAtomicClean() {
   return launch_atomic_clean;
 }
 
-void AscendBucket::Init() {
+void AscendBucket::Init(const std::vector<void *> &compute_streams, const std::vector<void *> &communication_streams) {
   pre_event_ = std::make_shared<AscendEvent>();
   post_event_ = std::make_shared<AscendEvent>();
 
-  auto kernel_runtime = KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
-  MS_EXCEPTION_IF_NULL(kernel_runtime);
-  compute_stream_ = kernel_runtime->compute_stream();
-  stream_ = kernel_runtime->communication_stream();
+  if (!compute_streams.empty()) {
+    compute_stream_ = compute_streams.front();
+  }
+  if (!communication_streams.empty()) {
+    stream_ = communication_streams.front();
+  }
+  MS_EXCEPTION_IF_NULL(compute_stream_);
+  MS_EXCEPTION_IF_NULL(stream_);
 
   MS_EXCEPTION_IF_NULL(pre_event_);
   MS_EXCEPTION_IF_NULL(post_event_);
@@ -18,6 +18,7 @@
 #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_BUCKET_H_
 
 #include <memory>
+#include <vector>
 #include "runtime/device/bucket.h"
 
 namespace mindspore::device::ascend {
@@ -26,7 +27,7 @@ class AscendBucket : public Bucket {
   AscendBucket(uint32_t id, uint32_t bucket_size) : Bucket(id, bucket_size) {}
   ~AscendBucket() override = default;
 
-  void Init() override;
+  void Init(const std::vector<void *> &compute_streams, const std::vector<void *> &communication_streams) override;
 
  private:
   void AllocateAllReduceAddr() override;
@@ -50,7 +50,7 @@ class Bucket {
   void Launch();
   void Release();
   void AddGradTensor(const tensor::TensorPtr &tensor);
-  virtual void Init() = 0;
+  virtual void Init(const std::vector<void *> &compute_streams, const std::vector<void *> &communication_streams) = 0;
 
  protected:
   uint32_t id_;
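A minimal standalone sketch of the contract this hunk establishes, using illustrative stand-in types rather than MindSpore's real classes: the caller now injects the compute and communication streams, instead of the bucket fetching them from a runtime singleton:

#include <cstdio>
#include <vector>

class Bucket {
 public:
  virtual ~Bucket() = default;
  // Streams are supplied by whoever owns them (a session or a device context).
  virtual void Init(const std::vector<void *> &compute_streams,
                    const std::vector<void *> &communication_streams) = 0;

 protected:
  void *compute_stream_{nullptr};
  void *stream_{nullptr};
};

class DemoBucket : public Bucket {
 public:
  void Init(const std::vector<void *> &compute_streams,
            const std::vector<void *> &communication_streams) override {
    // Mirrors the diff: take the first stream of each list when present.
    if (!compute_streams.empty()) compute_stream_ = compute_streams.front();
    if (!communication_streams.empty()) stream_ = communication_streams.front();
    std::printf("compute=%p comm=%p\n", compute_stream_, stream_);
  }
};

int main() {
  int compute = 0, comm = 0;  // stand-ins for real device stream handles
  DemoBucket bucket;
  bucket.Init({&compute}, {&comm});
  return 0;
}

The same injected-stream pattern serves both call sites in this commit: the legacy sessions pass streams taken from the current KernelRuntime, while the new DeviceContext path passes the streams it owns.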
@@ -165,14 +165,18 @@ std::shared_ptr<LaunchKernel> GPUBucket::CreateLaunchMul() {
   return launch_mul;
 }
 
-void GPUBucket::Init() {
+void GPUBucket::Init(const std::vector<void *> &compute_streams, const std::vector<void *> &communication_streams) {
   pre_event_ = std::make_shared<GpuEvent>();
   post_event_ = std::make_shared<GpuEvent>();
 
-  auto kernel_runtime = KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
-  MS_EXCEPTION_IF_NULL(kernel_runtime);
-  stream_ = kernel_runtime->communication_stream();
-  compute_stream_ = kernel_runtime->compute_stream();
+  if (!compute_streams.empty()) {
+    compute_stream_ = compute_streams.front();
+  }
+  if (!communication_streams.empty()) {
+    stream_ = communication_streams.front();
+  }
+  MS_EXCEPTION_IF_NULL(compute_stream_);
+  MS_EXCEPTION_IF_NULL(stream_);
 
   MS_EXCEPTION_IF_NULL(pre_event_);
   MS_EXCEPTION_IF_NULL(post_event_);
@@ -18,6 +18,7 @@
 #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_BUCKET_H_
 
 #include <memory>
+#include <vector>
 #include "runtime/device/bucket.h"
 
 namespace mindspore::device::gpu {
@@ -26,7 +27,7 @@ class GPUBucket : public Bucket {
   GPUBucket(uint32_t id, uint32_t bucket_size);
   ~GPUBucket() override = default;
 
-  void Init() override;
+  void Init(const std::vector<void *> &compute_streams, const std::vector<void *> &communication_streams) override;
 
  private:
   void AllocateAllReduceAddr() override;
@@ -225,6 +225,9 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) const {
 
   graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
 
+  MS_EXCEPTION_IF_NULL(session_);
+  session_->InitAllBucket(graph, device_context_);
+
   return graph->graph_id();
 }
 
@@ -312,7 +315,8 @@ void GraphCompiler::GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const
 void GraphCompiler::RecoverGraphOutput(
   const AnfNodePtr &kernel, const VectorRef &op_outputs,
   const std::map<KernelWithIndex, std::vector<std::vector<size_t>>> &output_indexes,
-  std::map<KernelWithIndex, TensorPtr> *op_output_map, VectorRef *outputs) {
+  std::map<KernelWithIndex, TensorPtr> *op_output_map, VectorRef *outputs,
+  std::vector<TensorPtr> *runop_output_tensors) {
   MS_EXCEPTION_IF_NULL(kernel);
   MS_EXCEPTION_IF_NULL(op_output_map);
   MS_EXCEPTION_IF_NULL(outputs);
@@ -349,8 +353,19 @@ void GraphCompiler::RecoverGraphOutput(
 
       BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
       tensor_ref = output_tensor;
+      runop_output_tensors->emplace_back(output_tensor);
     }
   }
 }
+
+void GraphCompiler::AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor) {
+  MS_EXCEPTION_IF_NULL(session_);
+  session_->AddGradAddrToBucket(graph_id, grad_tensor);
+}
+
+void GraphCompiler::ClearAllBucket(const GraphId &graph_id) {
+  MS_EXCEPTION_IF_NULL(session_);
+  session_->ClearAllBucket(graph_id);
+}
 }  // namespace runtime
 }  // namespace mindspore
@@ -82,7 +82,17 @@ class GraphCompiler {
   // Handle single op output tensor and recover output of original complete kernel graph.
   void RecoverGraphOutput(const AnfNodePtr &kernel, const VectorRef &op_outputs,
                           const std::map<KernelWithIndex, std::vector<std::vector<size_t>>> &output_indexes,
-                          std::map<KernelWithIndex, TensorPtr> *op_output_map, VectorRef *outputs);
+                          std::map<KernelWithIndex, TensorPtr> *op_output_map, VectorRef *outputs,
+                          std::vector<TensorPtr> *runop_output_tensors);
+
+  // Collect output tensors of back propagation graph for allreduce operators to average gradient,
+  // used in PyNative distributed training mode.
+  void AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor);
+
+  // Clear resource in bucket, such as useless tensors and device memory of all communication operators.
+  // Bucket is used in PyNative distributed training mode, one bucket handles all resource to launch and sync allreduce
+  // operator.
+  void ClearAllBucket(const GraphId &graph_id);
 
  private:
   GraphCompiler() = default;
@@ -21,6 +21,7 @@
 #include <vector>
 #include <memory>
 #include "runtime/device/device_address.h"
+#include "runtime/device/bucket.h"
 #include "backend/session/kernel_graph.h"
 #include "backend/session/anf_runtime_algorithm.h"
 
@@ -97,6 +98,10 @@ class DeviceContext {
   // Get device_context_key_ to obtain device name and device id.
   const DeviceContextKey &device_context_key() const { return device_context_key_; }
 
+  // Create and initialize bucket for every allreduce operator. Bucket is used in PyNative distributed training mode,
+  // one bucket handles all resource to launch and sync allreduce operator.
+  virtual std::shared_ptr<Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) const { return nullptr; }
+
  protected:
   DeviceContextKey device_context_key_;
 };
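The base-class default returns nullptr, so SessionBasic::InitAllBucket falls back to the session's own CreateBucket when a backend has not opted in. A hedged sketch of the override pattern, with stand-in types rather than the real DeviceContext/Bucket classes:

#include <cstdint>
#include <memory>
#include <vector>

// Stand-in types for illustration only.
struct Bucket {
  Bucket(uint32_t id, uint32_t size) : id_(id), size_(size) {}
  void Init(const std::vector<void *> &compute, const std::vector<void *> &comm) {
    // A real bucket would record the streams and set up pre/post events here.
    (void)compute;
    (void)comm;
  }
  uint32_t id_;
  uint32_t size_;
};

struct DeviceContext {
  virtual ~DeviceContext() = default;
  // Default: no bucket support; callers fall back to the session path.
  virtual std::shared_ptr<Bucket> CreateBucket(uint32_t id, uint32_t size) const { return nullptr; }
};

struct DemoDeviceContext : DeviceContext {
  std::shared_ptr<Bucket> CreateBucket(uint32_t id, uint32_t size) const override {
    auto bucket = std::make_shared<Bucket>(id, size);
    // Pair the backend's compute stream with its dedicated communication
    // stream, as the GPU context in this commit does with streams_[0] / streams_[1].
    bucket->Init({compute_stream_}, {communication_stream_});
    return bucket;
  }
  void *compute_stream_{nullptr};
  void *communication_stream_{nullptr};
};

int main() {
  DemoDeviceContext ctx;
  auto bucket = ctx.CreateBucket(/*id=*/0, /*size=*/16);
  return bucket != nullptr ? 0 : 1;
}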
@@ -30,6 +30,7 @@
 #include "runtime/hardware/gpu/optimizer.h"
 #include "common/trans.h"
 #include "utils/context/graph_kernel_flags.h"
+#include "runtime/device/gpu/gpu_bucket.h"
 
 namespace mindspore {
 namespace device {
@@ -94,12 +95,16 @@ bool GPUDeviceContext::InitDevice() {
   // Initialize device resource, such as stream, cudnn and cublas handle.
   GPUDeviceManager::GetInstance().InitDevice();
 
   auto stream = GPUDeviceManager::GetInstance().default_stream();
-  if (stream == nullptr) {
-    MS_LOG(ERROR) << "No default CUDA stream found.";
-    return false;
-  }
+  MS_ERROR_IF_NULL(stream);
   streams_.push_back(stream);
 
+  void *communication_stream = nullptr;
+  GPUDeviceManager::GetInstance().CreateStream(&communication_stream);
+  MS_ERROR_IF_NULL(communication_stream);
+  streams_.push_back(communication_stream);
+
   return true;
 }
 
@@ -286,6 +291,14 @@ bool GPUDeviceContext::SyncStream(size_t stream_id) const {
   return GPUDeviceManager::GetInstance().SyncStream(streams_[stream_id]);
 }
 
+std::shared_ptr<Bucket> GPUDeviceContext::CreateBucket(uint32_t bucket_id, uint32_t bucket_size) const {
+  auto bucket = std::make_shared<GPUBucket>(bucket_id, bucket_size);
+  MS_EXCEPTION_IF_NULL(bucket);
+
+  bucket->Init({streams_[0]}, {streams_[1]});
+  return bucket;
+}
+
 bool GPUDeviceContext::BindDeviceToCurrentThread() const {
   if (cur_thread_device_inited) {
     return true;
@@ -60,6 +60,10 @@ class GPUDeviceContext : public DeviceContext {
 
   bool SyncStream(size_t stream_id = 0) const override;
 
+  // Create bucket for every allreduce operator. Bucket is used in PyNative distributed training mode, one bucket
+  // handles all resource to launch and sync allreduce operator.
+  std::shared_ptr<Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) const override;
+
  private:
   DISABLE_COPY_AND_ASSIGN(GPUDeviceContext);
   bool InitDevice();
@@ -355,6 +355,11 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
     runtime::GraphCompiler::GetInstance().GetParamAndOutputIndex(graph, inputs[graph_index], outputs, &parameter_index,
                                                                  &output_indexes);
 
+    // Clear bucket resources every step
+    if (graph->is_bprop()) {
+      runtime::GraphCompiler::GetInstance().ClearAllBucket(graph->graph_id());
+    }
+
     for (const auto &kernel : graph->execution_order()) {
       OpRunInfo op_run_info;
       GraphInfo graph_info;
@@ -369,8 +374,14 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
       VectorRef op_outputs =
         RunGraph(actor_info, &input_tensor_info.input_tensors_mask, &input_tensor_info.input_tensors);
 
+      std::vector<tensor::TensorPtr> new_output_tensors;
       runtime::GraphCompiler::GetInstance().RecoverGraphOutput(kernel, op_outputs, output_indexes, &op_output_map,
-                                                               outputs);
+                                                               outputs, &new_output_tensors);
+
+      // Save grad node to Bucket
+      if (graph->is_bprop()) {
+        runtime::GraphCompiler::GetInstance().AddGradAddrToBucket(graph->graph_id(), new_output_tensors);
+      }
     }
   }
 }
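Taken together, the two RunGraphBySingleOp hunks give a bprop graph a per-step bucket lifecycle: clear the previous step's bucket resources before launching kernels, then feed each op's recovered output tensors into the bucket as gradients. A standalone control-flow sketch with stub types (illustrative, not the real backend classes):

#include <cstdint>
#include <memory>
#include <vector>

struct Tensor {};
using TensorPtr = std::shared_ptr<Tensor>;

// Stub standing in for runtime::GraphCompiler.
struct GraphCompilerStub {
  void ClearAllBucket(uint32_t /*graph_id*/) { /* free last step's bucket memory */ }
  void AddGradAddrToBucket(uint32_t /*graph_id*/, const std::vector<TensorPtr> & /*grads*/) {
    /* queue gradients so allreduce can launch once the bucket is full */
  }
};

struct GraphStub {
  uint32_t graph_id() const { return 0; }
  bool is_bprop() const { return true; }
  std::vector<int> execution_order() const { return {0, 1, 2}; }
};

void RunGraphBySingleOpSketch(const GraphStub &graph, GraphCompilerStub *compiler) {
  // 1. Clear bucket resources at the start of every step.
  if (graph.is_bprop()) {
    compiler->ClearAllBucket(graph.graph_id());
  }
  for (int kernel : graph.execution_order()) {
    (void)kernel;
    // 2. RecoverGraphOutput fills new_output_tensors for this op.
    std::vector<TensorPtr> new_output_tensors;
    // 3. Hand the recovered gradients to the bucket.
    if (graph.is_bprop()) {
      compiler->AddGradAddrToBucket(graph.graph_id(), new_output_tensors);
    }
  }
}

int main() {
  GraphCompilerStub compiler;
  RunGraphBySingleOpSketch(GraphStub{}, &compiler);
  return 0;
}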
@@ -529,6 +540,13 @@ VectorRef MindRTBackend::RunGraph(const ActorInfo &actor_info, const std::vector
   }
 
   mindspore::ScopedLongRunning long_running;
+
+  for (auto &tensor : tensors_without_value_node) {
+    if (tensor->NeedWaitDevice()) {
+      tensor->WaitDevice();
+    }
+  }
+
   runtime::GraphScheduler::GetInstance().PrepareRun(actor_set, graph_compiler_info, {tensors_without_value_node});
   if (!runtime::GraphScheduler::GetInstance().Run(actor_set, runtime::GraphExecutionStrategy::kStep, input_tensors)) {
     MS_LOG(EXCEPTION) << "The actor runs failed, actor name: " << actor_set->name_;