forked from mindspore-Ecosystem/mindspore
!1032 quantization aware training bug fix.
Merge pull request !1032 from SanjayChan/bug_fix
This commit is contained in:
commit
96e2f9cbbe
|
@ -20,8 +20,8 @@
|
|||
#include "device/gpu/cuda_common.h"
|
||||
#include "fake_quant_impl.cuh"
|
||||
|
||||
__global__ void FakeQuantize(const float* input, float* output, const int size, const float* nudge_min,
|
||||
const float* nudge_max, const float* scale, bool symmetric) {
|
||||
__global__ void FakeQuantize(const float *input, float *output, const int size, const float *nudge_min,
|
||||
const float *nudge_max, const float *scale, bool symmetric) {
|
||||
float input_x = 0.f;
|
||||
int nudge_input = 0;
|
||||
|
||||
|
@ -43,8 +43,8 @@ __global__ void FakeQuantize(const float* input, float* output, const int size,
|
|||
return;
|
||||
}
|
||||
|
||||
__global__ void FakeQuantizeGrad(const float* input, const float* gradient, float* output, const int size,
|
||||
const float* nudge_min, const float* nudge_max) {
|
||||
__global__ void FakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size,
|
||||
const float *nudge_min, const float *nudge_max) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
|
||||
if (input[i] < nudge_min[0] || input[i] > nudge_max[0]) {
|
||||
output[i] = 0;
|
||||
|
@ -55,15 +55,18 @@ __global__ void FakeQuantizeGrad(const float* input, const float* gradient, floa
|
|||
return;
|
||||
}
|
||||
|
||||
__global__ void NudgeMinMax(const float* input_min, const float* input_max, const float quant_min,
|
||||
const float quant_max, float* nudge_min, float* nudge_max, float* scale) {
|
||||
__global__ void NudgeMinMax(const float *input_min, const float *input_max, const float quant_min,
|
||||
const float quant_max, float *nudge_min, float *nudge_max, float *scale) {
|
||||
float zp_from_min = 0.f;
|
||||
if ((quant_max - quant_min) == 0 || (*input_max - *input_min) == 0) {
|
||||
*scale = 0.f;
|
||||
scale[0] = 0.f;
|
||||
nudge_max[0] = 0.f;
|
||||
nudge_min[0] = 0.f;
|
||||
if ((quant_max - quant_min) == 0 || (input_max[0] - input_min[0]) == 0) {
|
||||
scale[0] = 0.f;
|
||||
zp_from_min = 0.f;
|
||||
} else {
|
||||
*scale = (*input_max - *input_min) / (quant_max - quant_min);
|
||||
zp_from_min = quant_min - *input_min / *scale;
|
||||
scale[0] = (input_max[0] - input_min[0]) / (quant_max - quant_min);
|
||||
zp_from_min = quant_min - input_min[0] / scale[0];
|
||||
}
|
||||
|
||||
float nudge_zp = 0.f;
|
||||
|
@ -75,59 +78,59 @@ __global__ void NudgeMinMax(const float* input_min, const float* input_max, cons
|
|||
nudge_zp = round(zp_from_min);
|
||||
}
|
||||
|
||||
*nudge_min = (quant_min - nudge_zp) * (*scale);
|
||||
*nudge_max = (quant_max - nudge_zp) * (*scale);
|
||||
nudge_min[0] = (quant_min - nudge_zp) * (scale[0]);
|
||||
nudge_max[0] = (quant_max - nudge_zp) * (scale[0]);
|
||||
return;
|
||||
}
|
||||
|
||||
__global__ void UpdateInputMinMaxWithEMA(float* input_min, float* input_max, const float min, const float max,
|
||||
__global__ void UpdateInputMinMaxWithEMA(float *input_min, float *input_max, const float min, const float max,
|
||||
const float decay) {
|
||||
*input_min = decay * (min) + (1 - decay) * (*input_min);
|
||||
*input_min = *input_min > 0 ? 0 : *input_min;
|
||||
*input_max = decay * (max) + (1 - decay) * (*input_max);
|
||||
*input_max = *input_max < 0 ? 0 : *input_max;
|
||||
input_min[0] = decay * (min) + (1 - decay) * (input_min[0]);
|
||||
input_min[0] = input_min[0] > 0 ? 0 : input_min[0];
|
||||
input_max[0] = decay * (max) + (1 - decay) * (input_max[0]);
|
||||
input_max[0] = input_max[0] < 0 ? 0 : input_max[0];
|
||||
return;
|
||||
}
|
||||
|
||||
__global__ void UpdateInputMinMax(float* input_min, float* input_max, const float min, const float max) {
|
||||
*input_min = min;
|
||||
*input_max = max;
|
||||
__global__ void UpdateInputMinMax(float *input_min, float *input_max, const float min, const float max) {
|
||||
input_min[0] = min > 0 ? 0 : min;
|
||||
input_max[0] = max < 0 ? 0 : max;
|
||||
}
|
||||
|
||||
void CalFakeQuantize(const float* input, float* output, const int size, const float* nudge_min, const float* nudge_max,
|
||||
const float* scale, bool symmetric, cudaStream_t cuda_stream) {
|
||||
void CalFakeQuantize(const float *input, float *output, const int size, const float *nudge_min, const float *nudge_max,
|
||||
const float *scale, bool symmetric, cudaStream_t cuda_stream) {
|
||||
FakeQuantize<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max, scale,
|
||||
symmetric);
|
||||
return;
|
||||
}
|
||||
|
||||
void CalFakeQuantizeGrad(const float* input, const float* gradient, float* output, const int size,
|
||||
const float* nudge_min, const float* nudge_max, cudaStream_t cuda_stream) {
|
||||
void CalFakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size,
|
||||
const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream) {
|
||||
FakeQuantizeGrad<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, gradient, output, size, nudge_min,
|
||||
nudge_max);
|
||||
return;
|
||||
}
|
||||
|
||||
void CalNudge(const float* input_min, const float* input_max, const float quant_min, const float quant_max,
|
||||
float* nudge_min, float* nudge_max, float* scale, cudaStream_t cuda_stream) {
|
||||
NudgeMinMax<<<1, 1>>>(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale);
|
||||
void CalNudge(const float *input_min, const float *input_max, const float quant_min, const float quant_max,
|
||||
float *nudge_min, float *nudge_max, float *scale, cudaStream_t cuda_stream) {
|
||||
NudgeMinMax<<<1, 1, 0, cuda_stream>>>(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale);
|
||||
return;
|
||||
}
|
||||
|
||||
void CalMinMax(float* input, float* input_min, float* input_max, const int size, const float ema_decay, const bool ema,
|
||||
void CalMinMax(float *input, float *input_min, float *input_max, const int size, const float ema_decay, const bool ema,
|
||||
cudaStream_t cuda_stream) {
|
||||
float minel = 0.f;
|
||||
float maxel = 0.f;
|
||||
auto policy = thrust::cuda::par.on(cuda_stream);
|
||||
thrust::pair<thrust::device_ptr<float>, thrust::device_ptr<float>> tuple;
|
||||
tuple = thrust::minmax_element(thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + size);
|
||||
tuple = thrust::minmax_element(policy, thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + size);
|
||||
minel = tuple.first[0];
|
||||
maxel = tuple.second[0];
|
||||
|
||||
if (ema) {
|
||||
UpdateInputMinMaxWithEMA<<<1, 1>>>(input_min, input_max, minel, maxel, ema_decay);
|
||||
UpdateInputMinMaxWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, minel, maxel, ema_decay);
|
||||
} else {
|
||||
UpdateInputMinMax<<<1, 1>>>(input_min, input_max, minel, maxel);
|
||||
UpdateInputMinMax<<<1, 1, 0, cuda_stream>>>(input_min, input_max, minel, maxel);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -17,16 +17,16 @@
|
|||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
|
||||
|
||||
void CalFakeQuantize(const float* input, float* output, const int size, const float* nudge_min, const float* nudge_max,
|
||||
const float* scale, bool symmetric, cudaStream_t cuda_stream);
|
||||
void CalFakeQuantize(const float *input, float *output, const int size, const float *nudge_min, const float *nudge_max,
|
||||
const float *scale, bool symmetric, cudaStream_t cuda_stream);
|
||||
|
||||
void CalFakeQuantizeGrad(const float* input, const float* gradient, float* output, const int size,
|
||||
const float* nudge_min, const float* nudge_max, cudaStream_t cuda_stream);
|
||||
void CalFakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size,
|
||||
const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream);
|
||||
|
||||
void CalNudge(const float* input_min, const float* input_max, const float quant_min, const float quant_max,
|
||||
float* nudge_min, float* nudge_max, float* scale, cudaStream_t cuda_stream);
|
||||
void CalNudge(const float *input_min, const float *input_max, const float quant_min, const float quant_max,
|
||||
float *nudge_min, float *nudge_max, float *scale, cudaStream_t cuda_stream);
|
||||
|
||||
void CalMinMax(float* input, float* input_min, float* input_max, const int size, const float ema_decay, const bool ema,
|
||||
void CalMinMax(float *input, float *input_min, float *input_max, const int size, const float ema_decay, const bool ema,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
|
||||
|
|
|
@ -34,8 +34,8 @@
|
|||
* @param channel_num
|
||||
* @return
|
||||
*/
|
||||
__global__ void NudgeMinMaxPerChannel(const float* input_min, const float* input_max, const float quant_min,
|
||||
const float quant_max, float* nudge_min, float* nudge_max, float* scale,
|
||||
__global__ void NudgeMinMaxPerChannel(const float *input_min, const float *input_max, const float quant_min,
|
||||
const float quant_max, float *nudge_min, float *nudge_max, float *scale,
|
||||
int channel_num) {
|
||||
float zp_from_min = 0.f;
|
||||
float nudge_zp = 0.f;
|
||||
|
@ -62,8 +62,8 @@ __global__ void NudgeMinMaxPerChannel(const float* input_min, const float* input
|
|||
}
|
||||
}
|
||||
|
||||
void CalNudgePerChannel(const float* input_min, const float* input_max, const float quant_min, const float quant_max,
|
||||
float* nudge_min, float* nudge_max, float* scale, const int channel_num,
|
||||
void CalNudgePerChannel(const float *input_min, const float *input_max, const float quant_min, const float quant_max,
|
||||
float *nudge_min, float *nudge_max, float *scale, const int channel_num,
|
||||
cudaStream_t cuda_stream) {
|
||||
NudgeMinMaxPerChannel<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
|
||||
input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale, channel_num);
|
||||
|
@ -80,8 +80,8 @@ void CalNudgePerChannel(const float* input_min, const float* input_max, const fl
|
|||
* @param scale - array
|
||||
* @return
|
||||
*/
|
||||
__global__ void FakeQuantizePerChannel(const float* input, float* output, const int total_size, const int channel_size,
|
||||
const float* nudge_min, const float* nudge_max, const float* scale,
|
||||
__global__ void FakeQuantizePerChannel(const float *input, float *output, const int total_size, const int channel_size,
|
||||
const float *nudge_min, const float *nudge_max, const float *scale,
|
||||
bool symmetric) {
|
||||
float input_x = 0.f;
|
||||
int nudge_input = 0;
|
||||
|
@ -106,8 +106,8 @@ __global__ void FakeQuantizePerChannel(const float* input, float* output, const
|
|||
}
|
||||
}
|
||||
|
||||
void CalFakeQuantizePerChannel(const float* input, float* output, const int total_size, const int channel_size,
|
||||
const float* nudge_min, const float* nudge_max, const float* scale, bool symmetric,
|
||||
void CalFakeQuantizePerChannel(const float *input, float *output, const int total_size, const int channel_size,
|
||||
const float *nudge_min, const float *nudge_max, const float *scale, bool symmetric,
|
||||
cudaStream_t cuda_stream) {
|
||||
FakeQuantizePerChannel<<<GET_BLOCKS(total_size), GET_THREADS, 0, cuda_stream>>>(
|
||||
input, output, total_size, channel_size, nudge_min, nudge_max, scale, symmetric);
|
||||
|
@ -121,10 +121,10 @@ void CalFakeQuantizePerChannel(const float* input, float* output, const int tota
|
|||
* @param max
|
||||
* @return
|
||||
*/
|
||||
__global__ void UpdateInputMinMaxPerChannel(float* input_min, float* input_max, float* input, int channels,
|
||||
__global__ void UpdateInputMinMaxPerChannel(float *input_min, float *input_max, float *input, int channels,
|
||||
int per_channel_nums, bool ema, float ema_decay) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) {
|
||||
thrust::pair<float*, float*> sum =
|
||||
thrust::pair<float *, float *> sum =
|
||||
thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1));
|
||||
if (ema) {
|
||||
input_min[i] = ema_decay * sum.first[0] + (1 - ema_decay) * input_min[i];
|
||||
|
@ -133,25 +133,27 @@ __global__ void UpdateInputMinMaxPerChannel(float* input_min, float* input_max,
|
|||
input_min[i] = sum.first[0];
|
||||
input_max[i] = sum.second[0];
|
||||
}
|
||||
input_min[i] = input_min[i] > 0 ? 0 : input_min[i];
|
||||
input_max[i] = input_max[i] < 0 ? 0 : input_max[i];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void UpdateInputMinMaxPerChannelWithEMA(float* input_min, float* input_max, float min, float max,
|
||||
__global__ void UpdateInputMinMaxPerChannelWithEMA(float *input_min, float *input_max, float min, float max,
|
||||
const float decay) {
|
||||
*input_min = decay * (min) + (1 - decay) * (*input_min);
|
||||
*input_max = decay * (max) + (1 - decay) * (*input_max);
|
||||
}
|
||||
|
||||
void CalMinMaxPerChannel(float* input, float* input_min, float* input_max, const int total_size, const int channel_size,
|
||||
void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, const int total_size, const int channel_size,
|
||||
const float ema_decay, const bool ema, cudaStream_t cuda_stream) {
|
||||
int per_channel_num = total_size / channel_size;
|
||||
UpdateInputMinMaxPerChannel<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(
|
||||
input_min, input_max, input, channel_size, per_channel_num, ema, ema_decay);
|
||||
}
|
||||
|
||||
__global__ void FakeQuantizePerChannelGrad(const float* input, const float* gradient, float* output,
|
||||
const int total_size, const int channel_size, const float* nudge_min,
|
||||
const float* nudge_max) {
|
||||
__global__ void FakeQuantizePerChannelGrad(const float *input, const float *gradient, float *output,
|
||||
const int total_size, const int channel_size, const float *nudge_min,
|
||||
const float *nudge_max) {
|
||||
int channel_idx = 0;
|
||||
int per_channel_num = total_size / channel_size;
|
||||
|
||||
|
@ -165,10 +167,9 @@ __global__ void FakeQuantizePerChannelGrad(const float* input, const float* grad
|
|||
}
|
||||
}
|
||||
|
||||
void CalFakeQuantizePerChannelGrad(const float* input, const float* gradient, float* output, const int total_num,
|
||||
const int channel_num, const float* nudge_min, const float* nudge_max,
|
||||
void CalFakeQuantizePerChannelGrad(const float *input, const float *gradient, float *output, const int total_num,
|
||||
const int channel_num, const float *nudge_min, const float *nudge_max,
|
||||
cudaStream_t cuda_stream) {
|
||||
FakeQuantizePerChannelGrad<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
|
||||
input, gradient, output, total_num, channel_num, nudge_min, nudge_max);
|
||||
}
|
||||
|
||||
|
|
|
@ -114,8 +114,7 @@ class BatchNormFold2GpuKernel : public GpuKernel {
|
|||
|
||||
output_size_list_.push_back(input_size);
|
||||
|
||||
size_t workspace_size = 0;
|
||||
workspace_size_list_.push_back(workspace_size);
|
||||
workspace_size_list_.push_back(sizeof(int32_t));
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -70,9 +70,12 @@ class BatchNormFold2GradGpuKernel : public GpuKernel {
|
|||
|
||||
int32_t current_step_host[1];
|
||||
size_t x_size = batch_size_ * channel_ * height_ * width_ * sizeof(T);
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(current_step_host, global_step, sizeof(int32_t), cudaMemcpyDeviceToHost),
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, global_step, sizeof(int32_t), cudaMemcpyDeviceToHost,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Failed to copy gpu memory.");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(d_x, dout, x_size, cudaMemcpyDeviceToDevice), "Failed to copy gpu memory.");
|
||||
CHECK_CUDA_RET_WITH_ERROR(
|
||||
cudaMemcpyAsync(d_x, dout, x_size, cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Failed to copy gpu memory.");
|
||||
|
||||
BatchNormFold2GradReduce(dout, x, d_beta, tmp, reduce_x, tmp2, tmp_x, batch_size_, channel_, height_, width_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
|
|
@ -55,12 +55,13 @@ class BatchNormFoldGpuKernel : public GpuKernel {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
|
||||
(void)workspace;
|
||||
auto x = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto mean = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto variance = reinterpret_cast<T *>(inputs[2]->addr);
|
||||
int *current_step = reinterpret_cast<int *>(inputs[3]->addr);
|
||||
auto x = GetDeviceAddress<T>(inputs, 0);
|
||||
auto mean = GetDeviceAddress<T>(inputs, 1);
|
||||
auto variance = GetDeviceAddress<T>(inputs, 2);
|
||||
int *current_step = GetDeviceAddress<int>(inputs, 3);
|
||||
int current_step_host[1];
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost),
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Copy gpu memoy failed.");
|
||||
if (x == nullptr) {
|
||||
MS_LOG(ERROR) << "BatchNormFoldGpuKernel x is null.";
|
||||
|
@ -78,15 +79,17 @@ class BatchNormFoldGpuKernel : public GpuKernel {
|
|||
MS_LOG(ERROR) << "BatchNormFoldGpuKernel current_step is null.";
|
||||
return false;
|
||||
}
|
||||
auto batch_mean = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
auto batch_std = reinterpret_cast<T *>(outputs[1]->addr);
|
||||
auto running_mean = reinterpret_cast<T *>(outputs[2]->addr);
|
||||
auto running_std = reinterpret_cast<T *>(outputs[3]->addr);
|
||||
auto y = reinterpret_cast<T *>(workspace[0]->addr);
|
||||
auto batch_mean = GetDeviceAddress<T>(outputs, 0);
|
||||
auto batch_std = GetDeviceAddress<T>(outputs, 1);
|
||||
auto running_mean = GetDeviceAddress<T>(outputs, 2);
|
||||
auto running_std = GetDeviceAddress<T>(outputs, 3);
|
||||
auto y = GetDeviceAddress<T>(workspace, 0);
|
||||
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(running_mean, mean, output_size_, cudaMemcpyDeviceToDevice),
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(running_mean, mean, output_size_, cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Failed to copy gpu memory.");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(running_std, variance, output_size_, cudaMemcpyDeviceToDevice),
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(running_std, variance, output_size_, cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Failed to copy gpu memory.");
|
||||
CalUpdateRunningStd(channel_, epsilon_, running_std, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
if (!is_training_ || current_step_host[0] >= freeze_bn_) {
|
||||
|
|
|
@ -57,7 +57,8 @@ class BatchNormFoldGradGpuKernel : public GpuKernel {
|
|||
T *batch_std = GetDeviceAddress<T>(inputs, 4);
|
||||
int *current_step = GetDeviceAddress<int>(inputs, 5);
|
||||
int current_step_host[1];
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost),
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Copy gpu memoy failed.");
|
||||
if (d_batch_mean == nullptr) {
|
||||
MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel d_batch_mean is null.";
|
||||
|
@ -83,7 +84,7 @@ class BatchNormFoldGradGpuKernel : public GpuKernel {
|
|||
MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel current_step is null.";
|
||||
return false;
|
||||
}
|
||||
T *dx = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
T *dx = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
if (!is_training_ || current_step_host[0] >= freeze_bn_) {
|
||||
ThrustFillWith(dx, batch_ * channel_ * height_ * width_, 0.f, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
|
|
@ -60,7 +60,7 @@ bool FakeQuantGpuKernel::Init(const CNodePtr &kernel_node) {
|
|||
|
||||
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
|
||||
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
|
||||
ema_decay_ = 1.0 - GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
|
||||
ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
|
||||
training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
|
||||
|
||||
if (num_bits_ <= 2 || num_bits_ >= 16) {
|
||||
|
@ -115,7 +115,6 @@ void FakeQuantGpuKernel::InitSizeLists() {
|
|||
|
||||
bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
|
||||
(void)workspace;
|
||||
float *output = GetDeviceAddress<float>(outputs, 0);
|
||||
float *input = GetDeviceAddress<float>(inputs, 0);
|
||||
float *input_min = GetDeviceAddress<float>(inputs, 1);
|
||||
|
@ -151,7 +150,8 @@ bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
|||
CalFakeQuantize(input, output, quant_num_, d_nudge_min, d_nudge_max, d_scale, symmetric_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
} else {
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, input, input_size_, cudaMemcpyDeviceToDevice),
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Copy gpu memory failed");
|
||||
}
|
||||
global_step_++;
|
||||
|
|
|
@ -93,7 +93,6 @@ void FakeQuantGradGpuKernel::InitSizeLists() {
|
|||
|
||||
bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
|
||||
(void)workspace;
|
||||
float *output = GetDeviceAddress<float>(outputs, 0);
|
||||
float *gradient = GetDeviceAddress<float>(inputs, 0);
|
||||
float *input = GetDeviceAddress<float>(inputs, 1);
|
||||
|
@ -133,8 +132,9 @@ bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const
|
|||
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
|
||||
} else {
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, gradient, input_size_, cudaMemcpyDeviceToDevice),
|
||||
"Copy gpu memory failed.");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Copy gpu memory failed");
|
||||
}
|
||||
global_step_++;
|
||||
return true;
|
||||
|
|
|
@ -107,11 +107,13 @@ bool FakeQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
|
|||
}
|
||||
|
||||
void FakeQuantPerChannelGpuKernel::InitSizeLists() {
|
||||
input_size_list_.push_back(input_size_); // input
|
||||
input_size_list_.push_back(min_size_); // min
|
||||
input_size_list_.push_back(max_size_); // max
|
||||
output_size_list_.push_back(output_size_);
|
||||
workspace_size_list_.push_back(workspace_size_);
|
||||
input_size_list_.push_back(input_size_); // input in tensor
|
||||
input_size_list_.push_back(min_size_); // min one scalar
|
||||
input_size_list_.push_back(max_size_); // max on scalar
|
||||
output_size_list_.push_back(output_size_); // output in tensor
|
||||
workspace_size_list_.push_back(sizeof(float) * channel_out_); // scale in channel
|
||||
workspace_size_list_.push_back(sizeof(float) * channel_out_); // min in channel
|
||||
workspace_size_list_.push_back(sizeof(float) * channel_out_); // max in channel
|
||||
}
|
||||
|
||||
void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForTraining(float *input, float *output, float *input_min,
|
||||
|
@ -128,8 +130,9 @@ void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForTraining(float *input, floa
|
|||
CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max,
|
||||
d_scale, symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
} else {
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, input, input_size_, cudaMemcpyDeviceToDevice),
|
||||
"Copy gpu memory failed.");
|
||||
CHECK_CUDA_RET_WITH_ERROR(
|
||||
cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Copy gpu memory failed.");
|
||||
}
|
||||
global_step_++;
|
||||
}
|
||||
|
@ -152,6 +155,9 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
|
|||
float *input = GetDeviceAddress<float>(inputs, 0);
|
||||
float *input_min = GetDeviceAddress<float>(inputs, 1);
|
||||
float *input_max = GetDeviceAddress<float>(inputs, 2);
|
||||
float *d_scale = GetDeviceAddress<float>(workspace, 0);
|
||||
float *d_nudge_min = GetDeviceAddress<float>(workspace, 1);
|
||||
float *d_nudge_max = GetDeviceAddress<float>(workspace, 2);
|
||||
|
||||
if (input == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input is null.";
|
||||
|
@ -160,27 +166,12 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
|
|||
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input min or max is null.";
|
||||
}
|
||||
|
||||
// Allocate space for device copies
|
||||
float *d_scale = nullptr;
|
||||
float *d_nudge_min = nullptr;
|
||||
float *d_nudge_max = nullptr;
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), sizeof(float) * channel_out_),
|
||||
"Malloc gpu memory failed");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), sizeof(float) * channel_out_),
|
||||
"Malloc gpu memory failed");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), sizeof(float) * channel_out_),
|
||||
"Malloc gpu memory failed");
|
||||
|
||||
if (training_) {
|
||||
CalFakeQuantizeForTraining(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
|
||||
} else {
|
||||
CalFakeQuantizeForInfer(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -97,7 +97,9 @@ void FakeQuantPerChannelGradGpuKernel::InitSizeLists() {
|
|||
input_size_list_.push_back(min_size_); // min
|
||||
input_size_list_.push_back(max_size_); // max
|
||||
output_size_list_.push_back(output_size_);
|
||||
workspace_size_list_.push_back(workspace_size_);
|
||||
workspace_size_list_.push_back(sizeof(float) * channel_out_); // scale in channel
|
||||
workspace_size_list_.push_back(sizeof(float) * channel_out_); // min in channel
|
||||
workspace_size_list_.push_back(sizeof(float) * channel_out_); // max in channel
|
||||
}
|
||||
|
||||
bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
|
||||
|
@ -109,6 +111,9 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
|
|||
float *input = GetDeviceAddress<float>(inputs, 1);
|
||||
float *input_min = GetDeviceAddress<float>(inputs, 2);
|
||||
float *input_max = GetDeviceAddress<float>(inputs, 3);
|
||||
float *d_scale = GetDeviceAddress<float>(workspace, 0);
|
||||
float *d_nudge_min = GetDeviceAddress<float>(workspace, 1);
|
||||
float *d_nudge_max = GetDeviceAddress<float>(workspace, 2);
|
||||
|
||||
if (gradient == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel gradient is null";
|
||||
|
@ -125,28 +130,13 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
|
|||
|
||||
int total_size = input_size_ / sizeof(float);
|
||||
if (global_step_ >= quant_delay_) {
|
||||
float *d_scale = nullptr;
|
||||
float *d_nudge_min = nullptr;
|
||||
float *d_nudge_max = nullptr;
|
||||
// Allocate space for device copies
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), channel_out_ * sizeof(float)),
|
||||
"Malloc gpu memory failed");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), channel_out_ * sizeof(float)),
|
||||
"Malloc gpu memory failed");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), channel_out_ * sizeof(float)),
|
||||
"Malloc gpu memory failed");
|
||||
|
||||
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, channel_out_, d_nudge_min, d_nudge_max,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
// Cleanup
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
|
||||
} else {
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, gradient, input_size_, cudaMemcpyDeviceToDevice),
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Copy gpu memory failed.");
|
||||
}
|
||||
global_step_++;
|
||||
|
|
Loading…
Reference in New Issue