!40079 adapt ccsrc.gpu code to msvc
Merge pull request !40079 from chenminmin/master
Commit: 7ba8a7c9f0
@@ -170,7 +170,7 @@ bool CudaDriver::SyncStream(const CudaDeviceStream &stream) {
   return true;
 }
 
-bool CudaDriver::CreateEvent(CudaDeviceEvent *event, unsigned int flag) {
+bool CudaDriver::ConstructEvent(CudaDeviceEvent *event, unsigned int flag) {
   auto ret = cudaEventCreateWithFlags(reinterpret_cast<cudaEvent_t *>(event), flag);
   if (ret != cudaSuccess) {
     MS_LOG(ERROR) << "cudaEventCreateWithFlags failed, ret[" << static_cast<int>(ret) << "], "

@@ -55,7 +55,7 @@ class CudaDriver {
   static bool DestroyStream(const CudaDeviceStream &stream);
   static bool SyncStream(const CudaDeviceStream &stream);
 
-  static bool CreateEvent(CudaDeviceEvent *event, unsigned int flag = cudaEventDefault);
+  static bool ConstructEvent(CudaDeviceEvent *event, unsigned int flag = cudaEventDefault);
   static bool DestroyEvent(const CudaDeviceEvent &event);
   static bool RecordEvent(CudaDeviceEvent event, CudaDeviceStream stream = 0);
   static bool SyncEvent(const CudaDeviceEvent &event);

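The rename from CreateEvent to ConstructEvent is most likely a workaround for the Windows headers: <windows.h> defines CreateEvent as a macro that expands to CreateEventA or CreateEventW, so any method of that name is rewritten by the preprocessor once the header is pulled in under MSVC. A minimal sketch of the collision, using a hypothetical CudaDriverLike class rather than the real MindSpore code:

```cpp
// Hypothetical illustration: on Windows, CreateEvent is a macro from <windows.h>.
#ifdef _WIN32
#include <windows.h>
#endif

class CudaDriverLike {
 public:
  // With <windows.h> included, the next declaration would silently become
  // CreateEventA(...) and no longer match its call sites or its definition:
  //   static bool CreateEvent(void *event, unsigned int flag);
  static bool ConstructEvent(void *event, unsigned int flag);  // unaffected by the macro
};

bool CudaDriverLike::ConstructEvent(void *, unsigned int) { return true; }

int main() { return CudaDriverLike::ConstructEvent(nullptr, 0) ? 0 : 1; }
```
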
@@ -16,6 +16,7 @@
 #include "plugin/device/gpu/hal/device/gpu_kernel_runtime.h"
 #include <algorithm>
 #include <map>
+#include <chrono>
 #include "include/common/debug/anf_dump_utils.h"
 #include "plugin/device/gpu/hal/device/gpu_device_address.h"
 #include "plugin/device/gpu/hal/device/cuda_driver.h"

@@ -465,8 +466,7 @@ void GPUKernelRuntime::AssignMemory(const session::KernelGraph &graph) {
 }
 
 bool GPUKernelRuntime::Run(const session::KernelGraph &graph, bool is_task_sink) {
-  struct timeval start_time, end_time;
-  (void)gettimeofday(&start_time, nullptr);
+  std::chrono::system_clock::time_point start_time = std::chrono::system_clock::now();
   bool ret = true;
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);

@@ -497,10 +497,9 @@ bool GPUKernelRuntime::Run(const session::KernelGraph &graph, bool is_task_sink)
       ret = LaunchKernels(graph);
     }
   }
-  (void)gettimeofday(&end_time, nullptr);
-  const uint64_t kUSecondInSecond = 1000000;
-  uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
-  cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
+  std::chrono::system_clock::time_point end_time = std::chrono::system_clock::now();
+  auto ms_duration = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time);
+  uint64_t cost = ms_duration.count();
   MS_LOG(DEBUG) << "GPU kernel runtime run graph in " << cost << " us";
   return ret;
 }

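gettimeofday() lives in the POSIX header <sys/time.h> and has no MSVC counterpart, so the graph-run timing switches to <chrono>, which behaves the same on every supported compiler. A self-contained sketch of the pattern (not MindSpore code; the measured work is a placeholder):

```cpp
#include <chrono>
#include <cstdint>
#include <iostream>
#include <thread>

int main() {
  auto start_time = std::chrono::system_clock::now();
  std::this_thread::sleep_for(std::chrono::milliseconds(5));  // stand-in for LaunchKernels(graph)
  auto end_time = std::chrono::system_clock::now();
  auto us = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time);
  uint64_t cost = static_cast<uint64_t>(us.count());
  std::cout << "GPU kernel runtime run graph in " << cost << " us\n";
  return 0;
}
```

For pure interval measurement, std::chrono::steady_clock is often preferred over system_clock because it cannot jump backwards, but either one compiles cleanly under MSVC.
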
@@ -932,8 +931,8 @@ void GPUKernelRuntime::LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, c
   float cost_time = 0;
   CudaDeviceStream start = nullptr;
   CudaDeviceStream end = nullptr;
-  CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&start), "Failed to create event.");
-  CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&end), "Failed to create event.");
+  CHECK_OP_RET_WITH_EXCEPT(CudaDriver::ConstructEvent(&start), "Failed to create event.");
+  CHECK_OP_RET_WITH_EXCEPT(CudaDriver::ConstructEvent(&end), "Failed to create event.");
 
   MS_EXCEPTION_IF_NULL(stream_);
   CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(start, stream_), "Failed to record event to stream.");

@@ -34,7 +34,7 @@ void GPUMemCopyManager::AddMemSwapOutTask(const DeviceAddressPtr &device_address
   MS_EXCEPTION_IF_NULL(device_address);
   MS_EXCEPTION_IF_NULL(host_addr.addr);
   CudaDeviceStream event = nullptr;
-  CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&event, cudaEventDisableTiming), "Failed to create CUDA event.");
+  CHECK_OP_RET_WITH_EXCEPT(CudaDriver::ConstructEvent(&event, cudaEventDisableTiming), "Failed to create CUDA event.");
   DeviceMemPtr device_ptr = const_cast<DeviceMemPtr>(device_address->GetPtr());
   MS_EXCEPTION_IF_NULL(device_ptr);
   device_address->set_status(DeviceAddressStatus::kInDeviceToHost);

@@ -55,12 +55,12 @@ void GPUMemCopyManager::AddMemSwapInTask(const DeviceAddressPtr &device_address,
   CudaDeviceStream start = nullptr;
   CudaDeviceStream end = nullptr;
   if (profiling) {
-    CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&start), "Failed to create CUDA event.");
-    CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&end), "Failed to create CUDA event.");
+    CHECK_OP_RET_WITH_EXCEPT(CudaDriver::ConstructEvent(&start), "Failed to create CUDA event.");
+    CHECK_OP_RET_WITH_EXCEPT(CudaDriver::ConstructEvent(&end), "Failed to create CUDA event.");
     CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(start, swap_in_stream_),
                              "Failed to record CUDA event to swap in stream.");
   } else {
-    CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&end, cudaEventDisableTiming), "Failed to create CUDA event.");
+    CHECK_OP_RET_WITH_EXCEPT(CudaDriver::ConstructEvent(&end, cudaEventDisableTiming), "Failed to create CUDA event.");
   }
   DeviceMemPtr device_ptr = const_cast<DeviceMemPtr>(device_address->GetPtr());
   MS_EXCEPTION_IF_NULL(device_ptr);

@@ -158,10 +158,11 @@ class ExtractImagePatchesKernelMod : public DeprecatedNativeGpuKernelMod {
     if (padding == "VALID") {
       output_rows_ = std::ceil((input_row_size_ - patch_rows_eff + 1.f) / static_cast<float>(stride_row_));
       output_cols_ = std::ceil((input_col_size_ - patch_cols_eff + 1.f) / static_cast<float>(stride_col_));
+      constexpr int64_t zero_value = 0;
       row_padding_top_ =
-        std::max(0l, ((output_rows_ - 1) * stride_row_ + patch_rows_eff - input_row_size_) / kMidDividend);
+        std::max(zero_value, ((output_rows_ - 1) * stride_row_ + patch_rows_eff - input_row_size_) / kMidDividend);
       col_padding_left_ =
-        std::max(0l, ((output_cols_ - 1) * stride_col_ + patch_cols_eff - input_col_size_) / kMidDividend);
+        std::max(zero_value, ((output_cols_ - 1) * stride_col_ + patch_cols_eff - input_col_size_) / kMidDividend);
     } else if (padding == "SAME") {
       output_rows_ = std::ceil(input_row_size_ / static_cast<float>(stride_row_));
       output_cols_ = std::ceil(input_col_size_ / static_cast<float>(stride_col_));

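The `0l` literal works on LP64 Linux, where long and int64_t are both 64 bits, but MSVC is LLP64: long stays 32 bits, so `std::max(0l, <int64_t expression>)` mixes two different argument types and template deduction fails. A constexpr int64_t zero keeps both operands the same type on every platform. Standalone sketch of the issue, using a hypothetical helper rather than the kernel code:

```cpp
#include <algorithm>
#include <cstdint>

int64_t ClampNonNegative(int64_t padding) {
  // return std::max(0l, padding);      // fine where long is 64-bit, rejected by MSVC (long is 32-bit)
  constexpr int64_t zero_value = 0;
  return std::max(zero_value, padding);  // both operands are int64_t everywhere
}

int main() { return ClampNonNegative(-3) == 0 ? 0 : 1; }
```
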
@@ -88,8 +88,9 @@ bool ScatterNdFunctorGPUKernelMod::Init(const BaseOperatorPtr &base_operator,
   kernel_name_ = base_operator->name();
   auto iter = kScatterNdFunctorTypeMap.find(kernel_name_);
   if (iter == kScatterNdFunctorTypeMap.end()) {
-    MS_LOG(EXCEPTION) << "Only support these scatter functors: " << Map2Str(kScatterNdFunctorTypeMap)
-                      << " currently, but got " << kernel_name_;
+    MS_LOG(EXCEPTION) << "Only support these scatter functors: "
+                      << Map2Str<std::map, ScatterNdFunctorType>(kScatterNdFunctorTypeMap) << " currently, but got "
+                      << kernel_name_;
   }
   scatter_nd_functor_type_ = iter->second;

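This hunk and the many similar ones below all make the same change: every call to kernel::Map2Str now spells out its template arguments (the map template and the mapped value type) instead of relying on deduction. The most plausible reason is that Map2Str takes a template template parameter, which MSVC deduces less readily than GCC or Clang. A rough, self-contained sketch under that assumption; the real Map2Str signature may differ:

```cpp
#include <map>
#include <sstream>
#include <string>

// Toy stand-in for kernel::Map2Str: stream out the keys of a name->value map.
template <template <typename...> class Map, typename T>
std::string Map2Str(const Map<std::string, T> &value_map) {
  std::stringstream ss;
  for (const auto &item : value_map) {
    ss << item.first << " ";
  }
  return ss.str();
}

enum class ScatterNdFunctorType { kUpdate, kAdd };

int main() {
  std::map<std::string, ScatterNdFunctorType> functors{{"ScatterNdUpdate", ScatterNdFunctorType::kUpdate}};
  // Explicit template arguments, as in the patch; a deduction-only call is what MSVC chokes on here.
  return Map2Str<std::map, ScatterNdFunctorType>(functors).empty() ? 1 : 0;
}
```
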
@@ -61,8 +61,9 @@ class BesselHelperGpuKernel : public GpuKernelHelperBase {
     ResetResource();
     auto iter = kBesselOpTypeMap.find(kernel_name_);
     if (iter == kBesselOpTypeMap.end()) {
-      MS_LOG(ERROR) << "For 'BesselOp', only support these types: " << kernel::Map2Str(kBesselOpTypeMap)
-                    << " currently, but got " << kernel_name_;
+      MS_LOG(ERROR) << "For 'BesselOp', only support these types: "
+                    << kernel::Map2Str<std::map, BesselOptype>(kBesselOpTypeMap) << " currently, but got "
+                    << kernel_name_;
       return -1;
     }
     bessel_op_type_ = iter->second;

@@ -101,8 +102,9 @@ class BesselHelperGpuKernel : public GpuKernelHelperBase {
       iter->second(input_size_list_[0] / sizeof(T), input_addr, output_addr, device_id_,
                    reinterpret_cast<cudaStream_t>(cuda_stream));
     } else {
-      MS_LOG(ERROR) << "For 'BesselOp', only support these types: " << kernel::Map2Str(kBesselOpTypeMap)
-                    << " currently, but got " << kernel_name_;
+      MS_LOG(ERROR) << "For 'BesselOp', only support these types: "
+                    << kernel::Map2Str<std::map, BesselOptype>(kBesselOpTypeMap) << " currently, but got "
+                    << kernel_name_;
       return -1;
     }
     return 0;

@@ -221,26 +221,26 @@ class FractionalPoolHelperGpuKernel : public GpuKernelHelperBase {
 
     int seed = InitSeed(seed_, seed2_, deterministic_);
     // Generate pooling sequence.
-    int64_t height_cum_seq[output_shape_[kOutputShapeIndexH] + 1];
-    int64_t width_cum_seq[output_shape_[kOutputShapeIndexW] + 1];
-    flag = GeneratePoolingSequence(height_cum_seq, input_shape_[kInputShapeIndexH], output_shape_[kOutputShapeIndexH],
-                                   pseudo_random_, seed);
+    std::vector<int64_t> height_cum_seq(output_shape_[kOutputShapeIndexH] + 1);
+    std::vector<int64_t> width_cum_seq(output_shape_[kOutputShapeIndexW] + 1);
+    flag = GeneratePoolingSequence(height_cum_seq.data(), input_shape_[kInputShapeIndexH],
+                                   output_shape_[kOutputShapeIndexH], pseudo_random_, seed);
     if (flag != 0) {
       return flag;
     }
-    flag = GeneratePoolingSequence(width_cum_seq, input_shape_[kInputShapeIndexW], output_shape_[kOutputShapeIndexW],
-                                   pseudo_random_, seed);
+    flag = GeneratePoolingSequence(width_cum_seq.data(), input_shape_[kInputShapeIndexW],
+                                   output_shape_[kOutputShapeIndexW], pseudo_random_, seed);
     if (flag != 0) {
       return flag;
     }
 
-    auto cuda_ret = cudaMemcpy(row_pooling_sequence, height_cum_seq,
+    auto cuda_ret = cudaMemcpy(row_pooling_sequence, height_cum_seq.data(),
                                sizeof(int64_t) * (output_shape_[kOutputShapeIndexH] + 1), cudaMemcpyHostToDevice);
     if (cuda_ret != 0) {
       MS_LOG(ERROR) << "copy mem failed,ret " << cudaGetErrorName(cuda_ret);
       return -1;
     }
-    cuda_ret = cudaMemcpy(col_pooling_sequence, width_cum_seq,
+    cuda_ret = cudaMemcpy(col_pooling_sequence, width_cum_seq.data(),
                           sizeof(int64_t) * (output_shape_[kOutputShapeIndexW] + 1), cudaMemcpyHostToDevice);
     if (cuda_ret != 0) {
       MS_LOG(ERROR) << "copy mem failed,ret " << cudaGetErrorName(cuda_ret);

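height_cum_seq and width_cum_seq were runtime-sized stack arrays, a C99/GNU extension that GCC and Clang accept in C++ but MSVC does not. std::vector provides the same contiguous buffer portably, and .data() hands the raw pointer to C-style APIs such as cudaMemcpy; the same replacement appears again below for the float alpha_no_grad buffer. A minimal standalone sketch of the pattern, not the kernel code:

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

int64_t CumulativeSum(size_t n) {
  // int64_t cum_seq[n + 1];              // variable-length array: rejected by MSVC
  std::vector<int64_t> cum_seq(n + 1);    // portable contiguous buffer
  std::iota(cum_seq.begin(), cum_seq.end(), int64_t{0});
  // cum_seq.data() is what gets passed to C-style APIs like cudaMemcpy.
  return std::accumulate(cum_seq.begin(), cum_seq.end(), int64_t{0});
}

int main() { return CumulativeSum(4) == 10 ? 0 : 1; }
```
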
@@ -94,8 +94,9 @@ class UnaryHelperGpuKernel : public GpuKernelHelperBase {
     ResetResource();
     auto iter = kUnaryOpTypeMap.find(kernel_name_);
     if (iter == kUnaryOpTypeMap.end()) {
-      MS_LOG(ERROR) << "For 'UnaryOp', only support these types: " << kernel::Map2Str(kUnaryOpTypeMap)
-                    << " currently, but got " << kernel_name_;
+      MS_LOG(ERROR) << "For 'UnaryOp', only support these types: "
+                    << kernel::Map2Str<std::map, UnaryOptype>(kUnaryOpTypeMap) << " currently, but got "
+                    << kernel_name_;
       return -1;
     }
     unary_op_type_ = iter->second;

@@ -146,8 +147,9 @@ class UnaryHelperGpuKernel : public GpuKernelHelperBase {
       iter->second(input_addr, output_addr, input_size_list_[0] / sizeof(T),
                    reinterpret_cast<cudaStream_t>(cuda_stream));
     } else {
-      MS_LOG(ERROR) << "For 'UnaryOp', only support these types: " << kernel::Map2Str(kUnaryOpTypeMap)
-                    << " currently, but got " << kernel_name_;
+      MS_LOG(ERROR) << "For 'UnaryOp', only support these types: "
+                    << kernel::Map2Str<std::map, UnaryOptype>(kUnaryOpTypeMap) << " currently, but got "
+                    << kernel_name_;
       return -1;
     }
 

@@ -17,7 +17,9 @@
 #ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUDA_COMMON_H_
 #define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUDA_COMMON_H_
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
 
+#ifdef _MSC_VER
+#define uint unsigned int
+#endif
 namespace mindspore {
 namespace device {
 namespace gpu {

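The `uint` shim presumably exists because many CUDA kernels in this tree use the bare `uint` spelling, which on Linux comes from a glibc typedef in <sys/types.h>; MSVC's headers define no such name, so the header maps it to unsigned int when _MSC_VER is set. A standalone sketch of the idea, a hypothetical example rather than the real header:

```cpp
#include <cstdio>

#ifdef _MSC_VER
#define uint unsigned int  // MSVC has no built-in uint typedef
#else
#include <sys/types.h>     // glibc-based systems provide uint here
#endif

int main() {
  uint threads_per_block = 256;  // the spelling used throughout the CUDA kernels
  std::printf("%u\n", threads_per_block);
  return 0;
}
```
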
@@ -59,4 +61,17 @@ class CudaCommon {
 }  // namespace device
 }  // namespace mindspore
 
+#ifdef _MSC_VER
+// Some CUDA ops (such as cum_minmax) use isnan with integer types, which MSVC does not
+// support, so implement these overloads here.
+__device__ __forceinline__ bool IsNan(const int8_t &x) { return false; }
+__device__ __forceinline__ bool IsNan(const int16_t &x) { return false; }
+__device__ __forceinline__ bool IsNan(const int32_t &x) { return false; }
+__device__ __forceinline__ bool IsNan(const int64_t &x) { return false; }
+__device__ __forceinline__ bool IsNan(const uint8_t &x) { return false; }
+__device__ __forceinline__ bool IsNan(const uint16_t &x) { return false; }
+__device__ __forceinline__ bool IsNan(const uint32_t &x) { return false; }
+__device__ __forceinline__ bool IsNan(const uint64_t &x) { return false; }
+#endif  // _MSC_VER
+
 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUDA_COMMON_H_

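The integer overloads let the templated kernels keep calling IsNan(x) for every supported element type; an integer can never be NaN, so the overloads just return false where MSVC offers no integer form of isnan. A host-side analogue of the pattern (assumption: simplified, the real overloads above are __device__ functions):

```cpp
#include <cmath>
#include <cstdint>

inline bool IsNan(float x) { return std::isnan(x); }
inline bool IsNan(double x) { return std::isnan(x); }
inline bool IsNan(const int32_t &) { return false; }  // integers are never NaN
inline bool IsNan(const int64_t &) { return false; }

template <typename T>
int CountNan(const T *data, int n) {
  int count = 0;
  for (int i = 0; i < n; ++i) {
    count += IsNan(data[i]) ? 1 : 0;  // compiles for both floating-point and integer T
  }
  return count;
}

int main() {
  const int32_t ints[3] = {1, 2, 3};
  const float floats[3] = {1.0f, NAN, 3.0f};
  return (CountNan(ints, 3) == 0 && CountNan(floats, 3) == 1) ? 0 : 1;
}
```
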
@@ -144,7 +144,9 @@ int BesselGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::
 std::vector<KernelAttr> BesselGpuKernelMod::GetOpSupport() {
   auto iter = kernel_attr_map.find(kernel_type_);
   if (iter == kernel_attr_map.end()) {
-    MS_LOG(ERROR) << "For 'BesselOp', only support these types: " << kernel::Map2Str(kernel_attr_map)
+    MS_LOG(ERROR) << "For 'BesselOp', only support these types: "
+                  << kernel::Map2Str<std::map, std::vector<std::pair<KernelAttr, BesselPtrCreatorFunc>>>(
+                       kernel_attr_map)
                   << " currently, but got " << kernel_name_;
   }
   std::vector<KernelAttr> support_list;

@@ -33,8 +33,8 @@ bool BroadcastOpGradGpuKernelMod::GetOpType() {
   };
   auto iter = broadcast_type_map.find(kernel_name_);
   if (iter == broadcast_type_map.end()) {
-    MS_LOG(ERROR) << "For " << kernel::Map2Str(broadcast_type_map) << ", it only support max and min grad, but got "
-                  << kernel_name_;
+    MS_LOG(ERROR) << "For " << kernel::Map2Str<std::map, BroadcastGradOpType>(broadcast_type_map)
+                  << ", it only support max and min grad, but got " << kernel_name_;
     return false;
   }
   op_type_ = iter->second;

@@ -31,7 +31,7 @@ bool BroadcastOpGradGradGpuKernelMod::GetOpType() {
   };
   auto iter = broadcast_grad_grad_op_type.find(kernel_name_);
   if (iter == broadcast_grad_grad_op_type.end()) {
-    MS_LOG(ERROR) << "For " << kernel::Map2Str(broadcast_grad_grad_op_type)
+    MS_LOG(ERROR) << "For " << kernel::Map2Str<std::map, BroadcastGradGradOpType>(broadcast_grad_grad_op_type)
                   << ", it only support max and min grad grad, but got " << kernel_name_;
     return false;
   }

@@ -310,8 +310,10 @@ bool UnaryOpGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::
   kernel_name_ = base_operator->name();
   auto iter = kernel_attr_map_.find(kernel_name_);
   if (iter == kernel_attr_map_.end()) {
-    MS_LOG(ERROR) << "For 'Unary op', the kernel name must be in " << kernel::Map2Str(kernel_attr_map_) << ", but got "
-                  << kernel_name_;
+    MS_LOG(ERROR) << "For 'Unary op', the kernel name must be in "
+                  << kernel::Map2Str<std::map, std::vector<std::pair<KernelAttr, UnaryOpGpuKernelMod::UnaryOpFunc>>>(
+                       kernel_attr_map_)
+                  << ", but got " << kernel_name_;
     return false;
   }
   if (inputs.empty() || outputs.empty()) {

@@ -347,8 +349,10 @@ int UnaryOpGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
 std::vector<KernelAttr> UnaryOpGpuKernelMod::GetOpSupport() {
   auto iter = kernel_attr_map_.find(kernel_name_);
   if (iter == kernel_attr_map_.end()) {
-    MS_LOG(ERROR) << "For 'Unary op', the kernel name must be in " << kernel::Map2Str(kernel_attr_map_) << ", but got "
-                  << kernel_name_;
+    MS_LOG(ERROR) << "For 'Unary op', the kernel name must be in "
+                  << kernel::Map2Str<std::map, std::vector<std::pair<KernelAttr, UnaryOpGpuKernelMod::UnaryOpFunc>>>(
+                       kernel_attr_map_)
+                  << ", but got " << kernel_name_;
     return std::vector<KernelAttr>{};
   }
   std::vector<KernelAttr> support_list;

@@ -385,8 +389,10 @@ bool UnaryOpGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &in
 
   auto iter = func_map.find(kernel_name_);
   if (iter == func_map.end()) {
-    MS_LOG(ERROR) << "For 'UnaryOp', only support these types: " << kernel::Map2Str(func_map) << " currently, but got "
-                  << kernel_name_;
+    MS_LOG(ERROR) << "For 'UnaryOp', only support these types: "
+                  << kernel::Map2Str<std::map, std::function<void(const T *, T *, const size_t, cudaStream_t)>>(
+                       func_map)
+                  << " currently, but got " << kernel_name_;
     return false;
   }
   auto input_ptr = reinterpret_cast<T *>(inputs.at(kIndex0)->addr);

@@ -96,8 +96,11 @@ bool UnaryGradOpGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const s
   kernel_name_ = base_operator->name();
   auto iter = kernel_attr_map_.find(kernel_name_);
   if (iter == kernel_attr_map_.end()) {
-    MS_LOG(ERROR) << "For 'UnaryGrad op', the kernel name must be in" << kernel::Map2Str(kernel_attr_map_)
-                  << ", but got " << kernel_name_;
+    MS_LOG(ERROR)
+      << "For 'UnaryGrad op', the kernel name must be in"
+      << kernel::Map2Str<std::map, std::vector<std::pair<KernelAttr, UnaryGradOpGpuKernelMod::UnaryOpGradFunc>>>(
+           kernel_attr_map_)
+      << ", but got " << kernel_name_;
     return false;
   }
   size_t input_num = inputs.size();

@@ -157,8 +160,10 @@ bool UnaryGradOpGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
                                         {kInvGrad, InvGrad<T>}};
   auto iter = func_map.find(kernel_name_);
   if (iter == func_map.end()) {
-    MS_LOG(ERROR) << "For 'UnaryGrad', only support these types: " << kernel::Map2Str(func_map)
-                  << " currently, but got " << kernel_name_;
+    MS_LOG(ERROR)
+      << "For 'UnaryGrad', only support these types: "
+      << kernel::Map2Str<std::map, std::function<void(const T *, const T *, T *, const size_t, cudaStream_t)>>(func_map)
+      << " currently, but got " << kernel_name_;
     return false;
   }
   auto input_x_addr = reinterpret_cast<T *>(inputs.at(kIndex0)->addr);

@@ -55,8 +55,11 @@ bool ActivationFwdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
   kernel_name_ = base_operator->name();
   auto iter = kernel_attr_map_.find(kernel_name_);
   if (iter == kernel_attr_map_.end()) {
-    MS_LOG(ERROR) << "For 'Activation', the kernel name must be in " << kernel::Map2Str(kernel_attr_map_)
-                  << ", but got " << kernel_name_;
+    MS_LOG(ERROR)
+      << "For 'Activation', the kernel name must be in "
+      << kernel::Map2Str<std::map, std::vector<std::pair<KernelAttr, ActivationFwdGpuKernelMod::ActivationFunc>>>(
+           kernel_attr_map_)
+      << ", but got " << kernel_name_;
     return false;
   }
   auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);

@@ -93,8 +96,8 @@ int ActivationFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, cons
   }
   auto iter = activation_mode_map.find(kernel_name_);
   if (iter == activation_mode_map.end()) {
-    MS_LOG(ERROR) << "For '" << kernel_name_
-                  << "', only support these activations: " << kernel::Map2Str(activation_mode_map) << ", but got "
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', only support these activations: "
+                  << kernel::Map2Str<std::map, cudnnActivationMode_t>(activation_mode_map) << ", but got "
                   << kernel_name_;
     return KRET_RESIZE_FAILED;
   }

@@ -141,8 +144,11 @@ int ActivationFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, cons
 std::vector<KernelAttr> ActivationFwdGpuKernelMod::GetOpSupport() {
   auto iter = kernel_attr_map_.find(kernel_name_);
   if (iter == kernel_attr_map_.end()) {
-    MS_LOG(ERROR) << "For 'Activation', the kernel name must be in " << kernel::Map2Str(kernel_attr_map_)
-                  << ", but got " << kernel_name_;
+    MS_LOG(ERROR)
+      << "For 'Activation', the kernel name must be in "
+      << kernel::Map2Str<std::map, std::vector<std::pair<KernelAttr, ActivationFwdGpuKernelMod::ActivationFunc>>>(
+           kernel_attr_map_)
+      << ", but got " << kernel_name_;
     return std::vector<KernelAttr>{};
   }
   std::vector<KernelAttr> support_list;

@@ -54,8 +54,11 @@ bool ActivationGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, cons
   kernel_name_ = base_operator->name();
   auto iter = kernel_attr_map_.find(kernel_name_);
   if (iter == kernel_attr_map_.end()) {
-    MS_LOG(ERROR) << "For 'ActivationGrad', the kernel name must be in " << kernel::Map2Str(kernel_attr_map_)
-                  << ", but got " << kernel_name_;
+    MS_LOG(ERROR)
+      << "For 'ActivationGrad', the kernel name must be in "
+      << kernel::Map2Str<std::map, std::vector<std::pair<KernelAttr, ActivationGradGpuKernelMod::ActivationGradFunc>>>(
+           kernel_attr_map_)
+      << ", but got " << kernel_name_;
     return false;
   }
   auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);

@@ -92,8 +95,8 @@ int ActivationGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, con
   }
   auto iter = activation_mode_map.find(kernel_name_);
   if (iter == activation_mode_map.end()) {
-    MS_LOG(ERROR) << "For '" << kernel_name_
-                  << "', only support these activations: " << kernel::Map2Str(activation_mode_map) << ", but got "
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', only support these activations: "
+                  << kernel::Map2Str<std::map, cudnnActivationMode_t>(activation_mode_map) << ", but got "
                   << kernel_name_;
     return KRET_RESIZE_FAILED;
   }

@@ -135,8 +138,11 @@ int ActivationGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, con
 std::vector<KernelAttr> ActivationGradGpuKernelMod::GetOpSupport() {
   auto iter = kernel_attr_map_.find(kernel_name_);
   if (iter == kernel_attr_map_.end()) {
-    MS_LOG(ERROR) << "For 'ActivationGrad', the kernel name must be in " << kernel::Map2Str(kernel_attr_map_)
-                  << ", but got " << kernel_name_;
+    MS_LOG(ERROR)
+      << "For 'ActivationGrad', the kernel name must be in "
+      << kernel::Map2Str<std::map, std::vector<std::pair<KernelAttr, ActivationGradGpuKernelMod::ActivationGradFunc>>>(
+           kernel_attr_map_)
+      << ", but got " << kernel_name_;
     return std::vector<KernelAttr>{};
   }
   std::vector<KernelAttr> support_list;

@@ -85,7 +85,8 @@ class BiasAddGpuKernelMod : public DeprecatedNativeGpuKernelMod {
     }
 
     // Expand to 4 dims for cudnnSetTensorNdDescriptorEx.
-    auto cudnn_dims = std::max(num_dims, 4UL);
+    constexpr size_t four_4D = 4;
+    size_t cudnn_dims = std::max(num_dims, four_4D);
     std::unique_ptr<int[]> x_dims = std::make_unique<int[]>(cudnn_dims);
     std::unique_ptr<int[]> b_dims = std::make_unique<int[]>(cudnn_dims);
     for (size_t i = 0; i < cudnn_dims; i++) {

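The `4UL` literal runs into the same LLP64 mismatch as `0l` earlier: num_dims is a size_t, which on 64-bit Windows is unsigned long long, while 4UL is a 32-bit unsigned long, so std::max sees two different argument types and fails to deduce. A constexpr size_t keeps both operands identical (standalone sketch, not the kernel code):

```cpp
#include <algorithm>
#include <cstddef>

size_t ExpandTo4D(size_t num_dims) {
  // return std::max(num_dims, 4UL);   // OK where unsigned long == size_t, rejected by MSVC
  constexpr size_t four_4D = 4;
  return std::max(num_dims, four_4D);  // portable across LP64 and LLP64
}

int main() { return ExpandTo4D(2) == 4 ? 0 : 1; }
```
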
@@ -104,7 +104,8 @@ class BiasAddGradGpuKernelMod : public DeprecatedNativeGpuKernelMod {
       MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', 'C' character must be in 'format', but got " << format;
     }
     bias_size_ = LongToSizeClipNeg(dy_shape[pos]);
-    auto num_dims_fix = std::max(num_dims_, 4UL);
+    constexpr size_t four_4D = 4;
+    size_t num_dims_fix = std::max(num_dims_, four_4D);
     for (size_t i = 0; i < num_dims_fix; i++) {
       dy_shape_.push_back((i < num_dims_) ? dy_shape[i] : 1);
       db_shape_.push_back((i == pos) ? dy_shape[i] : 1);

@@ -185,7 +186,8 @@ class BiasAddGradGpuKernelMod : public DeprecatedNativeGpuKernelMod {
     CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateReduceTensorDescriptor(&op_desc_),
                                 "cudnnCreateOpTensorDescriptor failed");
     // Expand to 4 dims for cudnnSetTensorNdDescriptorEx.
-    auto cudnn_dims = std::max(num_dims_, 4UL);
+    constexpr size_t four_4D = 4;
+    size_t cudnn_dims = std::max(num_dims_, four_4D);
     std::unique_ptr<int[]> dy_dims = std::make_unique<int[]>(cudnn_dims);
     std::unique_ptr<int[]> db_dims = std::make_unique<int[]>(cudnn_dims);
     for (size_t i = 0; i < cudnn_dims; i++) {

@@ -503,8 +503,11 @@ std::map<std::string, std::vector<std::pair<KernelAttr, PoolingGradGpuKernelMod:
 std::vector<KernelAttr> PoolingGradGpuKernelMod::GetOpSupport() {
   auto iter = kernel_attr_map_.find(kernel_name_);
   if (iter == kernel_attr_map_.end()) {
-    MS_LOG(ERROR) << "For 'PoolingGradGpuKernelMod', the kernel name must be in " << kernel::Map2Str(kernel_attr_map_)
-                  << ", but got " << kernel_name_;
+    MS_LOG(ERROR)
+      << "For 'PoolingGradGpuKernelMod', the kernel name must be in "
+      << kernel::Map2Str<std::map, std::vector<std::pair<KernelAttr, PoolingGradGpuKernelMod::PoolingGradFunc>>>(
+           kernel_attr_map_)
+      << ", but got " << kernel_name_;
     return std::vector<KernelAttr>{};
   }
   std::vector<KernelAttr> support_list;

@@ -92,12 +92,12 @@ bool FakeLearnedScaleQuantPerChannelGradGpuKernelMod::Launch(const std::vector<A
   MS_EXCEPTION_IF_NULL(input_div_alpha);
   MS_EXCEPTION_IF_NULL(input_quant);
   const int kChannelLen = num_channels_;
-  float alpha_no_grad[kChannelLen];
-  memset_s(alpha_no_grad, kChannelLen * sizeof(float), 0, kChannelLen * sizeof(float));
+  std::vector<float> alpha_no_grad(kChannelLen);
+  memset_s(alpha_no_grad.data(), kChannelLen * sizeof(float), 0, kChannelLen * sizeof(float));
 
   if (global_step_ >= quant_delay_) {
     CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
-                              cudaMemcpyAsync(grad_alpha, alpha_no_grad, sizeof(float) * kChannelLen,
+                              cudaMemcpyAsync(grad_alpha, alpha_no_grad.data(), sizeof(float) * kChannelLen,
                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "Copy gpu memory failed");
     CalLSQNudgePerChannel(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, neg_trunc_,

@@ -106,7 +106,7 @@ bool FakeLearnedScaleQuantPerChannelGradGpuKernelMod::Launch(const std::vector<A
                           neg_trunc_, num_channels_, reinterpret_cast<cudaStream_t>(stream_ptr));
   } else {
     CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
-                              cudaMemcpyAsync(grad_alpha, alpha_no_grad, sizeof(float) * kChannelLen,
+                              cudaMemcpyAsync(grad_alpha, alpha_no_grad.data(), sizeof(float) * kChannelLen,
                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "Copy gpu memory failed");
     CHECK_CUDA_RET_WITH_ERROR(kernel_node_,

@@ -200,8 +200,8 @@ void TagEnvironment::StepKernelProfiling(const int *action, float *state, float
   device::gpu::CudaDeviceStream end = nullptr;
   float bind_cost = 0;
   float cross_cost = 0;
-  CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::CreateEvent(&start), "Failed to create event.");
-  CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::CreateEvent(&end), "Failed to create event.");
+  CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::ConstructEvent(&start), "Failed to create event.");
+  CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::ConstructEvent(&end), "Failed to create event.");
 
   CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::RecordEvent(start, stream), "Failed to record event to stream.");
   StepBindBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done, stream);