!363 clear the warmming scan by package

Merge pull request !363 from SanjayChan/labao
This commit is contained in:
mindspore-ci-bot 2020-04-16 09:18:44 +08:00 committed by Gitee
commit 58b013c319
12 changed files with 42 additions and 37 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -19,7 +19,6 @@
namespace mindspore {
namespace kernel {
DropoutGpuFwdKernel::DropoutGpuFwdKernel()
: cudnn_handle_(nullptr),
is_null_input_(false),

View File

@ -18,7 +18,6 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(BatchNormFold2,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)

View File

@ -132,7 +132,6 @@ class BatchNormFold2GpuKernel : public GpuKernel {
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore

View File

@ -18,7 +18,6 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(BatchNormFold2Grad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)

View File

@ -18,7 +18,6 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(BatchNormFold,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)

View File

@ -54,7 +54,6 @@ class CorrectionMulGpuKernel : public GpuKernel {
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (input_shape.size() != 4) {
MS_LOG(ERROR) << "CorrectionMulGpuKernel input shape needs (N,C,H,W).";
return false;

View File

@ -19,7 +19,6 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(CorrectionMulGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)

View File

@ -61,7 +61,6 @@ class CorrectionMulGradGpuKernel : public GpuKernel {
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (input_shape.size() != 4) {
MS_LOG(ERROR) << "CorrectionMulGradGpuKernel input shape needs (N,C,H,W).";
return false;

View File

@ -114,6 +114,36 @@ void FakeQuantPerChannelGpuKernel::InitSizeLists() {
workspace_size_list_.push_back(workspace_size_);
}
void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForTraining(float *input, float *output, float *input_min,
float *input_max, float *d_nudge_min, float *d_nudge_max,
float *d_scale, uintptr_t stream_ptr) {
// calculate the input min and max according by the parameter ema and ema_decay.
CalMinMaxPerChannel(input, input_min, input_max, input_size_ / sizeof(float), channel_out_, ema_decay_, ema_,
reinterpret_cast<cudaStream_t>(stream_ptr));
// control flow for quant_delay
if (global_step_ >= quant_delay_) {
// real launch
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max,
d_scale, symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, input, input_size_, cudaMemcpyDeviceToDevice),
"Copy gpu memory failed.");
}
global_step_++;
}
void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForInfer(float *input, float *output, float *input_min,
float *input_max, float *d_nudge_min, float *d_nudge_max,
float *d_scale, uintptr_t stream_ptr) {
// real launch
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max, d_scale,
symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
}
bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
@ -126,11 +156,8 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
if (input == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input is null.";
}
if (input_min == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input min is null.";
}
if (input_max == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input max is null.";
if (input_min == nullptr || input_max == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input min or max is null.";
}
// Allocate space for device copies
@ -143,30 +170,11 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
"Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), sizeof(float) * channel_out_),
"Malloc gpu memory failed");
int total_size = input_size_ / sizeof(float);
bool symmetric = false;
if (training_) {
// calculate the input min and max according by the parameter ema and ema_decay.
CalMinMaxPerChannel(input, input_min, input_max, total_size, channel_out_, ema_decay_, ema_,
reinterpret_cast<cudaStream_t>(stream_ptr));
// control flow for quant_delay
if (global_step_ >= quant_delay_) {
// real launch
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizePerChannel(input, output, total_size, channel_out_, d_nudge_min, d_nudge_max, d_scale, symmetric,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, input, input_size_, cudaMemcpyDeviceToDevice),
"Copy gpu memory failed.");
}
global_step_++;
CalFakeQuantizeForTraining(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
} else {
// real launch
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizePerChannel(input, output, total_size, channel_out_, d_nudge_min, d_nudge_max, d_scale, symmetric,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizeForInfer(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
}
// Cleanup

View File

@ -39,6 +39,11 @@ class FakeQuantPerChannelGpuKernel : public GpuKernel {
void InitSizeLists() override;
private:
void CalFakeQuantizeForTraining(float *input, float *output, float *input_min, float *input_max, float *d_nudge_min,
float *d_nudge_max, float *d_scale, uintptr_t stream_ptr);
void CalFakeQuantizeForInfer(float *input, float *output, float *input_min, float *input_max, float *d_nudge_min,
float *d_nudge_max, float *d_scale, uintptr_t stream_ptr);
size_t input_size_;
size_t min_size_;
size_t max_size_;