PyNative AllReduce Bucket

This commit is contained in:
caifubi 2021-02-27 16:39:57 +08:00
parent d2d6c3cfb5
commit 171b468bb3
41 changed files with 1246 additions and 88 deletions

View File

@ -64,6 +64,7 @@
#include "toolchain/adx_datadump_server.h"
#ifdef ENABLE_DUMP_IR
#include "debug/rdr/running_data_recorder.h"
#include "runtime/device/ascend/ascend_bucket.h"
#endif
#if ENABLE_CPU && ENABLE_D
#include "ps/util.h"
@ -258,6 +259,7 @@ GraphId AscendSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNode
// construct graph, if successfully, graph_sum_ + 1
auto graph = ConstructKernelGraph(lst, outputs);
auto graph_id = graph->graph_id();
InitAllBucket(graph);
MS_LOG(INFO) << "Compile graph " << graph_id << " success";
return graph_id;
}
@ -632,6 +634,13 @@ void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_inf
MS_EXCEPTION_IF_NULL(op_run_info);
BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
EraseValueNodeTensor(tensors_mask, input_tensors);
// wait for allreduce
for (auto &tensor : *input_tensors) {
if (tensor->NeedWaitDevice()) {
tensor->WaitDevice();
}
}
// Run op
auto graph = run_op_graphs_[graph_info];
MS_EXCEPTION_IF_NULL(graph);
@ -1510,5 +1519,9 @@ void AscendSession::SyncStream() {
MS_LOG(EXCEPTION) << "Sync stream error!";
}
}
std::shared_ptr<device::Bucket> AscendSession::CreateBucket(uint32_t bucket_id, uint32_t bucket_size) {
return std::make_shared<device::ascend::AscendBucket>(bucket_id, bucket_size);
}
} // namespace session
} // namespace mindspore

View File

@ -61,6 +61,7 @@ class AscendSession : public SessionBasic {
void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs,
const std::map<KernelWithIndex, size_t> &cnode_refcount) override;
std::string GetCommWorldGroup() override { return kHcclWorldGroup; }
private:
// compile child graph when session have multiple child graphs
@ -123,6 +124,7 @@ class AscendSession : public SessionBasic {
const std::vector<tensor::TensorPtr> &graph_inputs,
const std::map<KernelWithIndex, OutputTensorInfo> &node_output_info,
InputTensorInfo *input_tensor_info);
std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override;
// key is final_graph_id,value is child graph execute order of final graph
std::unordered_map<GraphId, std::vector<GraphId>> graph_execute_orders_;
// key is final_graph_id,value is the graph types of child graphs

View File

@ -16,6 +16,7 @@
#include "backend/session/gpu_session.h"
#include <string>
#include <utility>
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/pass_manager.h"
@ -63,6 +64,7 @@
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/gpu/cuda_driver.h"
#include "runtime/device/gpu/distribution/collective_init.h"
#include "runtime/device/gpu/gpu_bucket.h"
#include "utils/ms_utils.h"
#include "utils/config_manager.h"
#include "utils/ms_context.h"
@ -394,6 +396,8 @@ GraphId GPUSession::CompileGraphImpl(KernelGraphPtr graph) {
manager->AddFuncGraph(graph);
graph->set_manager(manager);
}
InitAllBucket(graph);
// Alloc memory in graph mode, including static memory and dynamic memory
if (!pynative_mode) {
AllocateMemory(graph.get());
@ -473,6 +477,12 @@ void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
MS_EXCEPTION_IF_NULL(op_run_info);
BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
EraseValueNodeTensor(tensors_mask, input_tensors);
// wait for allreduce
for (auto &tensor : *input_tensors) {
if (tensor->NeedWaitDevice()) {
tensor->WaitDevice();
}
}
// run op
auto kernel_graph = run_op_graphs_[graph_info];
MS_EXCEPTION_IF_NULL(kernel_graph);
@ -548,6 +558,10 @@ void GPUSession::SyncStream() {
MS_LOG(EXCEPTION) << "Sync stream error!";
}
}
std::shared_ptr<device::Bucket> GPUSession::CreateBucket(uint32_t bucket_id, uint32_t bucket_size) {
return std::make_shared<device::gpu::GPUBucket>(bucket_id, bucket_size);
}
} // namespace gpu
} // namespace session
} // namespace mindspore

View File

@ -19,6 +19,7 @@
#include <vector>
#include <memory>
#include <algorithm>
#include <string>
#include "backend/session/session_basic.h"
#include "backend/session/kernel_graph.h"
#include "backend/session/session_factory.h"
@ -44,6 +45,8 @@ class GPUSession : public SessionBasic {
const std::vector<int64_t> &tensors_mask) override;
void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override;
std::string GetCommWorldGroup() override { return kNcclWorldGroup; }
private:
void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;

View File

