forked from mindspore-Ecosystem/mindspore
PyNative AllReduce Bucket
parent d2d6c3cfb5
commit 171b468bb3
@@ -64,6 +64,7 @@
#include "toolchain/adx_datadump_server.h"
#ifdef ENABLE_DUMP_IR
#include "debug/rdr/running_data_recorder.h"
#include "runtime/device/ascend/ascend_bucket.h"
#endif
#if ENABLE_CPU && ENABLE_D
#include "ps/util.h"

@@ -258,6 +259,7 @@ GraphId AscendSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNode
// construct graph, if successfully, graph_sum_ + 1
auto graph = ConstructKernelGraph(lst, outputs);
auto graph_id = graph->graph_id();
InitAllBucket(graph);
MS_LOG(INFO) << "Compile graph " << graph_id << " success";
return graph_id;
}

@@ -632,6 +634,13 @@ void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_inf
MS_EXCEPTION_IF_NULL(op_run_info);
BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
EraseValueNodeTensor(tensors_mask, input_tensors);

// wait for allreduce
for (auto &tensor : *input_tensors) {
if (tensor->NeedWaitDevice()) {
tensor->WaitDevice();
}
}
// Run op
auto graph = run_op_graphs_[graph_info];
MS_EXCEPTION_IF_NULL(graph);

@@ -1510,5 +1519,9 @@ void AscendSession::SyncStream() {
MS_LOG(EXCEPTION) << "Sync stream error!";
}
}

std::shared_ptr<device::Bucket> AscendSession::CreateBucket(uint32_t bucket_id, uint32_t bucket_size) {
return std::make_shared<device::ascend::AscendBucket>(bucket_id, bucket_size);
}
} // namespace session
} // namespace mindspore

@@ -61,6 +61,7 @@ class AscendSession : public SessionBasic {
void BuildOpsInGraph(const GraphId &graph_id, const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs,
const std::map<KernelWithIndex, size_t> &cnode_refcount) override;
std::string GetCommWorldGroup() override { return kHcclWorldGroup; }

private:
// compile child graph when session have multiple child graphs

@@ -123,6 +124,7 @@ class AscendSession : public SessionBasic {
const std::vector<tensor::TensorPtr> &graph_inputs,
const std::map<KernelWithIndex, OutputTensorInfo> &node_output_info,
InputTensorInfo *input_tensor_info);
std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override;
// key is final_graph_id,value is child graph execute order of final graph
std::unordered_map<GraphId, std::vector<GraphId>> graph_execute_orders_;
// key is final_graph_id,value is the graph types of child graphs

@@ -16,6 +16,7 @@
#include "backend/session/gpu_session.h"

#include <string>
#include <utility>
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/pass_manager.h"

@@ -63,6 +64,7 @@
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/gpu/cuda_driver.h"
#include "runtime/device/gpu/distribution/collective_init.h"
#include "runtime/device/gpu/gpu_bucket.h"
#include "utils/ms_utils.h"
#include "utils/config_manager.h"
#include "utils/ms_context.h"

@@ -394,6 +396,8 @@ GraphId GPUSession::CompileGraphImpl(KernelGraphPtr graph) {
manager->AddFuncGraph(graph);
graph->set_manager(manager);
}

InitAllBucket(graph);
// Alloc memory in graph mode, including static memory and dynamic memory
if (!pynative_mode) {
AllocateMemory(graph.get());

@@ -473,6 +477,12 @@ void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
MS_EXCEPTION_IF_NULL(op_run_info);
BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
EraseValueNodeTensor(tensors_mask, input_tensors);
// wait for allreduce
for (auto &tensor : *input_tensors) {
if (tensor->NeedWaitDevice()) {
tensor->WaitDevice();
}
}
// run op
auto kernel_graph = run_op_graphs_[graph_info];
MS_EXCEPTION_IF_NULL(kernel_graph);

@@ -548,6 +558,10 @@ void GPUSession::SyncStream() {
MS_LOG(EXCEPTION) << "Sync stream error!";
}
}

std::shared_ptr<device::Bucket> GPUSession::CreateBucket(uint32_t bucket_id, uint32_t bucket_size) {
return std::make_shared<device::gpu::GPUBucket>(bucket_id, bucket_size);
}
} // namespace gpu
} // namespace session
} // namespace mindspore

@@ -19,6 +19,7 @@
#include <vector>
#include <memory>
#include <algorithm>
#include <string>
#include "backend/session/session_basic.h"
#include "backend/session/kernel_graph.h"
#include "backend/session/session_factory.h"

@@ -44,6 +45,8 @@ class GPUSession : public SessionBasic {
const std::vector<int64_t> &tensors_mask) override;
void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override;
std::string GetCommWorldGroup() override { return kNcclWorldGroup; }

private:
void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;

@@ -40,6 +40,7 @@
#include "debug/anf_ir_dump.h"
#include "debug/common.h"
#include "utils/trace_base.h"
#include "frontend/parallel/context.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/ps_cache/ps_cache_manager.h"
#include "ps/constants.h"

@@ -556,10 +557,12 @@ void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<Kern
void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
const std::map<KernelWithIndex, std::vector<std::vector<size_t>>> &output_indexes,
const std::map<KernelWithIndex, size_t> &ref_count,
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, VectorRef *outputs) {
std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, VectorRef *outputs,
std::vector<TensorPtr> *runop_output_tensors) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(op_output_map);
MS_EXCEPTION_IF_NULL(outputs);
MS_EXCEPTION_IF_NULL(runop_output_tensors);
auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
if (output_tensors.size() > op_outputs.size()) {
MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();

@@ -592,6 +595,7 @@ void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
}
BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
tensor_ref = output_tensor;
runop_output_tensors->emplace_back(output_tensor);
}
}
}

@@ -2196,6 +2200,11 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<
GetRefCount(kernel_graph.get(), &cnode_refcount);
BuildOpsInGraph(graph_id, parameter_index, inputs, cnode_refcount);

// Clear bucket resources every step
if (kernel_graph->is_bprop()) {
ClearAllBucket(graph_id);
}

std::map<KernelWithIndex, tensor::TensorPtr> op_output_map;
for (const auto &kernel : kernel_graph->execution_order()) {
// Generate input tensors, tensor masks and input kernel with index

@@ -2212,9 +2221,15 @@ void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<
RunOpImpl(graph_info, &run_info, &input_tensor_info.input_tensors, &op_outputs,
input_tensor_info.input_tensors_mask);

std::vector<tensor::TensorPtr> new_output_tensors;

// Handle inputs and outputs of current op
HandleOpInputs(input_tensor_info.input_kernel, &cnode_refcount, &op_output_map);
HandleOpOutputs(kernel, op_outputs, output_indexes, cnode_refcount, &op_output_map, outputs);
HandleOpOutputs(kernel, op_outputs, output_indexes, cnode_refcount, &op_output_map, outputs, &new_output_tensors);
// Save grad node to Bucket
if (kernel_graph->is_bprop()) {
AddGradAddrToBucket(graph_id, new_output_tensors);
}
}
MS_LOG(INFO) << "Finish!";
}

@@ -2287,6 +2302,137 @@ void SessionBasic::RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const {
}
}

std::vector<uint32_t> SessionBasic::GetAllReduceSplitIndex() {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string group = GetCommWorldGroup();
auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
// PyNative not support multi group allreduce
group += "sum1";
return parallel_context->GetAllReduceFusionSplitIndices(group);
}

uint32_t GetBpropGraphGradsCount(const KernelGraphPtr &graph) {
return AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}).size();
}

void SetGraphBpropAttr(const KernelGraphPtr &graph) {
auto &execution_orders = graph->execution_order();
if (std::any_of(execution_orders.begin(), execution_orders.end(),
[](const AnfNodePtr &node) { return node->scope()->name().rfind("Gradient", 0) == 0; })) {
graph->set_is_bprop(true);
MS_LOG(INFO) << "Match bprop graph";
} else {
graph->set_is_bprop(false);
}
}

std::vector<uint32_t> GenerateBucketSizeList(const KernelGraphPtr &graph, const std::vector<uint32_t> &split_index) {
if (split_index.empty()) {
auto grads_count = GetBpropGraphGradsCount(graph);
if (grads_count == 0) {
MS_LOG(EXCEPTION) << "Bprop graph has no grad";
}
return {grads_count};
}

std::vector<uint32_t> bucket_size_list;
uint32_t old_index = 0;
for (auto &index : split_index) {
if (old_index == 0) {
bucket_size_list.emplace_back(index - old_index + 1);
} else {
bucket_size_list.emplace_back(index - old_index);
}
old_index = index;
}
return bucket_size_list;
}

void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Init Bucket start, graph_id:" << graph->graph_id();
SetGraphBpropAttr(graph);

if (!graph->is_bprop()) {
return;
}

std::vector<std::shared_ptr<device::Bucket>> bucket_list;
// Create bucket for every split allreduce ops
auto split_index = GetAllReduceSplitIndex();
auto bucket_size_list = GenerateBucketSizeList(graph, split_index);
uint32_t bucket_id = 0;
for (auto bucket_size : bucket_size_list) {
MS_LOG(INFO) << "Create new bucket:" << bucket_id;
auto bucket = CreateBucket(bucket_id++, bucket_size);
bucket->Init();
bucket_list.emplace_back(bucket);
}

auto bucket_ret = bucket_map_.try_emplace(graph->graph_id(), bucket_list);
if (!bucket_ret.second) {
MS_LOG(EXCEPTION) << "Duplicate bucket_map_ graph key:" << graph->graph_id();
}
// set all free bucket index to 0
auto free_bucket_ret = free_bucket_id_map_.try_emplace(graph->graph_id(), 0);
if (!free_bucket_ret.second) {
MS_LOG(EXCEPTION) << "Duplicate free_bucket_id_map_ graph key:" << graph->graph_id();
}
MS_LOG(INFO) << "Init Bucket finish";
}

void SessionBasic::AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor) {
auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
auto parallel_mode = parallel_context->parallel_mode();
if (parallel_mode != parallel::DATA_PARALLEL) {
return;
}

auto iter = bucket_map_.find(graph_id);
if (iter == bucket_map_.end()) {
MS_LOG(EXCEPTION) << "unknown graph id:" << graph_id;
}
auto &bucket_list = iter->second;
auto free_bucket_iter = free_bucket_id_map_.find(graph_id);
if (free_bucket_iter == free_bucket_id_map_.end()) {
MS_LOG(EXCEPTION) << "unknown free graph id:" << graph_id;
}

auto free_bucket_index = free_bucket_iter->second;
for (auto &tensor : grad_tensor) {
if (free_bucket_index >= bucket_list.size()) {
MS_LOG(EXCEPTION) << "Invalid free bucket id:" << free_bucket_iter->second
<< " total bucket num:" << bucket_list.size();
}
auto &free_bucket = bucket_list[free_bucket_index];
free_bucket->AddGradTensor(tensor);
if (free_bucket->full()) {
MS_LOG(INFO) << "bucket is full";
free_bucket->Launch();
free_bucket_index = ++free_bucket_iter->second;
MS_LOG(INFO) << "new free bucket:" << free_bucket_index;
}
}
}

void SessionBasic::ClearAllBucket(const GraphId &graph_id) {
auto iter = bucket_map_.find(graph_id);
if (iter != bucket_map_.end()) {
auto bucket_list = iter->second;
for (auto &bucket : bucket_list) {
MS_LOG(INFO) << "Clear bucket:" << bucket->id();
bucket->Release();
}
}
auto free_iter = free_bucket_id_map_.find(graph_id);
if (free_iter != free_bucket_id_map_.end()) {
free_iter->second = 0;
}
}

#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
if (!ps::PSContext::instance()->is_worker()) {

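The split-index arithmetic in GenerateBucketSizeList above can be checked in isolation with a short standalone sketch. The helper name, the sample indices {3, 7} and the gradient count 8 below are illustrative assumptions for the example, not part of the commit.

// Standalone sketch of the bucket-size computation used by GenerateBucketSizeList above.
// With split indices {3, 7} over 8 gradients it yields bucket sizes {4, 4}:
// gradients 0..3 fill bucket 0 and gradients 4..7 fill bucket 1.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<uint32_t> BucketSizesFromSplitIndices(const std::vector<uint32_t> &split_index, uint32_t grads_count) {
  if (split_index.empty()) {
    return {grads_count};  // one bucket covering every gradient
  }
  std::vector<uint32_t> bucket_size_list;
  uint32_t old_index = 0;
  for (auto index : split_index) {
    // the first split index is inclusive, later entries are deltas from the previous one
    bucket_size_list.emplace_back(old_index == 0 ? index - old_index + 1 : index - old_index);
    old_index = index;
  }
  return bucket_size_list;
}

int main() {
  for (auto size : BucketSizesFromSplitIndices({3, 7}, 8)) {
    std::cout << size << std::endl;  // prints 4 then 4
  }
  return 0;
}
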
@@ -32,6 +32,7 @@
#include "utils/contract.h"
#include "runtime/device/kernel_info.h"
#include "utils/ms_context.h"
#include "runtime/device/bucket.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "debug/debugger/debugger.h"
#endif

@@ -224,12 +225,20 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
void UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs);
void RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const;
void RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const;
virtual std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) { return nullptr; }
void InitAllBucket(const KernelGraphPtr &graph);
void AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor);
void ClearAllBucket(const GraphId &graph_id);
std::vector<uint32_t> GetAllReduceSplitIndex();
virtual std::string GetCommWorldGroup() { return std::string(); }
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const;
void GetBatchElements(const AnfNodePtr &kernel_node) const;
void InitPsWorker(const KernelGraphPtr &kernel_graph);
#endif

std::map<uint32_t, std::vector<std::shared_ptr<device::Bucket>>> bucket_map_;
std::map<uint32_t, uint32_t> free_bucket_id_map_;
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
std::unordered_map<FuncGraphPtr, KernelGraphPtr> front_backend_graph_map_;

@@ -677,34 +677,9 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) {
return op_exec_info;
}

AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
abstract::AbstractBasePtrList *args_spec_list) {
MS_EXCEPTION_IF_NULL(op_masks);
MS_EXCEPTION_IF_NULL(args_spec_list);
MS_EXCEPTION_IF_NULL(op_exec_info);

void PynativeExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
std::vector<AnfNodePtr> *inputs, abstract::AbstractBasePtrList *args_spec_list) {
auto prim = op_exec_info->py_primitive;
std::vector<AnfNodePtr> inputs;
inputs.emplace_back(NewValueNode(prim));

const auto &signature = prim->signatures();
auto sig_size = signature.size();
auto size = op_exec_info->op_inputs.size();

// ignore monad signature
for (auto sig : signature) {
if (sig.default_value != nullptr && sig.default_value->isa<Monad>()) {
--sig_size;
}
}
if (sig_size > 0 && sig_size != size) {
MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the requires "
<< "inputs size " << sig_size;
}
if (op_exec_info->op_name != prim::kPrimCast->name()) {
RunParameterAutoMixPrecisionCast(op_exec_info);
}
MS_LOG(DEBUG) << "Get op " << op_exec_info->op_name << " grad_flag_ " << grad_flag();
for (size_t i = 0; i < op_exec_info->op_inputs.size(); i++) {
abstract::AbstractBasePtr abs = nullptr;
const auto &obj = op_exec_info->op_inputs[i];

@@ -733,11 +708,42 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
if (input_node->abstract() != nullptr) {
abs = input_node->abstract();
}
inputs.emplace_back(input_node);
inputs->emplace_back(input_node);
}
}
(*args_spec_list).emplace_back(CheckConstValue(prim, obj, abs, id, i));
}
}

AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
abstract::AbstractBasePtrList *args_spec_list) {
MS_EXCEPTION_IF_NULL(op_masks);
MS_EXCEPTION_IF_NULL(args_spec_list);
MS_EXCEPTION_IF_NULL(op_exec_info);

auto prim = op_exec_info->py_primitive;
std::vector<AnfNodePtr> inputs;
inputs.emplace_back(NewValueNode(prim));

const auto &signature = prim->signatures();
auto sig_size = signature.size();
auto size = op_exec_info->op_inputs.size();

// ignore monad signature
for (auto sig : signature) {
if (sig.default_value != nullptr && sig.default_value->isa<Monad>()) {
--sig_size;
}
}
if (sig_size > 0 && sig_size != size) {
MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the requires "
<< "inputs size " << sig_size;
}
if (op_exec_info->op_name != prim::kPrimCast->name()) {
RunParameterAutoMixPrecisionCast(op_exec_info);
}
MS_LOG(DEBUG) << "Get op " << op_exec_info->op_name << " grad_flag_ " << grad_flag();
GetArgsSpec(op_exec_info, op_masks, &inputs, args_spec_list);

CNodePtr cnode = nullptr;
if (need_construct_graph()) {

@@ -208,6 +208,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
PynativeStatusCode *const status);
AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id);
AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks, std::vector<AnfNodePtr> *inputs,
abstract::AbstractBasePtrList *args_spec_list);
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
abstract::AbstractBasePtrList *args_spec_list);
abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,

@@ -1,5 +1,7 @@
file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc"
"kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc" "memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc"
"kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc"
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc"
"bucket.cc"
)

if(ENABLE_GPU)

@@ -0,0 +1,173 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "runtime/device/ascend/ascend_bucket.h"

#include <vector>
#include <memory>
#include "runtime/mem.h"
#include "external/hccl/hccl.h"
#include "runtime/device/ascend/ascend_memory_pool.h"
#include "backend/kernel_compiler/hccl/hcom_util.h"
#include "backend/kernel_compiler/hccl/hccl_context.h"
#include "runtime/device/memory_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/ascend/ascend_event.h"
#include "utils/profile.h"

#define CHECK_ASCEND_RT_WITH_EXCEPTION(expression, message) \
{ \
rtError_t ret = (expression); \
if (ret != RT_ERROR_NONE) { \
MS_LOG(EXCEPTION) << message << ", error code: " << ret; \
} \
}

namespace mindspore::device::ascend {
void AscendBucket::AllocateAllReduceAddr() {
// Check bucket is full
if (grad_tensor_list_.size() != bucket_size_) {
MS_LOG(EXCEPTION) << "grad tensor list size:" << grad_tensor_list_.size()
<< " is not equal to bucket size:" << bucket_size_;
}

auto total_size = 0;
std::vector<size_t> align_size_list;
std::vector<size_t> origin_size_list;
for (auto &tensor : grad_tensor_list_) {
MS_EXCEPTION_IF_NULL(tensor);
tensor_type_list_.emplace_back(tensor->data_type());
DeviceAddressPtr device_address = std::dynamic_pointer_cast<DeviceAddress>(tensor->device_address());
auto origin_size = device_address->GetSize();
auto align_size = MemoryManager::GetCommonAlignSize(origin_size);
origin_size_list.emplace_back(origin_size);
align_size_list.emplace_back(align_size);
total_size += align_size;
memcpy_input_addrs_.emplace_back(std::make_shared<kernel::Address>(
static_cast<uint8_t *>(device_address->GetMutablePtr()), device_address->GetSize()));
}

total_size_ = total_size;

auto runtime_instance = device::KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
MS_EXCEPTION_IF_NULL(runtime_instance);
// AllReduce input output addr need to clear zero
ar_input_addr_ = runtime_instance->MallocCommunicationMemFromMemPool(total_size);
ar_output_addr_ = runtime_instance->MallocCommunicationMemFromMemPool(total_size);

// generate memcpy output addr
uint8_t *memcpy_output = ar_input_addr_;
for (size_t i = 0; i < bucket_size_; ++i) {
memcpy_output_addrs_.emplace_back(std::make_shared<kernel::Address>(memcpy_output, origin_size_list[i]));
memcpy_output += align_size_list[i];
}

// store output tensor addr
uint8_t *tensor_output = ar_output_addr_;
for (size_t i = 0; i < bucket_size_; ++i) {
new_tensor_output_addrs_.emplace_back(tensor_output);
tensor_output += align_size_list[i];
}
}

void AscendBucket::FreeDeviceMem(void *dev_ptr) { AscendMemoryPool::GetInstance().FreeTensorMem(dev_ptr); }

void AscendBucket::FreeAllDeviceMem() {
if (ar_input_addr_ != nullptr) {
uint8_t *origin_dev_addr = ar_input_addr_ - kMemAlignSize;
FreeDeviceMem(origin_dev_addr);
ar_input_addr_ = nullptr;
}
if (ar_output_addr_ != nullptr) {
uint8_t *origin_dev_addr = ar_output_addr_ - kMemAlignSize;
FreeDeviceMem(origin_dev_addr);
ar_output_addr_ = nullptr;
}
}

void AscendBucket::CopyTensorToContiguousMemory() {
// Clean input addr
CHECK_ASCEND_RT_WITH_EXCEPTION(rtMemsetAsync(ar_input_addr_, total_size_, 0, total_size_, compute_stream_),
"Call rtMemsetAsync failed");

for (size_t i = 0; i < bucket_size_; ++i) {
MS_EXCEPTION_IF_NULL(memcpy_input_addrs_[i]);
MS_EXCEPTION_IF_NULL(memcpy_output_addrs_[i]);
MS_LOG(DEBUG) << "MemcpyAsync dst size:" << memcpy_output_addrs_[i]->size
<< " src size:" << memcpy_input_addrs_[i]->size;
if (memcpy_output_addrs_[i]->size < memcpy_input_addrs_[i]->size) {
MS_LOG(EXCEPTION) << "rtMemcpyAsync dst size < src size";
}

CHECK_ASCEND_RT_WITH_EXCEPTION(
rtMemcpyAsync(memcpy_output_addrs_[i]->addr, memcpy_output_addrs_[i]->size, memcpy_input_addrs_[i]->addr,
memcpy_input_addrs_[i]->size, RT_MEMCPY_DEVICE_TO_DEVICE, compute_stream_),
"Call rtMemcpyAsync failed");
}
}

void AscendBucket::LaunchAllReduce() {
if (tensor_type_list_.empty()) {
MS_LOG(EXCEPTION) << "No tensor type found";
}

// AllReduce inputs data type should be same
auto type = tensor_type_list_[0];
if (std::any_of(tensor_type_list_.begin(), tensor_type_list_.end(),
[&type](TypeId tensor_type) { return type != tensor_type; })) {
MS_LOG(EXCEPTION) << "allreduce input have different dtype";
}

auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(type);
if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) {
MS_LOG(EXCEPTION) << "unknown data type:" << type;
}

uint32_t type_size;
if (!HcomUtil::GetHcomTypeSize(iter->second, &type_size)) {
MS_LOG(EXCEPTION) << "get hcom type size failed";
}

if (type_size == 0 || total_size_ % type_size != 0) {
MS_LOG(EXCEPTION) << "Total_size[" << total_size_ << "],Type_size[" << type_size << "] != 0, fail!";
}
auto hccl_count = total_size_ / type_size;

HcclReduceOp op_type = HcclReduceOp::HCCL_REDUCE_SUM;
auto hccl_result = HcclAllReduce(ar_input_addr_, ar_output_addr_, hccl_count, iter->second, op_type,
kernel::HcclContext::GetInstance().hccl_comm(), stream_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(EXCEPTION) << "HcclAllReduce failed, ret:" << hccl_result;
}
}

void AscendBucket::Init() {
pre_event_ = std::make_shared<AscendEvent>();
post_event_ = std::make_shared<AscendEvent>();

auto kernel_runtime = KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
MS_EXCEPTION_IF_NULL(kernel_runtime);
compute_stream_ = kernel_runtime->compute_stream();
stream_ = kernel_runtime->communication_stream();

MS_EXCEPTION_IF_NULL(pre_event_);
MS_EXCEPTION_IF_NULL(post_event_);
pre_event_->set_wait_stream(stream_);
pre_event_->set_record_stream(compute_stream_);
post_event_->set_wait_stream(compute_stream_);
post_event_->set_record_stream(stream_);
}
} // namespace mindspore::device::ascend

@@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_BUCKET_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_BUCKET_H_

#include "runtime/device/bucket.h"

namespace mindspore::device::ascend {
class AscendBucket : public Bucket {
public:
AscendBucket(uint32_t id, uint32_t bucket_size) : Bucket(id, bucket_size) {}
~AscendBucket() override = default;

void Init() override;

private:
void AllocateAllReduceAddr() override;
void FreeAllDeviceMem() override;
void FreeDeviceMem(void *dev_ptr) override;
void CopyTensorToContiguousMemory() override;
void LaunchAllReduce() override;
};
} // namespace mindspore::device::ascend
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_BUCKET_H_

@@ -0,0 +1,60 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "runtime/device/ascend/ascend_event.h"

#include "runtime/event.h"
#include "runtime/stream.h"
#include "utils/log_adapter.h"

namespace mindspore::device::ascend {
AscendEvent::AscendEvent() {
auto ret = rtEventCreate(&event_);
if (ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "rtEventCreate failed, ret:" << ret;
event_ = nullptr;
}
}

AscendEvent::~AscendEvent() {
auto ret = rtEventDestroy(event_);
if (ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "rtEventDestroy failed, ret:" << ret;
}
}

void AscendEvent::RecordEvent() {
MS_EXCEPTION_IF_NULL(event_);
MS_EXCEPTION_IF_NULL(record_stream_);
auto ret = rtEventRecord(event_, record_stream_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "rtEventRecord failed, ret:" << ret;
}
need_wait_ = true;
}

void AscendEvent::WaitEvent() {
MS_EXCEPTION_IF_NULL(event_);
MS_EXCEPTION_IF_NULL(wait_stream_);
auto ret = rtStreamWaitEvent(wait_stream_, event_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "rtStreamWaitEvent failed, ret:" << ret;
}
need_wait_ = false;
}

bool AscendEvent::NeedWait() { return need_wait_; }
} // namespace mindspore::device::ascend

@@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_ASCEND_EVENT_H
#define MINDSPORE_ASCEND_EVENT_H

#include "runtime/base.h"
#include "ir/device_event.h"
namespace mindspore::device::ascend {
class AscendEvent : public DeviceEvent {
public:
AscendEvent();
~AscendEvent() override;

void WaitEvent() override;
void RecordEvent() override;
bool NeedWait() override;
void set_wait_stream(rtStream_t wait_stream) override { wait_stream_ = wait_stream; }
void set_record_stream(rtStream_t record_stream) override { record_stream_ = record_stream; }

private:
rtEvent_t event_{nullptr};
rtStream_t wait_stream_{nullptr};
rtStream_t record_stream_{nullptr};
bool need_wait_{false};
};
} // namespace mindspore::device::ascend
#endif // MINDSPORE_ASCEND_EVENT_H

@@ -718,6 +718,10 @@ bool AscendKernelRuntime::SyncStream() {
MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
return false;
}
if (RT_ERROR_NONE != rtStreamSynchronize(communication_stream_)) { // o for switch stream
MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
return false;
}
FreeAndClearBufferPtrs();
return true;
}

@@ -786,6 +790,10 @@ bool AscendKernelRuntime::InitDevice() {
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
}
ret = rtStreamCreate(&communication_stream_, 0);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
}
return true;
}

@@ -799,6 +807,14 @@ bool AscendKernelRuntime::ResetDevice(uint32_t device_id) {
stream_ = nullptr;
}

if (communication_stream_ != nullptr) {
auto ret = rtStreamDestroy(communication_stream_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]";
}
communication_stream_ = nullptr;
}

auto ret = rtDeviceReset(device_id);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtDeviceReset, ret[" << ret << "]";

@@ -919,4 +935,5 @@ uint64_t AscendKernelRuntime::GetAvailableMemMaxSize() const {
auto ascend_mem_manager = dynamic_pointer_cast<AscendMemoryManager>(mem_manager_);
return ascend_mem_manager->GetDeviceMemSize();
}

} // namespace mindspore::device::ascend

@@ -57,6 +57,8 @@ class AscendKernelRuntime : public KernelRuntime {
void *context() const override { return rt_context_; }
void PreInit() override;
uint64_t GetAvailableMemMaxSize() const override;
void *compute_stream() const override { return stream_; }
void *communication_stream() const override { return communication_stream_; }

protected:
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,

@@ -162,6 +162,14 @@ void AscendMemoryManager::MallocSomasDynamicMem(const session::KernelGraph *grap
somas_reuse_util_ptr_->ConvertToProfilingNode(graph->graph_id());
}
}

// communication memory: [512align_size + data + 512align_size]
// return the pointer to the start of data address.
uint8_t *AscendMemoryManager::MallocCommunicationMemFromMemPool(size_t size) {
auto align_size = GetCommunicationAlignSize(size);
uint8_t *base_ptr = reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
return base_ptr + kMemAlignSize;
}
} // namespace ascend
} // namespace device
} // namespace mindspore

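The layout named in the comment above ([512 align + data + 512 align], with the returned pointer offset by kMemAlignSize) can be sketched as plain arithmetic. kMemAlignSize = 512 and the rounding formula below are assumptions made for illustration, since GetCommunicationAlignSize itself is not shown in this diff; the sketch also explains why AscendBucket::FreeAllDeviceMem above steps back by kMemAlignSize before freeing.

// Illustrative arithmetic only; not the project's implementation.
#include <cstddef>
#include <cstdio>

constexpr size_t kMemAlignSize = 512;  // assumed value, mirroring the "[512 ... 512]" comment

// Assumed shape of GetCommunicationAlignSize: round the payload up to the alignment,
// then add one pad block on each side of the data.
size_t CommunicationAlignSize(size_t size) {
  return ((size + kMemAlignSize - 1) / kMemAlignSize) * kMemAlignSize + 2 * kMemAlignSize;
}

int main() {
  size_t data = 1000;
  size_t block = CommunicationAlignSize(data);  // 1024 + 2 * 512 = 2048 bytes from the pool
  size_t data_offset = kMemAlignSize;           // the caller receives base_ptr + 512
  printf("pool block %zu bytes, data starts at +%zu\n", block, data_offset);
  // Because the caller only ever sees base_ptr + kMemAlignSize, freeing must subtract
  // kMemAlignSize again to recover the pointer the pool actually handed out.
  return 0;
}
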
@@ -33,6 +33,7 @@ class AscendMemoryManager : public MemoryManager {
void *MallocMemFromMemPool(size_t size) override;
uint64_t GetDeviceMemSize();
void MallocSomasDynamicMem(const session::KernelGraph *graph);
uint8_t *MallocCommunicationMemFromMemPool(size_t size) override;

protected:
uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id = kInvalidGraphId) override;

@@ -0,0 +1,106 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "runtime/device/bucket.h"

#include <memory>
#include "runtime/device/kernel_runtime_manager.h"
#include "utils/profile.h"

namespace mindspore::device {
void Bucket::AddGradTensor(const tensor::TensorPtr &tensor) {
if (grad_tensor_list_.size() >= bucket_size_) {
MS_LOG(EXCEPTION) << "bucket is full";
}
grad_tensor_list_.emplace_back(tensor);
if (grad_tensor_list_.size() > bucket_size_) {
MS_LOG(EXCEPTION) << "too many tensor add to the bucket, bucket_size_:" << bucket_size_
<< " total tensor size:" << grad_tensor_list_.size();
}
MS_LOG(INFO) << "current bucket tensors size:" << grad_tensor_list_.size();
// bucket is full, start to launch allreduce
if (grad_tensor_list_.size() == bucket_size_) {
full_ = true;
}
}

void Bucket::Launch() {
auto start = GetTime();
if (grad_tensor_list_.size() != bucket_size_) {
MS_LOG(EXCEPTION) << "Bucket is not full, grad_tensor_list_ size:" << grad_tensor_list_.size()
<< " bucket_size_:" << bucket_size_;
}
MS_LOG(INFO) << "Bucket is full, start to launch AllReduce";
MS_EXCEPTION_IF_NULL(pre_event_);
MS_EXCEPTION_IF_NULL(post_event_);
AllocateAllReduceAddr();
CopyTensorToContiguousMemory();
pre_event_->RecordEvent();
pre_event_->WaitEvent();
LaunchAllReduce();
post_event_->RecordEvent();
UpdateTensorAddr();
// pass event to the tensor
for (auto &tensor : grad_tensor_list_) {
MS_EXCEPTION_IF_NULL(tensor);
tensor->SetDeviceEvent(post_event_);
}
MS_LOG(INFO) << "Bucket launch cost:" << (GetTime() - start) * 1e6 << " us";
}

// TODO(caifubi): float16 grad cast to float32 grad

void Bucket::UpdateTensorAddr() {
if (grad_tensor_list_.size() != bucket_size_ || new_tensor_output_addrs_.size() != bucket_size_) {
MS_LOG(EXCEPTION) << "grad_tensor_list size:" << grad_tensor_list_.size()
<< " tensor output addr size:" << new_tensor_output_addrs_.size()
<< " bucket size:" << bucket_size_;
}

for (size_t i = 0; i < bucket_size_; ++i) {
auto &tensor = grad_tensor_list_[i];
MS_EXCEPTION_IF_NULL(tensor);
auto device_address = std::dynamic_pointer_cast<DeviceAddress>(tensor->device_address());
// release old addr and manage addr by this Bucket.
MS_EXCEPTION_IF_NULL(device_address);
auto origin_dev_ptr = device_address->GetMutablePtr();
// FreeDeviceMem(origin_dev_ptr);
tensor_old_addr_list_.emplace_back(origin_dev_ptr);
device_address->from_mem_pool_ = false;
device_address->set_ptr(new_tensor_output_addrs_[i]);
}
}

void Bucket::LazyDeleteOldAddr() {
MS_LOG(INFO) << "Lazy delete old grad address";
for (auto old_addr : tensor_old_addr_list_) {
FreeDeviceMem(old_addr);
}
tensor_old_addr_list_.clear();
}

void Bucket::Release() {
MS_LOG(INFO) << "Clear bucket:" << id_;
grad_tensor_list_.clear();
new_tensor_output_addrs_.clear();
memcpy_input_addrs_.clear();
memcpy_output_addrs_.clear();
tensor_type_list_.clear();
LazyDeleteOldAddr();
FreeAllDeviceMem();
full_ = false;
}
} // namespace mindspore::device

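The stream fencing performed by Bucket::Launch above (gradients are packed on the compute stream, the allreduce runs on the communication stream, and pre/post events order the two) can be summarized by this toy sketch; Event here is a stand-in type used only to show the ordering, not the project's DeviceEvent.

// Toy sketch of the pre_event_/post_event_ fencing in Bucket::Launch above.
#include <cstdio>

struct Event {
  const char *name;
  void Record(const char *stream) { printf("record %s on %s stream\n", name, stream); }
  void Wait(const char *stream) { printf("%s stream waits for %s\n", stream, name); }
};

int main() {
  Event pre{"pre_event"}, post{"post_event"};
  printf("memcpy gradients into the contiguous buffer on the compute stream\n");
  pre.Record("compute");         // pre_event_->RecordEvent(): the copies are now queued
  pre.Wait("communication");     // pre_event_->WaitEvent(): allreduce starts only after the copies
  printf("launch AllReduce on the communication stream\n");
  post.Record("communication");  // post_event_->RecordEvent(): the reduction is now queued
  // Each gradient tensor keeps post_event (tensor->SetDeviceEvent), so the next op that reads it
  // on the compute stream calls WaitDevice() and blocks until the allreduce has finished.
  post.Wait("compute");
  return 0;
}
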
@@ -0,0 +1,83 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_BUCKET_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_BUCKET_H_

#include <vector>
#include <utility>
#include <string>
#include <memory>
#include "ir/anf.h"
#include "ir/device_event.h"
#include "runtime/device/device_address.h"
#include "backend/session/kernel_graph.h"

namespace mindspore::device {
class Bucket {
public:
Bucket(uint32_t id, uint32_t bucket_size)
: id_(id),
bucket_size_(bucket_size),
full_(false),
stream_(nullptr),
compute_stream_(nullptr),
pre_event_(nullptr),
post_event_(nullptr),
total_size_(0),
ar_input_addr_(nullptr),
ar_output_addr_(nullptr) {}
virtual ~Bucket() = default;

uint32_t id() const { return id_; }
bool full() const { return full_; }
void Launch();
void Release();
void AddGradTensor(const tensor::TensorPtr &tensor);
virtual void Init() = 0;

protected:
uint32_t id_;
uint32_t bucket_size_;
bool full_;
void *stream_;
void *compute_stream_;

std::shared_ptr<DeviceEvent> pre_event_;
std::shared_ptr<DeviceEvent> post_event_;

size_t total_size_;
uint8_t *ar_input_addr_;
uint8_t *ar_output_addr_;
std::string group_;
std::vector<tensor::TensorPtr> grad_tensor_list_;
std::vector<uint8_t *> new_tensor_output_addrs_;
std::vector<kernel::AddressPtr> memcpy_input_addrs_;
std::vector<kernel::AddressPtr> memcpy_output_addrs_;
std::vector<TypeId> tensor_type_list_;
std::vector<void *> tensor_old_addr_list_;

virtual void AllocateAllReduceAddr() = 0;
void UpdateTensorAddr();
virtual void LaunchAllReduce() = 0;
virtual void FreeAllDeviceMem() = 0;
virtual void FreeDeviceMem(void *dev_ptr) = 0;
virtual void CopyTensorToContiguousMemory() = 0;
void LazyDeleteOldAddr();
};
} // namespace mindspore::device

#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_BUCKET_H_

@@ -26,6 +26,7 @@

namespace mindspore {
namespace device {
class Bucket;
namespace cpu {
class CPUSimpleMemPlan;
class CPUMemoryManager;

@@ -100,6 +101,7 @@ class DeviceAddress : public mindspore::DeviceSync {
friend class mindspore::device::ascend::AscendKernelRuntime;
friend class mindspore::device::ascend::AscendMemoryManager;
friend class mindspore::device::ascend::DataDumper;
friend class mindspore::device::Bucket;
};

using DeviceAddressPtr = std::shared_ptr<DeviceAddress>;

@@ -0,0 +1,177 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "runtime/device/gpu/gpu_bucket.h"

#include <cuda_runtime_api.h>
#include <nccl.h>
#include <vector>
#include <memory>
#include "abstract/utils.h"
#include "runtime/device/gpu/gpu_event.h"
#include "runtime/device/gpu/gpu_memory_allocator.h"
#include "runtime/device/gpu/gpu_device_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/gpu/distribution/collective_init.h"
#include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h"
#include "runtime/device/gpu/gpu_common.h"

namespace {
const size_t kCommunicationMemAlignSize = 16;
size_t AlignMemorySize(size_t size) {
if (size == 0) {
return kCommunicationMemAlignSize;
}
return ((size + kCommunicationMemAlignSize - 1) / kCommunicationMemAlignSize) * kCommunicationMemAlignSize;
}
} // namespace
namespace mindspore::device::gpu {
GPUBucket::GPUBucket(uint32_t id, uint32_t bucket_size) : Bucket(id, bucket_size), collective_handle_(nullptr) {
group_ = kNcclWorldGroup;
}

void GPUBucket::AllocateAllReduceAddr() {
MS_LOG(INFO) << "start";
if (grad_tensor_list_.size() != bucket_size_) {
MS_LOG(EXCEPTION) << "grad tensor list size:" << grad_tensor_list_.size()
<< " is not equal to bucket size:" << bucket_size_;
}

auto total_size = 0;
std::vector<size_t> size_list;
std::vector<size_t> align_size_list;
for (auto &tensor : grad_tensor_list_) {
MS_EXCEPTION_IF_NULL(tensor);
tensor_type_list_.emplace_back(tensor->data_type());
DeviceAddressPtr device_address = std::dynamic_pointer_cast<DeviceAddress>(tensor->device_address());
MS_EXCEPTION_IF_NULL(device_address);
auto origin_size = device_address->GetSize();
auto align_size = AlignMemorySize(origin_size);
size_list.emplace_back(origin_size);
align_size_list.emplace_back(align_size);
total_size += align_size;
memcpy_input_addrs_.emplace_back(
std::make_shared<kernel::Address>(static_cast<uint8_t *>(device_address->GetMutablePtr()), origin_size));
}
total_size_ = total_size;

ar_input_addr_ = static_cast<uint8_t *>(GPUMemoryAllocator::GetInstance().AllocTensorMem(total_size));
ar_output_addr_ = static_cast<uint8_t *>(GPUMemoryAllocator::GetInstance().AllocTensorMem(total_size));

uint8_t *memcpy_output = ar_input_addr_;
for (size_t i = 0; i < bucket_size_; ++i) {
memcpy_output_addrs_.emplace_back(std::make_shared<kernel::Address>(memcpy_output, size_list[i]));
memcpy_output += align_size_list[i];
}

uint8_t *tensor_output = ar_output_addr_;
for (size_t i = 0; i < bucket_size_; ++i) {
new_tensor_output_addrs_.emplace_back(tensor_output);
tensor_output += align_size_list[i];
}
MS_LOG(INFO) << "end";
}

void GPUBucket::FreeDeviceMem(void *dev_ptr) { GPUMemoryAllocator::GetInstance().FreeTensorMem(dev_ptr); }

void GPUBucket::FreeAllDeviceMem() {
MS_LOG(INFO) << "start";
if (ar_input_addr_ != nullptr) {
FreeDeviceMem(ar_input_addr_);
ar_input_addr_ = nullptr;
}
if (ar_output_addr_ != nullptr) {
FreeDeviceMem(ar_output_addr_);
ar_output_addr_ = nullptr;
}
MS_LOG(INFO) << "end";
}

void GPUBucket::CopyTensorToContiguousMemory() {
MS_LOG(INFO) << "start";
MS_EXCEPTION_IF_NULL(compute_stream_);
// Clean allreduce input
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemsetAsync(ar_input_addr_, 0, total_size_, static_cast<cudaStream_t>(compute_stream_)),
"Call cudaMemsetAsync failed");

for (size_t i = 0; i < bucket_size_; ++i) {
MS_EXCEPTION_IF_NULL(memcpy_output_addrs_[i]);
MS_EXCEPTION_IF_NULL(memcpy_input_addrs_[i]);
if (!GPUDeviceManager::GetInstance().CopyDeviceMemToDeviceAsync(memcpy_output_addrs_[i]->addr,
memcpy_input_addrs_[i]->addr,
memcpy_output_addrs_[i]->size, compute_stream_)) {
MS_LOG(EXCEPTION) << "Copy memory failed";
}
}
MS_LOG(INFO) << "end";
}

void GPUBucket::LaunchAllReduce() {
MS_LOG(INFO) << "start";
collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
auto all_reduce_funcptr =
reinterpret_cast<kernel::AllReduce>(dlsym(const_cast<void *>(collective_handle_), "AllReduce"));
MS_EXCEPTION_IF_NULL(all_reduce_funcptr);
MS_EXCEPTION_IF_NULL(stream_);

if (tensor_type_list_.empty()) {
MS_LOG(EXCEPTION) << "No tensor type found";
}
auto type = tensor_type_list_[0];
if (std::any_of(tensor_type_list_.begin(), tensor_type_list_.end(),
[&type](TypeId tensor_type) { return type != tensor_type; })) {
MS_LOG(EXCEPTION) << "AllReduce input have different dtype";
}

auto type_size = abstract::TypeIdSize(type);
if (type_size == 0) {
MS_LOG(EXCEPTION) << "Invalid type:" << type;
}

// typeid to nccl_data_type
auto nccl_data_type_iter = kernel::kNcclDtypeMap.find(TypeIdLabel(type));
if (nccl_data_type_iter == kernel::kNcclDtypeMap.end()) {
MS_LOG(EXCEPTION) << "Invalid type:" << type;
}

auto nccl_result =
(*all_reduce_funcptr)(ar_input_addr_, ar_output_addr_, total_size_ / type_size, nccl_data_type_iter->second,
ncclRedOp_t::ncclSum, static_cast<cudaStream_t>(stream_), group_);
if (nccl_result != ncclSuccess) {
MS_LOG(EXCEPTION) << "AllReduce failed, ret:" << nccl_result;
}

MS_LOG(INFO) << "end";
}

void GPUBucket::Init() {
pre_event_ = std::make_shared<GpuEvent>();
post_event_ = std::make_shared<GpuEvent>();

auto kernel_runtime = KernelRuntimeManager::Instance().GetCurrentKernelRuntime();
MS_EXCEPTION_IF_NULL(kernel_runtime);
stream_ = kernel_runtime->communication_stream();
compute_stream_ = kernel_runtime->compute_stream();

MS_EXCEPTION_IF_NULL(pre_event_);
MS_EXCEPTION_IF_NULL(post_event_);
pre_event_->set_record_stream(compute_stream_);
pre_event_->set_wait_stream(stream_);
post_event_->set_record_stream(stream_);
post_event_->set_wait_stream(compute_stream_);
}
} // namespace mindspore::device::gpu

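The 16-byte rounding helper in the anonymous namespace of gpu_bucket.cc above is small enough to verify in isolation; the sketch below simply reproduces it with a few illustrative checks.

// Standalone check of the AlignMemorySize helper shown above.
#include <cassert>
#include <cstddef>

constexpr size_t kCommunicationMemAlignSize = 16;

size_t AlignMemorySize(size_t size) {
  if (size == 0) {
    return kCommunicationMemAlignSize;  // empty tensors still get one aligned slot
  }
  return ((size + kCommunicationMemAlignSize - 1) / kCommunicationMemAlignSize) * kCommunicationMemAlignSize;
}

int main() {
  assert(AlignMemorySize(0) == 16);
  assert(AlignMemorySize(1) == 16);
  assert(AlignMemorySize(16) == 16);  // already aligned sizes are unchanged
  assert(AlignMemorySize(17) == 32);  // everything else rounds up to the next multiple of 16
  return 0;
}
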
@@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_BUCKET_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_BUCKET_H_

#include "runtime/device/bucket.h"

namespace mindspore::device::gpu {
class GPUBucket : public Bucket {
public:
GPUBucket(uint32_t id, uint32_t bucket_size);
~GPUBucket() override = default;

void Init() override;

private:
void AllocateAllReduceAddr() override;
void FreeAllDeviceMem() override;
void FreeDeviceMem(void *dev_ptr) override;
void CopyTensorToContiguousMemory() override;
void LaunchAllReduce() override;

const void *collective_handle_;
};
} // namespace mindspore::device::gpu
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_BUCKET_H_

@@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "runtime/device/gpu/gpu_event.h"
#include "runtime/device/gpu/gpu_common.h"

namespace mindspore::device::gpu {
GpuEvent::GpuEvent() {
auto ret = cudaEventCreate(&event_);
if (ret != cudaSuccess) {
MS_LOG(ERROR) << "cudaEventCreate failed, ret:" << ret;
event_ = nullptr;
}
}

GpuEvent::~GpuEvent() { CHECK_CUDA_RET_WITH_ERROR_NOTRACE(cudaEventDestroy(event_), "cudaEventDestroy failed"); }

void GpuEvent::WaitEvent() {
MS_EXCEPTION_IF_NULL(wait_stream_);
MS_EXCEPTION_IF_NULL(event_);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamWaitEvent(wait_stream_, event_, 0), "cudaStreamWaitEvent failed");
need_wait_ = false;
}

void GpuEvent::RecordEvent() {
MS_EXCEPTION_IF_NULL(event_);
MS_EXCEPTION_IF_NULL(record_stream_);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventRecord(event_, record_stream_), "cudaEventRecord failed");
need_wait_ = true;
}

bool GpuEvent::NeedWait() { return need_wait_; }
} // namespace mindspore::device::gpu

@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_EVENT_H_
|
||||
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_EVENT_H_
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include "ir/device_event.h"
|
||||
|
||||
namespace mindspore::device::gpu {
|
||||
class GpuEvent : public DeviceEvent {
|
||||
public:
|
||||
GpuEvent();
|
||||
~GpuEvent() override;
|
||||
|
||||
void WaitEvent() override;
|
||||
void RecordEvent() override;
|
||||
bool NeedWait() override;
|
||||
void set_wait_stream(void *wait_stream) override { wait_stream_ = static_cast<cudaStream_t>(wait_stream); }
|
||||
void set_record_stream(void *record_stream) override { record_stream_ = static_cast<cudaStream_t>(record_stream); }
|
||||
|
||||
private:
|
||||
cudaEvent_t event_{nullptr};
|
||||
cudaStream_t wait_stream_{nullptr};
|
||||
cudaStream_t record_stream_{nullptr};
|
||||
bool need_wait_{false};
|
||||
};
|
||||
} // namespace mindspore::device::gpu
|
||||
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_EVENT_H_
@@ -229,6 +229,11 @@ bool GPUKernelRuntime::InitDevice() {
MS_LOG(ERROR) << "No default CUDA stream found.";
return false;
}
GPUDeviceManager::GetInstance().CreateStream(&communication_stream_);
if (communication_stream_ == nullptr) {
MS_LOG(ERROR) << "Invalid communication stream";
return false;
}
return true;
}
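The extra stream created here is dedicated to collectives, so AllReduce launches can overlap with compute kernels on the default stream. A hedged sketch of the raw CUDA equivalent (not part of the commit; the GPUDeviceManager wrapper presumably does something along these lines, and the function name below is made up):

#include <cuda_runtime_api.h>

// A dedicated, non-blocking stream for communication work.
cudaStream_t CreateCommunicationStreamSketch() {
  cudaStream_t comm_stream = nullptr;
  cudaStreamCreateWithFlags(&comm_stream, cudaStreamNonBlocking);
  return comm_stream;
}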
@@ -1251,6 +1256,7 @@ session::KernelWithIndex GPUKernelRuntime::GetPrevNodeOutput(const AnfNodePtr &n

return addr_iter->second[i];
}

} // namespace gpu
} // namespace device
} // namespace mindspore

@@ -47,6 +47,8 @@ class GPUKernelRuntime : public KernelRuntime {
bool Run(session::KernelGraph *graph, bool is_task_sink) override;
bool GenDynamicKernel(const session::KernelGraph *graph) override { return true; }
bool RunDynamicKernelAsync(const session::KernelGraph *graph) override { return true; }
void *compute_stream() const override { return stream_; }
void *communication_stream() const override { return communication_stream_; }

protected:
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,

@@ -946,11 +946,13 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);

auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed.";
return false;
}

KernelLaunchProfiling(kernels[i]->fullname_with_scope());
}
}

@@ -33,6 +33,7 @@
#include "utils/ms_context.h"
#include "runtime/device/memory_manager.h"
#include "runtime/device/executor/dynamic_kernel.h"
#include "ir/device_event.h"

using mindspore::tensor::Tensor;
using std::vector;

@@ -83,6 +84,9 @@ class KernelRuntime {
uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) {
return mem_manager_->MallocMem(type, size, address);
}
uint8_t *MallocCommunicationMemFromMemPool(size_t size) {
return mem_manager_->MallocCommunicationMemFromMemPool(size);
}
static void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel,
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
AddressPtrList *kernel_outputs);

@@ -104,6 +108,8 @@ class KernelRuntime {
virtual uint64_t GetAvailableMemMaxSize() const { return 0; }
void AddBufferPtr(std::shared_ptr<char[]> ptr) { buffer_ptrs_.push_back(ptr); }
void FreeAndClearBufferPtrs() { buffer_ptrs_.clear(); }
virtual void *compute_stream() const { return nullptr; }
virtual void *communication_stream() const { return nullptr; }

protected:
virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,

@@ -149,7 +155,8 @@ class KernelRuntime {
#if !defined(_WIN32) && !defined(_WIN64)
std::shared_ptr<Debugger> debugger_;
#endif
void *stream_ = nullptr;
void *stream_{nullptr};
void *communication_stream_{nullptr};
std::shared_ptr<MemoryManager> mem_manager_{nullptr};
std::map<uint32_t, std::vector<DynamicKernelPtr>> graph_dynamic_kernel_map_;
std::vector<std::shared_ptr<char[]>> buffer_ptrs_ = {};

@@ -106,6 +106,14 @@ KernelRuntime *KernelRuntimeManager::GetKernelRuntime(const std::string &device_
return kernel_runtime.get();
}

KernelRuntime *KernelRuntimeManager::GetCurrentKernelRuntime() {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
uint32_t device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
return GetKernelRuntime(device_name, device_id);
}

void KernelRuntimeManager::ReleaseKernelRuntime(const std::string &device_name, uint32_t device_id) {
std::string runtime_key = GetDeviceKey(device_name, device_id);
std::lock_guard<std::mutex> guard(lock_);

@@ -38,6 +38,7 @@ class KernelRuntimeManager {
}
void Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator);
KernelRuntime *GetKernelRuntime(const std::string &device_name, uint32_t device_id);
KernelRuntime *GetCurrentKernelRuntime();
KernelRuntime *GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id);
void ReleaseKernelRuntime(const std::string &device_name, uint32_t device_id);
void ClearRuntimeResource();

@@ -27,7 +27,7 @@ using mindspore::memreuse::MemReuseUtilPtr;

namespace mindspore {
namespace device {
size_t MemoryManager::GetCommonAlignSize(size_t input_size) const {
size_t MemoryManager::GetCommonAlignSize(size_t input_size) {
return (input_size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize;
}
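A worked example of the alignment formula above: assuming kMemAlignSize were 512 (an illustrative value, not taken from this diff), input_size = 1000 gives (1000 + 512 + 31) / 512 * 512 = 1536 under integer division. In general the result is a multiple of kMemAlignSize with at least 32 bytes of headroom over the requested size.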
@@ -53,13 +53,14 @@ class MemoryManager {

virtual bool MallocMemFromMemPool(const DeviceAddressPtr address, size_t size);
virtual void *MallocMemFromMemPool(size_t size);
virtual uint8_t *MallocCommunicationMemFromMemPool(size_t size) { return nullptr; }
virtual void FreeMemFromMemPool(const DeviceAddressPtr address);
virtual void FreeMemFromMemPool(void *device_ptr);
virtual bool MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size,
std::vector<size_t> size_list);
virtual std::vector<void *> MallocContinuousMemFromMemPool(size_t total_size, std::vector<size_t> size_list);

size_t GetCommonAlignSize(size_t input_size) const;
static size_t GetCommonAlignSize(size_t input_size);
size_t GetCommunicationAlignSize(size_t input_size) const;

protected:

@@ -273,6 +273,10 @@ constexpr auto kOneHotOpName = "OneHot";
constexpr auto kSoftmaxCrossEntropyWithLogitsOpName = "SoftmaxCrossEntropyWithLogits";
constexpr auto kUniformCandidateSamplerOpName = "UniformCandidateSampler";

// Communication world group
constexpr auto kNcclWorldGroup = "nccl_world_group";
constexpr auto kHcclWorldGroup = "hccl_world_group";

// Hcom Op Type
constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce";
constexpr auto kHcomOpTypeAllGather = "HcomAllGather";

@@ -331,12 +335,15 @@ constexpr auto kAttrSrTag = "sr_tag";
constexpr auto kAttrRootRank = "root_rank";
constexpr auto kAttrIsTraining = "is_training";
constexpr auto kAttrFusionId = "fusion_id";
constexpr auto kAttrBucketId = "bucket_id";
constexpr auto kAttrGradOutputIndex = "grad_output_index";
constexpr auto kAttrLabelIndex = "label_index";
constexpr auto kAttrLabelSwitchList = "label_switch_list";
constexpr auto kAttrNewAxisMask = "new_axis_mask";
constexpr auto kAttrShrinkAxisMask = "shrink_axis_mask";
constexpr auto kAttrDatadumpOriginalNames = "_datadump_original_names";
constexpr auto kAttrDatadumpIsMultiop = "_datadump_is_multiop";
constexpr auto kAttrNeedRecordEvent = "need_record_event";
constexpr auto kAttrStreamId = "stream_id";
constexpr auto kAttrRecordEvent = "record_event";
constexpr auto kAttrWaitEvent = "wait_event";

@@ -0,0 +1,31 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_IR_DEVICE_EVENT_H
#define MINDSPORE_CORE_IR_DEVICE_EVENT_H

namespace mindspore {
class DeviceEvent {
public:
virtual ~DeviceEvent() = default;
virtual void WaitEvent() = 0;
virtual void RecordEvent() = 0;
virtual bool NeedWait() = 0;
virtual void set_wait_stream(void *stream) = 0;
virtual void set_record_stream(void *stream) = 0;
};
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_DEVICE_EVENT_H
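To make the contract of this interface concrete, here is a hypothetical do-nothing implementation (not part of the commit): a backend calls RecordEvent when it enqueues the work that will produce a tensor's data, and consumers check NeedWait / call WaitEvent before touching that data, which is what the Tensor::NeedWaitDevice and Tensor::WaitDevice helpers added later in this commit rely on. The GpuEvent class above implements the same interface on top of CUDA events.

#include "ir/device_event.h"

// Minimal sketch: only tracks the record/wait handshake, no real device synchronization.
class NoopDeviceEvent : public mindspore::DeviceEvent {
 public:
  void RecordEvent() override { need_wait_ = true; }   // producer: "data will be ready when this work finishes"
  void WaitEvent() override { need_wait_ = false; }    // consumer: would block here in a real backend, then clear
  bool NeedWait() override { return need_wait_; }
  void set_wait_stream(void *stream) override { wait_stream_ = stream; }
  void set_record_stream(void *stream) override { record_stream_ = stream; }

 private:
  bool need_wait_{false};
  void *wait_stream_{nullptr};
  void *record_stream_{nullptr};
};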
@@ -44,6 +44,7 @@ FuncGraph::FuncGraph()
kwonlyargs_count_(0),
hyper_param_count_(0),
is_generated_(false),
is_bprop_(false),
return_(nullptr),
manager_(std::weak_ptr<FuncGraphManager>()),
stub_(false),

@@ -217,6 +217,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
FuncGraphPtr GenerateGraph(const AbstractBasePtrList &args_spec_list);
void set_is_generate(bool generated) { is_generated_ = generated; }
bool is_generated() const { return is_generated_; }
void set_is_bprop(bool is_brop) { is_bprop_ = is_brop; }
bool is_bprop() const { return is_bprop_; }

std::unordered_map<std::string, ValuePtr> &attrs() { return attrs_; }
void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {

@@ -440,6 +442,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
// the argument input list for the graph used to generate this graph
bool is_generated_;

bool is_bprop_;

// the cnode that calls 'return' primitive
// we use shared pointer to manage it.
CNodePtr return_;
@@ -247,6 +247,47 @@ class TensorDataImpl : public TensorData {
}

private:
void OutputFloatDataString(std::ostringstream &ss, bool isScalar, const T &value) const {
if (isScalar) {
ss << value;
} else {
// The placeholder of float16 is fixed at 11, while float/double is fixed at 15.
const int width = std::is_same<T, float16>::value ? 11 : 15;
// The printing precision of float16 is fixed at 4, while float/double is fixed at 8.
const int precision = std::is_same<T, float16>::value ? 4 : 8;
ss << std::setw(width) << std::setprecision(precision) << std::setiosflags(std::ios::scientific | std::ios::right)
<< value;
}
}

void OutputBoolDataString(std::ostringstream &ss, bool isScalar, const T &value) const {
if (isScalar) {
ss << (value ? "True" : "False");
} else {
constexpr int bool_max_width = sizeof("False") - 1;
ss << std::setw(bool_max_width) << std::setiosflags(std::ios::right) << (value ? "True" : "False");
}
}

void OutputOtherDataString(std::ostringstream &ss, bool isScalar, const T &value, int *max_width) const {
if (isScalar) {
ss << value;
} else {
// Add a padding string before the number, such as "###123", for subsequent replacement.
const int width = GetNumLength(value);
*max_width = std::max(*max_width, width);
std::string pad(width, '#');
ss << pad;
if constexpr (std::is_same<T, uint8_t>::value) {
ss << static_cast<uint16_t>(value);
} else if constexpr (std::is_same<T, int8_t>::value) {
ss << static_cast<int16_t>(value);
} else {
ss << value;
}
}
}

void OutputDataString(std::ostringstream &ss, ssize_t cursor, ssize_t start, ssize_t end, bool use_comma,
int *max_width) const {
const bool isScalar = ndim_ == 0 && end - start == 1;

@@ -257,40 +298,11 @@ class TensorDataImpl : public TensorData {
for (ssize_t i = start; i < end && (cursor + i) < static_cast<ssize_t>(data_size_); i++) {
const auto value = data_[cursor + i];
if constexpr (isFloat) {
if (isScalar) {
ss << value;
} else {
// The placeholder of float16 is fixed at 11, while float/double is fixed at 15.
const int width = std::is_same<T, float16>::value ? 11 : 15;
// The printing precision of float16 is fixed at 4, while float/double is fixed at 8.
const int precision = std::is_same<T, float16>::value ? 4 : 8;
ss << std::setw(width) << std::setprecision(precision)
<< std::setiosflags(std::ios::scientific | std::ios::right) << value;
}
OutputFloatDataString(ss, isScalar, value);
} else if (isBool) {
if (isScalar) {
ss << (value ? "True" : "False");
} else {
constexpr int bool_max_width = sizeof("False") - 1;
ss << std::setw(bool_max_width) << std::setiosflags(std::ios::right) << (value ? "True" : "False");
}
OutputBoolDataString(ss, isScalar, value);
} else {
if (isScalar) {
ss << value;
} else {
// Add a padding string before the number, such as "###123", for subsequent replacement.
const int width = GetNumLength(value);
*max_width = std::max(*max_width, width);
std::string pad(width, '#');
ss << pad;
if constexpr (std::is_same<T, uint8_t>::value) {
ss << static_cast<uint16_t>(value);
} else if constexpr (std::is_same<T, int8_t>::value) {
ss << static_cast<int16_t>(value);
} else {
ss << value;
}
}
OutputOtherDataString(ss, isScalar, value, max_width);
}
if (!isScalar && i != end - 1) {
if (use_comma) {
@@ -452,7 +464,8 @@ Tensor::Tensor(const Tensor &tensor)
cache_enable_(tensor.cache_enable_),
cache_tensor_ptr_(tensor.cache_tensor_ptr_),
hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
padding_type_(tensor.padding_type()) {}
padding_type_(tensor.padding_type()),
device_event_(tensor.device_event_) {}

Tensor::Tensor(const Tensor &tensor, TypeId data_type)
: MetaTensor(data_type, tensor.shape_),

@@ -465,7 +478,8 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
cache_enable_(tensor.cache_enable_),
cache_tensor_ptr_(tensor.cache_tensor_ptr_),
hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
padding_type_(tensor.padding_type()) {}
padding_type_(tensor.padding_type()),
device_event_(tensor.device_event_) {}

Tensor::Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data)
: MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {}

@@ -527,6 +541,7 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) {
event_ = tensor.event_;
sync_status_ = tensor.sync_status_;
padding_type_ = tensor.padding_type_;
device_event_ = tensor.device_event_;
}
return *this;
}
@@ -30,6 +30,7 @@
#include "base/float16.h"
#include "utils/shape_utils.h"
#include "utils/ms_exception.h"
#include "ir/device_event.h"

// brief mindspore namespace.
//

@@ -326,6 +327,21 @@ class Tensor : public MetaTensor {
event_ = nullptr;
}

void SetDeviceEvent(const std::shared_ptr<DeviceEvent> &device_event) { device_event_ = device_event; }

void WaitDevice() {
if (device_event_ != nullptr) {
device_event_->WaitEvent();
}
}

bool NeedWaitDevice() const {
if (device_event_ != nullptr) {
return device_event_->NeedWait();
}
return false;
}

void set_sync_status(TensorSyncStatus sync_status) { sync_status_ = sync_status; }

TensorSyncStatus sync_status() const { return sync_status_; }

@@ -352,6 +368,7 @@ class Tensor : public MetaTensor {
std::shared_ptr<Tensor> hashmap_tensor_ptr_{nullptr};
std::vector<Axis> padding_type_;
TypePtr cast_dtype_{nullptr};
std::shared_ptr<DeviceEvent> device_event_{nullptr};
};
using TensorPtr = std::shared_ptr<Tensor>;
using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;

@@ -17,10 +17,11 @@ from mindspore import context
from mindspore.nn.cell import Cell
from mindspore.communication.management import GlobalComm, get_group_size
from mindspore.common.tensor import RowTensor
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops import functional as F, composite as C
from mindspore.ops.operations.comm_ops import AllReduce, AllGather
from mindspore.parallel._auto_parallel_context import auto_parallel_context
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor

reduce_opt = C.MultitypeFuncGraph("reduce_opt")

@@ -45,7 +46,7 @@ def _init_allreduce_operators(length, split_indices):
return op_list

@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "Tensor")
@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "Tensor")
def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, grad):
"""
Apply allreduce on gradient.

@@ -64,13 +65,33 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra
if allreduce_filter:
grad = allreduce(grad)
if mean:
degree = F.scalar_cast(degree, F.dtype(grad))
grad = F.tensor_mul(grad, F.cast(F.scalar_to_array(1.0 / degree), F.dtype(grad)))
grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
return grad
return grad

@reduce_opt.register("Tensor", "Bool", "Bool", "Tensor")
def _tensors_allreduce_post(degree, mean, allreduce_filter, grad):
"""
Apply allreduce on gradient in PyNative mode.

@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "Tensor", "Bool")
Args:
degree (int): The mean coefficient.
mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
allgather (Primitive): The communication operator for sparse gradients.
allreduce (Primitive): The communication operator for gradients.
allreduce_filter (bool): When it is true, allreduce would apply.
grad (Tensor): The gradient tensor before operation.

Returns:
Tensor, the gradient tensor after operation.
"""
if allreduce_filter:
if mean:
grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
return grad
return grad

@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "Tensor", "Bool")
def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
"""
Apply allreduce on gradient.

@@ -93,15 +114,12 @@ def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter,
if allreduce_filter:
grad = allreduce(grad)
if mean:
degree = F.scalar_cast(degree, F.dtype(grad))
cast_op = P.Cast()
mul_op = P.Mul()
grad = mul_op(grad, cast_op(F.scalar_to_array(1.0/degree), F.dtype(grad)))
grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
return grad
return grad

@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "RowTensor")
@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "RowTensor")
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad):
"""
Apply allgather on gradient instead of allreduce for sparse feature.

@@ -122,15 +140,12 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce
indices = allgather(grad.indices)
dout = allgather(grad.values)
if mean:
degree = F.scalar_cast(degree, F.dtype(grad.values))
cast_op = P.Cast()
mul_op = P.Mul()
dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
dout = F.tensor_mul(dout, F.cast(degree, F.dtype(dout)))
grad = RowTensor(indices, dout, grad.dense_shape)
return grad

@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "RowTensor", "Bool")
@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "RowTensor", "Bool")
def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
"""
Apply allgather on gradient instead of allreduce for sparse feature.

@@ -155,10 +170,7 @@ def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allred
indices = allgather(grad.indices)
dout = allgather(grad.values)
if mean:
degree = F.scalar_cast(degree, F.dtype(grad.values))
cast_op = P.Cast()
mul_op = P.Mul()
dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
dout = F.tensor_mul(dout, F.cast(degree, F.dtype(dout)))
grad = RowTensor(indices, dout, grad.dense_shape)
return grad

@@ -329,6 +341,7 @@ class DistributedGradReducer(Cell):
if not isinstance(degree, int) or degree <= 0:
raise ValueError("Parameter 'degree' in DistributedGradReducer should be larger than 0 and be int")
self.degree = degree
self.degree = Tensor(1.0 / self.degree, mstype.float32)
self.mean = mean
self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer")
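A note on this hunk: degree is now stored as a Tensor holding the reciprocal of the device count (with 8 devices it holds 0.125), so the mean is applied as grad * degree instead of grad * (1.0 / degree). That is why the F.scalar_cast / F.scalar_to_array(1.0 / degree) code paths in the functions above collapse to a single F.tensor_mul(grad, F.cast(degree, F.dtype(grad))).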
@@ -343,6 +356,7 @@ class DistributedGradReducer(Cell):
ps_filter = lambda x: x.is_param_ps
self.ps_parameters = tuple(ps_filter(x) for x in parameters)
self.enable_parameter_server = any(self.ps_parameters)
self.mode = context.get_context("mode")

def construct(self, grads):
"""

@@ -358,7 +372,9 @@ class DistributedGradReducer(Cell):
"""
datatypes = self.map_(F.partial(_get_datatype), grads)
grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
if self.split_fusion:
if self.mode == context.PYNATIVE_MODE:
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean), self.allreduce_filter, grads)
elif self.split_fusion:
if self.enable_parameter_server:
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
self.op_list, self.allreduce_filter, grads, self.ps_parameters)

@@ -97,10 +97,13 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/runtime/device/memory_manager.cc"
"../../../mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc"
"../../../mindspore/ccsrc/runtime/device/kernel_info.cc"
"../../../mindspore/ccsrc/runtime/device/bucket.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/profiling/*.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc"
"../../../mindspore/ccsrc/runtime/device/convert_tensor_utils.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/ascend_bucket.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/ascend_event.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/signal_util.cc"

@@ -160,3 +160,7 @@ RTS_API rtError_t rtSetMsprofReporterCallback(MsprofReporterCallback callback) {
RTS_API rtError_t rtRegTaskFailCallbackByModule(const char *moduleName, rtTaskFailCallback callback) {
return RT_ERROR_NONE;
}

RTS_API rtError_t rtMemsetAsync(void *ptr, uint64_t destMax, uint32_t value, uint64_t count, rtStream_t stream) {
return RT_ERROR_NONE;
}