!35124 Adjust cast op on Ascend for ClipByNorm

Merge pull request !35124 from JoyLvliang/correct_cast_op_for_ClopByNorm
i-robot 2022-05-31 02:17:12 +00:00 committed by Gitee
commit 3099e0ccc9
5 changed files with 91 additions and 40 deletions

View File

@@ -214,23 +214,26 @@ const AnfNodePtr ClipByNormSplit::Process(const FuncGraphPtr &func_graph, const
TypeId dst_type_id = kNumberTypeFloat32;
auto shape_vec = GetOutputInferShape(clip_by_norm);
auto x_type_id = common::AnfAlgo::GetPrevNodeOutputInferDataType(clip_by_norm, 0);
// Create `op1 = cast(x)` to float32 data type
auto x_cast = CreateCastNode(func_graph, inp_x, shape_vec, x_type_id, dst_type_id);
// Create `op2 = square(op1)` op
auto square = CreateSquareNode(func_graph, x_cast, shape_vec, dst_type_id);
// Create 'op3 = reduce_sum(op2)' op
auto reduce_sum = CreateReduceSumNode(func_graph, square, clip_by_norm, shape_vec, dst_type_id);
// Create `op1 = square(x)` op
auto square = CreateSquareNode(func_graph, inp_x, shape_vec, x_type_id);
// Create 'op2 = reduce_sum(op1)' op
auto reduce_sum = CreateReduceSumNode(func_graph, square, clip_by_norm, shape_vec, x_type_id);
// Create `op3 = cast(op2)` to float32 data type
auto reduce_sum_cast =
CreateCastNode(func_graph, reduce_sum, GetOutputInferShape(reduce_sum), x_type_id, dst_type_id);
// Create 'op4 = sqrt(op3)' op
auto sqrt = CreateSqrtNode(func_graph, reduce_sum, dst_type_id);
// Create 'op5 = cast(clip_norm)' to float32 data type.
auto sqrt = CreateSqrtNode(func_graph, reduce_sum_cast, dst_type_id);
// Create 'op5 = x * clip_norm' op
auto mul = CreateMulNode(func_graph, inp_x, inp_clip_norm, shape_vec, x_type_id);
// Create 'op6 = cast(clip_norm)' to float32 data type.
auto clip_norm_cast = CreateCastNode(func_graph, inp_clip_norm, GetOutputInferShape(inp_clip_norm),
common::AnfAlgo::GetOutputInferDataType(inp_clip_norm, 0), dst_type_id);
// Create 'op6 = x * op5' op
auto mul = CreateMulNode(func_graph, x_cast, clip_norm_cast, shape_vec, dst_type_id);
// Create `op7 = max(op5, op4)` op
// Create `op7 = max(op6, op4)` op
auto max = CreateMaxNode(func_graph, clip_norm_cast, sqrt, dst_type_id);
// Create 'op8 = op6 / op7' op
auto div = CreateDivNode(func_graph, mul, max, shape_vec, dst_type_id);
// Create 'op8 = cast(op5)' to float32 data type.
auto mul_cast = CreateCastNode(func_graph, mul, shape_vec, x_type_id, dst_type_id);
// Create 'op9 = op8 / op7' op
auto div = CreateDivNode(func_graph, mul_cast, max, shape_vec, dst_type_id);
return div;
}
} // namespace opt
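
The rewritten pass computes square and reduce_sum in the input dtype, casts the reduced sum to float32 only for sqrt, multiplies x by clip_norm in the input dtype, and casts that product to float32 for the final max/div. A minimal host-side sketch of the resulting computation, with the hypothetical helper ClipByNormReference standing in for the op1..op9 node sequence (plain C++ with every dtype collapsed to float for readability; an illustration, not MindSpore code):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Host-side reference for the node sequence built above:
// op1 square, op2 reduce_sum, op3 cast, op4 sqrt, op5 mul(x, clip_norm),
// op6 cast(clip_norm), op7 max, op8 cast(mul), op9 div.
std::vector<float> ClipByNormReference(const std::vector<float> &x, float clip_norm) {
  float sum_sq = 0.0f;
  for (float v : x) {
    sum_sq += v * v;                                 // op1 + op2: square and reduce_sum
  }
  const float l2_norm = std::sqrt(sum_sq);           // op3 + op4: cast to float32, then sqrt
  const float denom = std::max(clip_norm, l2_norm);  // op6 + op7: cast(clip_norm), max
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = x[i] * clip_norm / denom;               // op5 + op8 + op9: mul, cast, div
  }
  return out;
}

int main() {
  const std::vector<float> x = {3.0f, 4.0f};         // L2 norm is 5
  for (float v : ClipByNormReference(x, 1.0f)) {
    std::printf("%f\n", v);                          // expected: 0.6 and 0.8
  }
  return 0;
}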

View File

@@ -312,7 +312,11 @@ void ClipByNormCpuKernelMod::ClipNormMulAndCmpLaunch(T *x_addr, float *div_outpu
if (x_shape_.empty()) { // The input x is a scalar tensor
float mul_output = div_output_addr[0] * static_cast<float>(clip_norm_addr[0]);
float x = static_cast<float>(x_addr[0]);
if (x * mul_output >= 0) {
output_addr[0] = (mul_output * mul_output) > (x * x) ? x : mul_output;
} else {
output_addr[0] = mul_output;
}
return;
}
BroadcastIterator broadcast_base_iter(x_shape_, clip_norm_shape_, output_shape_);
@@ -325,7 +329,11 @@ void ClipByNormCpuKernelMod::ClipNormMulAndCmpLaunch(T *x_addr, float *div_outpu
float clip_norm = static_cast<float>(clip_norm_addr[iter.GetInputPosB()]);
float mul_output = clip_norm * div_out;
float x = static_cast<float>(x_addr[iter.GetInputPosA()]);
if (x * mul_output >= 0) {
output_addr[i] = (mul_output * mul_output) > (x * x) ? x : mul_output;
} else {
output_addr[i] = mul_output;
}
iter.GenNextPos();
}
};
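
Both the scalar branch and the broadcast loop above now apply the same sign-aware selection: the original input x is used only when it shares a sign with the scaled value and has the smaller magnitude; when the signs differ, the scaled value is kept. A standalone sketch of that element-wise rule (the SelectClipped name is hypothetical, for illustration only):

#include <cstdio>

// mul_output is the scaled value (div_output * clip_norm); x is the original input.
float SelectClipped(float x, float mul_output) {
  if (x * mul_output >= 0) {  // same sign, or one of the values is zero
    return (mul_output * mul_output) > (x * x) ? x : mul_output;  // keep the smaller magnitude
  }
  return mul_output;          // opposite signs: never revert to x
}

int main() {
  std::printf("%g\n", SelectClipped(2.0f, 5.0f));   // 2   (same sign, x has the smaller magnitude)
  std::printf("%g\n", SelectClipped(2.0f, 1.5f));   // 1.5 (same sign, scaled value is smaller)
  std::printf("%g\n", SelectClipped(-2.0f, 1.5f));  // 1.5 (signs differ, scaled value kept)
  return 0;
}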

View File

@@ -18,36 +18,78 @@
#include "include/cuda_fp16.h"
template <typename T>
__global__ void CompAndCastKernel(const size_t size, const T *x, const T *temp_output_addr, float *output_addr) {
__global__ void AbsKernel(const size_t size, const T *in, T *out) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
output_addr[i] = temp_output_addr[i] * temp_output_addr[i] > x[i] * x[i] ? x[i] : temp_output_addr[i];
out[i] = (in[i] >= 0) ? in[i] : -in[i];
}
}
template <>
__global__ void CompAndCastKernel(const size_t size, const half *x, const half *temp_output_addr, float *output_addr) {
__global__ void AbsKernel(const size_t size, const float *in, float *out) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
output_addr[i] =
temp_output_addr[i] * temp_output_addr[i] > x[i] * x[i] ? __half2float(x[i]) : __half2float(temp_output_addr[i]);
out[i] = fabs(in[i]);
}
}
template <>
__global__ void CompAndCastKernel(const size_t size, const float *x, const float *temp_output_addr,
float *output_addr) {
__global__ void AbsKernel(const size_t size, const half *in, half *out) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
output_addr[i] = temp_output_addr[i] * temp_output_addr[i] > x[i] * x[i] ? x[i] : temp_output_addr[i];
float zero = 0;
out[i] = (in[i] >= __float2half(zero)) ? in[i] : -in[i];
}
}
template <typename T>
void CompAndCastOp(const size_t size, const T *x, const T *temp_output_addr, float *output_addr,
cudaStream_t cuda_stream) {
CompAndCastKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, x, temp_output_addr, output_addr);
__global__ void CompKernel(const size_t size, const T *x, const T *temp_output_addr, float *output_addr) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
if (temp_output_addr[i] * x[i] >= 0) {
output_addr[i] = (temp_output_addr[i] * temp_output_addr[i]) > (x[i] * x[i]) ? x[i] : temp_output_addr[i];
} else {
output_addr[i] = temp_output_addr[i];
}
}
}
template CUDA_LIB_EXPORT void CompAndCastOp<float>(const size_t size, const float *x, const float *temp_output_addr,
template <>
__global__ void CompKernel(const size_t size, const half *x, const half *temp_output_addr, float *output_addr) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
float zero = 0;
if (temp_output_addr[i] * x[i] >= __float2half(zero)) {
output_addr[i] = (temp_output_addr[i] * temp_output_addr[i]) > (x[i] * x[i]) ? __half2float(x[i])
: __half2float(temp_output_addr[i]);
} else {
output_addr[i] = __half2float(temp_output_addr[i]);
}
}
}
template <>
__global__ void CompKernel(const size_t size, const float *x, const float *temp_output_addr, float *output_addr) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
if (temp_output_addr[i] * x[i] >= 0) {
output_addr[i] = (temp_output_addr[i] * temp_output_addr[i]) > (x[i] * x[i]) ? x[i] : temp_output_addr[i];
} else {
output_addr[i] = temp_output_addr[i];
}
}
}
template <typename T>
void AbsOp(const size_t size, const T *in, T *out, cudaStream_t cuda_stream) {
AbsKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, in, out);
}
template CUDA_LIB_EXPORT void AbsOp<float>(const size_t size, const float *in, float *out, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void AbsOp<half>(const size_t size, const half *in, half *out, cudaStream_t cuda_stream);
template <typename T>
void CompOp(const size_t size, const T *x, const T *temp_output_addr, float *output_addr, cudaStream_t cuda_stream) {
CompKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, x, temp_output_addr, output_addr);
}
template CUDA_LIB_EXPORT void CompOp<float>(const size_t size, const float *x, const float *temp_output_addr,
float *output_addr, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CompAndCastOp<half>(const size_t size, const half *x, const half *temp_output_addr,
template CUDA_LIB_EXPORT void CompOp<half>(const size_t size, const half *x, const half *temp_output_addr,
float *output_addr, cudaStream_t cuda_stream);

View File

@@ -20,7 +20,10 @@
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
template <typename T>
CUDA_LIB_EXPORT void CompAndCastOp(const size_t size, const T *x, const T *temp_output_addr, float *output_addr,
CUDA_LIB_EXPORT void AbsOp(const size_t size, const T *in, T *out, cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void CompOp(const size_t size, const T *x, const T *temp_output_addr, float *output_addr,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CLIP_BY_NORM_IMPL_CUH_

View File

@@ -132,12 +132,7 @@ bool ClipByNormGpuKernelMod<T, S>::DoLaunch(const std::vector<AddressPtr> &input
Cast(x_size_ / sizeof(T), x_addr, x_float_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
// Launch `cudnnReduceTensorNorm2` operator to achieve `L2_norm` calculation, keep_dims = true.
if (all_match_) {
MS_LOG(DEBUG) << "The corresponding dimension of the `input_x` and `l2_norm_output` are all matched, running "
"device to device copy.";
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(l2norm_output_addr, x_float_addr, (x_size_ / sizeof(T)) * sizeof(float), cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
kernel_name_ + " running cudaMemcpyAsync failed.");
AbsOp(x_size_ / sizeof(T), x_float_addr, l2norm_output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
constexpr float alpha = 1.0;
constexpr float beta = 0.0;
@@ -165,7 +160,7 @@ bool ClipByNormGpuKernelMod<T, S>::DoLaunch(const std::vector<AddressPtr> &input
clip_norm_mul_output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
}
// Running compare between `input_x` and `upper output` and cast final output to float type.
CompAndCastOp(output_size_ / sizeof(float), x_float_addr, clip_norm_mul_output_addr, output_addr,
CompOp(output_size_ / sizeof(float), x_float_addr, clip_norm_mul_output_addr, output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
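
In the all_match_ branch the reduction preserves the input shape, so the L2 norm of each element is sqrt(x * x), i.e. its absolute value; the earlier device-to-device copy forwarded the signed x itself, while AbsOp writes |x|. A small host-side check of that identity (plain C++, hypothetical example, no CUDA):

#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  // For a shape-preserving reduction, the per-element L2 norm is sqrt(x * x) == fabs(x),
  // so taking the absolute value stands in for the removed device-to-device copy of x.
  const float xs[] = {3.5f, -2.0f, 0.0f};
  for (float x : xs) {
    assert(std::sqrt(x * x) == std::fabs(x));
    std::printf("x = %g, per-element L2 norm = %g\n", x, std::fabs(x));
  }
  return 0;
}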