@ -40,6 +40,7 @@
#include "debug/anf_ir_dump.h"
#include "debug/common.h"
#include "utils/trace_base.h"
#include "frontend/parallel/context.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/ps_cache/ps_cache_manager.h"
#include "ps/constants.h"
@ -556,10 +557,12 @@ void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<Kern
void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
const std::map<KernelWithIndex, std::vector<std::vector<size_t>>> &output_indexes,
const std::map<KernelWithIndex, size_t> &ref_count,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, VectorRef *outputs) {
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, VectorRef *outputs,
std::vector<TensorPtr> *runop_output_tensors) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(op_output_map);
MS_EXCEPTION_IF_NULL(outputs);
MS_EXCEPTION_IF_NULL(runop_output_tensors);
auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
if (output_tensors.size() > op_outputs.size()) {
MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
@ -592,6 +595,7 @@ void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
}
BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
tensor_ref = output_tensor;
runop_output_tensors->emplace_back(output_tensor);
}
}
}
@ -2196,6 +2200,11 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<
GetRefCount(kernel_graph.get(), &cnode_refcount);
BuildOpsInGraph(graph_id, parameter_index, inputs, cnode_refcount);
// Clear bucket resources every step
if (kernel_graph->is_bprop()) {
ClearAllBucket(graph_id);
}
std::map<KernelWithIndex, tensor::TensorPtr> op_output_map;
for (const auto &kernel : kernel_graph->execution_order()) {
// Generate input tensors, tensor masks and input kernel with index
@ -2212,9 +2221,15 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<
RunOpImpl(graph_info, &run_info, &input_tensor_info.input_tensors, &op_outputs,
input_tensor_info.input_tensors_mask);
std::vector<tensor::TensorPtr> new_output_tensors;
// Handle inputs and outputs of current op
HandleOpInputs(input_tensor_info.input_kernel, &cnode_refcount, &op_output_map);
HandleOpOutputs(kernel, op_outputs, output_indexes, cnode_refcount, &op_output_map, outputs);
HandleOpOutputs(kernel, op_outputs, output_indexes, cnode_refcount, &op_output_map, outputs, &new_output_tensors);
// Save grad node to Bucket
if (kernel_graph->is_bprop()) {
AddGradAddrToBucket(graph_id, new_output_tensors);
}
}
MS_LOG(INFO) << "Finish!";
}
@ -2287,6 +2302,137 @@ void SessionBasic::RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const {
}
}
std::vector<uint32_t> SessionBasic::GetAllReduceSplitIndex() {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string group = GetCommWorldGroup();
auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
// PyNative not support multi group allreduce
group += "sum1";
return parallel_context->GetAllReduceFusionSplitIndices(group);
}
uint32_t GetBpropGraphGradsCount(const KernelGraphPtr &graph) {
return AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}).size();
}
void SetGraphBpropAttr(const KernelGraphPtr &graph) {
auto &execution_orders = graph->execution_order();
if (std::any_of(execution_orders.begin(), execution_orders.end(),
[](const AnfNodePtr &node) { return node->scope()->name().rfind("Gradient", 0) == 0; })) {
graph->set_is_bprop(true);
MS_LOG(INFO) << "Match bprop graph";
} else {
graph->set_is_bprop(false);
}
}
std::vector<uint32_t> GenerateBucketSizeList(const KernelGraphPtr &graph, const std::vector<uint32_t> &split_index) {
if (split_index.empty()) {
auto grads_count = GetBpropGraphGradsCount(graph);
if (grads_count == 0) {
MS_LOG(EXCEPTION) << "Bprop graph has no grad";
}
return {grads_count};
}
std::vector<uint32_t> bucket_size_list;
uint32_t old_index = 0;
for (auto &index : split_index) {
if (old_index == 0) {
bucket_size_list.emplace_back(index - old_index + 1);
} else {
bucket_size_list.emplace_back(index - old_index);
}
old_index = index;
}
return bucket_size_list;
}
void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Init Bucket start, graph_id:" << graph->graph_id();
SetGraphBpropAttr(graph);
if (!graph->is_bprop()) {
return;
}
std::vector<std::shared_ptr<device::Bucket>> bucket_list;
// Create bucket for every split allreduce ops
auto split_index = GetAllReduceSplitIndex();
auto bucket_size_list = GenerateBucketSizeList(graph, split_index);
uint32_t bucket_id = 0;
for (auto bucket_size : bucket_size_list) {
MS_LOG(INFO) << "Create new bucket:" << bucket_id;
auto bucket = CreateBucket(bucket_id++, bucket_size);
bucket->Init();
bucket_list.emplace_back(bucket);
}
auto bucket_ret = bucket_map_.try_emplace(graph->graph_id(), bucket_list);
if (!bucket_ret.second) {
MS_LOG(EXCEPTION) << "Duplicate bucket_map_ graph key:" << graph->graph_id();
}
// set all free bucket index to 0
auto free_bucket_ret = free_bucket_id_map_.try_emplace(graph->graph_id(), 0);
if (!free_bucket_ret.second) {
MS_LOG(EXCEPTION) << "Duplicate free_bucket_id_map_ graph key:" << graph->graph_id();
}
MS_LOG(INFO) << "Init Bucket finish";
}
void SessionBasic::AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor) {
auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
auto parallel_mode = parallel_context->parallel_mode();
if (parallel_mode != parallel::DATA_PARALLEL) {
return;
}
auto iter = bucket_map_.find(graph_id);
if (iter == bucket_map_.end()) {
MS_LOG(EXCEPTION) << "unknown graph id:" << graph_id;
}
auto &bucket_list = iter->second;
auto free_bucket_iter = free_bucket_id_map_.find(graph_id);
if (free_bucket_iter == free_bucket_id_map_.end()) {
MS_LOG(EXCEPTION) << "unknown free graph id:" << graph_id;
}
auto free_bucket_index = free_bucket_iter->second;
for (auto &tensor : grad_tensor) {
if (free_bucket_index >= bucket_list.size()) {
MS_LOG(EXCEPTION) << "Invalid free bucket id:" << free_bucket_iter->second
<< " total bucket num:" << bucket_list.size();
}
auto &free_bucket = bucket_list[free_bucket_index];
free_bucket->AddGradTensor(tensor);
if (free_bucket->full()) {
MS_LOG(INFO) << "bucket is full";
free_bucket->Launch();
free_bucket_index = ++free_bucket_iter->second;
MS_LOG(INFO) << "new free bucket:" << free_bucket_index;
}
}
}
void SessionBasic::ClearAllBucket(const GraphId &graph_id) {
auto iter = bucket_map_.find(graph_id);
if (iter != bucket_map_.end()) {
auto bucket_list = iter->second;
for (auto &bucket : bucket_list) {
MS_LOG(INFO) << "Clear bucket:" << bucket->id();
bucket->Release();
}
}
auto free_iter = free_bucket_id_map_.find(graph_id);
if (free_iter != free_bucket_id_map_.end()) {
free_iter->second = 0;
}
}
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
if (!ps::PSContext::instance()->is_worker()) {

View File

@ -32,6 +32,7 @@
#include "utils/contract.h"
#include "runtime/device/kernel_info.h"
#include "utils/ms_context.h"
#include "runtime/device/bucket.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "debug/debugger/debugger.h"
#endif
@ -224,12 +225,20 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
void UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs);
void RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const;
void RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const;
virtual std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) { return nullptr; }
void InitAllBucket(const KernelGraphPtr &graph);
void AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor);
void ClearAllBucket(const GraphId &graph_id);
std::vector<uint32_t> GetAllReduceSplitIndex();
virtual std::string GetCommWorldGroup() { return std::string(); }
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const;
void GetBatchElements(const AnfNodePtr &kernel_node) const;
void InitPsWorker(const KernelGraphPtr &kernel_graph);
#endif
std::map<uint32_t, std::vector<std::shared_ptr<device::Bucket>>> bucket_map_;
std::map<uint32_t, uint32_t> free_bucket_id_map_;
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
std::unordered_map<FuncGraphPtr, KernelGraphPtr> front_backend_graph_map_;

View File

@ -677,34 +677,9 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) {
return op_exec_info;
}
AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
abstract::AbstractBasePtrList *args_spec_list) {
MS_EXCEPTION_IF_NULL(op_masks);
MS_EXCEPTION_IF_NULL(args_spec_list);
MS_EXCEPTION_IF_NULL(op_exec_info);
void PynativeExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
std::vector<AnfNodePtr> *inputs, abstract::AbstractBasePtrList *args_spec_list) {
auto prim = op_exec_info->py_primitive;
std::vector<AnfNodePtr> inputs;
inputs.emplace_back(NewValueNode(prim));
const auto &signature = prim->signatures();
auto sig_size = signature.size();
auto size = op_exec_info->op_inputs.size();
// ignore monad signature
for (auto sig : signature) {
if (sig.default_value != nullptr && sig.default_value->isa<Monad>()) {
--sig_size;
}
}
if (sig_size > 0 && sig_size != size) {
MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the requires "
<< "inputs size " << sig_size;
}
if (op_exec_info->op_name != prim::kPrimCast->name()) {
RunParameterAutoMixPrecisionCast(op_exec_info);
}
MS_LOG(DEBUG) << "Get op " << op_exec_info->op_name << " grad_flag_ " << grad_flag();
for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
abstract::AbstractBasePtr abs = nullptr;
const auto &obj = op_exec_info->op_inputs[i];
@ -733,11 +708,42 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
if (input_node->abstract() != nullptr) {
abs = input_node->abstract();
}
inputs.emplace_back(input_node);
inputs->emplace_back(input_node);
}
}
(*args_spec_list).emplace_back(CheckConstValue(prim, obj, abs, id, i));
}
}
AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
abstract::AbstractBasePtrList *args_spec_list) {
MS_EXCEPTION_IF_NULL(op_masks);
MS_EXCEPTION_IF_NULL(args_spec_list);
MS_EXCEPTION_IF_NULL(op_exec_info);
auto prim = op_exec_info->py_primitive;
std::vector<AnfNodePtr> inputs;
inputs.emplace_back(NewValueNode(prim));
const auto &signature = prim->signatures();
auto sig_size = signature.size();
auto size = op_exec_info->op_inputs.size();
// ignore monad signature
for (auto sig : signature) {
if (sig.default_value != nullptr && sig.default_value->isa<Monad>()) {
--sig_size;
}
}
if (sig_size > 0 && sig_size != size) {
MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the requires "
<< "inputs size " << sig_size;
}
if (op_exec_info->op_name != prim::kPrimCast->name()) {
RunParameterAutoMixPrecisionCast(op_exec_info);
}
MS_LOG(DEBUG) << "Get op " << op_exec_info->op_name << " grad_flag_ " << grad_flag();
GetArgsSpec(op_exec_info, op_masks, &inputs, args_spec_list);
CNodePtr cnode = nullptr;
if (need_construct_graph()) {

View File

@ -208,6 +208,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
PynativeStatusCode *const status);
AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id);
AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks, std::vector<AnfNodePtr> *inputs,
abstract::AbstractBasePtrList *args_spec_list);
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
abstract::AbstractBasePtrList *args_spec_list);
abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,

