add op_mul fusion based on allreduce fusion in pynative mode

This commit is contained in:
lvchangquan 2021-03-13 14:23:15 +08:00
parent 64e83c1f26
commit 31f9e6a42c
24 changed files with 908 additions and 22 deletions

View File

@ -6,6 +6,7 @@ file(GLOB_RECURSE _SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"executor.cc"
"executor_manager.cc"
"anf_runtime_algorithm.cc"
"single_kernel_graph.cc"
)
if("${ENABLE_HIDDEN}" STREQUAL "OFF")

View File

@ -0,0 +1,80 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/session/single_kernel_graph.h"
#include <memory>
#include <string>
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace session {
std::shared_ptr<session::KernelGraph> SingleKernelGraph::ConstructKernelGraphBasedOnSingleOp(
const std::string &op_name, const std::vector<TypeId> &input_dtypes, const std::vector<ShapeVector> &input_shapes,
const std::vector<TypeId> &output_dtypes, const std::vector<std::vector<size_t>> &output_shapes) {
auto graph = std::make_shared<session::KernelGraph>();
std::vector<AnfNodePtr> inputs;
// set input[0]
PrimitivePtr op_prim = std::make_shared<Primitive>(op_name);
MS_EXCEPTION_IF_NULL(op_prim);
inputs.push_back(std::make_shared<ValueNode>(op_prim));
// construct real input
if (input_dtypes.size() != input_shapes.size()) {
MS_LOG(EXCEPTION) << " input_dtypes size should equal to input_shapes size";
}
auto input_num = input_dtypes.size();
for (size_t i = 0; i < input_num; ++i) {
auto tensor = std::make_shared<tensor::Tensor>(input_dtypes[i], input_shapes[i]);
auto value_node = ConstructRunOpValueNode(graph, tensor);
inputs.push_back(value_node);
}
// obtain cnode
auto cnode = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(cnode);
// get output dynamic shape info
AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(false), cnode);
if (output_dtypes.size() != output_shapes.size()) {
MS_LOG(EXCEPTION) << " output_dtypes size should equal to output_shapes size";
}
AnfAlgo::SetOutputInferTypeAndShape(output_dtypes, output_shapes, cnode.get());
// set execution order
std::vector<CNodePtr> exe_order = {cnode};
graph->set_execution_order(exe_order);
// set graph output
graph->set_output(cnode);
graph->SetInputNodes();
return graph;
}
ValueNodePtr SingleKernelGraph::ConstructRunOpValueNode(const std::shared_ptr<session::KernelGraph> &graph,
const tensor::TensorPtr &input_tensor) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(input_tensor);
auto value_node = std::make_shared<ValueNode>(input_tensor);
MS_EXCEPTION_IF_NULL(value_node);
// construct abstract of value node
auto type_of_tensor = input_tensor->Dtype();
auto shape_of_tensor = input_tensor->shape();
auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
value_node->set_abstract(abstract);
// add value node to graph
auto input_value_node = graph->NewValueNode(value_node);
graph->AddValueNodeToGraph(input_value_node);
return input_value_node;
}
} // namespace session
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_BACKEND_SESSION_SINGLE_KERNEL_GRAPH_H_
#define MINDSPORE_MINDSPORE_CCSRC_BACKEND_SESSION_SINGLE_KERNEL_GRAPH_H_
#include <string>
#include <vector>
#include <memory>
#include "backend/session/kernel_graph.h"
namespace mindspore {
namespace session {
class SingleKernelGraph {
public:
SingleKernelGraph() = default;
~SingleKernelGraph() = default;
static std::shared_ptr<session::KernelGraph> ConstructKernelGraphBasedOnSingleOp(
const std::string &op_name, const std::vector<TypeId> &input_dtypes, const std::vector<ShapeVector> &input_shapes,
const std::vector<TypeId> &output_dtypes, const std::vector<std::vector<size_t>> &output_shapes);
private:
static ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<session::KernelGraph> &graph,
const tensor::TensorPtr &input_tensor);
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_BACKEND_SESSION_SINGLE_KERNEL_GRAPH_H_

View File

@ -1,7 +1,7 @@
file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc"
"kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc"
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc"
"bucket.cc"
"bucket.cc" "launch_kernel.cc" "launch_mul.cc"
)
if("${ENABLE_HIDDEN}" STREQUAL "OFF")

View File

@ -26,6 +26,7 @@
#include "runtime/device/memory_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/ascend/ascend_event.h"
#include "runtime/device/ascend/ascend_launch_mul.h"
#include "utils/profile.h"
#define CHECK_ASCEND_RT_WITH_EXCEPTION(expression, message) \
@ -45,7 +46,6 @@ void AscendBucket::AllocateAllReduceAddr() {
}
auto total_size = 0;
std::vector<size_t> align_size_list;
std::vector<size_t> origin_size_list;
for (auto &tensor : grad_tensor_list_) {
MS_EXCEPTION_IF_NULL(tensor);
@ -54,7 +54,7 @@ void AscendBucket::AllocateAllReduceAddr() {
auto origin_size = device_address->GetSize();
auto align_size = MemoryManager::GetCommonAlignSize(origin_size);
origin_size_list.emplace_back(origin_size);
align_size_list.emplace_back(align_size);
align_size_list_.emplace_back(align_size);
total_size += align_size;
memcpy_input_addrs_.emplace_back(std::make_shared<kernel::Address>(
static_cast<uint8_t *>(device_address->GetMutablePtr()), device_address->GetSize()));
@ -72,14 +72,7 @@ void AscendBucket::AllocateAllReduceAddr() {
uint8_t *memcpy_output = ar_input_addr_;
for (size_t i = 0; i < bucket_size_; ++i) {
memcpy_output_addrs_.emplace_back(std::make_shared<kernel::Address>(memcpy_output, origin_size_list[i]));
memcpy_output += align_size_list[i];
}
// store output tensor addr
uint8_t *tensor_output = ar_output_addr_;
for (size_t i = 0; i < bucket_size_; ++i) {
new_tensor_output_addrs_.emplace_back(tensor_output);
tensor_output += align_size_list[i];
memcpy_output += align_size_list_[i];
}
}
@ -96,6 +89,10 @@ void AscendBucket::FreeAllDeviceMem() {
FreeDeviceMem(origin_dev_addr);
ar_output_addr_ = nullptr;
}
// clear launch mul device Memory
if (launch_kernel != nullptr) {
launch_kernel->FreeLaunchDeviceMem();
}
}
void AscendBucket::CopyTensorToContiguousMemory() {
@ -154,6 +151,15 @@ void AscendBucket::LaunchAllReduce() {
}
}
std::shared_ptr<LaunchKernel> AscendBucket::CreateLaunchKernel() {
if (tensor_type_list_.empty()) {
MS_LOG(ERROR) << "tensor_type_list_ is empty";
}
auto launch_mul = std::make_shared<AscendLaunchMul>(stream_, tensor_type_list_[0], total_size_, ar_output_addr_);
MS_EXCEPTION_IF_NULL(launch_mul);
return launch_mul;
}
void AscendBucket::Init() {
pre_event_ = std::make_shared<AscendEvent>();
post_event_ = std::make_shared<AscendEvent>();

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_BUCKET_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_BUCKET_H_
#include <memory>
#include "runtime/device/bucket.h"
namespace mindspore::device::ascend {
@ -33,6 +34,7 @@ class AscendBucket : public Bucket {
void FreeDeviceMem(void *dev_ptr) override;
void CopyTensorToContiguousMemory() override;
void LaunchAllReduce() override;
std::shared_ptr<LaunchKernel> CreateLaunchKernel() override;
};
} // namespace mindspore::device::ascend
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_BUCKET_H_

View File

@ -0,0 +1,52 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ascend_launch_kernel.h"
#include <vector>
#include <memory>
#include "runtime/device/memory_manager.h"
#include "runtime/device/ascend/ascend_memory_pool.h"
#include "runtime/device/ascend/kernel_build_ascend.h"
#include "runtime/device/ascend/kernel_select_ascend.h"
namespace mindspore::device::ascend {
void AscendLaunchKernel::FreeDeviceMem(void *addr) { AscendMemoryPool::GetInstance().FreeTensorMem(addr); }
size_t AscendLaunchKernel::AlignSizeForLaunchKernel(size_t size) { return MemoryManager::GetCommonAlignSize(size); }
uint8_t *AscendLaunchKernel::AllocDeviceMem(size_t size) {
auto device_memory = AscendMemoryPool::GetInstance().AllocTensorMem(size);
MS_EXCEPTION_IF_NULL(device_memory);
return static_cast<uint8_t *>(device_memory);
}
void AscendLaunchKernel::KernelSelect(std::shared_ptr<session::KernelGraph> kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto node_list = kernel_graph->execution_order();
for (size_t i = 0; i < node_list.size(); ++i) {
auto status = device::ascend::SelectKernelInfo(node_list[i]);
if (status == ascend::kNoMatched) {
MS_LOG(ERROR) << "cnode name : " << node_list[i]->fullname_with_scope() << " kernel select failed";
}
}
}
void AscendLaunchKernel::KernelBuild(std::shared_ptr<session::KernelGraph> kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
device::ascend::KernelBuild(kernel_graph->execution_order());
}
} // namespace mindspore::device::ascend

View File

@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_LAUNCH_KERNEL_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_LAUNCH_KERNEL_H_
#include <vector>
#include <memory>
#include "runtime/device/launch_kernel.h"
namespace mindspore::device::ascend {
class AscendLaunchKernel : public LaunchKernel {
public:
explicit AscendLaunchKernel(void *stream) : LaunchKernel(stream) {}
virtual ~AscendLaunchKernel() = default;
void FreeDeviceMem(void *addr) override;
size_t AlignSizeForLaunchKernel(size_t size) override;
uint8_t *AllocDeviceMem(size_t size) override;
void KernelSelect(std::shared_ptr<session::KernelGraph> kernel_graph) override;
void KernelBuild(std::shared_ptr<session::KernelGraph> kernel_graph) override;
void LaunchOpKernel() override = 0;
void FreeLaunchDeviceMem() override = 0;
};
} // namespace mindspore::device::ascend
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_LAUNCH_KERNEL_H_

View File

@ -0,0 +1,62 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ascend_launch_mul.h"
#include <memory>
#include "abstract/utils.h"
#include "runtime/mem.h"
#include "backend/session/single_kernel_graph.h"
#include "frontend/parallel/context.h"
namespace mindspore::device::ascend {
void AscendLaunchMul::FreeDeviceMem(void *addr) { AscendLaunchKernel::FreeDeviceMem(addr); }
size_t AscendLaunchMul::AlignSizeForLaunchKernel(size_t size) {
return AscendLaunchKernel::AlignSizeForLaunchKernel(size);
}
uint8_t *AscendLaunchMul::AllocDeviceMem(size_t size) { return AscendLaunchKernel::AllocDeviceMem(size); }
void AscendLaunchMul::KernelSelect(std::shared_ptr<session::KernelGraph> kernel_graph) {
AscendLaunchKernel::KernelSelect(kernel_graph);
}
void AscendLaunchMul::KernelBuild(std::shared_ptr<session::KernelGraph> kernel_graph) {
AscendLaunchKernel::KernelBuild(kernel_graph);
}
void AscendLaunchMul::LaunchOpKernel() {
kernel_mod_ = ObtainLaunchMulKernelMod();
MS_EXCEPTION_IF_NULL(kernel_mod_);
// construct mul inputs addr
ObtainMulInputsAddr();
// launch mul
LaunchSingleKernel(inputs_addr_);
}
void AscendLaunchMul::FreeLaunchDeviceMem() {
FreeInputDeviceMemory();
FreeOutputAndWorkspaceDeviceMem();
}
void AscendLaunchMul::CopyHostMemToDevice(size_t origin_size, size_t dst_size) {
auto ret = rtMemcpyAsync(input2_addr_, dst_size, &input2_value_, origin_size, RT_MEMCPY_HOST_TO_DEVICE, stream_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "launch rtMemcpyAsync failed, ret:" << ret;
}
}
} // namespace mindspore::device::ascend

View File

@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_LAUNCH_MUL_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_LAUNCH_MUL_H_
#include <vector>
#include <memory>
#include "runtime/device/launch_mul.h"
#include "runtime/device/ascend/ascend_launch_kernel.h"
namespace mindspore::device::ascend {
class AscendLaunchMul : public AscendLaunchKernel, public LaunchMul {
public:
AscendLaunchMul(void *stream, TypeId dtype, size_t total_size, uint8_t *input1_addr)
: AscendLaunchKernel(stream), LaunchMul(dtype, total_size, input1_addr) {}
~AscendLaunchMul() override = default;
void FreeDeviceMem(void *addr) override;
size_t AlignSizeForLaunchKernel(size_t size) override;
uint8_t *AllocDeviceMem(size_t size) override;
void KernelSelect(std::shared_ptr<session::KernelGraph> kernel_graph) override;
void KernelBuild(std::shared_ptr<session::KernelGraph> kernel_graph) override;
void LaunchOpKernel() override;
void FreeLaunchDeviceMem() override;
void CopyHostMemToDevice(size_t origin_size, size_t dst_size) override;
};
} // namespace mindspore::device::ascend
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_LAUNCH_MUL_H_

View File

@ -18,6 +18,7 @@
#include <memory>
#include "runtime/device/kernel_runtime_manager.h"
#include "frontend/parallel/context.h"
#include "utils/profile.h"
namespace mindspore::device {
@ -51,6 +52,8 @@ void Bucket::Launch() {
pre_event_->RecordEvent();
pre_event_->WaitEvent();
LaunchAllReduce();
// mul fusion
CalculateMean();
post_event_->RecordEvent();
UpdateTensorAddr();
// pass event to the tensor
@ -84,6 +87,29 @@ void Bucket::UpdateTensorAddr() {
}
}
void Bucket::CalculateMean() {
auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
auto grad_mean = parallel_context->gradients_mean();
if (!grad_mean) {
return;
}
launch_kernel = CreateLaunchKernel();
MS_EXCEPTION_IF_NULL(launch_kernel);
// launch mean
launch_kernel->LaunchOpKernel();
// store output tensor addr
auto launch_output = launch_kernel->GetKernelOutputAddr();
if (launch_output.size() != 1) {
MS_LOG(ERROR) << "launch mul outputs should have one output";
}
uint8_t *tensor_output = launch_output[0];
for (size_t i = 0; i < bucket_size_; ++i) {
new_tensor_output_addrs_.emplace_back(tensor_output);
tensor_output += align_size_list_[i];
}
}
void Bucket::LazyDeleteOldAddr() {
MS_LOG(INFO) << "Lazy delete old grad address";
for (auto old_addr : tensor_old_addr_list_) {
@ -95,6 +121,7 @@ void Bucket::LazyDeleteOldAddr() {
void Bucket::Release() {
MS_LOG(INFO) << "Clear bucket:" << id_;
grad_tensor_list_.clear();
align_size_list_.clear();
new_tensor_output_addrs_.clear();
memcpy_input_addrs_.clear();
memcpy_output_addrs_.clear();

View File

@ -23,6 +23,7 @@
#include <memory>
#include "ir/anf.h"
#include "ir/device_event.h"
#include "runtime/device/launch_kernel.h"
#include "runtime/device/device_address.h"
#include "backend/session/kernel_graph.h"
@ -37,6 +38,7 @@ class Bucket {
compute_stream_(nullptr),
pre_event_(nullptr),
post_event_(nullptr),
launch_kernel(nullptr),
total_size_(0),
ar_input_addr_(nullptr),
ar_output_addr_(nullptr) {}
@ -58,11 +60,13 @@ class Bucket {
std::shared_ptr<DeviceEvent> pre_event_;
std::shared_ptr<DeviceEvent> post_event_;
std::shared_ptr<LaunchKernel> launch_kernel;
size_t total_size_;
uint8_t *ar_input_addr_;
uint8_t *ar_output_addr_;
std::string group_;
std::vector<size_t> align_size_list_;
std::vector<tensor::TensorPtr> grad_tensor_list_;
std::vector<uint8_t *> new_tensor_output_addrs_;
std::vector<kernel::AddressPtr> memcpy_input_addrs_;
@ -72,6 +76,8 @@ class Bucket {
virtual void AllocateAllReduceAddr() = 0;
void UpdateTensorAddr();
void CalculateMean();
virtual std::shared_ptr<LaunchKernel> CreateLaunchKernel() = 0;
virtual void LaunchAllReduce() = 0;
virtual void FreeAllDeviceMem() = 0;
virtual void FreeDeviceMem(void *dev_ptr) = 0;

View File

@ -26,6 +26,7 @@
#include "runtime/device/gpu/gpu_device_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/gpu/distribution/collective_init.h"
#include "runtime/device/gpu/gpu_launch_mul.h"
#include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h"
#include "runtime/device/gpu/gpu_common.h"
@ -52,7 +53,6 @@ void GPUBucket::AllocateAllReduceAddr() {
auto total_size = 0;
std::vector<size_t> size_list;
std::vector<size_t> align_size_list;
for (auto &tensor : grad_tensor_list_) {
MS_EXCEPTION_IF_NULL(tensor);
tensor_type_list_.emplace_back(tensor->data_type());
@ -61,7 +61,7 @@ void GPUBucket::AllocateAllReduceAddr() {
auto origin_size = device_address->GetSize();
auto align_size = AlignMemorySize(origin_size);
size_list.emplace_back(origin_size);
align_size_list.emplace_back(align_size);
align_size_list_.emplace_back(align_size);
total_size += align_size;
memcpy_input_addrs_.emplace_back(
std::make_shared<kernel::Address>(static_cast<uint8_t *>(device_address->GetMutablePtr()), origin_size));
@ -74,13 +74,7 @@ void GPUBucket::AllocateAllReduceAddr() {
uint8_t *memcpy_output = ar_input_addr_;
for (size_t i = 0; i < bucket_size_; ++i) {
memcpy_output_addrs_.emplace_back(std::make_shared<kernel::Address>(memcpy_output, size_list[i]));
memcpy_output += align_size_list[i];
}
uint8_t *tensor_output = ar_output_addr_;
for (size_t i = 0; i < bucket_size_; ++i) {
new_tensor_output_addrs_.emplace_back(tensor_output);
tensor_output += align_size_list[i];
memcpy_output += align_size_list_[i];
}
MS_LOG(INFO) << "end";
}
@ -97,6 +91,10 @@ void GPUBucket::FreeAllDeviceMem() {
FreeDeviceMem(ar_output_addr_);
ar_output_addr_ = nullptr;
}
// clear launch mul device memory
if (launch_kernel != nullptr) {
launch_kernel->FreeLaunchDeviceMem();
}
MS_LOG(INFO) << "end";
}
@ -158,6 +156,15 @@ void GPUBucket::LaunchAllReduce() {
MS_LOG(INFO) << "end";
}
std::shared_ptr<LaunchKernel> GPUBucket::CreateLaunchKernel() {
if (tensor_type_list_.empty()) {
MS_LOG(ERROR) << "tensor_type_list_ is empty";
}
auto launch_mul = std::make_shared<GPULaunchMul>(stream_, tensor_type_list_[0], total_size_, ar_output_addr_);
MS_EXCEPTION_IF_NULL(launch_mul);
return launch_mul;
}
void GPUBucket::Init() {
pre_event_ = std::make_shared<GpuEvent>();
post_event_ = std::make_shared<GpuEvent>();

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_BUCKET_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_BUCKET_H_
#include <memory>
#include "runtime/device/bucket.h"
namespace mindspore::device::gpu {
@ -33,7 +34,7 @@ class GPUBucket : public Bucket {
void FreeDeviceMem(void *dev_ptr) override;
void CopyTensorToContiguousMemory() override;
void LaunchAllReduce() override;
std::shared_ptr<LaunchKernel> CreateLaunchKernel() override;
const void *collective_handle_;
};
} // namespace mindspore::device::gpu

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/gpu/gpu_launch_kernel.h"
#include <vector>
#include <memory>
#include "runtime/device/gpu/gpu_memory_allocator.h"
#include "runtime/device/gpu/gpu_device_manager.h"
#include "runtime/device/gpu/kernel_info_setter.h"
#include "runtime/device/gpu/gpu_kernel_build.h"
#include "abstract/utils.h"
namespace {
constexpr size_t kCommunicationMemAlignSize = 16;
size_t AlignMemorySize(size_t size) {
if (size == 0) {
return kCommunicationMemAlignSize;
}
return ((size + kCommunicationMemAlignSize - 1) / kCommunicationMemAlignSize) * kCommunicationMemAlignSize;
}
} // namespace
namespace mindspore::device::gpu {
void GPULaunchkernel::FreeDeviceMem(void *addr) { GPUMemoryAllocator::GetInstance().FreeTensorMem(addr); }
size_t GPULaunchkernel::AlignSizeForLaunchKernel(size_t size) { return AlignMemorySize(size); }
uint8_t *GPULaunchkernel::AllocDeviceMem(size_t size) {
auto device_memory = GPUMemoryAllocator::GetInstance().AllocTensorMem(size);
MS_EXCEPTION_IF_NULL(device_memory);
return static_cast<uint8_t *>(device_memory);
}
void GPULaunchkernel::KernelSelect(std::shared_ptr<session::KernelGraph> kernel_graph) {
auto node_list = kernel_graph->execution_order();
for (size_t i = 0; i < node_list.size(); ++i) {
device::gpu::SetKernelInfo(node_list[i]);
}
}
void GPULaunchkernel::KernelBuild(std::shared_ptr<session::KernelGraph> kernel_graph) {
device::gpu::GpuBuild(kernel_graph);
}
} // namespace mindspore::device::gpu

View File

@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_LAUNCH_KERNEL_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_LAUNCH_KERNEL_H_
#include <vector>
#include <memory>
#include "runtime/device/launch_kernel.h"
namespace mindspore::device::gpu {
class GPULaunchkernel : public LaunchKernel {
public:
explicit GPULaunchkernel(void *stream) : LaunchKernel(stream) {}
virtual ~GPULaunchkernel() = default;
void FreeDeviceMem(void *addr) override;
size_t AlignSizeForLaunchKernel(size_t size) override;
uint8_t *AllocDeviceMem(size_t size) override;
void KernelSelect(std::shared_ptr<session::KernelGraph> kernel_graph) override;
void KernelBuild(std::shared_ptr<session::KernelGraph> kernel_graph) override;
void LaunchOpKernel() override = 0;
void FreeLaunchDeviceMem() override = 0;
};
} // namespace mindspore::device::gpu
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_LAUNCH_KERNEL_H_

View File

@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/gpu/gpu_launch_mul.h"
#include <vector>
#include <memory>
#include "abstract/utils.h"
#include "runtime/device/gpu/gpu_memory_allocator.h"
#include "runtime/device/gpu/gpu_device_manager.h"
#include "backend/session/single_kernel_graph.h"
#include "frontend/parallel/context.h"
namespace mindspore::device::gpu {
void GPULaunchMul::FreeDeviceMem(void *addr) { GPULaunchkernel::FreeDeviceMem(addr); }
size_t GPULaunchMul::AlignSizeForLaunchKernel(size_t size) { return GPULaunchkernel::AlignSizeForLaunchKernel(size); }
uint8_t *GPULaunchMul::AllocDeviceMem(size_t size) { return GPULaunchkernel::AllocDeviceMem(size); }
void GPULaunchMul::KernelSelect(std::shared_ptr<session::KernelGraph> kernel_graph) {
GPULaunchkernel::KernelSelect(kernel_graph);
}
void GPULaunchMul::KernelBuild(std::shared_ptr<session::KernelGraph> kernel_graph) {
GPULaunchkernel::KernelBuild(kernel_graph);
}
void GPULaunchMul::LaunchOpKernel() {
kernel_mod_ = ObtainLaunchMulKernelMod();
MS_EXCEPTION_IF_NULL(kernel_mod_);
// construct mul inputs addr
ObtainMulInputsAddr();
// launch mul
LaunchSingleKernel(inputs_addr_);
}
void GPULaunchMul::FreeLaunchDeviceMem() {
FreeInputDeviceMemory();
FreeOutputAndWorkspaceDeviceMem();
}
void GPULaunchMul::CopyHostMemToDevice(size_t origin_size, size_t) {
if (!GPUDeviceManager::GetInstance().CopyHostMemToDeviceAsync(input2_addr_, &input2_value_, origin_size, stream_)) {
MS_LOG(EXCEPTION) << "Copy memory failed";
}
}
} // namespace mindspore::device::gpu

View File

@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_LAUNCH_MUL_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_LAUNCH_MUL_H_
#include <vector>
#include <memory>
#include "runtime/device/launch_mul.h"
#include "runtime/device/gpu/gpu_launch_kernel.h"
namespace mindspore::device::gpu {
class GPULaunchMul : public GPULaunchkernel, public LaunchMul {
public:
GPULaunchMul(void *stream, TypeId dtype, size_t total_size, uint8_t *input1_addr)
: GPULaunchkernel(stream), LaunchMul(dtype, total_size, input1_addr) {}
~GPULaunchMul() override = default;
void FreeDeviceMem(void *addr) override;
size_t AlignSizeForLaunchKernel(size_t size) override;
uint8_t *AllocDeviceMem(size_t size) override;
void KernelSelect(std::shared_ptr<session::KernelGraph> kernel_graph) override;
void KernelBuild(std::shared_ptr<session::KernelGraph> kernel_graph) override;
void LaunchOpKernel() override;
void FreeLaunchDeviceMem() override;
void CopyHostMemToDevice(size_t origin_size, size_t) override;
};
} // namespace mindspore::device::gpu
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_LAUNCH_MUL_H_

View File

@ -0,0 +1,107 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/launch_kernel.h"
#include <vector>
#include <memory>
namespace mindspore::device {
std::vector<kernel::AddressPtr> LaunchKernel::ObtainKernelAddress(const std::vector<size_t> &list,
std::vector<uint8_t *> *addr) {
std::vector<kernel::AddressPtr> kernel_address;
for (size_t i = 0; i < list.size(); ++i) {
auto size = AlignSizeForLaunchKernel(list[i]);
(*addr)[i] = AllocDeviceMem(size);
auto address = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(address);
address->addr = (*addr)[i];
MS_EXCEPTION_IF_NULL(address->addr);
address->size = size;
kernel_address.push_back(address);
}
return kernel_address;
}
std::vector<kernel::AddressPtr> LaunchKernel::ObtainKernelInputs(const std::vector<size_t> &inputs_list,
const std::vector<uint8_t *> &inputs_addr) {
std::vector<kernel::AddressPtr> kernel_inputs;
if (inputs_list.size() != inputs_addr.size()) {
MS_LOG(ERROR) << "input_list size should equal to input_addr_ size";
}
for (size_t i = 0; i < inputs_list.size(); ++i) {
auto input_size = AlignSizeForLaunchKernel(inputs_list[i]);
auto input = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(input);
input->addr = inputs_addr[i];
MS_EXCEPTION_IF_NULL(input->addr);
input->size = input_size;
kernel_inputs.push_back(input);
}
return kernel_inputs;
}
std::vector<kernel::AddressPtr> LaunchKernel::ObtainKernelOutputs(const std::vector<size_t> &outputs_list) {
// init output_addr_
outputs_addr_ = std::vector<uint8_t *>(outputs_list.size(), nullptr);
auto kernel_outputs = ObtainKernelAddress(outputs_list, &outputs_addr_);
return kernel_outputs;
}
std::vector<kernel::AddressPtr> LaunchKernel::ObtainKernelWorkspaces(const std::vector<size_t> &workspaces_list) {
std::vector<kernel::AddressPtr> kernel_workspace;
if (workspaces_list.empty()) {
return kernel_workspace;
}
// init workspace_addr_
workspaces_addr_ = std::vector<uint8_t *>(workspaces_list.size(), nullptr);
kernel_workspace = ObtainKernelAddress(workspaces_list, &workspaces_addr_);
return kernel_workspace;
}
void LaunchKernel::LaunchSingleKernel(const std::vector<uint8_t *> &inputs_addr) {
MS_EXCEPTION_IF_NULL(kernel_mod_);
// obtain kernel inputs
auto kernel_inputs = ObtainKernelInputs(kernel_mod_->GetInputSizeList(), inputs_addr);
// obtain kernel outputs
auto kernel_outputs = ObtainKernelOutputs(kernel_mod_->GetOutputSizeList());
// obtain kernel workspace
auto kernel_workspaces = ObtainKernelWorkspaces(kernel_mod_->GetWorkspaceSizeList());
// launch
auto ret_status = kernel_mod_->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
if (!ret_status) {
MS_LOG(ERROR) << "Launch mul kernel failed.";
}
}
void LaunchKernel::FreeOutputAndWorkspaceDeviceMem() {
// free outputs_addr and workspaces_addr_
for (size_t i = 0; i < outputs_addr_.size(); ++i) {
if (outputs_addr_[i] != nullptr) {
FreeDeviceMem(outputs_addr_[i]);
outputs_addr_[i] = nullptr;
}
}
for (size_t i = 0; i < workspaces_addr_.size(); ++i) {
if (workspaces_addr_[i] != nullptr) {
FreeDeviceMem(workspaces_addr_[i]);
workspaces_addr_[i] = nullptr;
}
}
outputs_addr_.clear();
workspaces_addr_.clear();
}
} // namespace mindspore::device

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_LAUNCH_KERNEL_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_LAUNCH_KERNEL_H_
#include <vector>
#include <memory>
#include "backend/session/kernel_graph.h"
namespace mindspore::device {
class LaunchKernel {
public:
explicit LaunchKernel(void *stream) : stream_(stream), kernel_mod_(nullptr) {}
virtual ~LaunchKernel() = default;
std::vector<uint8_t *> GetKernelOutputAddr() { return outputs_addr_; }
void LaunchSingleKernel(const std::vector<uint8_t *> &inputs_addr);
void FreeOutputAndWorkspaceDeviceMem();
virtual void FreeDeviceMem(void *addr) = 0;
virtual size_t AlignSizeForLaunchKernel(size_t size) = 0;
virtual uint8_t *AllocDeviceMem(size_t size) = 0;
virtual void KernelSelect(std::shared_ptr<session::KernelGraph> kernel_graph) = 0;
virtual void KernelBuild(std::shared_ptr<session::KernelGraph> kernel_graph) = 0;
virtual void LaunchOpKernel() = 0;
virtual void FreeLaunchDeviceMem() = 0;
protected:
void *stream_;
kernel::KernelMod *kernel_mod_;
std::vector<uint8_t *> outputs_addr_;
std::vector<uint8_t *> workspaces_addr_;
private:
std::vector<kernel::AddressPtr> ObtainKernelAddress(const std::vector<size_t> &list, std::vector<uint8_t *> *addr);
std::vector<kernel::AddressPtr> ObtainKernelInputs(const std::vector<size_t> &inputs_list,
const std::vector<uint8_t *> &inputs_addr);
std::vector<kernel::AddressPtr> ObtainKernelOutputs(const std::vector<size_t> &outputs_list);
std::vector<kernel::AddressPtr> ObtainKernelWorkspaces(const std::vector<size_t> &workspaces_list);
};
} // namespace mindspore::device
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_LAUNCH_KERNEL_H_

View File

@ -0,0 +1,83 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/launch_mul.h"
#include <vector>
#include <memory>
#include "abstract/utils.h"
#include "backend/session/single_kernel_graph.h"
#include "frontend/parallel/context.h"
namespace mindspore::device {
std::shared_ptr<session::KernelGraph> LaunchMul::ObtainMulKernelGraph() {
std::vector<TypeId> input_dtypes = {dtype_, dtype_};
std::vector<TypeId> output_dtypes = {dtype_};
// obtain input & output shapes
size_t dtype_size = abstract::TypeIdSize(dtype_);
int64_t shape = total_size_ / dtype_size;
std::vector<std::vector<int64_t>> input_shapes = {{shape}, {1}};
std::vector<std::vector<size_t>> output_shapes = {{static_cast<size_t>(shape)}};
auto mul_graph = session::SingleKernelGraph::ConstructKernelGraphBasedOnSingleOp(
kMulOpName, input_dtypes, input_shapes, output_dtypes, output_shapes);
MS_EXCEPTION_IF_NULL(mul_graph);
return mul_graph;
}
kernel::KernelMod *LaunchMul::ObtainLaunchMulKernelMod() {
if (mul_graph_ == nullptr) {
// construct mul kernel graph
mul_graph_ = ObtainMulKernelGraph();
MS_EXCEPTION_IF_NULL(mul_graph_);
// kernel select
KernelSelect(mul_graph_);
// kernel build
KernelBuild(mul_graph_);
}
// obtain kernel_mod
if (mul_graph_->execution_order().size() != 1) {
MS_LOG(ERROR) << "the execution order of the mul graph should have only one node";
}
return AnfAlgo::GetKernelMod(mul_graph_->execution_order()[0]);
}
void LaunchMul::ObtainMulInputsAddr() {
inputs_addr_.push_back(input1_addr_);
auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
auto device_num = parallel_context->device_num();
if (device_num == 0) {
MS_LOG(ERROR) << "device num can't be zero";
}
input2_value_ = 1.0 / device_num;
auto size = abstract::TypeIdSize(dtype_);
auto input_size = AlignSizeForLaunchKernel(size * 1);
// alloc memory
input2_addr_ = AllocDeviceMem(input_size);
CopyHostMemToDevice(size, input_size);
inputs_addr_.push_back(input2_addr_);
}
void LaunchMul::FreeInputDeviceMemory() {
input1_addr_ = nullptr;
if (input2_addr_ != nullptr) {
FreeDeviceMem(input2_addr_);
input2_addr_ = nullptr;
}
inputs_addr_.clear();
}
} // namespace mindspore::device

View File

@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_LAUNCH_MUL_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_LAUNCH_MUL_H_
#include <vector>
#include <memory>
#include "backend/session/kernel_graph.h"
namespace mindspore::device {
class LaunchMul {
public:
LaunchMul(TypeId dtype, size_t total_size, uint8_t *input1_addr)
: dtype_(dtype),
total_size_(total_size),
input1_addr_(input1_addr),
input2_addr_(nullptr),
input2_value_(0),
mul_graph_(nullptr) {}
virtual ~LaunchMul() = default;
virtual void FreeDeviceMem(void *addr) = 0;
virtual size_t AlignSizeForLaunchKernel(size_t size) = 0;
virtual uint8_t *AllocDeviceMem(size_t size) = 0;
virtual void KernelSelect(std::shared_ptr<session::KernelGraph> kernel_graph) = 0;
virtual void KernelBuild(std::shared_ptr<session::KernelGraph> kernel_graph) = 0;
virtual void CopyHostMemToDevice(size_t origin_size, size_t dst_size) = 0;
std::shared_ptr<session::KernelGraph> ObtainMulKernelGraph();
kernel::KernelMod *ObtainLaunchMulKernelMod();
void ObtainMulInputsAddr();
void FreeInputDeviceMemory();
protected:
TypeId dtype_;
size_t total_size_;
uint8_t *input1_addr_;
uint8_t *input2_addr_;
float input2_value_;
std::shared_ptr<session::KernelGraph> mul_graph_;
std::vector<uint8_t *> inputs_addr_;
};
} // namespace mindspore::device
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_LAUNCH_MUL_H_

View File

@ -373,7 +373,7 @@ class DistributedGradReducer(Cell):
datatypes = self.map_(F.partial(_get_datatype), grads)
grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
if self.mode == context.PYNATIVE_MODE:
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean), self.allreduce_filter, grads)
new_grad = grads
elif self.split_fusion:
if self.enable_parameter_server:
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),

View File

@ -98,8 +98,11 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc"
"../../../mindspore/ccsrc/runtime/device/kernel_info.cc"
"../../../mindspore/ccsrc/runtime/device/bucket.cc"
"../../../mindspore/ccsrc/runtime/device/launch_kernel.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/profiling/*.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/ascend_launch_kernel.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/ascend_launch_mul.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc"
"../../../mindspore/ccsrc/runtime/device/convert_tensor_utils.cc"
"../../../mindspore/ccsrc/runtime/device/ascend/ascend_bucket.cc"