forked from mindspore-Ecosystem/mindspore
!274 GPU multiple stream feature
Merge pull request !274 from ZPaC/gpu-backend-supports-multiple-streams
This commit is contained in:
commit 57aee805ce
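This merge moves the AllReduce collective onto a dedicated CUDA stream and keeps it ordered against the compute stream through Send/Recv nodes that record and wait on a CUDA event. As rough orientation for the diff below, here is a minimal, hypothetical sketch of that event-based hand-off; it is not part of the commit, and the stream and variable names are illustrative only:

#include <cuda_runtime.h>

int main() {
  cudaStream_t compute_stream = nullptr;
  cudaStream_t comm_stream = nullptr;
  cudaStreamCreate(&compute_stream);
  cudaStreamCreate(&comm_stream);

  // Timing-disabled event, matching the flag the commit passes to cudaEventCreate.
  cudaEvent_t event = nullptr;
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);

  // "Send": record on the compute stream once the AllReduce inputs have been produced there.
  cudaEventRecord(event, compute_stream);
  // "Recv": the communication stream waits on the event before the collective is launched on it.
  cudaStreamWaitEvent(comm_stream, event, 0);
  // ... a collective such as ncclAllReduce(..., comm_stream) would run here ...

  cudaStreamSynchronize(comm_stream);
  cudaEventDestroy(event);
  cudaStreamDestroy(comm_stream);
  cudaStreamDestroy(compute_stream);
  return 0;
}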
@@ -25,7 +25,7 @@ namespace device {
 namespace gpu {
 void GPUDeviceManager::InitDevice() {
   CHECK_OP_RET_WITH_EXCEPT(CudaDriver::set_current_device(SizeToInt(cur_dev_id_)), "Failed to set current device id");
-  CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(&stream_), "Failed to create CUDA stream.");
+  CHECK_OP_RET_WITH_EXCEPT(CreateStream(&default_stream_), "Failed to create CUDA stream.");
   CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreate(&cudnn_handle_), "Failed to create cuDNN handle");
   CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetStream(cudnn_handle_, reinterpret_cast<cudaStream_t>(default_stream())),
                               "Failed to set stream for cuDNN handle.");
@@ -36,19 +36,27 @@ void GPUDeviceManager::InitDevice() {
 }

 void GPUDeviceManager::ReleaseDevice() {
-  if (stream_ != nullptr) {
-    CHECK_OP_RET_WITH_ERROR(CudaDriver::DestroyStream(stream_), "Failed to destroy cuda stream.");
+  for (DeviceStream stream : gpu_streams_) {
+    if (stream != nullptr) {
+      CHECK_OP_RET_WITH_ERROR(CudaDriver::DestroyStream(stream), "Failed to destroy CUDA stream.");
+    }
   }
   if (cudnn_handle_ != nullptr) {
-    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroy(cudnn_handle_), "Failed to destroy cudnn handle");
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroy(cudnn_handle_), "Failed to destroy cuDNN handle");
   }
   if (cublas_handle_ != nullptr) {
-    CHECK_CUBLAS_RET_WITH_ERROR(cublasDestroy(cublas_handle_), "Failed to destroy cublas handle.");
+    CHECK_CUBLAS_RET_WITH_ERROR(cublasDestroy(cublas_handle_), "Failed to destroy cuBLAS handle.");
   }
   CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator");
 }

-const DeviceStream& GPUDeviceManager::default_stream() const { return stream_; }
+bool GPUDeviceManager::CreateStream(DeviceStream* stream) {
+  CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream");
+  gpu_streams_.emplace_back(*stream);
+  return true;
+}
+
+const DeviceStream& GPUDeviceManager::default_stream() const { return default_stream_; }

 int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); }

@@ -19,6 +19,7 @@

 #include <cudnn.h>
 #include <cublas_v2.h>
+#include <vector>
 #include <memory>
 #include "device/gpu/cuda_driver.h"
 #include "device/gpu/gpu_memory_allocator.h"
@@ -36,13 +37,15 @@ class GPUDeviceManager {
   uint32_t cur_device_id() const;
   bool is_device_id_init() const;

+  bool CreateStream(DeviceStream* stream);
+  bool SyncStream(const DeviceStream& stream) const;
   const DeviceStream& default_stream() const;
+
   const cudnnHandle_t& GetCudnnHandle() const;
   const cublasHandle_t& GetCublasHandle() const;

   bool CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const;
   bool CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const;
-  bool SyncStream(const DeviceStream& stream) const;

   static GPUDeviceManager& GetInstance() {
     static GPUDeviceManager instance;
@@ -55,13 +58,16 @@ class GPUDeviceManager {
   GPUDeviceManager(const GPUDeviceManager&) = delete;
   GPUDeviceManager& operator=(const GPUDeviceManager&) = delete;

-  // default cuda stream used for all the kernels.
-  DeviceStream stream_{nullptr};
+  // default CUDA stream used for all the kernels.
+  DeviceStream default_stream_{nullptr};

-  // handle used for cudnn kernels.
+  // all gpu CUDA streams including default_stream_.
+  std::vector<DeviceStream> gpu_streams_;
+
+  // handle used for cuDNN kernels.
   cudnnHandle_t cudnn_handle_{nullptr};

-  // handle used for cublas kernels.
+  // handle used for cuBLAS kernels.
   cublasHandle_t cublas_handle_{nullptr};

   bool dev_id_init_;

@@ -0,0 +1,181 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <set>
#include <string>
#include <memory>
#include <algorithm>
#include "device/gpu/gpu_common.h"
#include "device/gpu/kernel_info_setter.h"
#include "device/gpu/gpu_device_manager.h"
#include "device/gpu/gpu_stream_assign.h"

namespace mindspore {
namespace device {
namespace gpu {
void AssignGpuStream(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  std::vector<CNodePtr> allreduce_cnodes;
  auto execution_kernels = kernel_graph->execution_order();
  for (auto kernel : execution_kernels) {
    std::string kernel_name = AnfAlgo::GetCNodeName(kernel);
    if (kernel_name == kAllReduceOpName) {
      allreduce_cnodes.emplace_back(kernel);
    }
  }
  if (allreduce_cnodes.size() > 1) {
    DeviceStream comm_stream = nullptr;
    GPUDeviceManager::GetInstance().CreateStream(&comm_stream);
    std::transform(allreduce_cnodes.begin(), allreduce_cnodes.end(), allreduce_cnodes.begin(), [&](CNodePtr node) {
      AnfAlgo::SetNodeAttr("stream_id", MakeValue(reinterpret_cast<uintptr_t>(comm_stream)), node);
      return node;
    });

    std::vector<SendRecvPair> send_recv_pairs;
    FindAllReduceStreamSwitchPos(kernel_graph, &send_recv_pairs);
    InsertStreamSwitchNode(kernel_graph, send_recv_pairs);
  }
}

void FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &kernel_graph,
                                  std::vector<SendRecvPair> *send_recv_pairs) {
  auto execution_kernels = kernel_graph->execution_order();
  std::vector<CNodePtr>::iterator iter, iter_begin;
  iter = iter_begin = execution_kernels.begin();
  std::vector<CNodePtr>::iterator iter_end = execution_kernels.end();
  for (; iter != execution_kernels.end(); ++iter) {
    std::string kernel_name = AnfAlgo::GetCNodeName(*iter);
    if (kernel_name == kAllReduceOpName) {
      // Find AllReduce node's last input node.
      std::vector<CNodePtr>::iterator mock_send_node_iter =
        FindSendNodePos(iter_begin, iter + 1, *iter, kAllReduceStreamSwitch);
      if (mock_send_node_iter == iter + 1) {
        MS_LOG(WARNING) << "Can't find send node place before AllReduce node.";
        continue;
      }
      SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter,
                            IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)};
      send_recv_pairs->push_back(pair1);
      // Find node which uses AllReduce as input[0].
      std::vector<CNodePtr>::iterator mock_recv_node_iter =
        FindRecvNodePos(iter, iter_end, *iter, kAllReduceStreamSwitch);
      if (mock_recv_node_iter == iter_end) {
        MS_LOG(WARNING) << "Can't find recv node place after AllReduce node.";
        continue;
      }
      SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1),
                            IntToSize(mock_recv_node_iter - iter_begin)};
      send_recv_pairs->push_back(pair2);
    }
  }
}

std::vector<CNodePtr>::iterator FindSendNodePos(std::vector<CNodePtr>::iterator begin,
                                                std::vector<CNodePtr>::iterator end, const CNodePtr mock_recv_node,
                                                StreamSwitchType stream_switch_type) {
  MS_EXCEPTION_IF_NULL(mock_recv_node);
  if (stream_switch_type == kAllReduceStreamSwitch) {
    for (auto iter = begin; iter != end; iter++) {
      if (*(iter + 1) == mock_recv_node) {
        return iter;
      }
    }
  }
  return end;
}

std::vector<CNodePtr>::iterator FindRecvNodePos(std::vector<CNodePtr>::iterator begin,
                                                std::vector<CNodePtr>::iterator end, const CNodePtr mock_send_node,
                                                StreamSwitchType stream_switch_type) {
  MS_EXCEPTION_IF_NULL(mock_send_node);
  for (auto iter = begin; iter != end; iter++) {
    auto node = *iter;
    if (stream_switch_type == kAllReduceStreamSwitch) {
      for (auto input : node->inputs()) {
        if (mock_send_node == AnfAlgo::VisitKernel(input, 0).first) {
          return iter;
        }
      }
    }
  }
  return end;
}

void InsertStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &kernel_graph,
                            const std::vector<SendRecvPair> &send_recv_pairs) {
  std::set<StreamSwitchNode> ordered_stream_switch_nodes;
  for (SendRecvPair pair : send_recv_pairs) {
    StreamSwitchType stream_switch_type = pair.stream_switch_type;
    CNodePtr mock_send_node = pair.mock_send_node;
    CNodePtr mock_recv_node = pair.mock_recv_node;
    size_t send_node_offset = pair.send_node_offset;
    size_t recv_node_offset = pair.recv_node_offset;
    CNodePtr send_node = nullptr;
    CNodePtr recv_node = nullptr;
    // Step 1: generate Send and Recv CNodes.
    if (stream_switch_type == kAllReduceStreamSwitch) {
      if (!GenSendRecvCNodesForAllReduce(kernel_graph, mock_send_node, mock_recv_node, &send_node, &recv_node)) {
        MS_LOG(EXCEPTION) << "Generating CNodes for send and recv failed. Stream switch type: kAllReduceStreamSwitch";
      }
    }
    // Step 2: sort send and recv CNodes by offset.
    ordered_stream_switch_nodes.insert({send_node_offset, send_node});
    ordered_stream_switch_nodes.insert({recv_node_offset, recv_node});
  }
  // Step 3: insert stream switch CNodes into execution kernel list.
  auto execution_kernels = kernel_graph->execution_order();
  for (auto node = ordered_stream_switch_nodes.begin(); node != ordered_stream_switch_nodes.end(); node++) {
    execution_kernels.insert(execution_kernels.begin() + node->offset, node->cnode);
  }
  kernel_graph->set_execution_order(execution_kernels);
}

bool GenSendRecvCNodesForAllReduce(const std::shared_ptr<session::KernelGraph> &kernel_graph,
                                   const CNodePtr &mock_send_node, const CNodePtr &mock_recv_node, CNodePtr *send_node,
                                   CNodePtr *recv_node) {
  *send_node = CreateStreamSwitchNode(kernel_graph, kSendOpName);
  MS_EXCEPTION_IF_NULL(*send_node);
  *recv_node = CreateStreamSwitchNode(kernel_graph, kRecvOpName);
  MS_EXCEPTION_IF_NULL(*recv_node);

  cudaEvent_t event = nullptr;
  CHECK_CUDA_RET_WITH_EXCEPT(cudaEventCreate(&event, cudaEventDisableTiming), "Creating cuda event failed.");
  AnfAlgo::SetNodeAttr("record_event", MakeValue(reinterpret_cast<uintptr_t>(event)), *send_node);
  AnfAlgo::SetNodeAttr("wait_event", MakeValue(reinterpret_cast<uintptr_t>(event)), *recv_node);

  uintptr_t send_stream = AnfAlgo::GetNodeAttr<uintptr_t>(mock_send_node, "stream_id");
  AnfAlgo::SetNodeAttr("record_event_stream", MakeValue(send_stream), *send_node);
  uintptr_t recv_stream = AnfAlgo::GetNodeAttr<uintptr_t>(mock_recv_node, "stream_id");
  AnfAlgo::SetNodeAttr("wait_event_stream", MakeValue(recv_stream), *recv_node);
  return true;
}

CNodePtr CreateStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &name) {
  auto op = std::make_shared<Primitive>(name);
  auto apply = std::make_shared<ValueNode>(op);
  std::vector<AnfNodePtr> input_list = {apply};
  CNodePtr node = kernel_graph->NewCNode(input_list);
  MS_EXCEPTION_IF_NULL(node);
  kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
  AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), node.get());
  auto abstract_none = std::make_shared<abstract::AbstractNone>();
  node->set_abstract(abstract_none);
  SetKernelInfo(node);
  return node;
}
}  // namespace gpu
}  // namespace device
}  // namespace mindspore
@@ -0,0 +1,73 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_
#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_

#include <vector>
#include <string>
#include <memory>
#include "session/kernel_graph.h"
#include "session/anf_runtime_algorithm.h"

namespace mindspore {
namespace device {
namespace gpu {
enum StreamSwitchType { kAllReduceStreamSwitch, kStreamSwitchInvalidType = 255 };
struct SendRecvPair {
  StreamSwitchType stream_switch_type;
  CNodePtr mock_send_node;
  CNodePtr mock_recv_node;
  size_t send_node_offset;
  size_t recv_node_offset;
};
struct StreamSwitchNode {
  size_t offset;
  CNodePtr cnode;
  bool operator<(const StreamSwitchNode &n) const {
    if (offset < n.offset) {
      return true;
    } else if (offset == n.offset) {
      return AnfAlgo::GetCNodeName(cnode) == kSendOpName ? true : false;
    } else {
      return false;
    }
  }
};
void AssignGpuStream(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &kernel_graph,
                                  std::vector<SendRecvPair> *send_recv_pairs);
// Find the Send node position according to the "mock" recv node.
// A "mock" recv node is a gpu kernel node placed after a real Recv node, e.g. an AllReduce node.
std::vector<CNodePtr>::iterator FindSendNodePos(std::vector<CNodePtr>::iterator begin,
                                                std::vector<CNodePtr>::iterator end, const CNodePtr mock_recv_node,
                                                StreamSwitchType stream_switch_type);
// Find the Recv node position according to the "mock" send node.
// A "mock" send node is a gpu kernel node placed before a real Send node, e.g. an AllReduce node.
std::vector<CNodePtr>::iterator FindRecvNodePos(std::vector<CNodePtr>::iterator begin,
                                                std::vector<CNodePtr>::iterator end, const CNodePtr mock_send_node,
                                                StreamSwitchType stream_switch_type);
void InsertStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &kernel_graph,
                            const std::vector<SendRecvPair> &send_recv_pairs);
bool GenSendRecvCNodesForAllReduce(const std::shared_ptr<session::KernelGraph> &kernel_graph,
                                   const CNodePtr &mock_send_node, const CNodePtr &mock_recv_node, CNodePtr *send_node,
                                   CNodePtr *recv_node);
CNodePtr CreateStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &name);
}  // namespace gpu
}  // namespace device
}  // namespace mindspore

#endif
@@ -52,7 +52,8 @@ class NcclGpuKernel : public GpuKernel {
         nccl_reduce_type_(ncclSum),
         input_size_(0),
         output_size_(0),
-        collective_handle_(nullptr) {}
+        collective_handle_(nullptr),
+        comm_stream_(nullptr) {}
   ~NcclGpuKernel() override = default;

   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -63,34 +64,33 @@ class NcclGpuKernel : public GpuKernel {
     T *input_addr = GetDeviceAddress<T>(inputs, 0);
     T *output_addr = GetDeviceAddress<T>(outputs, 0);

+    cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr);
     switch (nccl_kernel_type_) {
       case NCCL_ALL_REDUCE: {
         auto all_reduce_funcptr =
           reinterpret_cast<AllReduce>(dlsym(const_cast<void *>(collective_handle_), "AllReduce"));
         MS_EXCEPTION_IF_NULL(all_reduce_funcptr);
-        CHECK_NCCL_RET_WITH_EXCEPT(
-          (*all_reduce_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), nccl_data_type_, nccl_reduce_type_,
-                                reinterpret_cast<cudaStream_t>(stream_ptr)),
-          "ncclAllReduce failed");
+        CHECK_NCCL_RET_WITH_EXCEPT((*all_reduce_funcptr)(input_addr, output_addr, output_size_ / sizeof(T),
+                                                         nccl_data_type_, nccl_reduce_type_, stream),
+                                   "ncclAllReduce failed");
         break;
       }
       case NCCL_ALL_GATHER: {
         auto all_gather_funcptr =
           reinterpret_cast<AllGather>(dlsym(const_cast<void *>(collective_handle_), "AllGather"));
         MS_EXCEPTION_IF_NULL(all_gather_funcptr);
-        CHECK_NCCL_RET_WITH_EXCEPT((*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T),
-                                                         nccl_data_type_, reinterpret_cast<cudaStream_t>(stream_ptr)),
-                                   "ncclAllGather failed");
+        CHECK_NCCL_RET_WITH_EXCEPT(
+          (*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream),
+          "ncclAllGather failed");
         break;
       }
       case NCCL_REDUCE_SCATTER: {
         auto reduce_scatter_funcptr =
           reinterpret_cast<ReduceScatter>(dlsym(const_cast<void *>(collective_handle_), "ReduceScatter"));
         MS_EXCEPTION_IF_NULL(reduce_scatter_funcptr);
-        CHECK_NCCL_RET_WITH_EXCEPT(
-          (*reduce_scatter_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), nccl_data_type_,
-                                    nccl_reduce_type_, reinterpret_cast<cudaStream_t>(stream_ptr)),
-          "ncclReduceScatter failed");
+        CHECK_NCCL_RET_WITH_EXCEPT((*reduce_scatter_funcptr)(input_addr, output_addr, output_size_ / sizeof(T),
+                                                             nccl_data_type_, nccl_reduce_type_, stream),
+                                   "ncclReduceScatter failed");
         break;
       }
       default: {
@@ -167,6 +167,7 @@ class NcclGpuKernel : public GpuKernel {
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
   const void *collective_handle_;
+  cudaStream_t comm_stream_;
 };
 }  // namespace kernel
 }  // namespace mindspore
@@ -17,6 +17,7 @@
 #include "device/gpu/kernel_info_setter.h"
 #include "device/gpu/gpu_kernel_build.h"
 #include "device/gpu/gpu_kernel_runtime.h"
+#include "device/gpu/gpu_stream_assign.h"
 #include "pre_activate/common/optimizer.h"
 #include "pre_activate/common/pass_manager.h"
 #include "pre_activate/common/ir_fusion/allreduce_fusion.h"
@@ -55,6 +56,11 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   kernel_graph->SetExecOrderByDefault();
 }

+void GPUSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) {
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  device::gpu::AssignGpuStream(kernel_graph);
+}
+
 void GPUSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
   device::gpu::GpuBuild(kernel_graph);
 }
@@ -94,6 +100,8 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
   StartKernelRT();
   // AllReduce Optimize
   Optimize(graph);
+  // Assign CUDA streams
+  AssignStream(graph);
   // Build kernel if node is cnode
   BuildKernel(graph);
   // Set graph execution order before memory alloc, ensure that memory alloc is according to the reorder graph
@@ -49,6 +49,8 @@ class GPUSession : public SessionBasic {

   void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph);

+  void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph);
+
   void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;

   void AllocateMemory(KernelGraph *kernel_graph) const;
@@ -113,6 +113,8 @@ constexpr auto kFusedMulAddNOpName = "FusedMulAddN";
 constexpr auto kFusedMulApplyMomentumOpName = "FusedMulApplyMomentum";
 constexpr auto kBiasAddOpName = "BiasAdd";
 constexpr auto kConfusionMulGradOpName = "ConfusionMulGrad";
+constexpr auto kSendOpName = "Send";
+constexpr auto kRecvOpName = "Recv";

 // attr key name
 constexpr auto kAttrInputNames = "input_names";