forked from mindspore-Ecosystem/mindspore
add ps cache
This commit is contained in:
parent
81f1283dd2
commit
a5f57ce8a0
|
@ -46,6 +46,7 @@ constexpr auto kTopKV2 = "TopKV2";
|
|||
constexpr auto kEditDistance = "EditDistance";
|
||||
constexpr auto kGatherD = "GatherD";
|
||||
constexpr auto kIdentity = "Identity";
|
||||
constexpr auto kUpdateCache = "UpdateCache";
|
||||
constexpr auto kCustRunApi = "RunCpuKernel";
|
||||
const std::set<std::string> kCustAiCpuKernelOps{kEditDistance, kIdentity};
|
||||
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "hash_impl.cuh"
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
__global__ void HashSwapOut(const T *hash_table, T *swap_out_value, const int *swap_out_index, const int index_size,
|
||||
const int hash_dim) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < index_size; i += blockDim.x * gridDim.x) {
|
||||
int hash_index = swap_out_index[i];
|
||||
for (int j = 0; j < hash_dim; j++) {
|
||||
swap_out_value[i * hash_dim + j] = hash_table[hash_index * hash_dim + j];
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void HashSwapIn(T *hash_table, const T *swap_in_value, const int *swap_in_index, const int index_size,
|
||||
const int hash_dim) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < index_size; i += blockDim.x * gridDim.x) {
|
||||
int hash_index = swap_in_index[i];
|
||||
for (int j = 0; j < hash_dim; j++) {
|
||||
hash_table[hash_index * hash_dim + j] = swap_in_value[i * hash_dim + j];
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DoHashSwapOut(const T *hash_table, T *swap_out_value, const int *swap_out_index, const int index_size,
|
||||
const int hash_dim, cudaStream_t cuda_stream) {
|
||||
HashSwapOut<<<GET_BLOCKS(index_size), GET_THREADS, 0, cuda_stream>>>(hash_table, swap_out_value, swap_out_index,
|
||||
index_size, hash_dim);
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DoHashSwapIn(T *hash_table, const T *swap_in_value, const int *swap_in_index, const int index_size,
|
||||
const int hash_dim, cudaStream_t cuda_stream) {
|
||||
HashSwapIn<<<GET_BLOCKS(index_size), GET_THREADS, 0, cuda_stream>>>(hash_table, swap_in_value, swap_in_index,
|
||||
index_size, hash_dim);
|
||||
return;
|
||||
}
|
||||
|
||||
template void DoHashSwapOut<float>(const float *hash_table, float *swap_out_value, const int *swap_out_index,
|
||||
const int index_size, const int hash_dim, cudaStream_t cuda_stream);
|
||||
|
||||
template void DoHashSwapIn<float>(float *hash_table, const float *swap_in_value, const int *swap_in_index,
|
||||
const int index_size, const int hash_dim, cudaStream_t cuda_stream);
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_KERNEL_HASH_IMPL_H_
|
||||
#define MINDSPORE_CCSRC_PS_PS_CACHE_KERNEL_HASH_IMPL_H_
|
||||
|
||||
template <typename T>
|
||||
void DoHashSwapOut(const T *hash_table, T *swap_out_value, const int *swap_out_index, const int index_size,
|
||||
const int hash_dim, cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
void DoHashSwapIn(T *hash_table, const T *swap_in_value, const int *swap_in_index, const int index_size,
|
||||
const int hash_dim, cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_PS_PS_CACHE_KERNEL_HASH_IMPL_H_
|
|
@ -116,6 +116,8 @@ using KernelPackPtr = std::shared_ptr<KernelPack>;
|
|||
* @brief base class for autotensor kernel and cce kernel.
|
||||
*/
|
||||
struct Address {
|
||||
Address() {}
|
||||
Address(void *address_addr, size_t address_size) : addr(address_addr), size(address_size) {}
|
||||
void *addr;
|
||||
size_t size;
|
||||
};
|
||||
|
|
|
@ -16,5 +16,16 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
|
|||
list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc")
|
||||
endif ()
|
||||
|
||||
if (NOT ENABLE_D)
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ascend/ascend_ps_cache.cc")
|
||||
endif()
|
||||
|
||||
if (NOT ENABLE_GPU)
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc")
|
||||
endif()
|
||||
|
||||
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc")
|
||||
add_subdirectory(ps_cache)
|
||||
|
||||
set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
|
||||
add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES})
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))
|
||||
file(GLOB_RECURSE _PS_CACHE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps_data/*.cc")
|
||||
set_property(SOURCE ${_PS_CACHE_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS)
|
||||
add_library(ps_cache SHARED ${_PS_CACHE_SRC_FILES})
|
||||
endif()
|
||||
|
||||
|
|
@ -0,0 +1,253 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/ps_cache/ascend/ascend_ps_cache.h"
|
||||
#include <google/protobuf/text_format.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ps/ps_cache/ps_cache_factory.h"
|
||||
#include "runtime/device/ascend/ascend_memory_pool.h"
|
||||
#include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "proto/tensor.pb.h"
|
||||
#include "proto/tensor_shape.pb.h"
|
||||
#include "proto/attr.pb.h"
|
||||
#include "proto/node_def.pb.h"
|
||||
|
||||
using mindspore::kernel::Address;
|
||||
using AddressPtr = std::shared_ptr<Address>;
|
||||
using AddressPtrList = std::vector<AddressPtr>;
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace ascend {
|
||||
MS_REG_PS_CACHE(kAscendDevice, AscendPsCache);
|
||||
namespace {
|
||||
void SetProtoInputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type,
|
||||
mindspore::NodeDef *proto) {
|
||||
MS_EXCEPTION_IF_NULL(proto);
|
||||
if (data_shape.size() != data_type.size()) {
|
||||
MS_LOG(EXCEPTION) << "The size of data shape is not equal to the size of data type.";
|
||||
}
|
||||
for (size_t input_index = 0; input_index < data_shape.size(); input_index++) {
|
||||
::mindspore::Tensor *proto_inputs = proto->add_inputs();
|
||||
MS_EXCEPTION_IF_NULL(proto_inputs);
|
||||
auto input_shape = data_shape[input_index];
|
||||
mindspore::TensorShape *tensorShape = proto_inputs->mutable_tensor_shape();
|
||||
MS_EXCEPTION_IF_NULL(tensorShape);
|
||||
for (auto item : input_shape) {
|
||||
mindspore::TensorShape_Dim *dim = tensorShape->add_dim();
|
||||
MS_EXCEPTION_IF_NULL(dim);
|
||||
dim->set_size((::google::protobuf::int64)item);
|
||||
}
|
||||
auto input_type = kernel::AicpuOpUtil::MsTypeToProtoType(data_type[input_index]);
|
||||
proto_inputs->set_tensor_type(input_type);
|
||||
proto_inputs->set_mem_device("HBM");
|
||||
}
|
||||
}
|
||||
|
||||
void SetProtoOutputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type,
|
||||
mindspore::NodeDef *proto) {
|
||||
MS_EXCEPTION_IF_NULL(proto);
|
||||
if (data_shape.size() != data_type.size()) {
|
||||
MS_LOG(EXCEPTION) << "The size of data shape is not equal to the size of data type.";
|
||||
}
|
||||
for (size_t output_index = 0; output_index < data_shape.size(); output_index++) {
|
||||
::mindspore::Tensor *proto_outputs = proto->add_outputs();
|
||||
MS_EXCEPTION_IF_NULL(proto_outputs);
|
||||
auto output_shape = data_shape[output_index];
|
||||
mindspore::TensorShape *tensorShape = proto_outputs->mutable_tensor_shape();
|
||||
MS_EXCEPTION_IF_NULL(tensorShape);
|
||||
for (auto item : output_shape) {
|
||||
mindspore::TensorShape_Dim *dim = tensorShape->add_dim();
|
||||
MS_EXCEPTION_IF_NULL(dim);
|
||||
dim->set_size((::google::protobuf::int64)item);
|
||||
}
|
||||
auto output_type = kernel::AicpuOpUtil::MsTypeToProtoType(data_type[output_index]);
|
||||
proto_outputs->set_tensor_type(output_type);
|
||||
proto_outputs->set_mem_device("HBM");
|
||||
}
|
||||
}
|
||||
|
||||
void SetNodedefProto(const std::shared_ptr<KernelNodeInfo> &op_info,
|
||||
const std::shared_ptr<kernel::AicpuOpKernelMod> &kernel_mod_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(op_info);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
|
||||
mindspore::NodeDef proto;
|
||||
proto.set_op(op_info->op_name_);
|
||||
SetProtoInputs(op_info->input_data_shape_, op_info->input_data_type_, &proto);
|
||||
SetProtoOutputs(op_info->output_data_shape_, op_info->output_data_type_, &proto);
|
||||
std::string nodeDefStr;
|
||||
if (!proto.SerializeToString(&nodeDefStr)) {
|
||||
MS_LOG(EXCEPTION) << "Serialize nodeDef to string failed.";
|
||||
}
|
||||
MS_LOG(DEBUG) << "Set node def proto, node name:" << op_info->op_name_;
|
||||
kernel_mod_ptr->SetNodeDef(nodeDefStr);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void AscendPsCache::InitDevice(uint32_t device_id, const void *context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto ret = rtSetDevice(device_id);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << ret << "]";
|
||||
}
|
||||
auto rt_context = const_cast<rtContext_t>(context);
|
||||
ret = rtCtxSetCurrent(rt_context);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]";
|
||||
}
|
||||
ret = rtStreamCreate(&stream_, 0);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_EXCEPTION(DeviceProcessError) << "Call rtStreamCreate, ret[" << ret << "]";
|
||||
}
|
||||
}
|
||||
|
||||
void *AscendPsCache::MallocMemory(size_t size) {
|
||||
return device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(size);
|
||||
}
|
||||
|
||||
void AscendPsCache::MallocConstantMemory(size_t constant_value) {
|
||||
offset_addr_ = reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
|
||||
MS_EXCEPTION_IF_NULL(offset_addr_);
|
||||
rtMemset(offset_addr_, sizeof(int), 0, sizeof(int));
|
||||
cache_vocab_size_addr_ =
|
||||
reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
|
||||
MS_EXCEPTION_IF_NULL(cache_vocab_size_addr_);
|
||||
rtMemset(cache_vocab_size_addr_, sizeof(int), constant_value, sizeof(int));
|
||||
}
|
||||
|
||||
void AscendPsCache::RecordEvent() {
|
||||
event_.reset(new rtEvent_t());
|
||||
auto ret = rtEventCreate(&(*event_));
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_EXCEPTION(DeviceProcessError) << "Create event failed";
|
||||
}
|
||||
ret = rtEventRecord(*event_, stream_);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_EXCEPTION(DeviceProcessError) << "Record event failed";
|
||||
}
|
||||
}
|
||||
|
||||
void AscendPsCache::SynchronizeEvent() {
|
||||
auto ret = rtEventSynchronize(*event_);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_EXCEPTION(DeviceProcessError) << "tEventSynchronize failed";
|
||||
}
|
||||
ret = rtEventDestroy(*event_);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_EXCEPTION(DeviceProcessError) << "rtEventDestroy failed";
|
||||
}
|
||||
}
|
||||
|
||||
void AscendPsCache::SynchronizeStream() {
|
||||
auto ret = rtStreamSynchronize(stream_);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_EXCEPTION(DeviceProcessError) << "rtStreamSynchronize failed";
|
||||
}
|
||||
}
|
||||
|
||||
void AscendPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(dst);
|
||||
MS_EXCEPTION_IF_NULL(src);
|
||||
auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_HOST_TO_DEVICE, stream_);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_EXCEPTION(DeviceProcessError) << "rtMemcpyAsync failed";
|
||||
}
|
||||
}
|
||||
|
||||
void AscendPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(dst);
|
||||
MS_EXCEPTION_IF_NULL(src);
|
||||
auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_DEVICE_TO_HOST, stream_);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_EXCEPTION(DeviceProcessError) << "rtMemcpyAsync failed";
|
||||
}
|
||||
}
|
||||
|
||||
void AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr,
|
||||
size_t hash_table_size, size_t embedding_size, size_t swap_out_size) {
|
||||
MS_EXCEPTION_IF_NULL(hash_table_addr);
|
||||
MS_EXCEPTION_IF_NULL(swap_out_value_addr);
|
||||
MS_EXCEPTION_IF_NULL(swap_out_index_addr);
|
||||
auto hash_swap_out_mod = std::make_shared<kernel::AicpuOpKernelMod>();
|
||||
MS_EXCEPTION_IF_NULL(hash_swap_out_mod);
|
||||
hash_swap_out_mod->SetNodeName(kEmbeddingLookupOpName);
|
||||
std::vector<std::vector<size_t>> input_shape;
|
||||
std::vector<std::vector<size_t>> output_shape;
|
||||
std::vector<TypeId> input_type = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32};
|
||||
std::vector<TypeId> output_type = {TypeId::kNumberTypeFloat32};
|
||||
input_shape.push_back({hash_table_size, embedding_size});
|
||||
input_shape.push_back({swap_out_size});
|
||||
input_shape.push_back({1});
|
||||
output_shape.push_back({swap_out_size, embedding_size});
|
||||
auto op_info =
|
||||
std::make_shared<KernelNodeInfo>(kEmbeddingLookupOpName, input_shape, input_type, output_shape, output_type);
|
||||
SetNodedefProto(op_info, hash_swap_out_mod);
|
||||
|
||||
AddressPtrList kernel_inputs;
|
||||
AddressPtrList kernel_outputs = {
|
||||
std::make_shared<Address>(swap_out_value_addr, swap_out_size * embedding_size * sizeof(float))};
|
||||
AddressPtrList kernel_workspaces;
|
||||
kernel_inputs.push_back(std::make_shared<Address>(hash_table_addr, hash_table_size * embedding_size * sizeof(float)));
|
||||
kernel_inputs.push_back(std::make_shared<Address>(swap_out_index_addr, swap_out_size * sizeof(int)));
|
||||
kernel_inputs.push_back(std::make_shared<Address>(offset_addr_, sizeof(int)));
|
||||
auto ret = hash_swap_out_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "Hash swap out launch failed.";
|
||||
}
|
||||
}
|
||||
|
||||
void AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr,
|
||||
size_t hash_table_size, size_t embedding_size, size_t swap_in_size) {
|
||||
MS_EXCEPTION_IF_NULL(hash_table_addr);
|
||||
MS_EXCEPTION_IF_NULL(swap_in_value_addr);
|
||||
MS_EXCEPTION_IF_NULL(swap_in_index_addr);
|
||||
auto hash_swap_in_mod = std::make_shared<kernel::AicpuOpKernelMod>();
|
||||
MS_EXCEPTION_IF_NULL(hash_swap_in_mod);
|
||||
hash_swap_in_mod->SetNodeName(kernel::kUpdateCache);
|
||||
std::vector<std::vector<size_t>> input_shape;
|
||||
std::vector<std::vector<size_t>> output_shape;
|
||||
std::vector<TypeId> input_type = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32,
|
||||
TypeId::kNumberTypeInt32};
|
||||
std::vector<TypeId> output_type = {TypeId::kNumberTypeInt32};
|
||||
input_shape.push_back({hash_table_size, embedding_size});
|
||||
input_shape.push_back({swap_in_size});
|
||||
input_shape.push_back({swap_in_size, embedding_size});
|
||||
input_shape.push_back({1});
|
||||
output_shape.push_back({1});
|
||||
auto op_info =
|
||||
std::make_shared<KernelNodeInfo>(kernel::kUpdateCache, input_shape, input_type, output_shape, output_type);
|
||||
SetNodedefProto(op_info, hash_swap_in_mod);
|
||||
|
||||
AddressPtrList kernel_inputs;
|
||||
AddressPtrList kernel_outputs;
|
||||
AddressPtrList kernel_workspaces;
|
||||
kernel_inputs.push_back(std::make_shared<Address>(hash_table_addr, hash_table_size * embedding_size * sizeof(float)));
|
||||
kernel_inputs.push_back(std::make_shared<Address>(swap_in_index_addr, swap_in_size * sizeof(int)));
|
||||
kernel_inputs.push_back(std::make_shared<Address>(swap_in_value_addr, swap_in_size * embedding_size * sizeof(float)));
|
||||
kernel_inputs.push_back(std::make_shared<Address>(cache_vocab_size_addr_, sizeof(int)));
|
||||
// The output of updateCache kernel is required but not useful, so any address can be assigned.
|
||||
kernel_outputs.push_back(std::make_shared<Address>(offset_addr_, sizeof(int)));
|
||||
auto ret = hash_swap_in_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "Hash swap in launch failed.";
|
||||
}
|
||||
}
|
||||
} // namespace ascend
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,73 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_ASCEND_ASCEND_PS_CACHE_H_
|
||||
#define MINDSPORE_CCSRC_PS_PS_CACHE_ASCEND_ASCEND_PS_CACHE_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "ps/ps_cache/ps_cache_basic.h"
|
||||
#include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h"
|
||||
#include "ir/dtype.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace ascend {
|
||||
struct KernelNodeInfo {
|
||||
KernelNodeInfo(const std::string &op_name, std::vector<std::vector<size_t>> input_data_shape,
|
||||
std::vector<TypeId> input_data_type, std::vector<std::vector<size_t>> output_data_shape,
|
||||
std::vector<TypeId> output_data_type)
|
||||
: op_name_(op_name) {
|
||||
input_data_shape_.swap(input_data_shape);
|
||||
input_data_type_.swap(input_data_type);
|
||||
output_data_shape_.swap(output_data_shape);
|
||||
output_data_type_.swap(output_data_type);
|
||||
}
|
||||
std::string op_name_;
|
||||
std::vector<std::vector<size_t>> input_data_shape_;
|
||||
std::vector<TypeId> input_data_type_;
|
||||
std::vector<std::vector<size_t>> output_data_shape_;
|
||||
std::vector<TypeId> output_data_type_;
|
||||
};
|
||||
|
||||
class AscendPsCache : public PsCacheBasic {
|
||||
public:
|
||||
AscendPsCache() = default;
|
||||
~AscendPsCache() override = default;
|
||||
void InitDevice(uint32_t device_id, const void *context) override;
|
||||
void *MallocMemory(size_t size) override;
|
||||
void MallocConstantMemory(size_t constant_value) override;
|
||||
void RecordEvent() override;
|
||||
void SynchronizeEvent() override;
|
||||
void SynchronizeStream() override;
|
||||
void CopyHostMemToDevice(void *dst, void *src, size_t size) override;
|
||||
void CopyDeviceMemToHost(void *dst, void *src, size_t size) override;
|
||||
void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size,
|
||||
size_t embedding_size, size_t swap_out_size) override;
|
||||
void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size,
|
||||
size_t embedding_size, size_t swap_in_size) override;
|
||||
|
||||
private:
|
||||
int *offset_addr_{nullptr};
|
||||
int *cache_vocab_size_addr_{nullptr};
|
||||
std::unique_ptr<rtEvent_t> event_;
|
||||
};
|
||||
} // namespace ascend
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_PS_CACHE_ASCEND_ASCEND_PS_CACHE_H_
|
|
@ -0,0 +1,92 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/ps_cache/gpu/gpu_ps_cache.h"
|
||||
#include "ps/ps_cache/ps_cache_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/hash_impl.cuh"
|
||||
#include "runtime/device/gpu/gpu_common.h"
|
||||
#include "runtime/device/gpu/gpu_memory_allocator.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace gpu {
|
||||
MS_REG_PS_CACHE(kGPUDevice, GPUPsCache);
|
||||
void GPUPsCache::InitDevice(uint32_t device_id, const void *) {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaSetDevice(device_id), "Cuda set device failed")
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamCreate(reinterpret_cast<CUstream_st **>(&stream_)),
|
||||
"Cuda create stream failed");
|
||||
}
|
||||
|
||||
void *GPUPsCache::MallocMemory(size_t size) {
|
||||
return device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(size);
|
||||
}
|
||||
|
||||
void GPUPsCache::RecordEvent() {
|
||||
event_.reset(new cudaEvent_t());
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventCreate(&(*event_)), "Cuda create event failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventRecord(*event_, reinterpret_cast<cudaStream_t>(stream_)),
|
||||
"Cuda record event failed");
|
||||
}
|
||||
|
||||
void GPUPsCache::SynchronizeEvent() {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventSynchronize(*event_), "Cuda sync event failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventDestroy(*event_), "Cuda destroy event failed");
|
||||
}
|
||||
|
||||
void GPUPsCache::SynchronizeStream() {
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_)),
|
||||
"Cuda sync stream failed");
|
||||
}
|
||||
|
||||
void GPUPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(dst);
|
||||
MS_EXCEPTION_IF_NULL(src);
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_)),
|
||||
"Cuda memcpy failed");
|
||||
}
|
||||
|
||||
void GPUPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) {
|
||||
MS_EXCEPTION_IF_NULL(dst);
|
||||
MS_EXCEPTION_IF_NULL(src);
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_)),
|
||||
"Cuda memcpy failed");
|
||||
}
|
||||
|
||||
void GPUPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t,
|
||||
size_t embedding_size, size_t swap_out_size) {
|
||||
MS_EXCEPTION_IF_NULL(hash_table_addr);
|
||||
MS_EXCEPTION_IF_NULL(swap_out_value_addr);
|
||||
MS_EXCEPTION_IF_NULL(swap_out_index_addr);
|
||||
DoHashSwapOut(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_out_value_addr),
|
||||
reinterpret_cast<int *>(swap_out_index_addr), swap_out_size, embedding_size,
|
||||
reinterpret_cast<cudaStream_t>(stream_));
|
||||
}
|
||||
|
||||
void GPUPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t,
|
||||
size_t embedding_size, size_t swap_in_size) {
|
||||
MS_EXCEPTION_IF_NULL(hash_table_addr);
|
||||
MS_EXCEPTION_IF_NULL(swap_in_value_addr);
|
||||
MS_EXCEPTION_IF_NULL(swap_in_index_addr);
|
||||
DoHashSwapIn(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_in_value_addr),
|
||||
reinterpret_cast<int *>(swap_in_index_addr), swap_in_size, embedding_size,
|
||||
reinterpret_cast<cudaStream_t>(stream_));
|
||||
}
|
||||
} // namespace gpu
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_GPU_GPU_PS_CACHE_H_
|
||||
#define MINDSPORE_CCSRC_PS_PS_CACHE_GPU_GPU_PS_CACHE_H_
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <memory>
|
||||
#include "ps/ps_cache/ps_cache_basic.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace gpu {
|
||||
class GPUPsCache : public PsCacheBasic {
|
||||
public:
|
||||
GPUPsCache() = default;
|
||||
~GPUPsCache() override = default;
|
||||
void InitDevice(uint32_t device_id, const void *context) override;
|
||||
void *MallocMemory(size_t size) override;
|
||||
void RecordEvent() override;
|
||||
void SynchronizeEvent() override;
|
||||
void SynchronizeStream() override;
|
||||
void CopyHostMemToDevice(void *dst, void *src, size_t size) override;
|
||||
void CopyDeviceMemToHost(void *dst, void *src, size_t size) override;
|
||||
void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size,
|
||||
size_t embedding_size, size_t swap_out_size) override;
|
||||
void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size,
|
||||
size_t embedding_size, size_t swap_in_size) override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<cudaEvent_t> event_;
|
||||
};
|
||||
} // namespace gpu
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_PS_CACHE_GPU_GPU_PS_CACHE_H_
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_BASIC_H
|
||||
#define MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_BASIC_H
|
||||
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
class PsCacheBasic {
|
||||
public:
|
||||
PsCacheBasic() = default;
|
||||
virtual ~PsCacheBasic() = default;
|
||||
virtual void InitDevice(uint32_t device_id, const void *context) = 0;
|
||||
virtual void *MallocMemory(size_t size) = 0;
|
||||
virtual void MallocConstantMemory(size_t constant_value) {}
|
||||
virtual void RecordEvent() = 0;
|
||||
virtual void SynchronizeEvent() = 0;
|
||||
virtual void SynchronizeStream() = 0;
|
||||
virtual void CopyHostMemToDevice(void *dst, void *src, size_t size) = 0;
|
||||
virtual void CopyDeviceMemToHost(void *dst, void *src, size_t size) = 0;
|
||||
virtual void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr,
|
||||
size_t hash_table_size, size_t embedding_size, size_t swap_out_size) = 0;
|
||||
virtual void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr,
|
||||
size_t hash_table_size, size_t embedding_size, size_t swap_in_size) = 0;
|
||||
|
||||
protected:
|
||||
void *stream_;
|
||||
};
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_BASIC_H
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/ps_cache/ps_cache_factory.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
PsCacheFactory &PsCacheFactory::Get() {
|
||||
static PsCacheFactory instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void PsCacheFactory::Register(const std::string &device_name, PsCacheCreator &&ps_cache_creator) {
|
||||
if (ps_cache_creators_.end() == ps_cache_creators_.find(device_name)) {
|
||||
(void)ps_cache_creators_.emplace(device_name, ps_cache_creator);
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<PsCacheBasic> PsCacheFactory::ps_cache(const std::string &device_name) {
|
||||
auto iter = ps_cache_creators_.find(device_name);
|
||||
if (ps_cache_creators_.end() != iter) {
|
||||
MS_EXCEPTION_IF_NULL(iter->second);
|
||||
return (iter->second)();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_FACTORY_H_
|
||||
#define MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_FACTORY_H_
|
||||
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "ps/ps_cache/ps_cache_basic.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
using PsCacheCreator = std::function<std::shared_ptr<PsCacheBasic>()>;
|
||||
class PsCacheFactory {
|
||||
public:
|
||||
static PsCacheFactory &Get();
|
||||
void Register(const std::string &device_name, PsCacheCreator &&ps_cache_creator);
|
||||
std::shared_ptr<PsCacheBasic> ps_cache(const std::string &device_name);
|
||||
|
||||
private:
|
||||
PsCacheFactory() = default;
|
||||
~PsCacheFactory() = default;
|
||||
DISABLE_COPY_AND_ASSIGN(PsCacheFactory)
|
||||
std::map<std::string, PsCacheCreator> ps_cache_creators_;
|
||||
};
|
||||
|
||||
class PsCacheRegistrar {
|
||||
public:
|
||||
PsCacheRegistrar(const std::string &device_name, PsCacheCreator &&ps_cache_creator) {
|
||||
PsCacheFactory::Get().Register(device_name, std::move(ps_cache_creator));
|
||||
}
|
||||
~PsCacheRegistrar() = default;
|
||||
};
|
||||
|
||||
#define MS_REG_PS_CACHE(DEVICE_NAME, PS_CACHE_CLASS) \
|
||||
static const PsCacheRegistrar g_ps_cache_registrar__##DEVICE_NAME##_##_reg( \
|
||||
DEVICE_NAME, []() { return std::make_shared<PS_CACHE_CLASS>(); });
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_FACTORY_H_
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/ps_cache/ps_data/ps_data_channel.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
void PsDataChannel::TryLockChannel() {
|
||||
// The prefetch order of data needs to be consistent with the graph execution order.
|
||||
// Example: if graph execution order is graph1 --> graph2 --> graph1 -->graph2,
|
||||
// then the data prefetch order needs be channel1 --> channel 2 --> channel1 --> channel2.
|
||||
if ((current_data_step_ != 0) && (current_data_step_ % step_num_ == 0)) {
|
||||
MS_LOG(INFO) << "Lock channel:" << channel_name_;
|
||||
std::unique_lock<std::mutex> locker(channel_mutex_);
|
||||
channel_.wait(locker, [this] { return channel_open_; });
|
||||
channel_open_ = false;
|
||||
}
|
||||
current_data_step_++;
|
||||
}
|
||||
|
||||
void PsDataChannel::TryWakeChannel() {
|
||||
if ((current_graph_step_ != 0) && (current_graph_step_ % step_num_ == 0)) {
|
||||
MS_LOG(INFO) << "Wake up channel:" << channel_name_;
|
||||
std::lock_guard<std::mutex> locker(channel_mutex_);
|
||||
channel_open_ = true;
|
||||
channel_.notify_one();
|
||||
}
|
||||
current_graph_step_++;
|
||||
}
|
||||
|
||||
void PsDataChannel::set_data(void *data, const size_t data_size) {
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
TryLockChannel();
|
||||
data_ = data;
|
||||
data_size_ = data_size;
|
||||
}
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_CHANNEL_H_
|
||||
#define MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_CHANNEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <condition_variable>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
class PsDataChannel {
|
||||
public:
|
||||
PsDataChannel(const std::string &channel_name, size_t step_num)
|
||||
: channel_name_(channel_name),
|
||||
step_num_(step_num),
|
||||
current_data_step_(0),
|
||||
current_graph_step_(0),
|
||||
channel_open_(false),
|
||||
data_(nullptr),
|
||||
data_size_(0) {}
|
||||
virtual ~PsDataChannel() = default;
|
||||
void set_data(void *data, const size_t data_size);
|
||||
void *data() const { return data_; }
|
||||
size_t data_size() const { return data_size_; }
|
||||
void ResetData() { data_ = nullptr; }
|
||||
void set_step_num(size_t step_num) { step_num_ = step_num; }
|
||||
void TryWakeChannel();
|
||||
|
||||
private:
|
||||
void TryLockChannel();
|
||||
std::string channel_name_;
|
||||
// The step num of each epoch.
|
||||
size_t step_num_;
|
||||
size_t current_data_step_;
|
||||
size_t current_graph_step_;
|
||||
bool channel_open_;
|
||||
std::mutex channel_mutex_;
|
||||
std::condition_variable channel_;
|
||||
void *data_;
|
||||
size_t data_size_;
|
||||
};
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_CHANNEL_H_
|
|
@ -53,6 +53,15 @@ namespace gpu {
|
|||
} \
|
||||
}
|
||||
|
||||
#define CHECK_CUDA_RET_WITH_ERROR_NOTRACE(expression, message) \
|
||||
{ \
|
||||
cudaError_t status = (expression); \
|
||||
if (status != cudaSuccess) { \
|
||||
MS_LOG(ERROR) << "CUDA Error: " << message << " | Error Number: " << status << " " \
|
||||
<< cudaGetErrorString(status); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define CHECK_CUDA_RET_WITH_EXCEPT(node, expression, message) \
|
||||
{ \
|
||||
cudaError_t status = (expression); \
|
||||
|
@ -62,6 +71,15 @@ namespace gpu {
|
|||
} \
|
||||
}
|
||||
|
||||
#define CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(expression, message) \
|
||||
{ \
|
||||
cudaError_t status = (expression); \
|
||||
if (status != cudaSuccess) { \
|
||||
MS_LOG(EXCEPTION) << "CUDA Error: " << message << " | Error Number: " << status << " " \
|
||||
<< cudaGetErrorString(status); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define CHECK_CUDNN_RET_WITH_EXCEPT(node, expression, message) \
|
||||
{ \
|
||||
cudnnStatus_t status = (expression); \
|
||||
|
|
|
@ -139,6 +139,8 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/util.cc")
|
|||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/scheduler.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info_builder.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc")
|
||||
|
|
Loading…
Reference in New Issue