View File

@ -1,5 +1,7 @@
file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc"
"kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc" "memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc"
"kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc"
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc"
"bucket.cc"
)
if(ENABLE_GPU)

View File

@ -0,0 +1,173 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ascend_bucket.h"
#include <vector>
#include <memory>
#include "runtime/mem.h"
#include "external/hccl/hccl.h"
#include "runtime/device/ascend/ascend_memory_pool.h"
#include "backend/kernel_compiler/hccl/hcom_util.h"
#include "backend/kernel_compiler/hccl/hccl_context.h"
#include "runtime/device/memory_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/ascend/ascend_event.h"
#include "utils/profile.h"
#define CHECK_ASCEND_RT_WITH_EXCEPTION(expression, message) \
{ \
rtError_t ret = (expression); \
if (ret != RT_ERROR_NONE) { \
MS_LOG(EXCEPTION) << message << ", error code: " << ret; \
} \
}
namespace mindspore::device::ascend {
void AscendBucket::AllocateAllReduceAddr() {
// Check bucket is full
if (grad_tensor_list_.size() != bucket_size_) {
MS_LOG(EXCEPTION) << "grad tensor list size:" << grad_tensor_list_.size()
<< " is not equal to bucket size:" << bucket_size_;
}
auto total_size = 0;
std::vector<size_t> align_size_list;
std::vector<size_t> origin_size_list;
for (auto &tensor : grad_tensor_list_) {
MS_EXCEPTION_IF_NULL(tensor);
tensor_type_list_.emplace_back(tensor->data_type());
DeviceAddressPtr device_address = std::dynamic_pointer_cast<DeviceAddress>(tensor->device_address());
auto origin_size = device_address->GetSize();
auto align_size = MemoryManager::GetCommonAlignSize(origin_size);
origin_size_list.emplace_back(origin_size);
align_size_list.emplace_back(align_size);
total_size += align_size;
memcpy_input_addrs_.emplace_back(std::make_shared<kernel::Address>(
static_cast<uint8_t *>(device_address->GetMutablePtr()), device_address->GetSize()));
}
total_size_ = total_size;
auto runtime_instance = device::KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
MS_EXCEPTION_IF_NULL(runtime_instance);
// AllReduce input output addr need to clear zero
ar_input_addr_ = runtime_instance->MallocCommunicationMemFromMemPool(total_size);
ar_output_addr_ = runtime_instance->MallocCommunicationMemFromMemPool(total_size);
// generate memecpy output addr
uint8_t *memcpy_output = ar_input_addr_;
for (size_t i = 0; i < bucket_size_; ++i) {
memcpy_output_addrs_.emplace_back(std::make_shared<kernel::Address>(memcpy_output, origin_size_list[i]));
memcpy_output += align_size_list[i];
}
// store output tensor addr
uint8_t *tensor_output = ar_output_addr_;
for (size_t i = 0; i < bucket_size_; ++i) {
new_tensor_output_addrs_.emplace_back(tensor_output);
tensor_output += align_size_list[i];
}
}
void AscendBucket::FreeDeviceMem(void *dev_ptr) { AscendMemoryPool::GetInstance().FreeTensorMem(dev_ptr); }
void AscendBucket::FreeAllDeviceMem() {
if (ar_input_addr_ != nullptr) {
uint8_t *origin_dev_addr = ar_input_addr_ - kMemAlignSize;
FreeDeviceMem(origin_dev_addr);
ar_input_addr_ = nullptr;
}
if (ar_output_addr_ != nullptr) {
uint8_t *origin_dev_addr = ar_output_addr_ - kMemAlignSize;
FreeDeviceMem(origin_dev_addr);
ar_output_addr_ = nullptr;
}
}
void AscendBucket::CopyTensorToContiguousMemory() {
// Clean input addr
CHECK_ASCEND_RT_WITH_EXCEPTION(rtMemsetAsync(ar_input_addr_, total_size_, 0, total_size_, compute_stream_),
"Call rtMemsetAsync failed");
for (size_t i = 0; i < bucket_size_; ++i) {
MS_EXCEPTION_IF_NULL(memcpy_input_addrs_[i]);
MS_EXCEPTION_IF_NULL(memcpy_output_addrs_[i]);
MS_LOG(DEBUG) << "MemcpyAsync dst size:" << memcpy_output_addrs_[i]->size
<< " src size:" << memcpy_input_addrs_[i]->size;
if (memcpy_output_addrs_[i]->size < memcpy_input_addrs_[i]->size) {
MS_LOG(EXCEPTION) << "rtMemcpyAsync dst size < src size";
}
CHECK_ASCEND_RT_WITH_EXCEPTION(
rtMemcpyAsync(memcpy_output_addrs_[i]->addr, memcpy_output_addrs_[i]->size, memcpy_input_addrs_[i]->addr,
memcpy_input_addrs_[i]->size, RT_MEMCPY_DEVICE_TO_DEVICE, compute_stream_),
"Call rtMemcpyAsync failed");
}
}
void AscendBucket::LaunchAllReduce() {
if (tensor_type_list_.empty()) {
MS_LOG(EXCEPTION) << "No tesnor type found";
}
// AllReduce inputs data type should be same
auto type = tensor_type_list_[0];
if (std::any_of(tensor_type_list_.begin(), tensor_type_list_.end(),
[&type](TypeId tensor_type) { return type != tensor_type; })) {
MS_LOG(EXCEPTION) << "allreduce input have different dtype";
}
auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(type);
if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) {
MS_LOG(EXCEPTION) << "unknown data type:" << type;
}
uint32_t type_size;
if (!HcomUtil::GetHcomTypeSize(iter->second, &type_size)) {
MS_LOG(EXCEPTION) << "get hcom type size fialed";
}
if (type_size == 0 || total_size_ % type_size != 0) {
MS_LOG(EXCEPTION) << "Total_size[" << total_size_ << "],Type_size[" << type_size << "] != 0, fail!";
}
auto hccl_count = total_size_ / type_size;
HcclReduceOp op_type = HcclReduceOp::HCCL_REDUCE_SUM;
auto hccl_result = HcclAllReduce(ar_input_addr_, ar_output_addr_, hccl_count, iter->second, op_type,
kernel::HcclContext::GetInstance().hccl_comm(), stream_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(EXCEPTION) << "HcclAllReduce faled, ret:" << hccl_result;
}
}
void AscendBucket::Init() {
pre_event_ = std::make_shared<AscendEvent>();
post_event_ = std::make_shared<AscendEvent>();
auto kernel_runtime = KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
MS_EXCEPTION_IF_NULL(kernel_runtime);
compute_stream_ = kernel_runtime->compute_stream();
stream_ = kernel_runtime->communication_stream();
MS_EXCEPTION_IF_NULL(pre_event_);
MS_EXCEPTION_IF_NULL(post_event_);
pre_event_->set_wait_stream(stream_);
pre_event_->set_record_stream(compute_stream_);
post_event_->set_wait_stream(compute_stream_);
post_event_->set_record_stream(stream_);
}
} // namespace mindspore::device::ascend

View File

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_BUCKET_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_BUCKET_H_
#include "runtime/device/bucket.h"
namespace mindspore::device::ascend {
class AscendBucket : public Bucket {
public:
AscendBucket(uint32_t id, uint32_t bucket_size) : Bucket(id, bucket_size) {}
~AscendBucket() override = default;
void Init() override;
private:
void AllocateAllReduceAddr() override;
void FreeAllDeviceMem() override;
void FreeDeviceMem(void *dev_ptr) override;
void CopyTensorToContiguousMemory() override;
void LaunchAllReduce() override;
};
} // namespace mindspore::device::ascend
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_BUCKET_H_

View File

@ -0,0 +1,60 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ascend_event.h"
#include "runtime/event.h"
#include "runtime/stream.h"
#include "utils/log_adapter.h"
namespace mindspore::device::ascend {
AscendEvent::AscendEvent() {
auto ret = rtEventCreate(&event_);
if (ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "rtEventCreate failed, ret:" << ret;
event_ = nullptr;
}
}
AscendEvent::~AscendEvent() {
auto ret = rtEventDestroy(event_);
if (ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "rtEventDestory failed, ret:" << ret;
}
}
void AscendEvent::RecordEvent() {
MS_EXCEPTION_IF_NULL(event_);
MS_EXCEPTION_IF_NULL(record_stream_);
auto ret = rtEventRecord(event_, record_stream_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "rtEventRecord failed, ret:" << ret;
}
need_wait_ = true;
}
void AscendEvent::WaitEvent() {
MS_EXCEPTION_IF_NULL(event_);
MS_EXCEPTION_IF_NULL(wait_stream_);
auto ret = rtStreamWaitEvent(wait_stream_, event_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "rtStreamWaitEvent failed, ret:" << ret;
}
need_wait_ = false;
}
bool AscendEvent::NeedWait() { return need_wait_; }
} // namespace mindspore::device::ascend

View File

@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_ASCEND_EVENT_H
#define MINDSPORE_ASCEND_EVENT_H
#include "runtime/base.h"
#include "ir/device_event.h"
namespace mindspore::device::ascend {
class AscendEvent : public DeviceEvent {
public:
AscendEvent();
~AscendEvent() override;
void WaitEvent() override;
void RecordEvent() override;
bool NeedWait() override;
void set_wait_stream(rtStream_t wait_stream) override { wait_stream_ = wait_stream; }
void set_record_stream(rtStream_t record_stream) override { record_stream_ = record_stream; }
private:
rtEvent_t event_{nullptr};
rtStream_t wait_stream_{nullptr};
rtStream_t record_stream_{nullptr};
bool need_wait_{false};
};
} // namespace mindspore::device::ascend
#endif // MINDSPORE_ASCEND_EVENT_H

View File

@ -718,6 +718,10 @@ bool AscendKernelRuntime::SyncStream() {
MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
return false;
}
if (RT_ERROR_NONE != rtStreamSynchronize(communication_stream_)) { // o for switch stream
MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
return false;
}
FreeAndClearBufferPtrs();
return true;
}
@ -786,6 +790,10 @@ bool AscendKernelRuntime::InitDevice() {
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
}
ret = rtStreamCreate(&communication_stream_, 0);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
}
return true;
}
@ -799,6 +807,14 @@ bool AscendKernelRuntime::ResetDevice(uint32_t device_id) {
stream_ = nullptr;
}
if (communication_stream_ != nullptr) {
auto ret = rtStreamDestroy(communication_stream_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]";
}
communication_stream_ = nullptr;
}
auto ret = rtDeviceReset(device_id);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtDeviceReset, ret[" << ret << "]";
@ -919,4 +935,5 @@ uint64_t AscendKernelRuntime::GetAvailableMemMaxSize() const {
auto ascend_mem_manager = dynamic_pointer_cast<AscendMemoryManager>(mem_manager_);
return ascend_mem_manager->GetDeviceMemSize();
}
} // namespace mindspore::device::ascend

View File

@ -57,6 +57,8 @@ class AscendKernelRuntime : public KernelRuntime {
void *context() const override { return rt_context_; }
void PreInit() override;
uint64_t GetAvailableMemMaxSize() const override;
void *compute_stream() const override { return stream_; }
void *communication_stream() const override { return communication_stream_; }
protected:
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,

View File

@ -162,6 +162,14 @@ void AscendMemoryManager::MallocSomasDynamicMem(const session::KernelGraph *grap
somas_reuse_util_ptr_->ConvertToProfilingNode(graph->graph_id());
}
}
// communication memory: [512align_size + data + 512align_size]
// return the pointer to the start of data address.
uint8_t *AscendMemoryManager::MallocCommunicationMemFromMemPool(size_t size) {
auto align_size = GetCommunicationAlignSize(size);
uint8_t *base_ptr = reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
return base_ptr + kMemAlignSize;
}
} // namespace ascend
} // namespace device
} // namespace mindspore

View File

@ -33,6 +33,7 @@ class AscendMemoryManager : public MemoryManager {
void *MallocMemFromMemPool(size_t size) override;
uint64_t GetDeviceMemSize();
void MallocSomasDynamicMem(const session::KernelGraph *graph);
uint8_t *MallocCommunicationMemFromMemPool(size_t size) override;
protected:
uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id = kInvalidGraphId) override;

View File

@ -0,0 +1,106 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/bucket.h"
#include <memory>
#include "runtime/device/kernel_runtime_manager.h"
#include "utils/profile.h"
namespace mindspore::device {
void Bucket::AddGradTensor(const tensor::TensorPtr &tensor) {
if (grad_tensor_list_.size() >= bucket_size_) {
MS_LOG(EXCEPTION) << "bucket is full";
}
grad_tensor_list_.emplace_back(tensor);
if (grad_tensor_list_.size() > bucket_size_) {
MS_LOG(EXCEPTION) << "too many tensor add to the bucket, bucket_size_:" << bucket_size_
<< " total tensor size:" << grad_tensor_list_.size();
}
MS_LOG(INFO) << "current bucket tensors size:" << grad_tensor_list_.size();
// bucket is full, start to launch allreduce
if (grad_tensor_list_.size() == bucket_size_) {
full_ = true;
}
}
void Bucket::Launch() {
auto start = GetTime();
if (grad_tensor_list_.size() != bucket_size_) {
MS_LOG(EXCEPTION) << "Bucket is not full, grad_tensor_list_ size:" << grad_tensor_list_.size()
<< " bucket_size_:" << bucket_size_;
}
MS_LOG(INFO) << "Bucket is full, start to launch AllReduce";
MS_EXCEPTION_IF_NULL(pre_event_);
MS_EXCEPTION_IF_NULL(post_event_);
AllocateAllReduceAddr();
CopyTensorToContiguousMemory();
pre_event_->RecordEvent();
pre_event_->WaitEvent();
LaunchAllReduce();
post_event_->RecordEvent();
UpdateTensorAddr();
// pass event to the tensor
for (auto &tensor : grad_tensor_list_) {
MS_EXCEPTION_IF_NULL(tensor);
tensor->SetDeviceEvent(post_event_);
}
MS_LOG(INFO) << "Bucket launch cost:" << (GetTime() - start) * 1e6 << " us";
}
// TODO(caifubi): float16 grad cast to float32 grad
void Bucket::UpdateTensorAddr() {
if (grad_tensor_list_.size() != bucket_size_ || new_tensor_output_addrs_.size() != bucket_size_) {
MS_LOG(EXCEPTION) << "grad_tensor_list size:" << grad_tensor_list_.size()
<< " tensor output addr size:" << new_tensor_output_addrs_.size()
<< " bucket size:" << bucket_size_;
}
for (size_t i = 0; i < bucket_size_; ++i) {
auto &tensor = grad_tensor_list_[i];
MS_EXCEPTION_IF_NULL(tensor);
auto device_address = std::dynamic_pointer_cast<DeviceAddress>(tensor->device_address());
// release old addr and manage addr by this Bucket.
MS_EXCEPTION_IF_NULL(device_address);
auto origin_dev_ptr = device_address->GetMutablePtr();
// FreeDeviceMem(origin_dev_ptr);
tensor_old_addr_list_.emplace_back(origin_dev_ptr);
device_address->from_mem_pool_ = false;
device_address->set_ptr(new_tensor_output_addrs_[i]);
}
}
void Bucket::LazyDeleteOldAddr() {
MS_LOG(INFO) << "Lazy delete old grad address";
for (auto old_addr : tensor_old_addr_list_) {
FreeDeviceMem(old_addr);
}
tensor_old_addr_list_.clear();
}
void Bucket::Release() {
MS_LOG(INFO) << "Clear bucket:" << id_;
grad_tensor_list_.clear();
new_tensor_output_addrs_.clear();
memcpy_input_addrs_.clear();
memcpy_output_addrs_.clear();
tensor_type_list_.clear();
LazyDeleteOldAddr();
FreeAllDeviceMem();
full_ = false;
}
} // namespace mindspore::device

View File

@ -0,0 +1,83 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_BUCKET_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_BUCKET_H_
#include <vector>
#include <utility>
#include <string>
#include <memory>
#include "ir/anf.h"
#include "ir/device_event.h"
#include "runtime/device/device_address.h"
#include "backend/session/kernel_graph.h"
namespace mindspore::device {
class Bucket {
public:
Bucket(uint32_t id, uint32_t bucket_size)
: id_(id),
bucket_size_(bucket_size),
full_(false),
stream_(nullptr),
compute_stream_(nullptr),
pre_event_(nullptr),
post_event_(nullptr),
total_size_(0),
ar_input_addr_(nullptr),
ar_output_addr_(nullptr) {}
virtual ~Bucket() = default;
uint32_t id() const { return id_; }
bool full() const { return full_; }
void Launch();
void Release();
void AddGradTensor(const tensor::TensorPtr &tensor);
virtual void Init() = 0;
protected:
uint32_t id_;
uint32_t bucket_size_;
bool full_;
void *stream_;
void *compute_stream_;
std::shared_ptr<DeviceEvent> pre_event_;
std::shared_ptr<DeviceEvent> post_event_;
size_t total_size_;
uint8_t *ar_input_addr_;
uint8_t *ar_output_addr_;
std::string group_;
std::vector<tensor::TensorPtr> grad_tensor_list_;
std::vector<uint8_t *> new_tensor_output_addrs_;
std::vector<kernel::AddressPtr> memcpy_input_addrs_;
std::vector<kernel::AddressPtr> memcpy_output_addrs_;
std::vector<TypeId> tensor_type_list_;
std::vector<void *> tensor_old_addr_list_;
virtual void AllocateAllReduceAddr() = 0;
void UpdateTensorAddr();
virtual void LaunchAllReduce() = 0;
virtual void FreeAllDeviceMem() = 0;
virtual void FreeDeviceMem(void *dev_ptr) = 0;
virtual void CopyTensorToContiguousMemory() = 0;
void LazyDeleteOldAddr();
};
} // namespace mindspore::device
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_BUCKET_H_

View File

@ -26,6 +26,7 @@
namespace mindspore {
namespace device {
class Bucket;
namespace cpu {
class CPUSimpleMemPlan;
class CPUMemoryManager;
@ -100,6 +101,7 @@ class DeviceAddress : public mindspore::DeviceSync {
friend class mindspore::device::ascend::AscendKernelRuntime;
friend class mindspore::device::ascend::AscendMemoryManager;
friend class mindspore::device::ascend::DataDumper;
friend class mindspore::device::Bucket;
};
using DeviceAddressPtr = std::shared_ptr<DeviceAddress>;

View File

@ -0,0 +1,177 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/gpu/gpu_bucket.h"
#include <cuda_runtime_api.h>
#include <nccl.h>
#include <vector>
#include <memory>
#include "abstract/utils.h"
#include "runtime/device/gpu/gpu_event.h"
#include "runtime/device/gpu/gpu_memory_allocator.h"
#include "runtime/device/gpu/gpu_device_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/gpu/distribution/collective_init.h"
#include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h"
#include "runtime/device/gpu/gpu_common.h"
namespace {
const size_t kCommunicationMemAlignSize = 16;
size_t AlignMemorySize(size_t size) {
if (size == 0) {
return kCommunicationMemAlignSize;
}
return ((size + kCommunicationMemAlignSize - 1) / kCommunicationMemAlignSize) * kCommunicationMemAlignSize;
}
} // namespace
namespace mindspore::device::gpu {
GPUBucket::GPUBucket(uint32_t id, uint32_t bucket_size) : Bucket(id, bucket_size), collective_handle_(nullptr) {
group_ = kNcclWorldGroup;
}
void GPUBucket::AllocateAllReduceAddr() {
MS_LOG(INFO) << "start";
if (grad_tensor_list_.size() != bucket_size_) {
MS_LOG(EXCEPTION) << "grad tensor list size:" << grad_tensor_list_.size()
<< " is not equal to bucket size:" << bucket_size_;
}
auto total_size = 0;
std::vector<size_t> size_list;
std::vector<size_t> align_size_list;
for (auto &tensor : grad_tensor_list_) {
MS_EXCEPTION_IF_NULL(tensor);
tensor_type_list_.emplace_back(tensor->data_type());
DeviceAddressPtr device_address = std::dynamic_pointer_cast<DeviceAddress>(tensor->device_address());
MS_EXCEPTION_IF_NULL(device_address);
auto origin_size = device_address->GetSize();
auto align_size = AlignMemorySize(origin_size);
size_list.emplace_back(origin_size);
align_size_list.emplace_back(align_size);
total_size += align_size;
memcpy_input_addrs_.emplace_back(
std::make_shared<kernel::Address>(static_cast<uint8_t *>(device_address->GetMutablePtr()), origin_size));
}
total_size_ = total_size;
ar_input_addr_ = static_cast<uint8_t *>(GPUMemoryAllocator::GetInstance().AllocTensorMem(total_size));
ar_output_addr_ = static_cast<uint8_t *>(GPUMemoryAllocator::GetInstance().AllocTensorMem(total_size));
uint8_t *memcpy_output = ar_input_addr_;
for (size_t i = 0; i < bucket_size_; ++i) {
memcpy_output_addrs_.emplace_back(std::make_shared<kernel::Address>(memcpy_output, size_list[i]));
memcpy_output += align_size_list[i];
}
uint8_t *tensor_output = ar_output_addr_;
for (size_t i = 0; i < bucket_size_; ++i) {
new_tensor_output_addrs_.emplace_back(tensor_output);
tensor_output += align_size_list[i];
}
MS_LOG(INFO) << "end";
}
void GPUBucket::FreeDeviceMem(void *dev_ptr) { GPUMemoryAllocator::GetInstance().FreeTensorMem(dev_ptr); }
void GPUBucket::FreeAllDeviceMem() {
MS_LOG(INFO) << "start";
if (ar_input_addr_ != nullptr) {
FreeDeviceMem(ar_input_addr_);
ar_input_addr_ = nullptr;
}
if (ar_output_addr_ != nullptr) {
FreeDeviceMem(ar_output_addr_);
ar_output_addr_ = nullptr;
}
MS_LOG(INFO) << "end";
}
void GPUBucket::CopyTensorToContiguousMemory() {
MS_LOG(INFO) << "start";
MS_EXCEPTION_IF_NULL(compute_stream_);
// Clean allreduce input
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemsetAsync(ar_input_addr_, 0, total_size_, static_cast<cudaStream_t>(compute_stream_)),
"Call cudaMemsetAsync failed");
for (size_t i = 0; i < bucket_size_; ++i) {
MS_EXCEPTION_IF_NULL(memcpy_output_addrs_[i]);
MS_EXCEPTION_IF_NULL(memcpy_input_addrs_[i]);
if (!GPUDeviceManager::GetInstance().CopyDeviceMemToDeviceAsync(memcpy_output_addrs_[i]->addr,
memcpy_input_addrs_[i]->addr,
memcpy_output_addrs_[i]->size, compute_stream_)) {
MS_LOG(EXCEPTION) << "Copy memory failed";
}
}
MS_LOG(INFO) << "end";
}
void GPUBucket::LaunchAllReduce() {
MS_LOG(INFO) << "start";
collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
auto all_reduce_funcptr =
reinterpret_cast<kernel::AllReduce>(dlsym(const_cast<void *>(collective_handle_), "AllReduce"));
MS_EXCEPTION_IF_NULL(all_reduce_funcptr);
MS_EXCEPTION_IF_NULL(stream_);
if (tensor_type_list_.empty()) {
MS_LOG(EXCEPTION) << "No tesnor type found";
}
auto type = tensor_type_list_[0];
if (std::any_of(tensor_type_list_.begin(), tensor_type_list_.end(),
[&type](TypeId tensor_type) { return type != tensor_type; })) {
MS_LOG(EXCEPTION) << "AllReduce input have different dtype";
}
auto type_size = abstract::TypeIdSize(type);
if (type_size == 0) {
MS_LOG(EXCEPTION) << "Invalid type:" << type;
}
// typeid to nccl_data_type
auto nccl_data_type_iter = kernel::kNcclDtypeMap.find(TypeIdLabel(type));
if (nccl_data_type_iter == kernel::kNcclDtypeMap.end()) {
MS_LOG(EXCEPTION) << "Invalid type:" << type;
}
auto nccl_result =
(*all_reduce_funcptr)(ar_input_addr_, ar_output_addr_, total_size_ / type_size, nccl_data_type_iter->second,
ncclRedOp_t::ncclSum, static_cast<cudaStream_t>(stream_), group_);
if (nccl_result != ncclSuccess) {
MS_LOG(EXCEPTION) << "AllReduce failed, ret:" << nccl_result;
}
MS_LOG(INFO) << "end";
}
void GPUBucket::Init() {
pre_event_ = std::make_shared<GpuEvent>();
post_event_ = std::make_shared<GpuEvent>();
auto kernel_runtime = KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
MS_EXCEPTION_IF_NULL(kernel_runtime);
stream_ = kernel_runtime->communication_stream();
compute_stream_ = kernel_runtime->compute_stream();
MS_EXCEPTION_IF_NULL(pre_event_);
MS_EXCEPTION_IF_NULL(post_event_);
pre_event_->set_record_stream(compute_stream_);
pre_event_->set_wait_stream(stream_);
post_event_->set_record_stream(stream_);
post_event_->set_wait_stream(compute_stream_);
}
} // namespace mindspore::device::gpu

View File

@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_BUCKET_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_BUCKET_H_
#include "runtime/device/bucket.h"
namespace mindspore::device::gpu {
class GPUBucket : public Bucket {
public:
GPUBucket(uint32_t id, uint32_t bucket_size);
~GPUBucket() override = default;
void Init() override;
private:
void AllocateAllReduceAddr() override;
void FreeAllDeviceMem() override;
void FreeDeviceMem(void *dev_ptr) override;
void CopyTensorToContiguousMemory() override;
void LaunchAllReduce() override;
const void *collective_handle_;
};
} // namespace mindspore::device::gpu
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_BUCKET_H_

View File

@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/gpu/gpu_event.h"
#include "runtime/device/gpu/gpu_common.h"
namespace mindspore::device::gpu {
GpuEvent::GpuEvent() {
auto ret = cudaEventCreate(&event_);
if (ret != cudaSuccess) {
MS_LOG(ERROR) << "cudaEventCreate failed, ret:" << ret;
event_ = nullptr;
}
}
GpuEvent::~GpuEvent() { CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaEventDestroy(event_), "cudaEventDestory failed"); }
void GpuEvent::WaitEvent() {
MS_EXCEPTION_IF_NULL(wait_stream_);
MS_EXCEPTION_IF_NULL(event_);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamWaitEvent(wait_stream_, event_, 0), "cudaStreamWaitEvent failed");
need_wait_ = false;
}
void GpuEvent::RecordEvent() {
MS_EXCEPTION_IF_NULL(event_);
MS_EXCEPTION_IF_NULL(record_stream_);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventRecord(event_, record_stream_), "cudaEventRecord failed");
need_wait_ = true;
}
bool GpuEvent::NeedWait() { return need_wait_; }
} // namespace mindspore::device::gpu

View File

@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_EVENT_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_EVENT_H_
#include <cuda_runtime_api.h>
#include "ir/device_event.h"
namespace mindspore::device::gpu {
class GpuEvent : public DeviceEvent {
public:
GpuEvent();
~GpuEvent() override;
void WaitEvent() override;
void RecordEvent() override;
bool NeedWait() override;
void set_wait_stream(void *wait_stream) override { wait_stream_ = static_cast<cudaStream_t>(wait_stream); }
void set_record_stream(void *record_stream) override { record_stream_ = static_cast<cudaStream_t>(record_stream); }
private:
cudaEvent_t event_{nullptr};
cudaStream_t wait_stream_{nullptr};
cudaStream_t record_stream_{nullptr};
bool need_wait_{false};
};
} // namespace mindspore::device::gpu
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_EVENT_H_

View File

@ -229,6 +229,11 @@ bool GPUKernelRuntime::InitDevice() {
MS_LOG(ERROR) << "No default CUDA stream found.";
return false;
}
GPUDeviceManager::GetInstance().CreateStream(&communication_stream_);
if (communication_stream_ == nullptr) {
MS_LOG(ERROR) << "Invalid communication stream";
return false;
}
return true;
}
@ -1251,6 +1256,7 @@ session::KernelWithIndex GPUKernelRuntime::GetPrevNodeOutput(const AnfNodePtr &n
return addr_iter->second[i];
}
} // namespace gpu
} // namespace device
} // namespace mindspore

View File

@ -47,6 +47,8 @@ class GPUKernelRuntime : public KernelRuntime {
bool Run(session::KernelGraph *graph, bool is_task_sink) override;
bool GenDynamicKernel(const session::KernelGraph *graph) override { return true; }
bool RunDynamicKernelAsync(const session::KernelGraph *graph) override { return true; }
void *compute_stream() const override { return stream_; }
void *communication_stream() const override { return communication_stream_; }
protected:
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,

View File

@ -946,11 +946,13 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed.";
return false;
}
KernelLaunchProfiling(kernels[i]->fullname_with_scope());
}
}

View File

@ -33,6 +33,7 @@
#include "utils/ms_context.h"
#include "runtime/device/memory_manager.h"
#include "runtime/device/executor/dynamic_kernel.h"
#include "ir/device_event.h"
using mindspore::tensor::Tensor;
using std::vector;
@ -83,6 +84,9 @@ class KernelRuntime {
uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) {
return mem_manager_->MallocMem(type, size, address);
}
uint8_t *MallocCommunicationMemFromMemPool(size_t size) {
return mem_manager_->MallocCommunicationMemFromMemPool(size);
}
static void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel,
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
AddressPtrList *kernel_outputs);
@ -104,6 +108,8 @@ class KernelRuntime {
virtual uint64_t GetAvailableMemMaxSize() const { return 0; }
void AddBufferPtr(std::shared_ptr<char[]> ptr) { buffer_ptrs_.push_back(ptr); }
void FreeAndClearBufferPtrs() { buffer_ptrs_.clear(); }
virtual void *compute_stream() const { return nullptr; }
virtual void *communication_stream() const { return nullptr; }
protected:
virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
@ -149,7 +155,8 @@ class KernelRuntime {
#if !defined(_WIN32) && !defined(_WIN64)
std::shared_ptr<Debugger> debugger_;
#endif
void *stream_ = nullptr;
void *stream_{nullptr};
void *communication_stream_{nullptr};
std::shared_ptr<MemoryManager> mem_manager_{nullptr};
std::map<uint32_t, std::vector<DynamicKernelPtr>> graph_dynamic_kernel_map_;
std::vector<std::shared_ptr<char[]>> buffer_ptrs_ = {};

View File

@ -106,6 +106,14 @@ KernelRuntime *KernelRuntimeManager::GetKernelRuntime(const std::string &device_
return kernel_runtime.get();
}
KernelRuntime *KernelRuntimeManager::GetCurrentKernelRuntime() {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
return GetKernelRuntime(device_name, device_id);
}
void KernelRuntimeManager::ReleaseKernelRuntime(const std::string &device_name, uint32_t device_id) {
std::string runtime_key = GetDeviceKey(device_name, device_id);
std::lock_guard<std::mutex> guard(lock_);

View File

@ -38,6 +38,7 @@ class KernelRuntimeManager {
}
void Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator);
KernelRuntime *GetKernelRuntime(const std::string &device_name, uint32_t device_id);
KernelRuntime *GetCurrentKernelRuntime();
KernelRuntime *GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id);
void ReleaseKernelRuntime(const std::string &device_name, uint32_t device_id);
void ClearRuntimeResource();

View File

@ -27,7 +27,7 @@ using mindspore::memreuse::MemReuseUtilPtr;
namespace mindspore {
namespace device {
size_t MemoryManager::GetCommonAlignSize(size_t input_size) const {
size_t MemoryManager::GetCommonAlignSize(size_t input_size) {
return (input_size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize;
}

View File

@ -53,13 +53,14 @@ class MemoryManager {
virtual bool MallocMemFromMemPool(const DeviceAddressPtr address, size_t size);
virtual void *MallocMemFromMemPool(size_t size);
virtual uint8_t *MallocCommunicationMemFromMemPool(size_t size) { return nullptr; }
virtual void FreeMemFromMemPool(const DeviceAddressPtr address);
virtual void FreeMemFromMemPool(void *device_ptr);
virtual bool MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size,
std::vector<size_t> size_list);
virtual std::vector<void *> MallocContinuousMemFromMemPool(size_t total_size, std::vector<size_t> size_list);
size_t GetCommonAlignSize(size_t input_size) const;
static size_t GetCommonAlignSize(size_t input_size);
size_t GetCommunicationAlignSize(size_t input_size) const;
protected:

View File

@ -273,6 +273,10 @@ constexpr auto kOneHotOpName = "OneHot";
constexpr auto kSoftmaxCrossEntropyWithLogitsOpName = "SoftmaxCrossEntropyWithLogits";
constexpr auto kUniformCandidateSamplerOpName = "UniformCandidateSampler";
// Communication world group
constexpr auto kNcclWorldGroup = "nccl_world_group";
constexpr auto kHcclWorldGroup = "hccl_world_group";
// Hcom Op Type
constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce";
constexpr auto kHcomOpTypeAllGather = "HcomAllGather";
@ -331,12 +335,15 @@ constexpr auto kAttrSrTag = "sr_tag";
constexpr auto kAttrRootRank = "root_rank";
constexpr auto kAttrIsTraining = "is_training";
constexpr auto kAttrFusionId = "fusion_id";
constexpr auto kAttrBucketId = "bucket_id";
constexpr auto kAttrGradOutputIndex = "grad_output_index";
constexpr auto kAttrLabelIndex = "label_index";
constexpr auto kAttrLabelSwitchList = "label_switch_list";
constexpr auto kAttrNewAxisMask = "new_axis_mask";
constexpr auto kAttrShrinkAxisMask = "shrink_axis_mask";
constexpr auto kAttrDatadumpOriginalNames = "_datadump_original_names";
constexpr auto kAttrDatadumpIsMultiop = "_datadump_is_multiop";
constexpr auto kAttrNeedRecordEvent = "need_record_event";
constexpr auto kAttrStreamId = "stream_id";
constexpr auto kAttrRecordEvent = "record_event";
constexpr auto kAttrWaitEvent = "wait_event";

View File

@ -0,0 +1,31 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_IR_DEVICE_EVENT_H
#define MINDSPORE_CORE_IR_DEVICE_EVENT_H
namespace mindspore {
class DeviceEvent {
public:
virtual ~DeviceEvent() = default;
virtual void WaitEvent() = 0;
virtual void RecordEvent() = 0;
virtual bool NeedWait() = 0;
virtual void set_wait_stream(void *stream) = 0;
virtual void set_record_stream(void *stream) = 0;
};
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_DEVICE_EVENT_H

View File

@ -44,6 +44,7 @@ FuncGraph::FuncGraph()
kwonlyargs_count_(0),
hyper_param_count_(0),
is_generated_(false),
is_bprop_(false),
return_(nullptr),
manager_(std::weak_ptr<FuncGraphManager>()),
stub_(false),

View File

@ -217,6 +217,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
FuncGraphPtr GenerateGraph(const AbstractBasePtrList &args_spec_list);
void set_is_generate(bool generated) { is_generated_ = generated; }
bool is_generated() const { return is_generated_; }
void set_is_bprop(bool is_brop) { is_bprop_ = is_brop; }
bool is_bprop() const { return is_bprop_; }
std::unordered_map<std::string, ValuePtr> &attrs() { return attrs_; }
void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
@ -440,6 +442,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
// the argument input list for the graph used to generate this graph
bool is_generated_;
bool is_bprop_;
// the cnode that calls 'return' primitive
// we use shared pointer to manage it.
CNodePtr return_;

View File

@ -247,6 +247,47 @@ class TensorDataImpl : public TensorData {
}
private:
void OutputFloatDataString(std::ostringstream &ss, bool isScalar, const T &value) const {
if (isScalar) {
ss << value;
} else {
// The placeholder of float16 is fixed at 11, while float/double is fixed at 15.
const int width = std::is_same<T, float16>::value ? 11 : 15;
// The printing precision of float16 is fixed at 4, while float/double is fixed at 8.
const int precision = std::is_same<T, float16>::value ? 4 : 8;
ss << std::setw(width) << std::setprecision(precision) << std::setiosflags(std::ios::scientific | std::ios::right)
<< value;
}
}
void OutputBoolDataString(std::ostringstream &ss, bool isScalar, const T &value) const {
if (isScalar) {
ss << (value ? "True" : "False");
} else {
constexpr int bool_max_width = sizeof("False") - 1;
ss << std::setw(bool_max_width) << std::setiosflags(std::ios::right) << (value ? "True" : "False");
}
}
void OutputOtherDataString(std::ostringstream &ss, bool isScalar, const T &value, int *max_width) const {
if (isScalar) {
ss << value;
} else {
// Add a padding string before the number, such as "###123", for subsequent replacement.
const int width = GetNumLength(value);
*max_width = std::max(*max_width, width);
std::string pad(width, '#');
ss << pad;
if constexpr (std::is_same<T, uint8_t>::value) {
ss << static_cast<uint16_t>(value);
} else if constexpr (std::is_same<T, int8_t>::value) {
ss << static_cast<int16_t>(value);
} else {
ss << value;
}
}
}
void OutputDataString(std::ostringstream &ss, ssize_t cursor, ssize_t start, ssize_t end, bool use_comma,
int *max_width) const {
const bool isScalar = ndim_ == 0 && end - start == 1;
@ -257,40 +298,11 @@ class TensorDataImpl : public TensorData {
for (ssize_t i = start; i < end && (cursor + i) < static_cast<ssize_t>(data_size_); i++) {
const auto value = data_[cursor + i];
if constexpr (isFloat) {
if (isScalar) {
ss << value;
} else {
// The placeholder of float16 is fixed at 11, while float/double is fixed at 15.
const int width = std::is_same<T, float16>::value ? 11 : 15;
// The printing precision of float16 is fixed at 4, while float/double is fixed at 8.
const int precision = std::is_same<T, float16>::value ? 4 : 8;
ss << std::setw(width) << std::setprecision(precision)
<< std::setiosflags(std::ios::scientific | std::ios::right) << value;
}
OutputFloatDataString(ss, isScalar, value);
} else if (isBool) {
if (isScalar) {
ss << (value ? "True" : "False");
} else {
constexpr int bool_max_width = sizeof("False") - 1;
ss << std::setw(bool_max_width) << std::setiosflags(std::ios::right) << (value ? "True" : "False");
}
OutputBoolDataString(ss, isScalar, value);
} else {
if (isScalar) {
ss << value;
} else {
// Add a padding string before the number, such as "###123", for subsequent replacement.
const int width = GetNumLength(value);
*max_width = std::max(*max_width, width);
std::string pad(width, '#');
ss << pad;
if constexpr (std::is_same<T, uint8_t>::value) {
ss << static_cast<uint16_t>(value);
} else if constexpr (std::is_same<T, int8_t>::value) {
ss << static_cast<int16_t>(value);
} else {
ss << value;
}
}
OutputOtherDataString(ss, isScalar, value, max_width);
}
if (!isScalar && i != end - 1) {
if (use_comma) {
@ -452,7 +464,8 @@ Tensor::Tensor(const Tensor &tensor)
cache_enable_(tensor.cache_enable_),
cache_tensor_ptr_(tensor.cache_tensor_ptr_),
hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
padding_type_(tensor.padding_type()) {}
padding_type_(tensor.padding_type()),
device_event_(tensor.device_event_) {}
Tensor::Tensor(const Tensor &tensor, TypeId data_type)
: MetaTensor(data_type, tensor.shape_),
@ -465,7 +478,8 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
cache_enable_(tensor.cache_enable_),
cache_tensor_ptr_(tensor.cache_tensor_ptr_),
hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
padding_type_(tensor.padding_type()) {}
padding_type_(tensor.padding_type()),
device_event_(tensor.device_event_) {}
Tensor::Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data)
: MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {}
@ -527,6 +541,7 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) {
event_ = tensor.event_;
sync_status_ = tensor.sync_status_;
padding_type_ = tensor.padding_type_;
device_event_ = tensor.device_event_;
}
return *this;
}

View File

@ -30,6 +30,7 @@
#include "base/float16.h"
#include "utils/shape_utils.h"
#include "utils/ms_exception.h"
#include "ir/device_event.h"
// brief mindspore namespace.
//
@ -326,6 +327,21 @@ class Tensor : public MetaTensor {
event_ = nullptr;
}
void SetDeviceEvent(const std::shared_ptr<DeviceEvent> &device_event) { device_event_ = device_event; }
void WaitDevice() {
if (device_event_ != nullptr) {
device_event_->WaitEvent();
}
}
bool NeedWaitDevice() const {
if (device_event_ != nullptr) {
return device_event_->NeedWait();
}
return false;
}
void set_sync_status(TensorSyncStatus sync_status) { sync_status_ = sync_status; }
TensorSyncStatus sync_status() const { return sync_status_; }
@ -352,6 +368,7 @@ class Tensor : public MetaTensor {
std::shared_ptr<Tensor> hashmap_tensor_ptr_{nullptr};
std::vector<Axis> padding_type_;
TypePtr cast_dtype_{nullptr};
std::shared_ptr<DeviceEvent> device_event_{nullptr};
};
using TensorPtr = std::shared_ptr<Tensor>;
using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;

View File

@ -17,10 +17,11 @@ from mindspore import context
from mindspore.nn.cell import Cell
from mindspore.communication.management import GlobalComm, get_group_size
from mindspore.common.tensor import RowTensor
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops import functional as F, composite as C
from mindspore.ops.operations.comm_ops import AllReduce, AllGather
from mindspore.parallel._auto_parallel_context import auto_parallel_context
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
reduce_opt = C.MultitypeFuncGraph("reduce_opt")
@ -45,7 +46,7 @@ def _init_allreduce_operators(length, split_indices):
return op_list
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "Tensor")
@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "Tensor")
def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, grad):
"""
Apply allreduce on gradient.
@ -64,13 +65,33 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra
if allreduce_filter:
grad = allreduce(grad)
if mean:
degree = F.scalar_cast(degree, F.dtype(grad))
grad = F.tensor_mul(grad, F.cast(F.scalar_to_array(1.0 / degree), F.dtype(grad)))
grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
return grad
return grad
@reduce_opt.register("Tensor", "Bool", "Bool", "Tensor")
def _tensors_allreduce_post(degree, mean, allreduce_filter, grad):
"""
Apply allreduce on gradient in PyNative mode.
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "Tensor", "Bool")
Args:
degree (int): The mean coefficient.
mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
allgather (Primitive): The communication operator for sparse gradients.
allreduce (Primitive): The communication operator for gradients.
allreduce_filter (bool): When it is true, allreduce would apply.
grad (Tensor): The gradient tensor before operation.
Returns:
Tensor, the gradient tensor after operation.
"""
if allreduce_filter:
if mean:
grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
return grad
return grad
@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "Tensor", "Bool")
def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
"""
Apply allreduce on gradient.
@ -93,15 +114,12 @@ def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter,
if allreduce_filter:
grad = allreduce(grad)
if mean:
degree = F.scalar_cast(degree, F.dtype(grad))
cast_op = P.Cast()
mul_op = P.Mul()
grad = mul_op(grad, cast_op(F.scalar_to_array(1.0/degree), F.dtype(grad)))
grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
return grad
return grad
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "RowTensor")
@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "RowTensor")
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad):
"""
Apply allgather on gradient instead of allreduce for sparse feature.
@ -122,15 +140,12 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce
indices = allgather(grad.indices)
dout = allgather(grad.values)
if mean:
degree = F.scalar_cast(degree, F.dtype(grad.values))
cast_op = P.Cast()
mul_op = P.Mul()
dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
dout = F.tensor_mul(dout, F.cast(degree, F.dtype(dout)))
grad = RowTensor(indices, dout, grad.dense_shape)
return grad
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "RowTensor", "Bool")
@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "RowTensor", "Bool")
def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
"""
Apply allgather on gradient instead of allreduce for sparse feature.
@ -155,10 +170,7 @@ def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allred
indices = allgather(grad.indices)
dout = allgather(grad.values)
if mean:
degree = F.scalar_cast(degree, F.dtype(grad.values))
cast_op = P.Cast()
mul_op = P.Mul()
dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
dout = F.tensor_mul(dout, F.cast(degree, F.dtype(dout)))
grad = RowTensor(indices, dout, grad.dense_shape)
return grad
@ -329,6 +341,7 @@ class DistributedGradReducer(Cell):
if not isinstance(degree, int) or degree <= 0:
raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int")
self.degree = degree
self.degree = Tensor(1.0 / self.degree, mstype.float32)
self.mean = mean
self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer")
@ -343,6 +356,7 @@ class DistributedGradReducer(Cell):
ps_filter = lambda x: x.is_param_ps
self.ps_parameters = tuple(ps_filter(x) for x in parameters)
self.enable_parameter_server = any(self.ps_parameters)
self.mode = context.get_context("mode")
def construct(self, grads):
"""
@ -358,7 +372,9 @@ class DistributedGradReducer(Cell):
"""
datatypes = self.map_(F.partial(_get_datatype), grads)
grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
if self.split_fusion:
if self.mode == context.PYNATIVE_MODE:
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean), self.allreduce_filter, grads)
elif self.split_fusion:
if self.enable_parameter_server:
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
self.op_list, self.allreduce_filter, grads, self.ps_parameters)

View File

@ -97,10 +97,13 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/runtime/device/memory_manager.cc"
"../../../mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc"
"../../../mindspore/ccsrc/runtime/device/kernel_info.cc"
"../../../mindspore/ccsrc/runtime/device/bucket.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/profiling/*.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc"
"../../../mindspore/ccsrc/runtime/device/convert_tensor_utils.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/ascend_bucket.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/ascend_event.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/signal_util.cc"

View File

@ -160,3 +160,7 @@ RTS_API rtError_t rtSetMsprofReporterCallback(MsprofReporterCallback callback) {
RTS_API rtError_t rtRegTaskFailCallbackByModule(const char *moduleName, rtTaskFailCallback callback) {
return RT_ERROR_NONE;
}
RTS_API rtError_t rtMemsetAsync(void *ptr, uint64_t destMax, uint32_t value, uint64_t count, rtStream_t stream) {
return RT_ERROR_NONE;
}