From 90d5d5dab30ac3ac849c58da86011d22f07837b7 Mon Sep 17 00:00:00 2001 From: Erpim Date: Sat, 17 Apr 2021 09:49:16 +0800 Subject: [PATCH] add lsq quantization method --- ...ake_learned_scale_quant_perchannel_impl.cu | 113 +++++ ...ke_learned_scale_quant_perchannel_impl.cuh | 33 ++ .../fake_learned_scale_quant_perlayer_impl.cu | 97 ++++ ...fake_learned_scale_quant_perlayer_impl.cuh | 32 ++ ...arned_scale_quant_perchannel_gpu_kernel.cc | 130 ++++++ ...earned_scale_quant_perchannel_gpu_kernel.h | 57 +++ ..._scale_quant_perchannel_grad_gpu_kernel.cc | 135 ++++++ ...d_scale_quant_perchannel_grad_gpu_kernel.h | 57 +++ ...learned_scale_quant_perlayer_gpu_kernel.cc | 121 +++++ ..._learned_scale_quant_perlayer_gpu_kernel.h | 56 +++ ...ed_scale_quant_perlayer_grad_gpu_kernel.cc | 130 ++++++ ...ned_scale_quant_perlayer_grad_gpu_kernel.h | 56 +++ ...e_learned_scale_quant_grad_unify_mindir.cc | 218 +++++++++ ...ke_learned_scale_quant_grad_unify_mindir.h | 57 +++ .../ccsrc/backend/session/ascend_session.cc | 3 + mindspore/ccsrc/pipeline/jit/pipeline.cc | 18 +- mindspore/compression/export/quant_export.py | 22 +- mindspore/compression/quant/qat.py | 215 +++++++-- mindspore/compression/quant/quant_utils.py | 128 +++++- mindspore/compression/quant/quantizer.py | 3 + mindspore/core/base/core_ops.h | 4 + mindspore/nn/layer/quant.py | 270 +++++++++--- mindspore/ops/_grad/grad_quant_ops.py | 27 ++ mindspore/ops/_op_impl/_custom_op/__init__.py | 6 + .../fake_learned_scale_quant_perchannel.py | 125 ++++++ ...ake_learned_scale_quant_perchannel_grad.py | 191 ++++++++ ...rned_scale_quant_perchannel_grad_reduce.py | 88 ++++ .../fake_learned_scale_quant_perlayer.py | 117 +++++ .../fake_learned_scale_quant_perlayer_grad.py | 184 ++++++++ ...earned_scale_quant_perlayer_grad_reduce.py | 88 ++++ mindspore/ops/operations/_quant_ops.py | 323 ++++++++++++++ .../cv/mobilenetv2_quant/README_CN.md | 162 +++++-- .../official/cv/mobilenetv2_quant/Readme.md | 165 +++++-- .../official/cv/mobilenetv2_quant/eval.py | 51 ++- .../official/cv/mobilenetv2_quant/export.py | 33 +- .../scripts/run_lsq_infer.sh | 54 +++ .../scripts/run_lsq_train.sh | 177 ++++++++ .../cv/mobilenetv2_quant/src/config.py | 39 ++ .../src/mobilenetv2_mix_quant.py | 414 ++++++++++++++++++ .../official/cv/mobilenetv2_quant/train.py | 71 ++- .../lenet_quant/test_lenet_quant.py | 100 ++++- tests/ut/python/train/quant/test_quant.py | 54 +++ 42 files changed, 4157 insertions(+), 267 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perchannel_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perchannel_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perlayer_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perlayer_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_gpu_kernel.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_grad_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_grad_gpu_kernel.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_gpu_kernel.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_grad_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_grad_gpu_kernel.h create mode 100644 mindspore/ccsrc/backend/optimizer/ascend/mindir/fake_learned_scale_quant_grad_unify_mindir.cc create mode 100644 mindspore/ccsrc/backend/optimizer/ascend/mindir/fake_learned_scale_quant_grad_unify_mindir.h create mode 100644 mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py create mode 100644 mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py create mode 100644 mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py create mode 100644 mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py create mode 100644 mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py create mode 100644 mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py create mode 100644 model_zoo/official/cv/mobilenetv2_quant/scripts/run_lsq_infer.sh create mode 100644 model_zoo/official/cv/mobilenetv2_quant/scripts/run_lsq_train.sh create mode 100644 model_zoo/official/cv/mobilenetv2_quant/src/mobilenetv2_mix_quant.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perchannel_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perchannel_impl.cu new file mode 100644 index 00000000000..9e196785b3e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perchannel_impl.cu @@ -0,0 +1,113 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fake_learned_scale_quant_perchannel_impl.cuh" +#include +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh" + +__global__ void FakeLearnedScaleQuantPerChannel(float *output, const int size, float *input_alpha, + float *input_quant, const int channel_num) { + int channel_idx = 0; + int per_channel_num = size / channel_num; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + channel_idx = floor(static_cast(i) / static_cast(per_channel_num)); + // dequantize + output[i] = input_quant[i] * input_alpha[channel_idx]; + } + return; +} + +__global__ void FakeLearnedScaleQuantPerChannelGrad(float *grad_input, float *grad_alpha, const float *gradient, + const int size, const float *input_div_alpha, + const float *input_quant, const bool neg_trunc, + const int channel_num) { + int channel_idx = 0; + int per_channel_num = size / channel_num; + float lower_bound = -1.0 * !neg_trunc; + + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + float grad_alpha_temp = 0.f; + channel_idx = floor(static_cast(i) / static_cast(per_channel_num)); + if (input_div_alpha[i] > 1.0) { + grad_alpha_temp = gradient[i]; + grad_input[i] = 0; + } else if (input_div_alpha[i] < lower_bound) { + grad_alpha_temp = -gradient[i]; + grad_input[i] = 0; + } else { + grad_input[i] = gradient[i]; + grad_alpha_temp = (gradient[i] * (input_quant[i] - input_div_alpha[i])); + } + MsAtomicAdd(grad_alpha + channel_idx, grad_alpha_temp); + } + return; +} + +__global__ void LSQNudgePerChannel(const float *input, const int size, float *input_alpha, float *input_quant_max, + float *input_div_alpha, float *input_quant, const bool neg_trunc, + const int channel_num) { + float input_x; + int channel_idx = 0; + int per_channel_num = size / channel_num; + float lower_bound = -1.0 * !neg_trunc; + + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + channel_idx = floor(static_cast(i) / static_cast(per_channel_num)); + input_x = input[i] / input_alpha[channel_idx]; + input_div_alpha[i] = input_x; + input_x = max(input_x, lower_bound); + input_x = min(input_x, 1.0); + + // quantize + input_quant[i] = floor(input_x * input_quant_max[0] + 0.5f) / input_quant_max[0]; + } + return; +} + +void CalFakeLearnedScaleQuantPerChannel(float *output, const int size, float *input_alpha, float *input_quant, + const int channel_num, cudaStream_t cuda_stream) { + FakeLearnedScaleQuantPerChannel<<>>(output, size, input_alpha, + input_quant, channel_num); + return; +} + +void CalFakeLearnedScaleQuantPerChannelGrad(float *grad_input, float *grad_alpha, const float *gradient, const int size, + const float *input_div_alpha, const float *input_quant, + const bool neg_trunc, const int channel_num, cudaStream_t cuda_stream) { + FakeLearnedScaleQuantPerChannelGrad<<>>(grad_input, + grad_alpha, + gradient, + size, + input_div_alpha, + input_quant, + neg_trunc, + channel_num); + return; +} + +void CalLSQNudgePerChannel(const float *input, const int size, float *input_alpha, float *input_quant_max, + float *input_div_alpha, float *input_quant, const bool neg_trunc, const int channel_num, + cudaStream_t cuda_stream) { + LSQNudgePerChannel<<>>(input, size, input_alpha, input_quant_max, + input_div_alpha, input_quant, neg_trunc, + channel_num); + return; +} diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perchannel_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perchannel_impl.cuh new file mode 100644 index 00000000000..a835c962248 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perchannel_impl.cuh @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_LEARNED_SCALE_QUANT_PERCHANNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_LEARNED_SCALE_QUANT_PERCHANNEL_H_ + +#include "runtime/device/gpu/cuda_common.h" + +void CalLSQNudgePerChannel(const float *input, const int size, float *input_alpha, float *input_quant_max, + float *input_div_alpha, float *input_quant, const bool neg_trunc, const int channel_num, + cudaStream_t cuda_stream); + +void CalFakeLearnedScaleQuantPerChannel(float *output, const int size, float *input_alpha, float *input_quant, + const int channel_num, cudaStream_t cuda_stream); + +void CalFakeLearnedScaleQuantPerChannelGrad(float *grad_input, float *grad_alpha, const float *gradient, const int size, + const float *input_div_alpha, const float *input_quant, + const bool neg_trunc, const int channel_num, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_LEARNED_SCALE_QUANT_PERCHANNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perlayer_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perlayer_impl.cu new file mode 100644 index 00000000000..e86b47d55ba --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perlayer_impl.cu @@ -0,0 +1,97 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fake_learned_scale_quant_perlayer_impl.cuh" +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh" + +__global__ void FakeLearnedScaleQuantPerLayer(float *output, const int size, float *input_alpha, + float *input_quant) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + // dequantize + output[i] = input_quant[i] * input_alpha[0]; + } + return; +} + +__global__ void FakeLearnedScaleQuantPerLayerGrad(float *grad_input, float *grad_alpha, const float *gradient, + const int size, const float *input_div_alpha, + const float *input_quant, const bool neg_trunc) { + float grad_alpha_temp = 0.f; + float lower_bound = -1.0 * !neg_trunc; + + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + if (input_div_alpha[i] > 1.0) { + grad_alpha_temp += gradient[i]; + grad_input[i] = 0; + } else if (input_div_alpha[i] < lower_bound) { + grad_alpha_temp -= gradient[i]; + grad_input[i] = 0; + } else { + grad_input[i] = gradient[i]; + grad_alpha_temp += (gradient[i] * (input_quant[i] - input_div_alpha[i])); + } + } + MsAtomicAdd(grad_alpha, grad_alpha_temp); + return; +} + +__global__ void LSQNudgePerLayer(const float *input, const int size, float *input_alpha, float *input_quant_max, + float *input_div_alpha, float *input_quant, const bool neg_trunc) { + float input_x; + float lower_bound = -1.0 * !neg_trunc; + + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { + input_x = input[i] / input_alpha[0]; + input_div_alpha[i] = input_x; + input_x = max(input_x, lower_bound); + input_x = min(input_x, 1.0); + + // quantize + input_quant[i] = floor(input_x * input_quant_max[0] + 0.5f) / input_quant_max[0]; + } + return; +} + +void CalFakeLearnedScaleQuantPerLayer(float *output, const int size, float *input_alpha, float *input_quant, + cudaStream_t cuda_stream) { + FakeLearnedScaleQuantPerLayer<<>>(output, size, input_alpha, + input_quant); + return; +} + +void CalFakeLearnedScaleQuantPerLayerGrad(float *grad_input, float *grad_alpha, const float *gradient, const int size, + const float *input_div_alpha, const float *input_quant, const bool neg_trunc, + cudaStream_t cuda_stream) { + FakeLearnedScaleQuantPerLayerGrad<<>>(grad_input, + grad_alpha, + gradient, + size, + input_div_alpha, + input_quant, + neg_trunc); + return; +} + +void CalLSQNudgePerLayer(const float *input, const int size, float *input_alpha, float *input_quant_max, + float *input_div_alpha, float *input_quant, const bool neg_trunc, cudaStream_t cuda_stream) { + LSQNudgePerLayer<<>>(input, size, input_alpha, input_quant_max, + input_div_alpha, input_quant, neg_trunc); + return; +} diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perlayer_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perlayer_impl.cuh new file mode 100644 index 00000000000..9667766367b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perlayer_impl.cuh @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_LEARNED_SCALE_QUANT_PERLAYER_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_LEARNED_SCALE_QUANT_PERLAYER_H_ + +#include "runtime/device/gpu/cuda_common.h" + +void CalLSQNudgePerLayer(const float *input, const int size, float *input_alpha, float *input_quant_max, + float *input_div_alpha, float *input_quant, const bool neg_trunc, cudaStream_t cuda_stream); + +void CalFakeLearnedScaleQuantPerLayer(float *output, const int size, float *input_alpha, float *input_quant, + cudaStream_t cuda_stream); + +void CalFakeLearnedScaleQuantPerLayerGrad(float *grad_input, float *grad_alpha, const float *gradient, const int size, + const float *input_div_alpha, const float *input_quant, const bool neg_trunc, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_LEARNED_SCALE_QUANT_PERLAYER_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_gpu_kernel.cc new file mode 100644 index 00000000000..4efa5b9045e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_gpu_kernel.cc @@ -0,0 +1,130 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_gpu_kernel.h" +#include "backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perchannel_impl.cuh" +#include +#include +#include +#include + +namespace mindspore { +namespace kernel { +FakeLearnedScaleQuantPerChannelGpuKernel::FakeLearnedScaleQuantPerChannelGpuKernel() + : input_size_(0), + quant_num_(1), + global_step_(0), + quant_delay_(0), + training_(false), + neg_trunc_(false), + num_channels_(0) {} + +const std::vector &FakeLearnedScaleQuantPerChannelGpuKernel::GetInputSizeList() const { + return input_size_list_; +} + +const std::vector &FakeLearnedScaleQuantPerChannelGpuKernel::GetOutputSizeList() const { + return output_size_list_; +} + +const std::vector &FakeLearnedScaleQuantPerChannelGpuKernel::GetWorkspaceSizeList() const { + return workspace_size_list_; +} + +bool FakeLearnedScaleQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) { + kernel_node_ = kernel_node; + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(EXCEPTION) << "Input number is " << input_num + << ", but FakeLearnedScaleQuantPerChannel GpuKernel OP needs 3 Input."; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num + << ", but FakeLearnedScaleQuantPerChannel GpuKernel OP needs 1 output."; + } + + quant_delay_ = static_cast(GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"))); + training_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training")); + neg_trunc_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("neg_trunc")); + + // init size + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + num_channels_ = SizeToInt(input_shape[0]); + for (size_t i = 0; i < input_shape.size(); ++i) { + quant_num_ *= SizeToInt(input_shape[i]); + } + input_size_ = sizeof(float) * quant_num_; + InitSizeLists(); + return true; +} + +void FakeLearnedScaleQuantPerChannelGpuKernel::InitSizeLists() { + input_size_list_.push_back(input_size_); // x + input_size_list_.push_back(sizeof(float) * num_channels_); // alpha + input_size_list_.push_back(sizeof(float)); // quant_max + output_size_list_.push_back(input_size_); // y + workspace_size_list_.push_back(input_size_); // input_div_alpha + workspace_size_list_.push_back(input_size_); // input_quant +} + +bool FakeLearnedScaleQuantPerChannelGpuKernel::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + float *input = GetDeviceAddress(inputs, 0); + float *input_alpha = GetDeviceAddress(inputs, 1); + float *input_quant_max = GetDeviceAddress(inputs, 2); + float *output = GetDeviceAddress(outputs, 0); + float *input_div_alpha = GetDeviceAddress(workspace, 0); + float *input_quant = GetDeviceAddress(workspace, 1); + + MS_EXCEPTION_IF_NULL(input); + MS_EXCEPTION_IF_NULL(input_alpha); + MS_EXCEPTION_IF_NULL(input_quant_max); + MS_EXCEPTION_IF_NULL(output); + MS_EXCEPTION_IF_NULL(input_div_alpha); + MS_EXCEPTION_IF_NULL(input_quant); + + if (training_) { + // control flow for quant_delay + if (global_step_ >= quant_delay_) { + // real launch + CalLSQNudgePerChannel(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, neg_trunc_, + num_channels_, reinterpret_cast(stream_ptr)); + CalFakeLearnedScaleQuantPerChannel(output, quant_num_, input_alpha, input_quant, num_channels_, + reinterpret_cast(stream_ptr)); + } else { + CHECK_CUDA_RET_WITH_ERROR(kernel_node_, + cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + } + global_step_++; + } else { + // real launch + CalLSQNudgePerChannel(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, neg_trunc_, + num_channels_, reinterpret_cast(stream_ptr)); + CalFakeLearnedScaleQuantPerChannel(output, quant_num_, input_alpha, input_quant, num_channels_, + reinterpret_cast(stream_ptr)); + } + + return true; +} + +MS_REG_GPU_KERNEL(FakeLearnedScaleQuantPerChannel, FakeLearnedScaleQuantPerChannelGpuKernel) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_gpu_kernel.h new file mode 100644 index 00000000000..e25c030cc5a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_gpu_kernel.h @@ -0,0 +1,57 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKE_LEARNED_SCALE_QUANT_PER_CHANNEL_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_FAKE_LEARNED_SCALE_QUANT_PER_CHANNEL_GPUKERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class FakeLearnedScaleQuantPerChannelGpuKernel : public GpuKernel { + public: + FakeLearnedScaleQuantPerChannelGpuKernel(); + ~FakeLearnedScaleQuantPerChannelGpuKernel() = default; + + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + bool Init(const CNodePtr &kernel) override; + + protected: + void InitSizeLists() override; + + private: + size_t input_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + int quant_num_; + int global_step_; + int quant_delay_; + bool training_; + bool neg_trunc_; + int num_channels_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKE_LEARNED_SCALE_QUANT_PER_CHANNEL_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_grad_gpu_kernel.cc new file mode 100644 index 00000000000..7121050cc41 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_grad_gpu_kernel.cc @@ -0,0 +1,135 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_grad_gpu_kernel.h" +#include "backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perchannel_impl.cuh" + +namespace mindspore { +namespace kernel { +FakeLearnedScaleQuantPerChannelGradGpuKernel::FakeLearnedScaleQuantPerChannelGradGpuKernel() + : input_size_(0), + workspace_size_(0), + quant_num_(1), + quant_delay_(0), + global_step_(0), + neg_trunc_(false), + num_channels_(0) {} + +const std::vector &FakeLearnedScaleQuantPerChannelGradGpuKernel::GetInputSizeList() const { + return input_size_list_; +} + +const std::vector &FakeLearnedScaleQuantPerChannelGradGpuKernel::GetOutputSizeList() const { + return output_size_list_; +} + +const std::vector &FakeLearnedScaleQuantPerChannelGradGpuKernel::GetWorkspaceSizeList() const { + return workspace_size_list_; +} + +bool FakeLearnedScaleQuantPerChannelGradGpuKernel::Init(const CNodePtr &kernel_node) { + kernel_node_ = kernel_node; + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + + if (input_num != 4) { + MS_LOG(EXCEPTION) << "Input number is " << input_num + << ", but FakeLearnedScaleQuantPerChannelGrad GpuKernel OP needs 4 input."; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 2) { + MS_LOG(EXCEPTION) << "Output number is " << output_num + << ", but FakeLearnedScaleQuantPerChannelGrad GpuKernel OP needs 2 output."; + } + + quant_delay_ = static_cast(GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"))); + if (quant_delay_ < 0) { + MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less than 0, require larger than 0."; + } + + neg_trunc_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("neg_trunc")); + + // init size + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + num_channels_ = SizeToInt(input_shape[0]); + for (size_t i = 0; i < input_shape.size(); ++i) { + quant_num_ *= SizeToInt(input_shape[i]); + } + input_size_ = sizeof(float) * quant_num_; + InitSizeLists(); + return true; +} + +void FakeLearnedScaleQuantPerChannelGradGpuKernel::InitSizeLists() { + input_size_list_.push_back(input_size_); // gradient + input_size_list_.push_back(input_size_); // input + input_size_list_.push_back(sizeof(float) * num_channels_); // alpha + input_size_list_.push_back(sizeof(float)); // quant_max + output_size_list_.push_back(input_size_); // grad_input + output_size_list_.push_back(sizeof(float) * num_channels_); // grad_alpha + workspace_size_list_.push_back(input_size_); // input_div_alpha + workspace_size_list_.push_back(input_size_); // input_quant +} + +bool FakeLearnedScaleQuantPerChannelGradGpuKernel::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + float *grad_input = GetDeviceAddress(outputs, 0); + float *grad_alpha = GetDeviceAddress(outputs, 1); + float *gradient = GetDeviceAddress(inputs, 0); + float *input = GetDeviceAddress(inputs, 1); + float *input_alpha = GetDeviceAddress(inputs, 2); + float *input_quant_max = GetDeviceAddress(inputs, 3); + float *input_div_alpha = GetDeviceAddress(workspace, 0); + float *input_quant = GetDeviceAddress(workspace, 1); + + MS_EXCEPTION_IF_NULL(grad_input); + MS_EXCEPTION_IF_NULL(grad_alpha); + MS_EXCEPTION_IF_NULL(gradient); + MS_EXCEPTION_IF_NULL(input); + MS_EXCEPTION_IF_NULL(input_alpha); + MS_EXCEPTION_IF_NULL(input_quant_max); + MS_EXCEPTION_IF_NULL(input_div_alpha); + MS_EXCEPTION_IF_NULL(input_quant); + const int kChannelLen = num_channels_; + float alpha_no_grad[kChannelLen]; + memset_s(alpha_no_grad, kChannelLen * sizeof(float), 0, kChannelLen * sizeof(float)); + + if (global_step_ >= quant_delay_) { + CHECK_CUDA_RET_WITH_ERROR(kernel_node_, + cudaMemcpyAsync(grad_alpha, alpha_no_grad, sizeof(float) * num_channels_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + CalLSQNudgePerChannel(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, neg_trunc_, + num_channels_, reinterpret_cast(stream_ptr)); + CalFakeLearnedScaleQuantPerChannelGrad(grad_input, grad_alpha, gradient, quant_num_, input_div_alpha, input_quant, + neg_trunc_, num_channels_, reinterpret_cast(stream_ptr)); + } else { + CHECK_CUDA_RET_WITH_ERROR(kernel_node_, + cudaMemcpyAsync(grad_alpha, alpha_no_grad, sizeof(float) * num_channels_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + CHECK_CUDA_RET_WITH_ERROR(kernel_node_, + cudaMemcpyAsync(grad_input, gradient, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + } + global_step_++; + return true; +} + +MS_REG_GPU_KERNEL(FakeLearnedScaleQuantPerChannelGrad, FakeLearnedScaleQuantPerChannelGradGpuKernel) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_grad_gpu_kernel.h new file mode 100644 index 00000000000..4774bdfa46c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perchannel_grad_gpu_kernel.h @@ -0,0 +1,57 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKE_LEARNED_SCALE_QUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_FAKE_LEARNED_SCALE_QUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class FakeLearnedScaleQuantPerChannelGradGpuKernel : public GpuKernel { + public: + FakeLearnedScaleQuantPerChannelGradGpuKernel(); + ~FakeLearnedScaleQuantPerChannelGradGpuKernel() = default; + + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + bool Init(const CNodePtr &kernel_node) override; + + protected: + void InitSizeLists() override; + + private: + size_t input_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + int quant_num_; + int quant_delay_; + int global_step_; + bool neg_trunc_; + int num_channels_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKE_LEARNED_SCALE_QUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_gpu_kernel.cc new file mode 100644 index 00000000000..b3ef2818c3b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_gpu_kernel.cc @@ -0,0 +1,121 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_gpu_kernel.h" +#include "backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perlayer_impl.cuh" +#include +#include +#include +#include + +namespace mindspore { +namespace kernel { +FakeLearnedScaleQuantPerLayerGpuKernel::FakeLearnedScaleQuantPerLayerGpuKernel() + : input_size_(0), quant_num_(1), global_step_(0), quant_delay_(0), training_(false), neg_trunc_(false) {} + +const std::vector &FakeLearnedScaleQuantPerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; } + +const std::vector &FakeLearnedScaleQuantPerLayerGpuKernel::GetOutputSizeList() const { + return output_size_list_; +} + +const std::vector &FakeLearnedScaleQuantPerLayerGpuKernel::GetWorkspaceSizeList() const { + return workspace_size_list_; +} + +bool FakeLearnedScaleQuantPerLayerGpuKernel::Init(const CNodePtr &kernel_node) { + kernel_node_ = kernel_node; + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(EXCEPTION) << "Input number is " << input_num + << ", but FakeLearnedScaleQuantPerLayer GpuKernel OP needs 3 Input."; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num + << ", but FakeLearnedScaleQuantPerLayer GpuKernel OP needs 1 output."; + } + + quant_delay_ = static_cast(GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"))); + training_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training")); + neg_trunc_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("neg_trunc")); + + // init size + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shape.size(); ++i) { + quant_num_ *= SizeToInt(input_shape[i]); + } + input_size_ = sizeof(float) * quant_num_; + InitSizeLists(); + return true; +} + +void FakeLearnedScaleQuantPerLayerGpuKernel::InitSizeLists() { + input_size_list_.push_back(input_size_); // x + input_size_list_.push_back(sizeof(float)); // alpha + input_size_list_.push_back(sizeof(float)); // quant_max + output_size_list_.push_back(input_size_); // y + workspace_size_list_.push_back(input_size_); // input_div_alpha + workspace_size_list_.push_back(input_size_); // input_quant +} + +bool FakeLearnedScaleQuantPerLayerGpuKernel::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + float *input = GetDeviceAddress(inputs, 0); + float *input_alpha = GetDeviceAddress(inputs, 1); + float *input_quant_max = GetDeviceAddress(inputs, 2); + float *output = GetDeviceAddress(outputs, 0); + float *input_div_alpha = GetDeviceAddress(workspace, 0); + float *input_quant = GetDeviceAddress(workspace, 1); + + MS_EXCEPTION_IF_NULL(input); + MS_EXCEPTION_IF_NULL(input_alpha); + MS_EXCEPTION_IF_NULL(input_quant_max); + MS_EXCEPTION_IF_NULL(output); + MS_EXCEPTION_IF_NULL(input_div_alpha); + MS_EXCEPTION_IF_NULL(input_quant); + + if (training_) { + // control flow for quant_delay + if (global_step_ >= quant_delay_) { + // real launch + CalLSQNudgePerLayer(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, neg_trunc_, + reinterpret_cast(stream_ptr)); + CalFakeLearnedScaleQuantPerLayer(output, quant_num_, input_alpha, input_quant, + reinterpret_cast(stream_ptr)); + } else { + CHECK_CUDA_RET_WITH_ERROR(kernel_node_, + cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + } + global_step_++; + } else { + // real launch + CalLSQNudgePerLayer(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, neg_trunc_, + reinterpret_cast(stream_ptr)); + CalFakeLearnedScaleQuantPerLayer(output, quant_num_, input_alpha, input_quant, + reinterpret_cast(stream_ptr)); + } + + return true; +} + +MS_REG_GPU_KERNEL(FakeLearnedScaleQuantPerLayer, FakeLearnedScaleQuantPerLayerGpuKernel) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_gpu_kernel.h new file mode 100644 index 00000000000..3e8216e7416 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_gpu_kernel.h @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKE_LEARNED_SCALE_QUANT_PERLAYER_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_FAKE_LEARNED_SCALE_QUANT_PERLAYER_GPUKERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class FakeLearnedScaleQuantPerLayerGpuKernel : public GpuKernel { + public: + FakeLearnedScaleQuantPerLayerGpuKernel(); + ~FakeLearnedScaleQuantPerLayerGpuKernel() = default; + + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + bool Init(const CNodePtr &kernel) override; + + protected: + void InitSizeLists() override; + + private: + size_t input_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + int quant_num_; + int global_step_; + int quant_delay_; + bool training_; + bool neg_trunc_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKE_LEARNED_SCALE_QUANT_PERLAYER_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_grad_gpu_kernel.cc new file mode 100644 index 00000000000..42fbd722b68 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_grad_gpu_kernel.cc @@ -0,0 +1,130 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_grad_gpu_kernel.h" +#include "backend/kernel_compiler/gpu/cuda_impl/fake_learned_scale_quant_perlayer_impl.cuh" + +namespace mindspore { +namespace kernel { +FakeLearnedScaleQuantPerLayerGradGpuKernel::FakeLearnedScaleQuantPerLayerGradGpuKernel() + : input_size_(0), workspace_size_(0), quant_num_(1), quant_delay_(0), global_step_(0), neg_trunc_(false) {} + +const std::vector &FakeLearnedScaleQuantPerLayerGradGpuKernel::GetInputSizeList() const { + return input_size_list_; +} + +const std::vector &FakeLearnedScaleQuantPerLayerGradGpuKernel::GetOutputSizeList() const { + return output_size_list_; +} + +const std::vector &FakeLearnedScaleQuantPerLayerGradGpuKernel::GetWorkspaceSizeList() const { + return workspace_size_list_; +} + +bool FakeLearnedScaleQuantPerLayerGradGpuKernel::Init(const CNodePtr &kernel_node) { + kernel_node_ = kernel_node; + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + + if (input_num != 4) { + MS_LOG(EXCEPTION) << "Input number is " << input_num + << ", but FakeLearnedScaleQuantPerLayerGrad GpuKernel OP needs 4 input."; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 2) { + MS_LOG(EXCEPTION) << "Output number is " << output_num + << ", but FakeLearnedScaleQuantPerLayerGrad GpuKernel OP needs 2 output."; + } + + quant_delay_ = static_cast(GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"))); + if (quant_delay_ < 0) { + MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less than 0, require larger than 0."; + } + + neg_trunc_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("neg_trunc")); + + // init size + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shape.size(); ++i) { + quant_num_ *= SizeToInt(input_shape[i]); + } + input_size_ = sizeof(float); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + InitSizeLists(); + return true; +} + +void FakeLearnedScaleQuantPerLayerGradGpuKernel::InitSizeLists() { + input_size_list_.push_back(input_size_); // gradient + input_size_list_.push_back(input_size_); // input + input_size_list_.push_back(sizeof(float)); // alpha + input_size_list_.push_back(sizeof(float)); // quant_max + output_size_list_.push_back(input_size_); // grad_input + output_size_list_.push_back(sizeof(float)); // grad_alpha + workspace_size_list_.push_back(input_size_); // input_div_alpha + workspace_size_list_.push_back(input_size_); // input_quant +} + +bool FakeLearnedScaleQuantPerLayerGradGpuKernel::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + float *grad_input = GetDeviceAddress(outputs, 0); + float *grad_alpha = GetDeviceAddress(outputs, 1); + float *gradient = GetDeviceAddress(inputs, 0); + float *input = GetDeviceAddress(inputs, 1); + float *input_alpha = GetDeviceAddress(inputs, 2); + float *input_quant_max = GetDeviceAddress(inputs, 3); + float *input_div_alpha = GetDeviceAddress(workspace, 0); + float *input_quant = GetDeviceAddress(workspace, 1); + + MS_EXCEPTION_IF_NULL(grad_input); + MS_EXCEPTION_IF_NULL(grad_alpha); + MS_EXCEPTION_IF_NULL(gradient); + MS_EXCEPTION_IF_NULL(input); + MS_EXCEPTION_IF_NULL(input_alpha); + MS_EXCEPTION_IF_NULL(input_quant_max); + MS_EXCEPTION_IF_NULL(input_div_alpha); + MS_EXCEPTION_IF_NULL(input_quant); + + const float alpha_no_grad[1] = {0.f}; + + if (global_step_ >= quant_delay_) { + CHECK_CUDA_RET_WITH_ERROR(kernel_node_, + cudaMemcpyAsync(grad_alpha, alpha_no_grad, sizeof(float), cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + CalLSQNudgePerLayer(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, neg_trunc_, + reinterpret_cast(stream_ptr)); + CalFakeLearnedScaleQuantPerLayerGrad(grad_input, grad_alpha, gradient, quant_num_, input_div_alpha, input_quant, + neg_trunc_, reinterpret_cast(stream_ptr)); + } else { + CHECK_CUDA_RET_WITH_ERROR(kernel_node_, + cudaMemcpyAsync(grad_alpha, alpha_no_grad, sizeof(float), cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + CHECK_CUDA_RET_WITH_ERROR(kernel_node_, + cudaMemcpyAsync(grad_input, gradient, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + } + global_step_++; + return true; +} + +MS_REG_GPU_KERNEL(FakeLearnedScaleQuantPerLayerGrad, FakeLearnedScaleQuantPerLayerGradGpuKernel) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_grad_gpu_kernel.h new file mode 100644 index 00000000000..1773db4678a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_learned_scale_quant_perlayer_grad_gpu_kernel.h @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKE_LEARNED_SCALE_QUANT_PERLAYER_GRAD_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_FAKE_LEARNED_SCALE_QUANT_PERLAYER_GRAD_GPUKERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class FakeLearnedScaleQuantPerLayerGradGpuKernel : public GpuKernel { + public: + FakeLearnedScaleQuantPerLayerGradGpuKernel(); + ~FakeLearnedScaleQuantPerLayerGradGpuKernel() = default; + + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + bool Init(const CNodePtr &kernel_node) override; + + protected: + void InitSizeLists() override; + + private: + size_t input_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + int quant_num_; + int quant_delay_; + int global_step_; + bool neg_trunc_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKE_LEARNED_SCALE_QUANT_PERLAYER_GRAD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/fake_learned_scale_quant_grad_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/fake_learned_scale_quant_grad_unify_mindir.cc new file mode 100644 index 00000000000..f481f863273 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/fake_learned_scale_quant_grad_unify_mindir.cc @@ -0,0 +1,218 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/mindir/fake_learned_scale_quant_grad_unify_mindir.h" + +#include +#include + +#include "utils/utils.h" +#include "utils/ms_context.h" +#include "backend/optimizer/common/helper.h" +#include "runtime/device/kernel_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/trace_base.h" + +namespace mindspore { +namespace opt { +namespace { +void CreateOutputsOfLSQPerLayerGradD(const FuncGraphPtr &graph, const CNodePtr &lsq_perlayer_grad_node, + std::vector *lsq_perlayer_grad_d_outputs) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(lsq_perlayer_grad_node); + const auto &lsq_perlayer_grad_inputs = lsq_perlayer_grad_node->inputs(); + if (lsq_perlayer_grad_inputs.size() < kFakeLearnedScaleQuantGradInputNum) { + MS_LOG(EXCEPTION) << "lsq_perlayer_grad_node has wrong inputs size." + << " trace: " << trace::DumpSourceLines(lsq_perlayer_grad_node); + } + std::vector lsq_perlayer_grad_d_inputs = { + NewValueNode(std::make_shared(kFakeLearnedScaleQuantPerLayerGradDOpName)), lsq_perlayer_grad_inputs[1], + lsq_perlayer_grad_inputs[2], lsq_perlayer_grad_inputs[3], lsq_perlayer_grad_inputs[4]}; + auto lsq_perlayer_grad_d = graph->NewCNode(lsq_perlayer_grad_d_inputs); + MS_EXCEPTION_IF_NULL(lsq_perlayer_grad_d); + lsq_perlayer_grad_d->set_scope(lsq_perlayer_grad_node->scope()); + + auto types = {AnfAlgo::GetOutputInferDataType(lsq_perlayer_grad_node, 0), + AnfAlgo::GetOutputInferDataType(lsq_perlayer_grad_node, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(lsq_perlayer_grad_node, 0), + AnfAlgo::GetOutputInferShape(lsq_perlayer_grad_node, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, lsq_perlayer_grad_d.get()); + + AnfAlgo::CopyNodeAttr(kAttrNeg_trunc, lsq_perlayer_grad_node, lsq_perlayer_grad_d); + CreateMultipleOutputsOfAnfNode(graph, lsq_perlayer_grad_d, kFakeLearnedScaleQuantGradDOutputNum, + lsq_perlayer_grad_d_outputs); +} + +void CreateOutputsOfLSQPerLayerReduceGrad(const FuncGraphPtr &graph, const CNodePtr &lsq_perlayer_grad_node, + const std::vector &lsq_perlayer_grad_d_outputs, + std::vector *lsq_perlayer_reduce_grad_outputs) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(lsq_perlayer_grad_node); + MS_EXCEPTION_IF_NULL(lsq_perlayer_reduce_grad_outputs); + const auto &lsq_perlayer_grad_inputs = lsq_perlayer_grad_node->inputs(); + if (lsq_perlayer_grad_inputs.size() < kFakeLearnedScaleQuantGradInputNum) { + MS_LOG(EXCEPTION) << "lsq_perlayer_grad_node has wrong inputs size" + << " trace: " << trace::DumpSourceLines(lsq_perlayer_grad_node); + } + if (lsq_perlayer_grad_d_outputs.size() != kFakeLearnedScaleQuantGradDOutputNum) { + MS_LOG(EXCEPTION) << "lsq_perlayer_grad_d_outputs has wrong size" + << " trace: " << trace::DumpSourceLines(lsq_perlayer_grad_node); + } + std::vector lsq_perlayer_reduce_grad_inputs = { + NewValueNode(std::make_shared(kFakeLearnedScaleQuantPerLayerGradDReduceOpName)), + lsq_perlayer_grad_d_outputs[1]}; + auto lsq_perlayer_reduce_grad = graph->NewCNode(lsq_perlayer_reduce_grad_inputs); + MS_EXCEPTION_IF_NULL(lsq_perlayer_reduce_grad); + lsq_perlayer_reduce_grad->set_scope(lsq_perlayer_grad_node->scope()); + + auto types = {AnfAlgo::GetOutputInferDataType(lsq_perlayer_grad_node, 1)}; + auto shapes = {AnfAlgo::GetOutputInferShape(lsq_perlayer_grad_node, 1)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, lsq_perlayer_reduce_grad.get()); + + (*lsq_perlayer_reduce_grad_outputs).push_back(lsq_perlayer_reduce_grad); +} + +void CreateOutputsOfLSQPerChannelGradD(const FuncGraphPtr &graph, const CNodePtr &lsq_perchannel_grad_node, + std::vector *lsq_perchannel_grad_d_outputs) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(lsq_perchannel_grad_node); + const auto &lsq_perchannel_grad_inputs = lsq_perchannel_grad_node->inputs(); + if (lsq_perchannel_grad_inputs.size() < kFakeLearnedScaleQuantGradInputNum) { + MS_LOG(EXCEPTION) << "lsq_perchannel_grad_node has wrong inputs size." + << " trace: " << trace::DumpSourceLines(lsq_perchannel_grad_node); + } + std::vector lsq_perchannel_grad_d_inputs = { + NewValueNode(std::make_shared(kFakeLearnedScaleQuantPerChannelGradDOpName)), + lsq_perchannel_grad_inputs[1], lsq_perchannel_grad_inputs[2], lsq_perchannel_grad_inputs[3], + lsq_perchannel_grad_inputs[4]}; + auto lsq_perchannel_grad_d = graph->NewCNode(lsq_perchannel_grad_d_inputs); + MS_EXCEPTION_IF_NULL(lsq_perchannel_grad_d); + lsq_perchannel_grad_d->set_scope(lsq_perchannel_grad_node->scope()); + + auto types = {AnfAlgo::GetOutputInferDataType(lsq_perchannel_grad_node, 0), + AnfAlgo::GetOutputInferDataType(lsq_perchannel_grad_node, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(lsq_perchannel_grad_node, 0), + AnfAlgo::GetOutputInferShape(lsq_perchannel_grad_node, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, lsq_perchannel_grad_d.get()); + + AnfAlgo::CopyNodeAttr(kAttrNeg_trunc, lsq_perchannel_grad_node, lsq_perchannel_grad_d); + AnfAlgo::CopyNodeAttr(kAttrChannelAxis, lsq_perchannel_grad_node, lsq_perchannel_grad_d); + CreateMultipleOutputsOfAnfNode(graph, lsq_perchannel_grad_d, kFakeLearnedScaleQuantGradDOutputNum, + lsq_perchannel_grad_d_outputs); +} + +void CreateOutputsOfLSQPerChannelReduceGrad(const FuncGraphPtr &graph, const CNodePtr &lsq_perchannel_grad_node, + const std::vector &lsq_perchannel_grad_d_outputs, + std::vector *lsq_perchannel_reduce_grad_outputs) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(lsq_perchannel_grad_node); + MS_EXCEPTION_IF_NULL(lsq_perchannel_reduce_grad_outputs); + const auto &lsq_perchannel_grad_inputs = lsq_perchannel_grad_node->inputs(); + if (lsq_perchannel_grad_inputs.size() < kFakeLearnedScaleQuantGradInputNum) { + MS_LOG(EXCEPTION) << "lsq_perchannel_grad_node has wrong inputs size" + << " trace: " << trace::DumpSourceLines(lsq_perchannel_grad_node); + } + if (lsq_perchannel_grad_d_outputs.size() != kFakeLearnedScaleQuantGradDOutputNum) { + MS_LOG(EXCEPTION) << "lsq_perchannel_grad_d_outputs has wrong size" + << " trace: " << trace::DumpSourceLines(lsq_perchannel_grad_node); + } + std::vector lsq_perchannel_reduce_grad_inputs = { + NewValueNode(std::make_shared(kFakeLearnedScaleQuantPerChannelGradDReduceOpName)), + lsq_perchannel_grad_d_outputs[1]}; + auto lsq_perchannel_reduce_grad = graph->NewCNode(lsq_perchannel_reduce_grad_inputs); + MS_EXCEPTION_IF_NULL(lsq_perchannel_reduce_grad); + lsq_perchannel_reduce_grad->set_scope(lsq_perchannel_grad_node->scope()); + + auto types = {AnfAlgo::GetOutputInferDataType(lsq_perchannel_grad_node, 1)}; + auto shapes = {AnfAlgo::GetOutputInferShape(lsq_perchannel_grad_node, 1)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, lsq_perchannel_reduce_grad.get()); + AnfAlgo::CopyNodeAttr(kAttrChannelAxis, lsq_perchannel_grad_node, lsq_perchannel_reduce_grad); + (*lsq_perchannel_reduce_grad_outputs).push_back(lsq_perchannel_reduce_grad); +} +} // namespace +const BaseRef FakeLearnedScaleQuantPerLayerGradUnifyMindIR::DefinePattern() const { + VarPtr Xs = std::make_shared(); + auto prim = std::make_shared(kFakeLearnedScaleQuantPerLayerGradOpName); + return VectorRef({prim, Xs}); +} + +const AnfNodePtr FakeLearnedScaleQuantPerLayerGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, + const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(func_graph); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto primitive = AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(primitive); + + std::vector lsq_perlayer_grad_d_outputs; + CreateOutputsOfLSQPerLayerGradD(func_graph, cnode, &lsq_perlayer_grad_d_outputs); + if (lsq_perlayer_grad_d_outputs.size() != kFakeLearnedScaleQuantGradOutputNum) { + MS_LOG(EXCEPTION) << "fake_learned_scale_quant_perlayer_grad_d_outputs has wrong size" + << " trace: " << trace::DumpSourceLines(node); + } + + std::vector lsq_perlayer_reduce_grad_outputs; + CreateOutputsOfLSQPerLayerReduceGrad(func_graph, cnode, lsq_perlayer_grad_d_outputs, + &lsq_perlayer_reduce_grad_outputs); + if (lsq_perlayer_reduce_grad_outputs.size() != kSingleOutputNum) { + MS_LOG(EXCEPTION) << "fake_learned_scale_quant_perlayer_reduce_grad_outputs has wrong size" + << " trace: " << trace::DumpSourceLines(node); + } + + std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), lsq_perlayer_grad_d_outputs[0], + lsq_perlayer_reduce_grad_outputs[0]}; + auto make_tuple = func_graph->NewCNode(make_tuple_inputs); + return make_tuple; +} + +const BaseRef FakeLearnedScaleQuantPerChannelGradUnifyMindIR::DefinePattern() const { + VarPtr Xs = std::make_shared(); + auto prim = std::make_shared(kFakeLearnedScaleQuantPerChannelGradOpName); + return VectorRef({prim, Xs}); +} + +const AnfNodePtr FakeLearnedScaleQuantPerChannelGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, + const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(func_graph); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto primitive = AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(primitive); + + std::vector lsq_perchannel_grad_d_outputs; + CreateOutputsOfLSQPerChannelGradD(func_graph, cnode, &lsq_perchannel_grad_d_outputs); + if (lsq_perchannel_grad_d_outputs.size() != kFakeLearnedScaleQuantGradOutputNum) { + MS_LOG(EXCEPTION) << "fake_learned_scale_quant_perchannel_grad_d_outputs has wrong size" + << " trace: " << trace::DumpSourceLines(node); + } + + std::vector lsq_perchannel_reduce_grad_outputs; + CreateOutputsOfLSQPerChannelReduceGrad(func_graph, cnode, lsq_perchannel_grad_d_outputs, + &lsq_perchannel_reduce_grad_outputs); + if (lsq_perchannel_reduce_grad_outputs.size() != kSingleOutputNum) { + MS_LOG(EXCEPTION) << "fake_learned_scale_quant_perchannel_reduce_grad_outputs has wrong size" + << " trace: " << trace::DumpSourceLines(node); + } + + std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), lsq_perchannel_grad_d_outputs[0], + lsq_perchannel_reduce_grad_outputs[0]}; + auto make_tuple = func_graph->NewCNode(make_tuple_inputs); + return make_tuple; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/fake_learned_scale_quant_grad_unify_mindir.h b/mindspore/ccsrc/backend/optimizer/ascend/mindir/fake_learned_scale_quant_grad_unify_mindir.h new file mode 100644 index 00000000000..c00e4047c93 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/fake_learned_scale_quant_grad_unify_mindir.h @@ -0,0 +1,57 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_FAKE_LEARNED_SCALE_QUANT_GRAD_UNIFY_MINDIR_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_FAKE_LEARNED_SCALE_QUANT_GRAD_UNIFY_MINDIR_H_ + +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +constexpr size_t kFakeLearnedScaleQuantGradOutputNum = 2; +constexpr size_t kFakeLearnedScaleQuantGradInputNum = 5; +constexpr size_t kFakeLearnedScaleQuantGradDOutputNum = 2; +constexpr auto kFakeLearnedScaleQuantPerLayerGradOpName = "FakeLearnedScaleQuantPerLayerGrad"; +constexpr auto kFakeLearnedScaleQuantPerLayerGradDOpName = "FakeLearnedScaleQuantPerLayerGradD"; +constexpr auto kFakeLearnedScaleQuantPerLayerGradDReduceOpName = "FakeLearnedScaleQuantPerLayerGradDReduce"; +constexpr auto kFakeLearnedScaleQuantPerChannelGradOpName = "FakeLearnedScaleQuantPerChannelGrad"; +constexpr auto kFakeLearnedScaleQuantPerChannelGradDOpName = "FakeLearnedScaleQuantPerChannelGradD"; +constexpr auto kFakeLearnedScaleQuantPerChannelGradDReduceOpName = "FakeLearnedScaleQuantPerChannelGradDReduce"; + +constexpr auto kAttrNeg_trunc = "neg_trunc"; +constexpr auto kAttrChannelAxis = "channel_axis"; + +class FakeLearnedScaleQuantPerLayerGradUnifyMindIR : public PatternProcessPass { + public: + explicit FakeLearnedScaleQuantPerLayerGradUnifyMindIR(bool multigraph = true) + : PatternProcessPass("fake_learned_scale_quant_perlayer_grad_unify_mindir", multigraph) {} + ~FakeLearnedScaleQuantPerLayerGradUnifyMindIR() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; + +class FakeLearnedScaleQuantPerChannelGradUnifyMindIR : public PatternProcessPass { + public: + explicit FakeLearnedScaleQuantPerChannelGradUnifyMindIR(bool multigraph = true) + : PatternProcessPass("fake_learned_scale_quant_perchannel_grad_unify_mindir", multigraph) {} + ~FakeLearnedScaleQuantPerChannelGradUnifyMindIR() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; + +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_FAKE_LEARNED_SCALE_QUANT_GRAD_UNIFY_MINDIR_H_ diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 6b62316c8eb..0a277c3875c 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -39,6 +39,7 @@ #include "backend/optimizer/ascend/mindir/maxpool_with_argmax_unify_mindir.h" #include "backend/optimizer/ascend/mindir/conv2d_unify_mindir.h" #include "backend/optimizer/ascend/mindir/optimizer_unify_output.h" +#include "backend/optimizer/ascend/mindir/fake_learned_scale_quant_grad_unify_mindir.h" #include "backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.h" #include "backend/optimizer/ascend/mindir/slice_grad_unify_mindir.h" #include "backend/optimizer/ascend/mindir/avg_pool_grad_unify_mindir.h" @@ -374,6 +375,8 @@ void AscendSession::UnifyMindIR(const KernelGraphPtr &graph) { unify_mindir_pm->AddPass(std::make_shared()); unify_mindir_pm->AddPass(std::make_shared()); unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); + unify_mindir_pm->AddPass(std::make_shared()); auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); if (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode) { diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index cf7d233a25a..15f8132cc2c 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -397,7 +397,9 @@ void ExecutorPy::GetWeightInfo(const CNodePtr &root_node, const AnfNodePtr &weig CNodePtr cnode = nullptr; auto is_quant_cnode = [](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) || - IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel); + IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel) || + IsPrimitiveCNode(node, prim::kPrimFakeLearnedScaleQuantPerLayer) || + IsPrimitiveCNode(node, prim::kPrimFakeLearnedScaleQuantPerChannel); }; while (!is_quant_cnode(x)) { if (count >= max_depth) { @@ -452,7 +454,9 @@ std::map> ExecutorPy: std::vector nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter); auto is_quant_cnode = [](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) || - IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel); + IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel) || + IsPrimitiveCNode(node, prim::kPrimFakeLearnedScaleQuantPerLayer) || + IsPrimitiveCNode(node, prim::kPrimFakeLearnedScaleQuantPerChannel); }; for (const auto &node : nodes) { auto root_node = node->cast(); @@ -461,7 +465,15 @@ std::map> ExecutorPy: } auto weight = root_node->input(2); if (!is_quant_cnode(weight)) { - continue; + auto tuple_node = weight->cast(); + if (tuple_node != nullptr) { + auto fake_node = tuple_node->input(1); + if (!is_quant_cnode(fake_node)) { + continue; + } else { + weight = fake_node; + } + } } // get parameter weight's name auto cnode = weight->cast(); diff --git a/mindspore/compression/export/quant_export.py b/mindspore/compression/export/quant_export.py index 472170fe832..e625c937128 100644 --- a/mindspore/compression/export/quant_export.py +++ b/mindspore/compression/export/quant_export.py @@ -56,6 +56,7 @@ class ExportToQuantInferNetwork: self.input_zero_point = round(mean) self.data_type = mstype.int8 self.network = copy.deepcopy(network) + self.network_bk = copy.deepcopy(network) self.all_parameters = {p.name: p for p in self.network.get_parameters()} self.get_inputs_table(inputs) self.mean = mean @@ -83,6 +84,7 @@ class ExportToQuantInferNetwork: """convert network's quant subcell to deploy subcell""" # Calculate the scale and zero point w_minq_name = cell_core.fake_quant_weight.minq.name + w_maxq_name = cell_core.fake_quant_weight.maxq.name np_type = mstype.dtype_to_nptype(self.data_type) param_dict = dict() param_dict["filter_maxq"] = None @@ -102,16 +104,23 @@ class ExportToQuantInferNetwork: quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type) info = self.quant_info_table.get(w_minq_name, None) + if not info: + info = self.quant_info_table.get(w_maxq_name, None) if info: - fake_quant_a_in_op, minq_name = info + _, minq_name = info if minq_name == 'input': scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \ self.input_scale, self.input_zero_point, 'None', 'None' else: - maxq = self.all_parameters[minq_name[:-4] + "maxq"] - minq = self.all_parameters[minq_name] + fake_quant_a_in_prefix = minq_name[:-5] + cells = self.network_bk.cells_and_names() + for cell in cells: + if cell[0].endswith(fake_quant_a_in_prefix): + fake_quant_a_in = cell[1] + break + scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \ - quant_utils.scale_zp_max_min_from_data(fake_quant_a_in_op, minq, maxq, np_type) + quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_in, np_type) else: # skip quant layer scale_a_in, zp_a_in = 1.0, 0.0 @@ -140,9 +149,8 @@ class ExportToQuantInferNetwork: weight_b = weight bias_b = bias # apply the quant - fake_quant_weight_op = cell_core.fake_quant_weight.fake_quant_infer - weight = quant_utils.weight2int(weight, scale_w, zp_w, np_type, fake_quant_weight_op.num_bits, - fake_quant_weight_op.narrow_range) + weight = quant_utils.weight2int(weight, scale_w, zp_w, np_type, cell_core.fake_quant_weight.num_bits, + cell_core.fake_quant_weight.narrow_range) if bias is not None: bias = Tensor(bias / scale_a_in / scale_w, mstype.int32) diff --git a/mindspore/compression/quant/qat.py b/mindspore/compression/quant/qat.py index e2ff6177e37..4fb99001525 100644 --- a/mindspore/compression/quant/qat.py +++ b/mindspore/compression/quant/qat.py @@ -22,15 +22,15 @@ aware training, MindSpore provides conversion functions to convert the trained m """ import re - import mindspore.context as context - +import numpy as np from ... import nn, ops from ..._checkparam import Validator, Rel from ...nn.layer import quant from ...ops import functional as F from ..common import QuantDtype from .quantizer import Quantizer, OptimizeOption +from .quant_utils import compute_KL_threshold __all__ = ["QuantizationAwareTraining", "create_quant_config"] @@ -41,7 +41,8 @@ def create_quant_config(quant_observer=(nn.FakeQuantWithMinMaxObserver, nn.FakeQ quant_dtype=(QuantDtype.INT8, QuantDtype.INT8), per_channel=(False, False), symmetric=(False, False), - narrow_range=(False, False)): + narrow_range=(False, False), + mode="DEFAULT"): r""" Config the observer type of weights and data flow with quant params. @@ -62,6 +63,8 @@ def create_quant_config(quant_observer=(nn.FakeQuantWithMinMaxObserver, nn.FakeQ element represents data flow. Default: (False, False) narrow_range (Union[bool, list, tuple]): Whether the quantization algorithm uses narrow range or not. The first element represents weights and the second element represents data flow. Default: (False, False) + mode (String): Optional quantization mode, currently only `DEFAULT`(QAT) and `LEARNED_SCALE` are supported. + Default: ("DEFAULT") Returns: QuantConfig, Contains the observer type of weight and activation. @@ -70,10 +73,10 @@ def create_quant_config(quant_observer=(nn.FakeQuantWithMinMaxObserver, nn.FakeQ raise ValueError("Arg 'per_channel' second element must be 'False'.") weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0], per_channel=per_channel[0], symmetric=symmetric[0], - narrow_range=narrow_range[0]) + narrow_range=narrow_range[0], mode=mode) act_observer = quant_observer[-1].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1], per_channel=per_channel[-1], symmetric=symmetric[-1], - narrow_range=narrow_range[-1]) + narrow_range=narrow_range[-1], mode=mode) return quant.QuantConfig(weight=weight_observer, activation=act_observer) @@ -103,14 +106,24 @@ class _AddFakeQuantAfterSubCell(nn.Cell): def __init__(self, subcell, **kwargs): super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False) self.subcell = subcell - self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=-6, - max_init=6, + self.mode = "DEFAULT" + self.max_init = 6 + self.min_init = -6 + + if OptimizeOption.LEARNED_SCALE in kwargs["optimize_option"]: + self.mode = "LEARNED_SCALE" + self.max_init = 16 + self.min_init = -16 + + self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=self.min_init, + max_init=self.max_init, ema=True, quant_dtype=kwargs["quant_dtype"], quant_delay=kwargs["quant_delay"], per_channel=kwargs["per_channel"], symmetric=kwargs["symmetric"], - narrow_range=kwargs["narrow_range"]) + narrow_range=kwargs["narrow_range"], + mode=self.mode) def construct(self, *data): output = self.subcell(*data) @@ -128,7 +141,8 @@ class QuantizationAwareTraining(Quantizer): quant_delay (Union[int, list, tuple]): Number of steps after which weights and activations are quantized during eval. The first element represents weights and second element represents data flow. Default: (0, 0) quant_dtype (Union[QuantDtype, list, tuple]): Datatype to use for quantize weights and activations. The first - element represents weights and second element represents data flow. + element represents weights and second element represents data flow. It is necessary to consider the + precision support of hardware devices in the practical quantization infer scenario. Default: (QuantDtype.INT8, QuantDtype.INT8) per_channel (Union[bool, list, tuple]): Quantization granularity based on layer or on channel. If `True` then base on per channel otherwise base on per layer. The first element represents weights @@ -139,7 +153,11 @@ class QuantizationAwareTraining(Quantizer): narrow_range (Union[bool, list, tuple]): Whether the quantization algorithm uses narrow range or not. The first element represents weights and the second element represents data flow. Default: (False, False) optimize_option (Union[OptimizeOption, list, tuple]): Specifies the quant algorithm and options, currently only - support QAT. Default: OptimizeOption.QAT + support QAT and LEARNED_SCALE (Note that, if both QAT and LEARNED_SCALE are configured, LEARNED_SCALE has + a higher priority. LEARNED_SCALE currently only work under some constraints, which includes: freeze_bn=0, + quant_delay=0, symmetric=Ture, narrow_range=True, More specifically, for operators such as ReLu and ReLu6, + which only have positive values, we add a negative truncation to optimize this scenario, and narrow_range + will automatically match to False). Default: OptimizeOption.QAT one_conv_fold (bool): Flag to used one conv bn fold ops for simulation inference operation. Default: True. Examples: @@ -218,11 +236,27 @@ class QuantizationAwareTraining(Quantizer): self.one_conv_fold = Validator.check_bool(one_conv_fold, "one conv fold") self._convert_method_map = {nn.Conv2dBnAct: self._convert_conv, nn.DenseBnAct: self._convert_dense} + self.mode = "DEFAULT" + if OptimizeOption.LEARNED_SCALE in self.optimize_option: + self.mode = "LEARNED_SCALE" + if not self.weight_symmetric or not self.act_symmetric: + raise ValueError("OptimizeOption.LEARNED_SCALE currently only support " + "symmetric=(True, True) for quant") + if not self.weight_range or not self.act_range: + raise ValueError("OptimizeOption.LEARNED_SCALE currently only support narrow_range=(True, True) " + "for quant") + if self.freeze_bn != 0: + raise ValueError("OptimizeOption.LEARNED_SCALE currently only support freeze_bn equal to 0, " + "but get freeze_bn={}".format(self.freeze_bn)) + if self.weight_qdelay != 0 or self.act_qdelay != 0: + raise ValueError("OptimizeOption.LEARNED_SCALE currently only support quant_delay=(0, 0)") self.quant_config = create_quant_config(quant_delay=quant_delay, quant_dtype=quant_dtype, per_channel=per_channel, symmetric=symmetric, - narrow_range=narrow_range) + narrow_range=narrow_range, + mode=self.mode) + self.eps = 1e-5 def _convert_op_name(self, name): pattern = re.compile(r'([A-Z]{1})') @@ -247,7 +281,7 @@ class QuantizationAwareTraining(Quantizer): if context.get_context('device_target') not in support_device: raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) - if OptimizeOption.QAT in self.optimize_option: + if OptimizeOption.QAT in self.optimize_option or OptimizeOption.LEARNED_SCALE in self.optimize_option: network.update_cell_prefix() network = self._convert_subcells2quant(network) network.update_cell_type("quant") @@ -274,7 +308,18 @@ class QuantizationAwareTraining(Quantizer): if isinstance(network, nn.SequentialCell) and change: network.cell_list = list(network.cells()) - # add FakeQuant OP after OP in while list + # add FakeQuant OP after OP in white list, but not including those wrapped in the below quantization cell. + if isinstance(network, (nn.FakeQuantWithMinMaxObserver, + nn.Conv2dBnFoldQuantOneConv, + nn.Conv2dBnFoldQuant, + nn.Conv2dBnWithoutFoldQuant, + nn.Conv2dQuant, + nn.DenseQuant, + nn.ActQuant, + nn.TensorAddQuant, + nn.MulQuant)): + return network + add_list = [] for name in network.__dict__: if name[0] == '_': @@ -289,7 +334,8 @@ class QuantizationAwareTraining(Quantizer): quant_delay=self.act_qdelay, per_channel=self.act_channel, symmetric=self.act_symmetric, - narrow_range=self.act_range) + narrow_range=self.act_range, + optimize_option=self.optimize_option) prefix = self._convert_op_name(prim_op.name) if network.param_prefix: prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)]) @@ -302,10 +348,22 @@ class QuantizationAwareTraining(Quantizer): """ convert Conv2d cell to quant cell """ + min_init = -6 + max_init = 6 + if OptimizeOption.LEARNED_SCALE in self.optimize_option: + subcell_weight_para = subcell.conv.weight.data.asnumpy() + if subcell.has_bn: + scale_factor = (subcell.batchnorm.gamma.data.asnumpy() / + np.sqrt(subcell.batchnorm.moving_variance.data.asnumpy() + self.eps)) + subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1) + min_init, max_init = self._KL_init(subcell_weight_para, self.weight_dtype) + self.quant_config = self.quant_config._replace( + weight=self.quant_config.weight.partial_init(min_init=min_init, max_init=max_init)) + conv_inner = subcell.conv if subcell.has_bn: + bn_inner = subcell.batchnorm if self.bn_fold: - bn_inner = subcell.batchnorm if self.one_conv_fold: conv_inner = quant.Conv2dBnFoldQuantOneConv(conv_inner.in_channels, conv_inner.out_channels, @@ -344,11 +402,7 @@ class QuantizationAwareTraining(Quantizer): conv_inner.beta = subcell.batchnorm.beta conv_inner.moving_mean = subcell.batchnorm.moving_mean conv_inner.moving_variance = subcell.batchnorm.moving_variance - del subcell.batchnorm - subcell.batchnorm = None - subcell.has_bn = False else: - bn_inner = subcell.batchnorm conv_inner = quant.Conv2dBnWithoutFoldQuant(conv_inner.in_channels, conv_inner.out_channels, kernel_size=conv_inner.kernel_size, @@ -368,20 +422,15 @@ class QuantizationAwareTraining(Quantizer): conv_inner.batchnorm.beta = subcell.batchnorm.beta conv_inner.batchnorm.moving_mean = subcell.batchnorm.moving_mean conv_inner.batchnorm.moving_variance = subcell.batchnorm.moving_variance - del subcell.batchnorm - subcell.batchnorm = None - subcell.has_bn = False + del subcell.batchnorm + subcell.batchnorm = None + subcell.has_bn = False else: - conv_inner = quant.Conv2dQuant(conv_inner.in_channels, - conv_inner.out_channels, - kernel_size=conv_inner.kernel_size, - stride=conv_inner.stride, - pad_mode=conv_inner.pad_mode, - padding=conv_inner.padding, - dilation=conv_inner.dilation, - group=conv_inner.group, - has_bias=conv_inner.has_bias, - quant_config=self.quant_config, + conv_inner = quant.Conv2dQuant(conv_inner.in_channels, conv_inner.out_channels, + kernel_size=conv_inner.kernel_size, stride=conv_inner.stride, + pad_mode=conv_inner.pad_mode, padding=conv_inner.padding, + dilation=conv_inner.dilation, group=conv_inner.group, + has_bias=conv_inner.has_bias, quant_config=self.quant_config, quant_dtype=self.weight_dtype) # change original network Conv2D OP parameters to quant network conv_inner.weight = subcell.conv.weight @@ -392,18 +441,28 @@ class QuantizationAwareTraining(Quantizer): subcell.activation = self._convert_activation(subcell.activation) elif subcell.after_fake: subcell.has_act = True - subcell.activation = _AddFakeQuantAfterSubCell(F.identity, - quant_dtype=self.act_dtype, - quant_delay=self.act_qdelay, - per_channel=self.act_channel, - symmetric=self.act_symmetric, - narrow_range=self.act_range) + subcell.activation = _AddFakeQuantAfterSubCell(F.identity, quant_dtype=self.act_dtype, + quant_delay=self.act_qdelay, per_channel=self.act_channel, + symmetric=self.act_symmetric, narrow_range=self.act_range, + optimize_option=self.optimize_option) return subcell def _convert_dense(self, subcell): """ convert dense cell to quant cell """ + min_init = -6 + max_init = 6 + if OptimizeOption.LEARNED_SCALE in self.optimize_option: + subcell_weight_para = subcell.dense.weight.data.asnumpy() + if subcell.has_bn: + scale_factor = (subcell.batchnorm.gamma.data.asnumpy() / + np.sqrt(subcell.batchnorm.moving_variance.data.asnumpy() + self.eps)) + subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1) + min_init, max_init = self._KL_init(subcell_weight_para, self.weight_dtype) + self.quant_config = self.quant_config._replace( + weight=self.quant_config.weight.partial_init(min_init=min_init, max_init=max_init)) + dense_inner = subcell.dense dense_inner = quant.DenseQuant(dense_inner.in_channels, dense_inner.out_channels, @@ -424,7 +483,8 @@ class QuantizationAwareTraining(Quantizer): quant_delay=self.act_qdelay, per_channel=self.act_channel, symmetric=self.act_symmetric, - narrow_range=self.act_range) + narrow_range=self.act_range, + optimize_option=self.optimize_option) return subcell def _convert_activation(self, activation): @@ -434,6 +494,7 @@ class QuantizationAwareTraining(Quantizer): act_class = activation.__class__ act_list = [nn.ReLU, nn.ReLU6, nn.Sigmoid] act_list_with_fake_before = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish] + if act_class in act_list: return quant.ActQuant(activation=activation, quant_config=self.quant_config, @@ -445,3 +506,79 @@ class QuantizationAwareTraining(Quantizer): quant_config=self.quant_config, quant_dtype=self.act_dtype) raise ValueError("Unsupported activation in auto quant: ", act_class) + + def _KL_init(self, subcell_weight_para, weight_dtype): + """ + Calculate the value of max_init and min_init with compute_KL_threshold. + """ + if self.weight_channel: + max_init = [compute_KL_threshold(weight_para_each, weight_dtype) + for weight_para_each in subcell_weight_para] + min_init = [-x for x in max_init] + else: + max_init = [compute_KL_threshold(subcell_weight_para, weight_dtype)] + min_init = [-x for x in max_init] + return min_init, max_init + + def set_mixed_bits(self, network, strategy): + r""" + Set network's quantization strategy, this function is currently only valid for `LEARNED_SCALE` + optimize_option. + Input: + network (Cell): input network + strategy (List): the quantization strategy for layers that need to be quantified (eg. [[8], [8], + ..., [6], [4], [8]]), currently only the quant_dtype for weights of the dense layer and the + convolution layer is supported. + Output: + network (Cell) + """ + if OptimizeOption.LEARNED_SCALE not in self.optimize_option: + raise ValueError("The `set_mixed_bits` function is currently only valid for `LEARNED_SCALE` " + "optimize_option.") + + self.quantizable_idx = [] + pass_cell = None + for i, cell_and_name in enumerate(network.cells_and_names()): + cell = cell_and_name[1] + if isinstance(cell, (nn.Conv2dBnAct, nn.DenseBnAct)) and cell is not pass_cell: + self.quantizable_idx.append(i) + + assert len(self.quantizable_idx) == len(strategy) + quantizable_layer_bit_dict = {idx: bit for idx, bit in zip(self.quantizable_idx, strategy)} + type_map = { + QuantDtype.INT2.num_bits: QuantDtype.INT2, + QuantDtype.INT3.num_bits: QuantDtype.INT3, + QuantDtype.INT4.num_bits: QuantDtype.INT4, + QuantDtype.INT5.num_bits: QuantDtype.INT5, + QuantDtype.INT6.num_bits: QuantDtype.INT6, + QuantDtype.INT7.num_bits: QuantDtype.INT7, + QuantDtype.INT8.num_bits: QuantDtype.INT8 + } + for i, cell_and_name in enumerate(network.cells_and_names()): + cell = cell_and_name[1] + if i not in self.quantizable_idx: + continue + else: + if isinstance(cell, (nn.Conv2dBnAct, nn.DenseBnAct)): + cell.weight_dtype = type_map[quantizable_layer_bit_dict[i][0]] + if isinstance(cell, nn.Conv2dBnAct): + subcell_weight_para = cell.conv.weight.data.asnumpy() + if hasattr(cell.conv, 'gamma'): + scale_factor = (cell.conv.gamma.data.asnumpy() / + np.sqrt(cell.conv.moving_variance.data.asnumpy() + self.eps)) + subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1) + min_init, max_init = self._KL_init(subcell_weight_para, cell.weight_dtype) + cell.conv.fake_quant_weight.reset(quant_dtype=cell.weight_dtype, + min_init=min_init, + max_init=max_init) + elif isinstance(cell, nn.DenseBnAct): + subcell_weight_para = cell.dense.weight.data.asnumpy() + if hasattr(cell.dense, 'gamma'): + scale_factor = (cell.dense.gamma.data.asnumpy() / + np.sqrt(cell.dense.moving_variance.data.asnumpy() + self.eps)) + subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1) + min_init, max_init = self._KL_init(subcell_weight_para, cell.weight_dtype) + cell.dense.fake_quant_weight.reset(quant_dtype=cell.weight_dtype, + min_init=min_init, + max_init=max_init) + return network diff --git a/mindspore/compression/quant/quant_utils.py b/mindspore/compression/quant/quant_utils.py index 890e33b6dd4..df74169e611 100644 --- a/mindspore/compression/quant/quant_utils.py +++ b/mindspore/compression/quant/quant_utils.py @@ -15,9 +15,10 @@ """Quantization utils.""" import numpy as np +from mindspore._checkparam import Validator +from ... import nn - -__all__ = ["load_nonquant_param_into_quant_net"] +__all__ = ["load_nonquant_param_into_quant_net", "query_quant_layers"] def cal_quantization_params(input_min, @@ -25,7 +26,8 @@ def cal_quantization_params(input_min, data_type, num_bits=8, symmetric=False, - narrow_range=False): + narrow_range=False, + neg_trunc=False): r""" Calculate quantization params for scale and zero point. @@ -36,6 +38,7 @@ def cal_quantization_params(input_min, num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False. narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False. + neg_trunc (bool): Whether the quantization algorithm uses negative truncation or not. Default: False. Returns: scale (numpy.ndarray): quantization param. @@ -65,17 +68,14 @@ def cal_quantization_params(input_min, quant_min = quant_min + 1 # calculate scale - if symmetric: + if symmetric and not neg_trunc: input_max = np.maximum(-input_min, input_max) input_min = -input_max scale = (input_max - input_min) / (quant_max - quant_min) # calculate zero point - if symmetric: - zp = np.zeros(input_min.shape) - else: - zp_double = quant_min - input_min / scale - zp = np.floor(zp_double + 0.5) + zp_double = quant_min - input_min / scale + zp = np.floor(zp_double + 0.5) return scale, zp @@ -135,16 +135,16 @@ def weight2int(data, scale, zero_point, data_type, num_bits=8, narrow_range=Fals def scale_zp_max_min_from_fake_quant_cell(cell, data_type): - """Get calculate quantization params for scale, zero point, max and min from `FakeQuantWithMinMax`.""" + """Get calculate quantization params for scale, zero point, max and min from `FakeQuantWithMinMaxObserver`.""" minq = cell.minq.data.asnumpy() maxq = cell.maxq.data.asnumpy() - op = cell.fake_quant_infer scale, zp = cal_quantization_params( minq, maxq, data_type, - num_bits=op.num_bits, - symmetric=op.symmetric, - narrow_range=op.narrow_range) + num_bits=cell.num_bits, + symmetric=cell.symmetric, + narrow_range=cell.narrow_range, + neg_trunc=cell.neg_trunc) return scale, zp, maxq, minq @@ -267,6 +267,80 @@ def without_fold_batchnorm(weight, cell_quant): return weight, bias +def compute_KL_threshold(data, bitwidth): + r""" + Using KL-J Distance to calculate the clip threshold. + + Args: + - **data** (NumpyArray) - Data observed to calculate the threshold for quantization, + - **bitwidth** (QuantDtype) - The datatype of quantization. + Outputs: + Tensor with Shape 1. Threshold to calculate the data. + """ + bitwidth = bitwidth.num_bits + + data_min = 0 + data_max = np.abs(data).max() + if data_max < 1e-5: + return 1e-5 + hist, bin_edges = np.histogram(np.abs(data), bins='sqrt', range=(data_min, data_max), density=True) + hist = hist / np.sum(hist) + cumsum = np.cumsum(hist) + bit_pow_range = pow(2, int(bitwidth) - 1) + threshold = [] + scaling_factor = [] + kl = [] + if bit_pow_range + 1 > len(bin_edges) - 1: + th_layer_out = bin_edges[-1] + return float(th_layer_out) + for i in range(bit_pow_range + 1, len(bin_edges), 1): + threshold_tmp = (i + 0.5) * (bin_edges[1] - bin_edges[0]) + threshold = np.concatenate((threshold, [threshold_tmp])) + scaling_factor_tmp = threshold_tmp / (bit_pow_range - 1) + scaling_factor = np.concatenate((scaling_factor, [scaling_factor_tmp])) + # forward interpolation + cumsum_tmp = np.copy(cumsum) + cumsum_tmp[(i - 1):] = 1 + fwd_x = np.linspace(0.0, 1.0, bit_pow_range) + fwd_xp = np.linspace(0.0, 1.0, i) + fwd_fp = cumsum_tmp[:i] + forward_interp = np.interp(fwd_x, fwd_xp, fwd_fp) + # backward interpolation + bwd_x = np.linspace(0.0, 1.0, i) + bwd_xp = np.linspace(0.0, 1.0, bit_pow_range) + bwd_fp = forward_interp + backward_interp = np.interp(bwd_x, bwd_xp, bwd_fp) + cumsum_tmp[:i] = backward_interp + kl_tmp = np.sum((cumsum - cumsum_tmp) * np.log2(cumsum / cumsum_tmp)) # Kullback-Leibler-J + kl = np.concatenate((kl, [kl_tmp])) + th_layer_out = threshold[np.argmin(kl)] + threshold = float(th_layer_out) + if threshold < 1e-5: + threshold = 1e-5 + return threshold + + +def query_quant_layers(network): + r""" + Query the network's quantization strategy of each quantized layer and print it to the screen, note that all the + quantization layers are queried before graph compile optimization in the graph mode, thus may be appear some + redundant quantized layers, which are not exist in practical execution. + + Input: + network (Cell): input network + + Returns: + None + """ + network = Validator.check_isinstance("network", network, nn.Cell) + tplt = "{0:60}\t{1:10}" + for cell_and_name in network.cells_and_names(): + cell_name = cell_and_name[0] + cell = cell_and_name[1] + if isinstance(cell, nn.FakeQuantWithMinMaxObserver): + print(tplt.format(cell_name, cell.quant_dtype)) + + def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_params=None): r""" Load fp32 model parameters into quantization model. @@ -287,7 +361,8 @@ def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_param 'moving_mean': iter(list(filter(lambda item: item[0].endswith('moving_mean'), params_dict.items()))), 'moving_variance': iter(list(filter(lambda item: item[0].endswith('moving_variance'), params_dict.items()))), 'minq': iter(list(filter(lambda item: item[0].endswith('minq'), params_dict.items()))), - 'maxq': iter(list(filter(lambda item: item[0].endswith('maxq'), params_dict.items()))) + 'maxq': iter(list(filter(lambda item: item[0].endswith('maxq'), params_dict.items()))), + 'quant_max': iter(list(filter(lambda item: item[0].endswith('quant_max'), params_dict.items()))) } for name, param in quant_model.parameters_and_names(): @@ -300,3 +375,26 @@ def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_param if value_param: param.set_data(value_param[1].data) print(f'init model param {name} with checkpoint param {value_param[0]}') + + + # Perform KL_init when learned scale quantization is executed. + for cell_and_name in quant_model.cells_and_names(): + cell = cell_and_name[1] + if isinstance(cell, (nn.Conv2dBnFoldQuantOneConv, nn.Conv2dBnFoldQuant, nn.Conv2dBnWithoutFoldQuant, + nn.Conv2dQuant, nn.DenseQuant)) and cell.fake_quant_weight.mode == "LEARNED_SCALE": + subcell_weight_para = cell.weight.data.asnumpy() + if hasattr(cell, 'gamma'): + scale_factor = (cell.gamma.data.asnumpy() / + np.sqrt(cell.moving_variance.data.asnumpy() + 1e-5)) + subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1) + + if cell.fake_quant_weight.per_channel: + max_init = [compute_KL_threshold(weight_para_each, cell.fake_quant_weight.quant_dtype) + for weight_para_each in subcell_weight_para] + min_init = [-x for x in max_init] + else: + max_init = [compute_KL_threshold(subcell_weight_para, cell.fake_quant_weight.quant_dtype)] + min_init = [-x for x in max_init] + + cell.fake_quant_weight.reset(quant_dtype=cell.fake_quant_weight.quant_dtype, + min_init=min_init, max_init=max_init) diff --git a/mindspore/compression/quant/quantizer.py b/mindspore/compression/quant/quantizer.py index 24dd0d8f395..9a7e457418d 100644 --- a/mindspore/compression/quant/quantizer.py +++ b/mindspore/compression/quant/quantizer.py @@ -29,6 +29,9 @@ class OptimizeOption(Enum): # using quantization aware training QAT = "QAT" + # using the learned scale quantization + LEARNED_SCALE = "LEARNED_SCALE" + def __str__(self): return self.value diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index c5d7a3e7c87..82c9bdc5140 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -295,6 +295,10 @@ inline const PrimitivePtr kPrimOnesLike = std::make_shared("OnesLike" inline const PrimitivePtr kPrimBpropCut = std::make_shared("bprop_cut"); inline const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared("FakeQuantPerLayer"); inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared("FakeQuantPerChannel"); +inline const PrimitivePtr kPrimFakeLearnedScaleQuantPerLayer = + std::make_shared("FakeLearnedScaleQuantPerLayer"); +inline const PrimitivePtr kPrimFakeLearnedScaleQuantPerChannel = + std::make_shared("FakeLearnedScaleQuantPerChannel"); inline const PrimitivePtr kPrimFakeQuantWithMinMaxVars = std::make_shared("FakeQuantWithMinMaxVars"); inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared("ApplyRMSProp"); inline const PrimitivePtr kPrimSparseApplyFtrl = std::make_shared("SparseApplyFtrl"); diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index 549d85b62b4..278fbb483ff 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -30,6 +30,7 @@ import mindspore.context as context from .normalization import BatchNorm2d from .activation import get_activation, ReLU from ..cell import Cell +from ... import nn from ...ops.operations import _quant_ops as Q __all__ = [ @@ -215,6 +216,8 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): r""" Quantization aware operation which provides the fake quantization observer function on data with min and max. + The detail of the quantization mode `DEFAULT` is described as below: + The running min/max :math:`x_{min}` and :math:`x_{max}` are computed as: .. math:: @@ -269,10 +272,59 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): output = u_X * scale + u_{min} \end{array} + The detail of the quantization mode `LEARNED_SCALE` is described as below: + + The fake quant output is computed as: + + .. math:: + + \bar{X}=\left\{\begin{matrix} + clip\left ( \frac{X}{maxq},0,1\right ) \qquad \quad if\quad neg\_trunc\\ + clip\left ( \frac{X}{maxq},-1,1\right )\qquad \ if\quad otherwise + \end{matrix}\right. \\ + + output=\frac{floor\left ( \bar{X}\ast Q_{max}+0.5 \right ) \ast scale }{Q_{max}} + + where X is the input tensor. + where :math:`Q_{max}` (quant_max) is decided by quant_dtype and neg_trunc, for example, if quant_dtype=INT8 + and neg_trunc works, :math:`Q_{max} = 256` , otherwise math:`Q_{max} = 127`. + + The maxq is updated by training, and its gradient is calculated as follows: + + .. math:: + + \frac{\partial \ output}{\partial \ maxq} & = \left\{\begin{matrix} + -\frac{X}{maxq}+\left \lfloor \frac{X}{maxq} \right \rceil \qquad if\quad bound_{lower}< \frac{X}{maxq}< 1\\ + -1 \qquad \quad \qquad \quad if\quad \frac{X}{maxq}\le bound_{lower}\\ + 1 \qquad \quad \qquad \quad if\quad \frac{X}{maxq}\ge 1 \qquad \quad + \end{matrix}\right. \\ + + bound_{lower}= + \end{align}\left\{\begin{matrix} + 0\qquad \quad if\quad neg\_trunc\\ + -1\qquad if\quad otherwise + \end{matrix}\right. + + Then minq is computed as: + + .. math:: + + minq=\left\{\begin{matrix} + 0 \qquad \qquad \quad if\quad neg\_trunc\\ + -maxq\qquad if\quad otherwise + \end{matrix}\right. + + When exporting, the scale and zero point zp is computed as: + + .. math:: + + scale=\frac{maxq}{quant\_max} ,\quad zp=0 \\ + + zp is equal to 0 consistently, due to the LEARNED_SCALE`s symmetric nature. Args: - min_init (int, float): The initialized min value. Default: -6. - max_init (int, float): The initialized max value. Default: 6. + min_init (int, float, list): The initialized min value. Default: -6. + max_init (int, float, list): The initialized max value. Default: 6. ema (bool): The exponential Moving Average algorithm updates min and max. Default: False. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. @@ -282,7 +334,9 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False. narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False. quant_delay (int): Quantization delay parameters according to the global step. Default: 0. - + neg_trunc (bool): Whether the quantization algorithm uses nagetive truncation or not. Default: False. + mode (string): Optional quantization mode, currently only `DEFAULT`(QAT) and `LEARNED_SCALE` are supported. + Default: ("DEFAULT") Inputs: - **input** (Tensor) - The input of FakeQuantWithMinMaxObserver. @@ -290,7 +344,7 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): Tensor, with the same type and shape as the `input`. Raises: - TypeError: If `min_init` or `max_init` is neither int nor float. + TypeError: If `min_init` or `max_init` is not int, float or list. TypeError: If `quant_delay` is not an int. TypeError: If `min_init` is not less than `max_init`. TypeError: If `quant_delay` is not greater than or equal to 0. @@ -318,18 +372,24 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): quant_dtype=QuantDtype.INT8, symmetric=False, narrow_range=False, - quant_delay=0): + quant_delay=0, + neg_trunc=False, + mode="DEFAULT"): """Initialize FakeQuantWithMinMaxObserver""" super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel, symmetric=symmetric, narrow_range=narrow_range, num_channels=num_channels) - Validator.check_value_type("min_init", min_init, [int, float], type(self).__name__) - Validator.check_value_type("max_init", max_init, [int, float], type(self).__name__) - Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT) + Validator.check_value_type("min_init", min_init, [int, float, list], type(self).__name__) + Validator.check_value_type("max_init", max_init, [int, float, list], type(self).__name__) + if isinstance(max_init, (int, float)) and isinstance(min_init, (int, float)): + Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT) + elif not np.greater(max_init, min_init).all(): + raise ValueError("`min_init` is not less than `max_init`, please reset the initial value.") Validator.check_non_negative_int(quant_delay, 'quant_delay') self.min_init = min_init self.max_init = max_init self.quant_dtype = quant_dtype + self.num_bits = quant_dtype.num_bits self.ema = ema self.ema_decay = ema_decay self.per_channel = per_channel @@ -338,43 +398,124 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): self.quant_delay = quant_delay self.symmetric = symmetric self.narrow_range = narrow_range + self.neg_trunc = neg_trunc + self.mode = mode self.is_ascend = context.get_context('device_target') == "Ascend" + self.Neg = P.Neg() - # init tensor min and max for fake quantized operation - if self.per_channel: - min_array = np.array([self.min_init] * self.num_channels).astype(np.float32) - max_array = np.array([self.max_init] * self.num_channels).astype(np.float32) - else: - min_array = np.array([self.min_init]).astype(np.float32) - max_array = np.array([self.max_init]).astype(np.float32) - self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) - self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) + min_array = self._get_init_array(self.min_init) + max_array = self._get_init_array(self.max_init) - # init fake quant relative op - if self.per_channel: - quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis) - ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis) - else: - quant_fun = Q.FakeQuantPerLayer - ema_fun = Q.MinMaxUpdatePerLayer + if self.mode == "DEFAULT": + # init tensor min and max for fake quantized operation + self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) + self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) + + # init fake quant relative op + if self.per_channel: + quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis) + ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis) + else: + quant_fun = Q.FakeQuantPerLayer + ema_fun = Q.MinMaxUpdatePerLayer + + self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay) + if self.is_ascend: + self.fake_quant_train = quant_fun(num_bits=self.quant_dtype.num_bits, + symmetric=self.symmetric, + narrow_range=self.narrow_range, + quant_delay=self.quant_delay) + self.fake_quant_infer = self.fake_quant_train + else: + quant_fun = partial(quant_fun, + ema=self.ema, + ema_decay=ema_decay, + num_bits=self.quant_dtype.num_bits, + symmetric=self.symmetric, + narrow_range=self.narrow_range, + quant_delay=self.quant_delay) + self.fake_quant_train = quant_fun(training=True) + self.fake_quant_infer = quant_fun(training=False) + elif self.mode == "LEARNED_SCALE": + if not self.symmetric: + raise ValueError("The 'LEARNED_SCALE' mode only support symmetric quant, please set symmetric to True.") + if self.neg_trunc: + min_array = self._get_init_array(0) + self.narrow_range = False + elif not self.narrow_range: + raise ValueError("The 'LEARNED_SCALE' mode only support narrow_range=True config, " + "except for neg_trunc=True scenario.") + + self._calculate_quant_max() + + self.minq = Parameter(Tensor(min_array), name='minq') + self.maxq = Parameter(Tensor(max_array), name='maxq') + self.quant_max = Parameter(Tensor(np.array([self._quant_max]).astype(np.float32)), + name="quant_max", requires_grad=False) + + # init fake quant relative op + if self.per_channel: + quant_fun = partial(Q.FakeLearnedScaleQuantPerChannel, channel_axis=self.channel_axis) + else: + quant_fun = Q.FakeLearnedScaleQuantPerLayer - self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay) - if self.is_ascend: - self.fake_quant_train = quant_fun(num_bits=self.quant_dtype.num_bits, - symmetric=self.symmetric, - narrow_range=self.narrow_range, - quant_delay=self.quant_delay) - self.fake_quant_infer = self.fake_quant_train - else: quant_fun = partial(quant_fun, - ema=self.ema, - ema_decay=ema_decay, - num_bits=self.quant_dtype.num_bits, - symmetric=self.symmetric, - narrow_range=self.narrow_range, - quant_delay=self.quant_delay) + quant_delay=self.quant_delay, + neg_trunc=self.neg_trunc) self.fake_quant_train = quant_fun(training=True) self.fake_quant_infer = quant_fun(training=False) + else: + raise ValueError("Invalid mode, currently only valid for `DEFAULT` and `LEARNED_SCALE` mode.") + + def reset(self, quant_dtype=QuantDtype.INT8, min_init=-6, max_init=6): + r""" + Reset the quant max parameter (eg. 256) and the initial value of the minq parameter and maxq parameter, + this function is currently only valid for `LEARNED_SCALE` mode. + """ + if self.mode == "LEARNED_SCALE": + self.quant_dtype = quant_dtype + self.num_bits = quant_dtype.num_bits + self._calculate_quant_max() + if self.neg_trunc: + min_init = 0 + + self.min_init = min_init + self.max_init = max_init + min_array = self._get_init_array(self.min_init) + max_array = self._get_init_array(self.max_init) + self.minq.set_data(Tensor(min_array)) + self.maxq.set_data(Tensor(max_array)) + self.quant_max.set_data(Tensor(np.array([self._quant_max]).astype(np.float32))) + else: + raise ValueError("The `reset` function is currently only valid for `LEARNED_SCALE` mode.") + + def _get_init_array(self, init_date): + """ + Convert the initial value to array. + """ + if isinstance(init_date, list) and self.per_channel and len(init_date) != self.num_channels: + raise ValueError("The length of the min_init/max_init list shuold be equal to num_channels for " + "perchannel quant scenario, but get {}".format(len(init_date))) + if isinstance(init_date, list) and not self.per_channel and len(init_date) != 1: + raise ValueError("The length of the min_init/max_init list shuold be 1 for perlayer quant " + "scenario, but get {}".format(len(init_date))) + + if isinstance(init_date, list): + min_max_array = np.array(init_date).astype(np.float32) + elif self.per_channel and not isinstance(init_date, list): + min_max_array = np.array([init_date] * self.num_channels).astype(np.float32) + else: + min_max_array = np.array([init_date]).astype(np.float32) + return min_max_array + + def _calculate_quant_max(self): + """ + The quantization range is calculated according to num_bits. + """ + if not self.neg_trunc: + self._quant_max = (1 << (self.num_bits - 1)) - 1 + else: + self._quant_max = (1 << self.num_bits) - 1 def extend_repr(self): s = 'quant_dtype={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \ @@ -385,13 +526,21 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): return s def construct(self, x): - if self.training: - min_up, max_up = self.ema_update(x, self.minq, self.maxq) - self.minq = min_up - self.maxq = max_up - out = self.fake_quant_train(x, self.minq, self.maxq) + if self.mode == "LEARNED_SCALE": + if self.training: + out = self.fake_quant_train(x, self.maxq, self.quant_max) + if not self.neg_trunc: + self.minq = self.Neg(self.maxq) + else: + out = self.fake_quant_infer(x, self.maxq, self.quant_max) else: - out = self.fake_quant_infer(x, self.minq, self.maxq) + if self.training: + min_up, max_up = self.ema_update(x, self.minq, self.maxq) + self.minq = min_up + self.maxq = max_up + out = self.fake_quant_train(x, self.minq, self.maxq) + else: + out = self.fake_quant_infer(x, self.minq, self.maxq) return out @@ -539,12 +688,13 @@ class Conv2dBnFoldQuantOneConv(Cell): requires_grad=False) # initialize fake ops - self.fake_quant_weight = quant_config.weight(min_init=-6, - max_init=6, - ema=False, + self.fake_quant_weight = quant_config.weight(ema=False, channel_axis=channel_axis, num_channels=out_channels, quant_dtype=quant_dtype) + self.freeze_bn = False + if self.fake_quant_weight.mode == "LEARNED_SCALE": + self.freeze_bn = True self.bn_train = P.BatchNorm(is_training=True, epsilon=self.eps, momentum=self.momentum, data_format=self.format) @@ -579,6 +729,9 @@ class Conv2dBnFoldQuantOneConv(Cell): if self.fake: weight = self.fake_quant_weight(weight) conv = self.conv(x, weight) + + if self.freeze_bn: + return conv + self.reshape((self.beta - self.gamma * self.moving_mean / running_std), (1, -1, 1, 1)) scale_factor = self.reshape(scale_factor, (1, -1, 1, 1)) if self.enable_default_train: scale_factor = P.Reciprocal()(scale_factor) @@ -730,9 +883,7 @@ class Conv2dBnFoldQuant(Cell): requires_grad=False) # initialize fake ops - self.fake_quant_weight = quant_config.weight(min_init=-6, - max_init=6, - ema=False, + self.fake_quant_weight = quant_config.weight(ema=False, channel_axis=channel_axis, num_channels=out_channels, quant_dtype=quant_dtype) @@ -886,9 +1037,7 @@ class Conv2dBnWithoutFoldQuant(Cell): weight_shape = [out_channels, in_channels // group, *self.kernel_size] channel_axis = 0 self.weight = Parameter(initializer(weight_init, weight_shape), name='weight') - self.fake_quant_weight = quant_config.weight(min_init=-6, - max_init=6, - ema=False, + self.fake_quant_weight = quant_config.weight(ema=False, channel_axis=channel_axis, num_channels=out_channels, quant_dtype=quant_dtype) @@ -1005,9 +1154,7 @@ class Conv2dQuant(Cell): dilation=self.dilation, group=self.group) channel_axis = 0 - self.fake_quant_weight = quant_config.weight(min_init=-6, - max_init=6, - ema=False, + self.fake_quant_weight = quant_config.weight(ema=False, channel_axis=channel_axis, num_channels=out_channels, quant_dtype=quant_dtype) @@ -1111,9 +1258,7 @@ class DenseQuant(Cell): if activation is not None and not isinstance(self.activation, (Cell, Primitive)): raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation)) self.activation_flag = self.activation is not None - self.fake_quant_weight = quant_config.weight(min_init=-6, - max_init=6, - ema=False, + self.fake_quant_weight = quant_config.weight(ema=False, channel_axis=0, num_channels=out_channels, quant_dtype=quant_dtype) @@ -1198,6 +1343,8 @@ class ActQuant(_QuantActivation): quant_config=quant_config_default, quant_dtype=QuantDtype.INT8): super(ActQuant, self).__init__() + act_class = activation.__class__ + act_list = [nn.ReLU, nn.ReLU6] self.act = Validator.check_isinstance("activation", activation, Cell) self.fake_before = Validator.check_bool(fake_before, "fake_before") if self.fake_before: @@ -1206,11 +1353,14 @@ class ActQuant(_QuantActivation): ema=ema, ema_decay=ema_decay, quant_dtype=quant_dtype) + + neg_trunc = bool(act_class in act_list) self.fake_quant_act = quant_config.activation(min_init=-6, max_init=6, ema=ema, ema_decay=ema_decay, - quant_dtype=quant_dtype) + quant_dtype=quant_dtype, + neg_trunc=neg_trunc) def construct(self, x): if self.fake_before: diff --git a/mindspore/ops/_grad/grad_quant_ops.py b/mindspore/ops/_grad/grad_quant_ops.py index 9abbd4119d2..6983257928d 100644 --- a/mindspore/ops/_grad/grad_quant_ops.py +++ b/mindspore/ops/_grad/grad_quant_ops.py @@ -200,3 +200,30 @@ def get_bprop_wts_arq(self): return (dout, zeros_like(w_min), zeros_like(w_max)) return bprop + + +@bprop_getters.register(Q.FakeLearnedScaleQuantPerLayer) +def get_bprop_fakequant_with_learned_scale_perlayer(self): + """Generate bprop for FakeLearnedScaleQuantPerLayer for GPU""" + op = Q.FakeLearnedScaleQuantPerLayerGrad(quant_delay=self.quant_delay, + neg_trunc=self.neg_trunc) + + def bprop(x, x_alpha, x_quant_max, out, dout): + dx, dalpha = op(dout, x, x_alpha, x_quant_max) + return dx, dalpha, zeros_like(x_quant_max) + + return bprop + + +@bprop_getters.register(Q.FakeLearnedScaleQuantPerChannel) +def get_bprop_fakequant_with_learned_scale_perchannel(self): + """Generate bprop for FakeLearnedScaleQuantPerChannel for GPU""" + op = Q.FakeLearnedScaleQuantPerChannelGrad(quant_delay=self.quant_delay, + neg_trunc=self.neg_trunc, + channel_axis=self.channel_axis) + + def bprop(x, x_alpha, x_quant_max, out, dout): + dx, dalpha = op(dout, x, x_alpha, x_quant_max) + return dx, dalpha, zeros_like(x_quant_max) + + return bprop diff --git a/mindspore/ops/_op_impl/_custom_op/__init__.py b/mindspore/ops/_op_impl/_custom_op/__init__.py index 5fe583a60fc..63e94190104 100644 --- a/mindspore/ops/_op_impl/_custom_op/__init__.py +++ b/mindspore/ops/_op_impl/_custom_op/__init__.py @@ -14,3 +14,9 @@ # ============================================================================ """custom ops""" +from .fake_learned_scale_quant_perlayer import _fake_learned_scale_quant_perlayer_tbe +from .fake_learned_scale_quant_perlayer_grad import _fake_learned_scale_quant_perlayer_grad_d_tbe +from .fake_learned_scale_quant_perlayer_grad_reduce import _fake_learned_scale_quant_perlayer_grad_d_reduce_tbe +from .fake_learned_scale_quant_perchannel import _fake_learned_scale_quant_perchannel_tbe +from .fake_learned_scale_quant_perchannel_grad import _fake_learned_scale_quant_perchannel_grad_d_tbe +from .fake_learned_scale_quant_perchannel_grad_reduce import _fake_learned_scale_quant_perchannel_grad_d_reduce_tbe diff --git a/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py b/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py new file mode 100644 index 00000000000..da2a74701a1 --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py @@ -0,0 +1,125 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""FakeLearnedScaleQuantPerChannel op""" +import te.lang.cce +from te import tvm +from te.platform.fusion_manager import fusion_manager +from topi import generic +from topi.cce import util +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +fake_learned_scale_quant_perchannel_op_info = TBERegOp("FakeLearnedScaleQuantPerChannel") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("fake_learned_scale_quant_perchannel.so") \ + .compute_cost(10) \ + .kernel_name("fake_learned_scale_quant_perchannel") \ + .partial_flag(True) \ + .attr("neg_trunc", "optional", "bool", "all") \ + .attr("channel_axis", "optional", "int", "all") \ + .input(0, "input_x", None, "required", None) \ + .input(1, "alpha", None, "required", None) \ + .input(2, "quant_max", None, "required", None) \ + .output(0, "out", True, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(fake_learned_scale_quant_perchannel_op_info) +def _fake_learned_scale_quant_perchannel_tbe(): + """FakeLearnedScaleQuantPerChannel TBE register""" + return + + +@fusion_manager.register("fake_learned_scale_quant_perchannel") +def fake_learned_scale_quant_perchannel_compute(input_data, alpha_data, quant_max_data, neg_trunc, + kernel_name="fake_learned_scale_quant_perchannel"): + """FakeLearnedScaleQuantPerChannel""" + input_shape = te.lang.cce.util.shape_to_list(input_data.shape) + alpha_data = te.lang.cce.broadcast(alpha_data, input_shape, input_data.dtype) + quant_max_data = te.lang.cce.broadcast(quant_max_data, input_shape, input_data.dtype) + + input_x = te.lang.cce.vdiv(input_data, alpha_data) + + if neg_trunc: + input_x = te.lang.cce.round_to(input_x, 1.0, 0.0) + else: + input_x = te.lang.cce.round_to(input_x, 1.0, -1.0) + + nudge_input = te.lang.cce.floor(te.lang.cce.vadds(te.lang.cce.vmul(input_x, quant_max_data), 0.5)) + input_quant = te.lang.cce.vdiv(nudge_input, quant_max_data) + res = te.lang.cce.vmul(input_quant, alpha_data) + + return res + + +@util.check_input_type(dict, dict, dict, dict, bool, int, str) +def fake_learned_scale_quant_perchannel(input_x, alpha, quant_max, out, neg_trunc, channel_axis, + kernel_name="fake_learned_scale_quant_perchannel"): + """FakeLearnedScaleQuantPerChannel""" + input_shape = input_x.get("shape") + input_x_shape_ = input_x.get("ori_shape") + input_x_format = input_x.get("format") + input_dtype = input_x.get("dtype") + alpha_shape = alpha.get("ori_shape") + alpha_dtype = alpha.get("dtype") + quant_max_shape = quant_max.get("ori_shape") + quant_max_dtype = quant_max.get("dtype") + # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1. + if channel_axis == 0 and input_x_shape_[0] != alpha_shape[0] and input_x_shape_[1] == alpha_shape[0]: + channel_axis_ = 1 + else: + channel_axis_ = channel_axis + + util.check_kernel_name(kernel_name) + util.check_shape_rule(input_shape) + util.check_shape_rule(alpha_shape, 1, 1, input_x_shape_[channel_axis_]) + util.check_shape_rule(quant_max_shape, 1, 1, 1) + util.check_tensor_shape_size(input_shape) + util.check_tensor_shape_size(alpha_shape) + util.check_tensor_shape_size(quant_max_shape) + + check_list = ["float32", "float16"] + input_dtype = input_dtype.lower() + alpha_dtype = alpha_dtype.lower() + quant_max_dtype = quant_max_dtype.lower() + util.check_dtype_rule(input_dtype, check_list) + util.check_dtype_rule(alpha_dtype, check_list) + util.check_dtype_rule(quant_max_dtype, check_list) + + shape_c = [1] * len(input_shape) + shape_c[channel_axis_] = alpha.get("ori_shape")[0] + if input_x_format == "NC1HWC0" and channel_axis_ == 1: + shape_c = alpha.get("shape") + + input_data = tvm.placeholder(input_shape, name="x", dtype=input_dtype) + alpha_data = tvm.placeholder(shape_c, name="alpha_data", dtype=alpha_dtype) + quant_max_data = tvm.placeholder(quant_max_shape, name="quant_max_data", dtype=quant_max_dtype) + res = fake_learned_scale_quant_perchannel_compute(input_data, alpha_data, quant_max_data, neg_trunc, kernel_name) + + with tvm.target.cce(): + sch = generic.auto_schedule(res) + + tensor_list = [input_data, alpha_data, quant_max_data, res] + config = {"print_ir": False, + "name": kernel_name, + "tensor_list": tensor_list, + "bool_storage_as_1bit": False} + + te.lang.cce.cce_build_code(sch, config) diff --git a/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py b/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py new file mode 100644 index 00000000000..e8e1ec10c58 --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py @@ -0,0 +1,191 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""FakeLearnedScaleQuantPerChannelGradD op""" +import te.lang.cce +from te import tvm +from te.platform.fusion_manager import fusion_manager +from topi import generic +from topi.cce import util +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +NEG_SCALAR_MIN_FP16 = -(2 ** (-24)) +NEG_SCALAR_MIN_FP32 = -(2 ** (-126)) +SCALAR_MIN_FP16 = 2 ** (-24) +SCALAR_MIN_FP32 = 2 ** (-126) + +fake_learned_scale_quant_perchannel_grad_d_op_info = TBERegOp("FakeLearnedScaleQuantPerChannelGradD") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("fake_learned_scale_quant_perchannel_grad_d.so") \ + .compute_cost(10) \ + .kernel_name("fake_learned_scale_quant_perchannel_grad_d") \ + .partial_flag(True) \ + .attr("neg_trunc", "optional", "bool", "all") \ + .attr("channel_axis", "optional", "int", "all") \ + .input(0, "dout", None, "required", None) \ + .input(1, "input_x", None, "required", None) \ + .input(2, "alpha", None, "required", None) \ + .input(3, "quant_max", None, "required", None) \ + .output(0, "dx", True, "required", "all") \ + .output(1, "dalpha", True, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, + DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, + DataType.F32_5HD, DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(fake_learned_scale_quant_perchannel_grad_d_op_info) +def _fake_learned_scale_quant_perchannel_grad_d_tbe(): + """FakeLearnedScaleQuantPerChannelGradD TBE register""" + return + + +@fusion_manager.register("fake_learned_scale_quant_perchannel_grad_d") +def fake_learned_scale_quant_perchannel_grad_d_compute(dout, input_data, alpha_data, quant_max_data, neg_trunc, + kernel_name="fake_learned_scale_quant_perchannel_grad_d"): + """FakeLearnedScaleQuantPerChannelGradD""" + input_shape = te.lang.cce.util.shape_to_list(input_data.shape) + alpha_data = te.lang.cce.broadcast(alpha_data, input_shape, input_data.dtype) + quant_max_data = te.lang.cce.broadcast(quant_max_data, input_shape, input_data.dtype) + + input_x = te.lang.cce.vdiv(input_data, alpha_data) + input_div_alpha = input_x + + if neg_trunc: + input_x = te.lang.cce.round_to(input_x, 1.0, 0.0) + else: + input_x = te.lang.cce.round_to(input_x, 1.0, -1.0) + + nudge_input = te.lang.cce.floor(te.lang.cce.vadds(te.lang.cce.vmul(input_x, quant_max_data), 0.5)) + input_quant = te.lang.cce.vdiv(nudge_input, quant_max_data) + + dtype = input_div_alpha.dtype.lower() + shape = te.lang.cce.util.shape_to_list(input_div_alpha.shape) + + dx = dout + tensor_one = tvm.const(1.0, input_div_alpha.dtype) + tensor_one = te.lang.cce.broadcast(tensor_one, shape) + + #out_of_bounds = te.lang.cce.vcmpsel(te.lang.cce.vabs(input_div_alpha), 1.0, 'gt', 1.0, 0.0) + out_of_upper_bounds = te.lang.cce.vcmpsel(input_div_alpha, 1.0, 'gt', 1.0, 0.0) + if neg_trunc: + out_of_lower_bounds = te.lang.cce.vcmpsel(input_div_alpha, 0.0, 'lt', 1.0, 0.0) + else: + out_of_lower_bounds = te.lang.cce.vcmpsel(input_div_alpha, -1.0, 'lt', 1.0, 0.0) + out_of_bounds = te.lang.cce.vadd(out_of_lower_bounds, out_of_upper_bounds) + + dx = te.lang.cce.vmul(dx, te.lang.cce.vsub(tensor_one, out_of_bounds)) + + # sign function imp + if dtype == "float32": + data_min = tvm.const(SCALAR_MIN_FP32, dtype=dtype) + neg_data_min = tvm.const(NEG_SCALAR_MIN_FP32, dtype=dtype) + elif dtype == "float16": + data_min = tvm.const(SCALAR_MIN_FP16, dtype=dtype) + neg_data_min = tvm.const(NEG_SCALAR_MIN_FP16, dtype=dtype) + else: + data_min = tvm.const(1, dtype=dtype) + neg_data_min = tvm.const(-1, dtype=dtype) + vmax = te.lang.cce.vmaxs(input_div_alpha, neg_data_min) + vmin = te.lang.cce.vmins(vmax, data_min) + if dtype == "float32": + # max num of float32 is 2**126 + max_support_fp32 = tvm.const(2 ** 62, dtype=dtype) + res_mul1 = te.lang.cce.vmuls(vmin, max_support_fp32) + res_mul2 = te.lang.cce.vmuls(res_mul1, max_support_fp32) + sign = te.lang.cce.vmuls(res_mul2, tvm.const(2 ** 2, dtype=dtype)) + elif dtype == "float16": + # max num of float16 is 2**24 + # but cce can only support 2**12, so use 12/12 to adaptor 24 + max_support_fp16 = tvm.const(2 ** 12, dtype=dtype) + res_mul1 = te.lang.cce.vmuls(vmin, max_support_fp16) + sign = te.lang.cce.vmuls(res_mul1, max_support_fp16) + else: + sign = vmin + + # The following lines are equivalent to : + # dalpha_each = dout * sign if out of bounds + # dout * (input_quant - input_div_alpha) if within bounds + + quant_error = te.lang.cce.vsub(input_quant, input_div_alpha) + within_bounds = te.lang.cce.vsub(tensor_one, out_of_bounds) + error_within_bounds = te.lang.cce.vmul(quant_error, within_bounds) + grad_range = te.lang.cce.vmadd(sign, error_within_bounds, out_of_bounds) + dalpha_each = te.lang.cce.vmul(dout, grad_range) + + return [dx, dalpha_each] + + +@util.check_input_type(dict, dict, dict, dict, dict, dict, bool, int, str) +def fake_learned_scale_quant_perchannel_grad_d(dout, input_x, alpha, quant_max, dx, dalpha, neg_trunc, + channel_axis, kernel_name="fake_learned_scale_quant_perchannel_grad_d"): + """FakeLearnedScaleQuantPerChannelGradD""" + input_shape = input_x.get("shape") + input_x_shape_ = input_x.get("ori_shape") + input_x_format = input_x.get("format") + input_dtype = input_x.get("dtype") + alpha_shape = alpha.get("ori_shape") + alpha_dtype = alpha.get("dtype") + quant_max_shape = quant_max.get("ori_shape") + quant_max_dtype = quant_max.get("dtype") + # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1. + if channel_axis == 0 and input_x_shape_[0] != alpha_shape[0] and input_x_shape_[1] == alpha_shape[0]: + channel_axis_ = 1 + else: + channel_axis_ = channel_axis + + util.check_kernel_name(kernel_name) + util.check_shape_rule(input_shape) + util.check_shape_rule(alpha_shape, 1, 1, input_x_shape_[channel_axis_]) + util.check_shape_rule(quant_max_shape, 1, 1, 1) + util.check_tensor_shape_size(input_shape) + util.check_tensor_shape_size(alpha_shape) + util.check_tensor_shape_size(quant_max_shape) + + check_list = ["float32", "float16"] + input_dtype = input_dtype.lower() + alpha_dtype = alpha_dtype.lower() + quant_max_dtype = quant_max_dtype.lower() + util.check_dtype_rule(input_dtype, check_list) + util.check_dtype_rule(alpha_dtype, check_list) + util.check_dtype_rule(quant_max_dtype, check_list) + + shape_c = [1] * len(input_shape) + shape_c[channel_axis_] = alpha.get("ori_shape")[0] + if input_x_format == "NC1HWC0" and channel_axis_ == 1: + shape_c = alpha.get("shape") + + dout_data = tvm.placeholder(input_shape, name="dout", dtype=input_dtype) + input_data = tvm.placeholder(input_shape, name="x", dtype=input_dtype) + alpha_data = tvm.placeholder(shape_c, name="alpha_data", dtype=alpha_dtype) + quant_max_data = tvm.placeholder(quant_max_shape, name="quant_max_data", dtype=quant_max_dtype) + res = fake_learned_scale_quant_perchannel_grad_d_compute(dout_data, input_data, alpha_data, quant_max_data, + neg_trunc, kernel_name) + + with tvm.target.cce(): + sch = generic.auto_schedule(res) + + tensor_list = [dout_data, input_data, alpha_data, quant_max_data] + list(res) + config = {"print_ir": False, + "name": kernel_name, + "tensor_list": tensor_list} + + te.lang.cce.cce_build_code(sch, config) diff --git a/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py b/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py new file mode 100644 index 00000000000..90a40457b19 --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py @@ -0,0 +1,88 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""FakeLearnedScaleQuantPerChannelGradDReduce op""" +import te.lang.cce +from te import tvm +from te.platform.fusion_manager import fusion_manager +from topi import generic +from topi.cce import util +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + + +fake_learned_scale_quant_perchannel_grad_d_reduce_op_info = TBERegOp("FakeLearnedScaleQuantPerChannelGradDReduce") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("fake_learned_scale_quant_perchannel_grad_d_reduce.so") \ + .compute_cost(10) \ + .kernel_name("fake_learned_scale_quant_perchannel_grad_d_reduce") \ + .partial_flag(True) \ + .attr("channel_axis", "optional", "int", "all") \ + .input(0, "dout_alpha", None, "required", None) \ + .output(0, "dalpha", True, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(fake_learned_scale_quant_perchannel_grad_d_reduce_op_info) +def _fake_learned_scale_quant_perchannel_grad_d_reduce_tbe(): + """FakeLearnedScaleQuantPerChannelGradDReduce TBE register""" + return + + +@fusion_manager.register("fake_learned_scale_quant_perchannel_grad_d_reduce") +def fake_learned_scale_quant_perchannel_grad_d_reduce_compute(dout_alpha_data, dout_alpha, channel_axis, + kernel_name="fake_learned_scale_quant_perchannel_" + "grad_d_reduce"): + """FakeLearnedScaleQuantPerChannelGradDReduce""" + dout_alpha_shape = dout_alpha.get("shape") + axis = list(range(len(dout_alpha_shape))) + axis.remove(channel_axis) + dalpha = te.lang.cce.sum(dout_alpha_data, axis, False) + return dalpha + + +@util.check_input_type(dict, dict, int, str) +def fake_learned_scale_quant_perchannel_grad_d_reduce(dout_alpha, dalpha, channel_axis, + kernel_name="fake_learned_scale_quant_perchannel_grad_d_reduce"): + """FakeLearnedScaleQuantPerChannelGradDReduce""" + + dout_alpha_shape = dout_alpha.get("shape") + dout_alpha_dtype = dout_alpha.get("dtype") + + util.check_kernel_name(kernel_name) + util.check_shape_rule(dout_alpha_shape) + util.check_tensor_shape_size(dout_alpha_shape) + + check_list = ["float32", 'float16'] + dout_alpha_dtype = dout_alpha_dtype.lower() + util.check_dtype_rule(dout_alpha_dtype, check_list) + + dout_alpha_data = tvm.placeholder(dout_alpha_shape, name="dout_alpha", dtype=dout_alpha_dtype) + res = fake_learned_scale_quant_perchannel_grad_d_reduce_compute(dout_alpha_data, dout_alpha, + channel_axis, kernel_name) + + with tvm.target.cce(): + sch = generic.auto_schedule(res) + + tensor_list = [dout_alpha_data, res] + config = {"print_ir": False, + "name": kernel_name, + "tensor_list": tensor_list} + + te.lang.cce.cce_build_code(sch, config) diff --git a/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py b/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py new file mode 100644 index 00000000000..2c0be71b5db --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py @@ -0,0 +1,117 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""FakeLearnedScaleQuantPerLayer op""" +from functools import reduce as functools_reduce +import te.lang.cce +from te import tvm +from te.platform.fusion_manager import fusion_manager +from topi import generic +from topi.cce import util +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +fake_learned_scale_quant_perlayer_op_info = TBERegOp("FakeLearnedScaleQuantPerLayer") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("fake_learned_scale_quant_perlayer.so") \ + .compute_cost(10) \ + .kernel_name("fake_learned_scale_quant_perlayer") \ + .partial_flag(True) \ + .attr("neg_trunc", "optional", "bool", "all") \ + .input(0, "input_x", None, "required", None) \ + .input(1, "alpha", None, "required", None) \ + .input(2, "quant_max", None, "required", None) \ + .output(0, "out", True, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(fake_learned_scale_quant_perlayer_op_info) +def _fake_learned_scale_quant_perlayer_tbe(): + """FakeLearnedScaleQuantPerLayer TBE register""" + return + + +@fusion_manager.register("fake_learned_scale_quant_perlayer") +def fake_learned_scale_quant_perlayer_compute(input_data, alpha_data, quant_max_data, neg_trunc, + kernel_name="fake_learned_scale_quant_perlayer"): + """FakeLearnedScaleQuantPerLayer""" + input_shape = te.lang.cce.util.shape_to_list(input_data.shape) + alpha_data = te.lang.cce.broadcast(alpha_data, input_shape, input_data.dtype) + quant_max_data = te.lang.cce.broadcast(quant_max_data, input_shape, input_data.dtype) + + input_x = te.lang.cce.vdiv(input_data, alpha_data) + + if neg_trunc: + input_x = te.lang.cce.round_to(input_x, 1.0, 0.0) + else: + input_x = te.lang.cce.round_to(input_x, 1.0, -1.0) + + nudge_input = te.lang.cce.floor(te.lang.cce.vadds(te.lang.cce.vmul(input_x, quant_max_data), 0.5)) + input_quant = te.lang.cce.vdiv(nudge_input, quant_max_data) + res = te.lang.cce.vmul(input_quant, alpha_data) + + return res + + +@util.check_input_type(dict, dict, dict, dict, bool, str) +def fake_learned_scale_quant_perlayer(input_x, alpha, quant_max, out, neg_trunc, + kernel_name="fake_learned_scale_quant_perlayer"): + """FakeLearnedScaleQuantPerLayer""" + input_shape = input_x.get("shape") + input_dtype = input_x.get("dtype") + alpha_shape = alpha.get("ori_shape") + alpha_dtype = alpha.get("dtype") + quant_max_shape = quant_max.get("ori_shape") + quant_max_dtype = quant_max.get("dtype") + + alpha_shape = util.scalar2tensor_one(alpha_shape) + quant_max_shape = util.scalar2tensor_one(quant_max_shape) + util.check_kernel_name(kernel_name) + util.check_shape_rule(input_shape) + util.check_shape_rule(alpha_shape, 1, 1, 1) + util.check_shape_rule(quant_max_shape, 1, 1, 1) + util.check_tensor_shape_size(input_shape) + util.check_tensor_shape_size(alpha_shape) + util.check_tensor_shape_size(quant_max_shape) + + check_list = ["float32", "float16"] + input_dtype = input_dtype.lower() + alpha_dtype = alpha_dtype.lower() + quant_max_dtype = quant_max_dtype.lower() + util.check_dtype_rule(input_dtype, check_list) + util.check_dtype_rule(alpha_dtype, check_list) + util.check_dtype_rule(quant_max_dtype, check_list) + + input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),) + + input_data = tvm.placeholder(input_shape, name="x", dtype=input_dtype) + alpha_data = tvm.placeholder(alpha_shape, name="alpha_data", dtype=alpha_dtype) + quant_max_data = tvm.placeholder(quant_max_shape, name="quant_max_data", dtype=quant_max_dtype) + res = fake_learned_scale_quant_perlayer_compute(input_data, alpha_data, quant_max_data, neg_trunc, kernel_name) + + with tvm.target.cce(): + sch = generic.auto_schedule(res) + + tensor_list = [input_data, alpha_data, quant_max_data, res] + config = {"print_ir": False, + "name": kernel_name, + "tensor_list": tensor_list, + "bool_storage_as_1bit": False} + + te.lang.cce.cce_build_code(sch, config) diff --git a/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py b/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py new file mode 100644 index 00000000000..83f297f614b --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py @@ -0,0 +1,184 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""FakeLearnedScaleQuantPerLayerGradD op""" + +from functools import reduce as functools_reduce +import te.lang.cce +from te import tvm +from te.platform.fusion_manager import fusion_manager +from topi import generic +from topi.cce import util +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +NEG_SCALAR_MIN_FP16 = -(2 ** (-24)) +NEG_SCALAR_MIN_FP32 = -(2 ** (-126)) +SCALAR_MIN_FP16 = 2 ** (-24) +SCALAR_MIN_FP32 = 2 ** (-126) + +fake_learned_scale_quant_perlayer_grad_d_op_info = TBERegOp("FakeLearnedScaleQuantPerLayerGradD") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("fake_learned_scale_quant_perlayer_grad_d.so") \ + .compute_cost(10) \ + .kernel_name("fake_learned_scale_quant_perlayer_grad_d") \ + .partial_flag(True) \ + .attr("neg_trunc", "optional", "bool", "all") \ + .input(0, "dout", None, "required", None) \ + .input(1, "input_x", None, "required", None) \ + .input(2, "alpha", None, "required", None) \ + .input(3, "quant_max", None, "required", None) \ + .output(0, "dx", True, "required", "all") \ + .output(1, "dalpha", True, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, + DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, + DataType.F32_5HD, DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(fake_learned_scale_quant_perlayer_grad_d_op_info) +def _fake_learned_scale_quant_perlayer_grad_d_tbe(): + """FakeLearnedScaleQuantPerLayerGradD TBE register""" + return + + +@fusion_manager.register("fake_learned_scale_quant_perlayer_grad_d") +def fake_learned_scale_quant_perlayer_grad_d_compute(dout, input_data, alpha_data, quant_max_data, neg_trunc, + kernel_name="fake_learned_scale_quant_perlayer_grad_d"): + """FakeLearnedScaleQuantPerLayerGradD""" + input_shape = te.lang.cce.util.shape_to_list(input_data.shape) + alpha_data = te.lang.cce.broadcast(alpha_data, input_shape, input_data.dtype) + quant_max_data = te.lang.cce.broadcast(quant_max_data, input_shape, input_data.dtype) + + input_x = te.lang.cce.vdiv(input_data, alpha_data) + input_div_alpha = input_x + + if neg_trunc: + input_x = te.lang.cce.round_to(input_x, 1.0, 0.0) + else: + input_x = te.lang.cce.round_to(input_x, 1.0, -1.0) + + nudge_input = te.lang.cce.floor(te.lang.cce.vadds(te.lang.cce.vmul(input_x, quant_max_data), 0.5)) + input_quant = te.lang.cce.vdiv(nudge_input, quant_max_data) + + dtype = input_div_alpha.dtype.lower() + shape = te.lang.cce.util.shape_to_list(input_div_alpha.shape) + + dx = dout + tensor_one = tvm.const(1.0, input_div_alpha.dtype) + tensor_one = te.lang.cce.broadcast(tensor_one, shape) + + #out_of_bounds = te.lang.cce.vcmpsel(te.lang.cce.vabs(input_div_alpha), 1.0, 'gt', 1.0, 0.0) + out_of_upper_bounds = te.lang.cce.vcmpsel(input_div_alpha, 1.0, 'gt', 1.0, 0.0) + if neg_trunc: + out_of_lower_bounds = te.lang.cce.vcmpsel(input_div_alpha, 0.0, 'lt', 1.0, 0.0) + else: + out_of_lower_bounds = te.lang.cce.vcmpsel(input_div_alpha, -1.0, 'lt', 1.0, 0.0) + out_of_bounds = te.lang.cce.vadd(out_of_lower_bounds, out_of_upper_bounds) + + dx = te.lang.cce.vmul(dx, te.lang.cce.vsub(tensor_one, out_of_bounds)) + + # sign function imp + if dtype == "float32": + data_min = tvm.const(SCALAR_MIN_FP32, dtype=dtype) + neg_data_min = tvm.const(NEG_SCALAR_MIN_FP32, dtype=dtype) + elif dtype == "float16": + data_min = tvm.const(SCALAR_MIN_FP16, dtype=dtype) + neg_data_min = tvm.const(NEG_SCALAR_MIN_FP16, dtype=dtype) + else: + data_min = tvm.const(1, dtype=dtype) + neg_data_min = tvm.const(-1, dtype=dtype) + vmax = te.lang.cce.vmaxs(input_div_alpha, neg_data_min) + vmin = te.lang.cce.vmins(vmax, data_min) + if dtype == "float32": + # max num of float32 is 2**126 + max_support_fp32 = tvm.const(2 ** 62, dtype=dtype) + res_mul1 = te.lang.cce.vmuls(vmin, max_support_fp32) + res_mul2 = te.lang.cce.vmuls(res_mul1, max_support_fp32) + sign = te.lang.cce.vmuls(res_mul2, tvm.const(2 ** 2, dtype=dtype)) + elif dtype == "float16": + # max num of float16 is 2**24 + # but cce can only support 2**12, so use 12/12 to adaptor 24 + max_support_fp16 = tvm.const(2 ** 12, dtype=dtype) + res_mul1 = te.lang.cce.vmuls(vmin, max_support_fp16) + sign = te.lang.cce.vmuls(res_mul1, max_support_fp16) + else: + sign = vmin + + # The following lines are equivalent to : + # dalpha_each = dout * sign if out of bounds + # dout * (input_quant - input_div_alpha) if within bounds + + quant_error = te.lang.cce.vsub(input_quant, input_div_alpha) + within_bounds = te.lang.cce.vsub(tensor_one, out_of_bounds) + error_within_bounds = te.lang.cce.vmul(quant_error, within_bounds) + grad_range = te.lang.cce.vmadd(sign, error_within_bounds, out_of_bounds) + dalpha_each = te.lang.cce.vmul(dout, grad_range) + + return [dx, dalpha_each] + + +@util.check_input_type(dict, dict, dict, dict, dict, dict, bool, str) +def fake_learned_scale_quant_perlayer_grad_d(dout, input_x, alpha, quant_max, dx, dalpha, neg_trunc, + kernel_name="fake_learned_scale_quant_perlayer_grad_d"): + """FakeLearnedScaleQuantPerLayerGradD""" + input_shape = input_x.get("shape") + input_dtype = input_x.get("dtype") + alpha_shape = alpha.get("ori_shape") + alpha_dtype = alpha.get("dtype") + quant_max_shape = quant_max.get("ori_shape") + quant_max_dtype = quant_max.get("dtype") + + alpha_shape = util.scalar2tensor_one(alpha_shape) + quant_max_shape = util.scalar2tensor_one(quant_max_shape) + util.check_kernel_name(kernel_name) + util.check_shape_rule(input_shape) + util.check_shape_rule(alpha_shape, 1, 1, 1) + util.check_shape_rule(quant_max_shape, 1, 1, 1) + util.check_tensor_shape_size(input_shape) + util.check_tensor_shape_size(alpha_shape) + util.check_tensor_shape_size(quant_max_shape) + + check_list = ["float32", "float16"] + input_dtype = input_dtype.lower() + alpha_dtype = alpha_dtype.lower() + quant_max_dtype = quant_max_dtype.lower() + util.check_dtype_rule(input_dtype, check_list) + util.check_dtype_rule(alpha_dtype, check_list) + util.check_dtype_rule(quant_max_dtype, check_list) + + input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),) + + dout_data = tvm.placeholder(input_shape, name="dout", dtype=input_dtype) + input_data = tvm.placeholder(input_shape, name="x", dtype=input_dtype) + alpha_data = tvm.placeholder(alpha_shape, name="alpha_data", dtype=alpha_dtype) + quant_max_data = tvm.placeholder(quant_max_shape, name="quant_max_data", dtype=quant_max_dtype) + res = fake_learned_scale_quant_perlayer_grad_d_compute(dout_data, input_data, alpha_data, quant_max_data, + neg_trunc, kernel_name) + + with tvm.target.cce(): + sch = generic.auto_schedule(res) + + tensor_list = [dout_data, input_data, alpha_data, quant_max_data] + list(res) + config = {"print_ir": False, + "name": kernel_name, + "tensor_list": tensor_list} + + te.lang.cce.cce_build_code(sch, config) diff --git a/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py b/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py new file mode 100644 index 00000000000..d4af1f6381f --- /dev/null +++ b/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py @@ -0,0 +1,88 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""FakeLearnedScaleQuantPerLayerGradDReduce op""" + +from functools import reduce as functools_reduce +import te.lang.cce +from te import tvm +from te.platform.fusion_manager import fusion_manager +from topi import generic +from topi.cce import util +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + + +fake_learned_scale_quant_perlayer_grad_d_reduce_op_info = TBERegOp("FakeLearnedScaleQuantPerLayerGradDReduce") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("fake_learned_scale_quant_perlayer_grad_d_reduce.so") \ + .compute_cost(10) \ + .kernel_name("fake_learned_scale_quant_perlayer_grad_d_reduce") \ + .partial_flag(True) \ + .input(0, "dout_alpha", None, "required", None) \ + .output(0, "dalpha", True, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(fake_learned_scale_quant_perlayer_grad_d_reduce_op_info) +def _fake_learned_scale_quant_perlayer_grad_d_reduce_tbe(): + """FakeLearnedScaleQuantPerLayerGradDReduce TBE register""" + return + + +@fusion_manager.register("fake_learned_scale_quant_perlayer_grad_d_reduce") +def fake_learned_scale_quant_perlayer_grad_d_reduce_compute(dout_alpha, + kernel_name="fake_learned_scale_quant_perlayer_" + "grad_d_reduce"): + """FakeLearnedScaleQuantPerLayerGradDReduce""" + dalpha = te.lang.cce.sum(dout_alpha, 0, False) + + return dalpha + + +@util.check_input_type(dict, dict, str) +def fake_learned_scale_quant_perlayer_grad_d_reduce(dout_alpha, dalpha, + kernel_name="fake_learned_scale_quant_perlayer_grad_d_reduce"): + """FakeLearnedScaleQuantPerLayerGradDReduce""" + + dout_alpha_shape = dout_alpha.get("shape") + dout_alpha_dtype = dout_alpha.get("dtype") + + util.check_kernel_name(kernel_name) + util.check_shape_rule(dout_alpha_shape) + util.check_tensor_shape_size(dout_alpha_shape) + + check_list = ["float32", 'float16'] + dout_alpha_dtype = dout_alpha_dtype.lower() + util.check_dtype_rule(dout_alpha_dtype, check_list) + + input_shape = (functools_reduce(lambda x, y: x * y, dout_alpha_shape[:]),) + + dout_alpha_data = tvm.placeholder(input_shape, name="dout_alpha", dtype=dout_alpha_dtype) + res = fake_learned_scale_quant_perlayer_grad_d_reduce_compute(dout_alpha_data, kernel_name) + + with tvm.target.cce(): + sch = generic.auto_schedule(res) + + tensor_list = [dout_alpha_data, res] + config = {"print_ir": False, + "name": kernel_name, + "tensor_list": tensor_list} + + te.lang.cce.cce_build_code(sch, config) diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py index 9586e3bda15..e097c8df3c4 100644 --- a/mindspore/ops/operations/_quant_ops.py +++ b/mindspore/ops/operations/_quant_ops.py @@ -24,6 +24,14 @@ from ...common import dtype as mstype __all__ = ["MinMaxUpdatePerLayer", "MinMaxUpdatePerChannel", + "FakeLearnedScaleQuantPerLayer", + "FakeLearnedScaleQuantPerLayerGrad", + "FakeLearnedScaleQuantPerLayerGradD", + "FakeLearnedScaleQuantPerLayerGradDReduce", + "FakeLearnedScaleQuantPerChannel", + "FakeLearnedScaleQuantPerChannelGrad", + "FakeLearnedScaleQuantPerChannelGradD", + "FakeLearnedScaleQuantPerChannelGradDReduce", "FakeQuantWithMinMaxVars", "FakeQuantWithMinMaxVarsGradient", "FakeQuantWithMinMaxVarsPerChannel", @@ -169,6 +177,321 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer): return min_type, max_type +class FakeLearnedScaleQuantPerLayer(PrimitiveWithInfer): + r""" + Simulates the quantize and dequantize operations of the fake learned scale quant per-layer case in training time. + + Args: + quant_delay (int): Quantilization delay parameter. Before delay step in training time not update + simulate quantization aware function. After delay step in training time begin simulate the aware + quantize function. Default: 0. + neg_trunc (bool): Whether the quantization algorithm uses nagetive truncation or not. Default: False. + training (bool): Training the network or not. Default: True. + + Inputs: + - **input_x** (Tensor) : Input tensor that needs to be quantified. + - **alpha** (Tensor) : Value of the max clipping range of the input data `input_x`. + - **quant_max** (Tensor) : Value of the quantization range. + + Outputs: + - Tensor: Simulates quantize tensor of `input_x`,with the same type and shape as the `input_x`. + + Examples: + >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) + >>> alpha_tensor = Tensor(np.array([6]), mstype.float32) + >>> quant_max_tensor = Tensor(np.array([127]), mstype.float32) + >>> output_tensor = FakeLearnedScaleQuantPerLayer()(input_tensor, alpha_tensor, quant_max_tensor) + """ + @prim_attr_register + def __init__(self, + quant_delay=0, + neg_trunc=False, + training=True): + """init FakeLearnedScaleQuantPerLayer OP""" + if context.get_context('device_target') == "Ascend": + from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perlayer + + self.quant_delay = validator.check_non_negative_int( + quant_delay, 'quant_delay', self.name) + self.neg_trunc = validator.check_value_type( + 'neg_trunc', neg_trunc, (bool,), self.name) + self.training = validator.check_value_type( + 'training', training, (bool,), self.name) + self.init_prim_io_names(inputs=['input_x', 'alpha', 'quant_max'], + outputs=['out']) + + def infer_shape(self, input_x_shape, alpha_shape, quant_max_shape): + validator.check_int(len(input_x_shape), 1, Rel.GE, "input_x rank", self.name) + validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name) + validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name) + return input_x_shape + + def infer_dtype(self, input_x_type, alpha_type, quant_max_type): + if context.get_context('device_target') == "GPU": + valid_dtypes = (mstype.float32,) + else: + valid_dtypes = (mstype.float16, mstype.float32) + tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name), + ("input_x", "alpha", "quant_max"), + (input_x_type, alpha_type, quant_max_type))) + return input_x_type + + +class FakeLearnedScaleQuantPerLayerGrad(PrimitiveWithInfer): + r""" + Performs grad of FakeLearnedScaleQuantPerLayer operation. + + Examples: + >>> fake_learned_scale_grad = FakeLearnedScaleQuantPerLayerGrad() + >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32) + >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32) + >>> _alpha = Tensor(np.array([6]), mindspore.float32) + >>> _quant_max = Tensor(np.array([127]), mindspore.float32) + >>> result = fake_learned_scale_grad(dout, input_x, _min, _max) + """ + + @prim_attr_register + def __init__(self, + quant_delay=0, + neg_trunc=False): + self.quant_delay = validator.check_non_negative_int( + quant_delay, 'quant_delay', self.name) + self.neg_trunc = validator.check_value_type( + 'neg_trunc', neg_trunc, (bool,), self.name) + self.init_prim_io_names( + inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha']) + + def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape): + validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name) + validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name) + validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name) + return dout_shape, alpha_shape + + def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type): + if context.get_context('device_target') == "GPU": + valid_dtypes = (mstype.float32,) + else: + valid_dtypes = (mstype.float16, mstype.float32) + tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name), + ("dout", "x", "alpha", "quant_max"), + (dout_type, x_type, alpha_type, quant_max_type))) + return dout_type, alpha_type + + +class FakeLearnedScaleQuantPerLayerGradD(PrimitiveWithInfer): + r""" + Performs input grad of FakeLearnedScaleQuantPerLayer operation. + """ + + @prim_attr_register + def __init__(self, + neg_trunc=False): + from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perlayer_grad + self.neg_trunc = validator.check_value_type( + 'neg_trunc', neg_trunc, (bool,), self.name) + self.init_prim_io_names( + inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha']) + + def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape): + validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name) + validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name) + validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name) + return dout_shape, dout_shape + + def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type): + valid_dtypes = (mstype.float16, mstype.float32) + tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name), + ("dout", "x", "alpha", "quant_max"), + (dout_type, x_type, alpha_type, quant_max_type))) + return dout_type, dout_type + + +class FakeLearnedScaleQuantPerLayerGradDReduce(PrimitiveWithInfer): + r""" + Performs alpha grad reduce of FakeLearnedScaleQuantPerLayer operation. + """ + + @prim_attr_register + def __init__(self): + from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perlayer_grad_reduce + self.init_prim_io_names( + inputs=['dout_alpha'], outputs=['dalpha']) + + def infer_shape(self, dout_alpha_shape): + return (1,) + + def infer_dtype(self, dout_alpha_type): + valid_dtypes = (mstype.float16, mstype.float32) + validator.check_tensor_dtype_valid("dout_alpha", dout_alpha_type, valid_dtypes, self.name) + return dout_alpha_type + + +class FakeLearnedScaleQuantPerChannel(PrimitiveWithInfer): + r""" + Simulates the quantize and dequantize operations of the fake learned scale quant per-chnnel case in training time. + + Args: + quant_delay (int): Quantilization delay parameter. Before delay step in training time not update + simulate quantization aware function. After delay step in training time begin simulate the aware + quantize function. Default: 0. + neg_trunc (bool): Whether the quantization algorithm uses nagetive truncation or not. Default: False. + training (bool): Training the network or not. Default: True. + channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1. + + Inputs: + - **input_x** (Tensor) : Input tensor that needs to be quantified. + - **alpha** (Tensor) : Value of the max clipping range of the input data `input_x`. + - **quant_max** (Tensor) : Value of the quantization range. + + Outputs: + - Tensor: Simulates quantize tensor of `input_x`,with the same type and shape as the `input_x`. + + Examples: + >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) + >>> alpha_tensor = Tensor(np.array([6]*3), mstype.float32) + >>> quant_max_tensor = Tensor(np.array([127]), mstype.float32) + >>> output_tensor = FakeLearnedScaleQuantPerChannel()(input_tensor, alpha_tensor, quant_max_tensor) + """ + ascend_support_x_rank = [2, 4] + + @prim_attr_register + def __init__(self, + quant_delay=0, + neg_trunc=False, + training=True, + channel_axis=1): + """init FakeLearnedScaleQuantPerChannel OP""" + if context.get_context('device_target') == "Ascend": + from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perchannel + self.is_ascend = context.get_context('device_target') == "Ascend" + self.quant_delay = validator.check_non_negative_int( + quant_delay, 'quant_delay', self.name) + self.neg_trunc = validator.check_value_type( + 'neg_trunc', neg_trunc, (bool,), self.name) + self.training = validator.check_value_type( + 'training', training, (bool,), self.name) + if self.is_ascend: + self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name) + else: + self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name) + self.init_prim_io_names(inputs=['input_x', 'alpha', 'quant_max'], + outputs=['out']) + + def infer_shape(self, input_x_shape, alpha_shape, quant_max_shape): + if self.is_ascend and len(input_x_shape) not in self.ascend_support_x_rank: + raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'") + if not self.is_ascend: + validator.check_int(len(input_x_shape), 1, Rel.GE, "input_x rank", self.name) + if len(input_x_shape) == 1: + self.channel_axis = 0 + + validator.check_equal_int(alpha_shape[0], input_x_shape[self.channel_axis], "alpha rank", self.name) + validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name) + return input_x_shape + + def infer_dtype(self, input_x_type, alpha_type, quant_max_type): + if context.get_context('device_target') == "GPU": + valid_dtypes = (mstype.float32,) + else: + valid_dtypes = (mstype.float16, mstype.float32) + tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name), + ("input_x", "alpha", "quant_max"), + (input_x_type, alpha_type, quant_max_type))) + return input_x_type + + +class FakeLearnedScaleQuantPerChannelGrad(PrimitiveWithInfer): + r""" + Performs grad of FakeLearnedScaleQuantPerChannel operation. + + Examples: + >>> fake_learned_scale_grad = FakeLearnedScaleQuantPerChannelGrad() + >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32) + >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32) + >>> _alpha = Tensor(np.array([6]*2), mindspore.float32) + >>> _quant_max = Tensor(np.array([127]), mindspore.float32) + >>> result = fake_learned_scale_grad(dout, input_x, _min, _max) + """ + + @prim_attr_register + def __init__(self, + quant_delay=0, + neg_trunc=False, + channel_axis=1): + self.quant_delay = validator.check_non_negative_int( + quant_delay, 'quant_delay', self.name) + self.neg_trunc = validator.check_value_type( + 'neg_trunc', neg_trunc, (bool,), self.name) + self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name) + self.init_prim_io_names( + inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha']) + + def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape): + validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name) + return dout_shape, alpha_shape + + def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type): + if context.get_context('device_target') == "GPU": + valid_dtypes = (mstype.float32,) + else: + valid_dtypes = (mstype.float16, mstype.float32) + tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name), + ("dout", "x", "alpha", "quant_max"), + (dout_type, x_type, alpha_type, quant_max_type))) + return dout_type, alpha_type + + +class FakeLearnedScaleQuantPerChannelGradD(PrimitiveWithInfer): + r""" + Performs input grad of FakeLearnedScaleQuantPerChannel operation. + """ + + @prim_attr_register + def __init__(self, + neg_trunc=False, + channel_axis=1): + from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perchannel_grad + self.neg_trunc = validator.check_value_type( + 'neg_trunc', neg_trunc, (bool,), self.name) + self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name) + self.init_prim_io_names( + inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha']) + + def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape): + validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name) + validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name) + validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name) + return dout_shape, dout_shape + + def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type): + valid_dtypes = (mstype.float16, mstype.float32) + tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name), + ("dout", "x", "alpha", "quant_max"), + (dout_type, x_type, alpha_type, quant_max_type))) + return dout_type, dout_type + + +class FakeLearnedScaleQuantPerChannelGradDReduce(PrimitiveWithInfer): + r""" + Performs alpha grad reduce of FakeLearnedScaleQuantPerChannel operation. + """ + + @prim_attr_register + def __init__(self, channel_axis=1): + from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perchannel_grad_reduce + self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name) + self.init_prim_io_names( + inputs=['dout_alpha'], outputs=['dalpha']) + + def infer_shape(self, dout_alpha_shape): + return (dout_alpha_shape[self.channel_axis],) + + def infer_dtype(self, dout_alpha_type): + valid_dtypes = (mstype.float16, mstype.float32) + validator.check_tensor_dtype_valid("dout_alpha", dout_alpha_type, valid_dtypes, self.name) + return dout_alpha_type + + class FakeQuantWithMinMaxVars(PrimitiveWithInfer): r""" Fake-quantize the input by min and max. diff --git a/model_zoo/official/cv/mobilenetv2_quant/README_CN.md b/model_zoo/official/cv/mobilenetv2_quant/README_CN.md index 62c2a36edc8..34c4af3553d 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/README_CN.md +++ b/model_zoo/official/cv/mobilenetv2_quant/README_CN.md @@ -59,13 +59,20 @@ MobileNetV2总体网络架构如下: ## 混合精度 -采用[混合精度](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html)的训练方法使用支持单精度和半精度数据来提高深度学习神经网络的训练速度,同时保持单精度训练所能达到的网络精度。混合精度训练提高计算速度、减少内存使用的同时,支持在特定硬件上训练更大的模型或实现更大批次的训练。 +采用[混合精度](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) +的训练方法使用支持单精度和半精度数据来提高深度学习神经网络的训练速度,同时保持单精度训练所能达到的网络精度。混合精度训练提高计算速度、减少内存使用的同时,支持在特定硬件上训练更大的模型或实现更大批次的训练。 以FP16算子为例,如果输入数据类型为FP32,MindSpore后台会自动降低精度来处理数据。用户可打开INFO日志,搜索“reduce precision”查看精度降低的算子。 +## 量化步长可学习的量化感知训练 + +参考论文[Learned Step Size Quantization](https://arxiv.org/abs/1902.08153) +对量化感知训练进优化,量化步长由训练学习得到,该特性对低比特量化场景有较好收益,下述中简称为LSQ。 +用户可自由选择是否使用LSQ量化方案进行量化。 + # 环境要求 -- 硬件:昇腾处理器(Ascend) - - 使用昇腾处理器来搭建硬件环境。 +- 硬件 (Ascend/GPU) + - 使用Ascend或GPU来搭建硬件环境。 - 框架 - [MindSpore](https://www.mindspore.cn/install) - 如需查看详情,请参见如下资源 @@ -80,8 +87,10 @@ MobileNetV2总体网络架构如下: ├── mobileNetv2_quant ├── Readme.md # MobileNetV2-Quant相关描述 ├── scripts - │ ├──run_train.sh # 使用昇腾处理器进行训练的shell脚本 - │ ├──run_infer.sh # 使用昇腾处理器进行评估的shell脚本 + │ ├──run_train.sh # 使用Ascend或GPU进行训练的shell脚本 + │ ├──run_infer.sh # 使用Ascend或GPU进行评估的shell脚本 + │ ├──run_lsq_train.sh # 使用Ascend或GPU进行LSQ训练的shell脚本 + │ ├──run_lsq_infer.sh # 使用Ascend或GPU进行LSQ评估的shell脚本 ├── src │ ├──config.py # 参数配置 │ ├──dataset.py # 创建数据集 @@ -91,14 +100,14 @@ MobileNetV2总体网络架构如下: │ ├──utils.py # 提供监控模块 ├── train.py # 训练脚本 ├── eval.py # 评估脚本 - ├── export.py # 导出检查点文件到air/onnx中 + ├── export.py # 导出检查点文件到air/mindir中 ``` ## 脚本参数 在config.py中可以同时配置训练参数和评估参数。 -- 配置MobileNetV2-quant和ImageNet2012数据集。 +- 配置MobileNetV2-quant和ImageNet2012数据集(以Ascend环境配置为例,详见src/config.py )。 ```python 'num_classes':1000 # 数据集类数 @@ -124,15 +133,42 @@ MobileNetV2总体网络架构如下: 使用python或shell脚本开始训练。shell脚本的使用方法如下: +传统量化感知训练(默认): + - bash run_train.sh [Ascend] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH]\(可选) + - bash run_train.sh [GPU] [DEVICE_ID_LIST] [DATASET_PATH] [PRETRAINED_CKPT_PATH]\(可选) +量化步长可学习的量化感知训练: + +- bash run_lsq_train.sh [Ascend] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH] + +- bash run_lsq_train.sh [GPU] [DEVICE_ID_LIST] [DATASET_PATH] [PRETRAINED_CKPT_PATH] + + `PRETRAINED_CKPT_PATH` 是可选择的选项,如果用户配置该选项,则基于用户指定的预训练模型进行量化,我们更推荐用户基于预训练模型量化。 + + `RANK_TABLE_FILE` 是在Ascned上运行分布式任务时HCCL的配置文件。 + ### 启动 ``` bash - # 训练示例 - >>> bash run_train.sh Ascend ~/hccl_4p_0123_x.x.x.x.json ~/imagenet/train/ ~/mobilenet.ckpt - >>> bash run_train.sh GPU 1,2 ~/imagenet/train/ ~/mobilenet.ckpt + # 训练示例-传统量化感知训练(默认) + python: + Ascend: python train.py --device_target Ascend --dataset_path ~/imagenet/train/ + GPU: python train.py --device_target GPU --dataset_path ~/imagenet/train/ + shell: + Ascend: bash run_train.sh Ascend ~/hccl_4p_0123_x.x.x.x.json ~/imagenet/train/ ~/mobilenet.ckpt + GPU: bash run_train.sh GPU 1,2 ~/imagenet/train/ ~/mobilenet.ckpt + + # 训练示例-量化步长可学习的量化感知训练 + python: + Ascend: python train.py --device_target Ascend --dataset_path ~/imagenet/train/ \ + --pre_trained ~/mobilenet.ckpt --optim_option "LEARNED_SCALE" + GPU: python train.py --device_target GPU --dataset_path ~/imagenet/train/ \ + --pre_trained ~/mobilenet.ckpt --optim_option "LEARNED_SCALE" + shell: + Ascend: bash run_lsq_train.sh Ascend ~/hccl_4p_0123_x.x.x.x.json ~/imagenet/train/ ~/mobilenet.ckpt + GPU: bash run_lsq_train.sh GPU 1,2 ~/imagenet/train/ ~/mobilenet.ckpt ``` ### 结果 @@ -151,16 +187,38 @@ epoch time:138331.250, per step time:221.330, avg loss:3.917 ### 用法 -使用python或shell脚本开始训练。shell脚本的使用方法如下: +使用python或shell脚本开始评估。shell脚本的使用方法如下: -- Ascend: sh run_infer_quant.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH] +传统量化感知训练(默认): + +- Ascend: sh run_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH] +- GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH] + +量化步长可学习的量化感知训练: + +- Ascend: sh run_lsq_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH] +- GPU: sh run_lsq_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH] ### 启动 ```bash -# 推理示例 - shell: - Ascend: sh run_infer_quant.sh Ascend ~/imagenet/val/ ~/train/mobilenet-60_1601.ckpt +# 推理示例-传统量化感知训练(默认) +python: + Ascend: python eval.py --device_target Ascend --dataset_path [VAL_DATASET_PATH] --checkpoint_path ~/train/mobilenet-60_1601.ckpt + GPU: python eval.py --device_target GPU --dataset_path [VAL_DATASET_PATH] --checkpoint_path ~/train/mobilenet-60_1601.ckpt +shell: + Ascend: sh run_infer.sh Ascend ~/imagenet/val/ ~/train/mobilenet-60_1601.ckpt + GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-60_1601.ckpt + +# 推理示例-量化步长可学习的量化感知训练 +python: + Ascend: python eval.py --device_target Ascend --dataset_path ~/imagenet/val/ \ + --checkpoint_path ~/train/mobilenet-60_1601.ckpt --optim_option "LEARNED_SCALE" + GPU: python eval.py --device_target GPU --dataset_path ~/imagenet/val/ \ + --checkpoint_path ~/train/mobilenet-60_1601.ckpt --optim_option "LEARNED_SCALE" +shell: + Ascend: sh run_lsq_infer.sh Ascend ~/imagenet/val/ ~/train/mobilenet-60_1601.ckpt + GPU: sh run_lsq_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-60_1601.ckpt ``` > 训练过程中可以生成检查点。 @@ -173,45 +231,59 @@ epoch time:138331.250, per step time:221.330, avg loss:3.917 result:{'acc':0.71976314102564111} ``` +## 模型导出 + +```shell +python export.py --checkpoint_path [CKPT_PATH] --file_format [EXPORT_FORMAT] --device_target [PLATFORM] --optim_option [OptimizeOption] +``` + +`EXPORT_FORMAT` 可选 ["AIR", "MINDIR"]. + +`OptimizeOption` 可选 ["QAT", "LEARNED_SCALE"]. + # 模型描述 ## 性能 ### 训练性能 -| 参数 | MobilenetV2 | -| -------------------------- | ---------------------------------------------------------- | -| 模型版本 | V2 | -| 资源 | Ascend 910;CPU 2.60GHz,192核;内存 755G;系统 Euler2.8 | -| 上传日期 | 2020-06-06 | -| MindSpore版本 | 0.3.0 | -| 数据集 | ImageNet | -| 训练参数 | src/config.py | -| 优化器 | Momentum | -| 损失函数 | Softmax交叉熵 | -| 输出 | ckpt文件 | -| 损失 | 1.913 | -| 准确率 | | -| 总时长 | 16 h | -| 参数(M) | batch_size=192, epoch=60 | -| 微调检查点 | | -| 推理模型 | | +| 参数 | MobilenetV2 | MobilenetV2 | +| ---------------| ----------------------------------------| ---------------------------- | +| 模型版本 | V2 | V2 | +| 量化方案 | 传统量化感知训练(默认) |量化步长可学习的量化感知训练 | +| 量化策略 | W:8bit, A:8bit | W:4bit (首尾层为 8bit), A:8bit| +| 资源 | Ascend 910;CPU 2.60GHz,192核;内存 755G;系统 Euler2.8 |Ascend 910;CPU 2.60GHz,192核;内存 755G;系统 Euler2.8 | +| 上传日期 | 2020-06-06 |2021-04-30 | +| MindSpore版本 | 0.3.0 |1.3.0 | +| 数据集 | ImageNet |ImageNet | +| 训练参数 | src/config.py |src/config.py | +| 优化器 | Momentum |Momentum | +| 损失函数 | Softmax交叉熵 |Softmax交叉熵 | +| 输出 | ckpt文件 |ckpt文件 | +| 损失 | 1.913 | | +| 准确率 | | | +| 总时长 | 16 h | | +| 参数(M) | batch_size=192, epoch=60 |batch_size=192, epoch=40 | +| 微调检查点 | | | +| 推理模型 | | | #### 评估性能 -| 参数 | | -| -------------------------- | ----------------------------- | -| 模型版本 | V2 | -| 资源 | Ascend 910;系统 Euler2.8 | -| 上传日期 | 2020-06-06 | -| MindSpore版本 | 0.3.0 | -| 数据集 | ImageNet, 1.2W | -| 批次大小 | 130(8P) | -| 输出 | 概率 | -| 准确率 | ACC1[71.78%] ACC5[90.90%] | -| 速度 | 200毫秒/步 | -| 总时长 | 5分钟 | -| 推理模型 | | +| 参数 | | | +| ------------------ | --------------------------|------------------------------ | +| 模型版本 | V2 |V2 | +| 量化方案 | 传统量化感知训练(默认) |量化步长可学习的量化感知训练 | +| 量化策略 | W:8bit, A:8bit | W:4bit (首尾层为 8bit), A:8bit| +| 资源 | Ascend 910;系统 Euler2.8 | Ascend 910;系统 Euler2.8 | +| 上传日期 | 2020-06-06 |2021-04-30 | +| MindSpore版本 | 0.3.0 | 1.3.0 | +| 数据集 | ImageNet, 1.2W | ImageNet, 1.2W | +| 批次大小 | 130(8P) | | +| 输出 | 概率 | 概率 | +| 准确率 | ACC1[71.78%] ACC5[90.90%] | | +| 速度 | 200毫秒/步 | | +| 总时长 | 5分钟 | | +| 推理模型 | | | # 随机情况说明 diff --git a/model_zoo/official/cv/mobilenetv2_quant/Readme.md b/model_zoo/official/cv/mobilenetv2_quant/Readme.md index 7eaa8480417..3e31b8b4657 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/Readme.md +++ b/model_zoo/official/cv/mobilenetv2_quant/Readme.md @@ -49,13 +49,20 @@ Dataset used: [imagenet](http://www.image-net.org/) The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware. For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching ‘reduce precision’. +## [Learned Step Size Quantization](#contents) + +Inspired by paper [Learned Step Size Quantization](https://arxiv.org/abs/1902.08153) +, we proposed an optimize option, whose quantization scale is learned during the fine-tune process. +This feature has good benefits for low bits quantization scenarios, which is referred to as LSQ. +Users are free to choose whether to use the LEARNED_SCALE optimize option for quantization. + # [Environment Requirements](#contents) -- Hardware:Ascend - - Prepare hardware environment with Ascend. +- Hardware (Ascend/GPU) + - Prepare hardware environment with Ascend or GPU. - Framework - [MindSpore](https://www.mindspore.cn/install/en) -- For more information, please check the resources below +- For more information, please check the resources below: - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) @@ -67,8 +74,10 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil ├── mobileNetv2_quant ├── Readme.md # descriptions about MobileNetV2-Quant ├── scripts - │ ├──run_train.sh # shell script for train on Ascend - │ ├──run_infer.sh # shell script for evaluation on Ascend + │ ├──run_train.sh # shell script for train on Ascend or GPU + │ ├──run_infer.sh # shell script for evaluation on Ascend or GPU + │ ├──run_lsq_train.sh # shell script for train (using the LEARNED_SCALE optimize option) on Ascend or GPU + │ ├──run_lsq_infer.sh # shell script for evaluation (using the LEARNED_SCALE optimize option) on Ascend or GPU ├── src │ ├──config.py # parameter configuration │ ├──dataset.py # creating dataset @@ -85,7 +94,7 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil Parameters for both training and evaluation can be set in config.py -- config for MobileNetV2-quant, ImageNet2012 dataset +- config for MobileNetV2-quant, ImageNet2012 dataset(We take the environment configuration of ascend as an example here, and you will get more detail in src/config.py) ```python 'num_classes': 1000 # the number of classes in the dataset @@ -111,15 +120,44 @@ Parameters for both training and evaluation can be set in config.py You can start training using python or shell scripts. The usage of shell scripts as follows: +For quantization aware training (default): + - bash run_train.sh [Ascend] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH]\(optional) - bash run_train.sh [GPU] [DEVICE_ID_LIST] [DATASET_PATH] [PRETRAINED_CKPT_PATH]\(optional) +For Learned Step Size Quantization: + +- bash run_lsq_train.sh [Ascend] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH] +- bash run_lsq_train.sh [GPU] [DEVICE_ID_LIST] [DATASET_PATH] [PRETRAINED_CKPT_PATH] + +`PRETRAINED_CKPT_PATH` is optional. If it is given, quantization is based on the specified pre training ckpt file. We recommend users to execute quantization based on the pre training ckpt file. + +`RANK_TABLE_FILE` is HCCL configuration file when running on Ascend. +> The common restrictions on using the distributed service are as follows. For details, see the HCCL documentation. +> +> - In a single-node system, a cluster of 1, 2, 4, or 8 devices is supported. In a multi-node system, a cluster of 8 x N devices is supported. +> - Each host has four devices numbered 0 to 3 and four devices numbered 4 to 7 deployed on two different networks. During training of 2 or 4 devices, the devices must be connected and clusters cannot be created across networks. + ### Launch ``` bash - # training example - >>> bash run_train.sh Ascend ~/hccl_4p_0123_x.x.x.x.json ~/imagenet/train/ ~/mobilenet.ckpt - >>> bash run_train.sh GPU 1,2 ~/imagenet/train/ ~/mobilenet.ckpt + # training example for quantization aware training (default) + python: + Ascend: python train.py --device_target Ascend --dataset_path ~/imagenet/train/ + GPU: python train.py --device_target GPU --dataset_path ~/imagenet/train/ + shell: + Ascend: bash run_train.sh Ascend ~/hccl_4p_0123_x.x.x.x.json ~/imagenet/train/ ~/mobilenet.ckpt + GPU: bash run_train.sh GPU 1,2 ~/imagenet/train/ ~/mobilenet.ckpt + + # training example for Learned Step Size Quantization + python: + Ascend: python train.py --device_target Ascend --dataset_path ~/imagenet/train/ \ + --pre_trained ~/mobilenet.ckpt --optim_option "LEARNED_SCALE" + GPU: python train.py --device_target GPU --dataset_path ~/imagenet/train/ \ + --pre_trained ~/mobilenet.ckpt --optim_option "LEARNED_SCALE" + shell: + Ascend: bash run_lsq_train.sh Ascend ~/hccl_4p_0123_x.x.x.x.json ~/imagenet/train/ ~/mobilenet.ckpt + GPU: bash run_lsq_train.sh GPU 1,2 ~/imagenet/train/ ~/mobilenet.ckpt ``` ### Result @@ -138,16 +176,40 @@ epoch time: 138331.250, per step time: 221.330, avg loss: 3.917 ### Usage -You can start training using python or shell scripts. The usage of shell scripts as follows: +You can start evaluating using python or shell scripts. The usage of shell scripts as follows: -- Ascend: sh run_infer_quant.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH] +For quantization aware training (default): + +- Ascend: sh run_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH] +- GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH] + +For Learned Step Size Quantization: + +- Ascend: sh run_lsq_infer.sh Ascend [DATASET_PATH] [CHECKPOINT_PATH] +- GPU: sh run_lsq_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH] ### Launch -``` bash -# infer example - shell: - Ascend: sh run_infer_quant.sh Ascend ~/imagenet/val/ ~/train/mobilenet-60_1601.ckpt +```bash +# training example for quantization aware training (default) +python: + Ascend: python eval.py --device_target Ascend --dataset_path [VAL_DATASET_PATH] --checkpoint_path ~/train/mobilenet-60_1601.ckpt + GPU: python eval.py --device_target GPU --dataset_path [VAL_DATASET_PATH] --checkpoint_path ~/train/mobilenet-60_1601.ckpt + +shell: + Ascend: sh run_infer.sh Ascend ~/imagenet/val/ ~/train/mobilenet-60_1601.ckpt + GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-60_1601.ckpt + +# training example for Learned Step Size Quantization +python: + Ascend: python eval.py --device_target Ascend --dataset_path ~/imagenet/val/ \ + --checkpoint_path ~/train/mobilenet-60_1601.ckpt --optim_option "LEARNED_SCALE" + GPU: python eval.py --device_target GPU --dataset_path ~/imagenet/val/ \ + --checkpoint_path ~/train/mobilenet-60_1601.ckpt --optim_option "LEARNED_SCALE" + +shell: + Ascend: sh run_lsq_infer.sh Ascend ~/imagenet/val/ ~/train/mobilenet-60_1601.ckpt + GPU: sh run_lsq_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-60_1601.ckpt ``` > checkpoint can be produced in training process. @@ -160,45 +222,58 @@ Inference result will be stored in the example path, you can find result like th result: {'acc': 0.71976314102564111} ``` +## [Model Export](#contents) + +```shell +python export.py --checkpoint_path [CKPT_PATH] --file_format [EXPORT_FORMAT] --device_target [PLATFORM] --optim_option [OptimizeOption] +``` + +`EXPORT_FORMAT` should be in ["AIR", "MINDIR"]. +`OptimizeOption` should be in ["QAT", "LEARNED_SCALE"]. + # [Model description](#contents) ## [Performance](#contents) ### Training Performance -| Parameters | MobilenetV2 | -| -------------------------- | ---------------------------------------------------------- | -| Model Version | V2 | -| Resource | Ascend 910; cpu 2.60GHz, 192cores; memory 755G; OS Euler2.8 | -| uploaded Date | 06/06/2020 | -| MindSpore Version | 0.3.0 | -| Dataset | ImageNet | -| Training Parameters | src/config.py | -| Optimizer | Momentum | -| Loss Function | SoftmaxCrossEntropy | -| outputs | ckpt file | -| Loss | 1.913 | -| Accuracy | | -| Total time | 16h | -| Params (M) | batch_size=192, epoch=60 | -| Checkpoint for Fine tuning | | -| Model for inference | | +| Parameters | MobilenetV2 | MobilenetV2 | +| -------------------------- | --------------------------------------------------| --------------------------------------------------| +| Model Version | V2 | V2 | +| Optimize Option | QAT | LEARNED_SCALE | +| Quantization Strategy | W:8bit, A:8bit | W:4bit (The first and last layers are 8bit), A:8bit| +| Resource | Ascend 910; cpu 2.60GHz, 192cores; memory 755G; OS Euler2.8 | Ascend 910; cpu 2.60GHz, 192cores; memory 755G; OS Euler2.8 | +| uploaded Date | 06/06/2020 | 04/30/2021 | +| MindSpore Version | 0.3.0 | 1.3.0 | +| Dataset | ImageNet | ImageNet | +| Training Parameters | src/config.py | src/config.py | +| Optimizer | Momentum | Momentum | +| Loss Function | SoftmaxCrossEntropy | SoftmaxCrossEntropy | +| outputs | ckpt file | ckpt file | +| Loss | 1.913 | | +| Accuracy | | | +| Total time | 16h | | +| Params (M) | batch_size=192, epoch=60 | batch_size=192, epoch=40 | +| Checkpoint for Fine tuning | | | +| Model for inference | | | #### Evaluation Performance -| Parameters | | -| -------------------------- | ----------------------------- | -| Model Version | V2 | -| Resource | Ascend 910; OS Euler2.8 | -| uploaded Date | 06/06/2020 | -| MindSpore Version | 0.3.0 | -| Dataset | ImageNet, 1.2W | -| batch_size | 130(8P) | -| outputs | probability | -| Accuracy | ACC1[71.78%] ACC5[90.90%] | -| Speed | 200ms/step | -| Total time | 5min | -| Model for inference | | +| Parameters | | | +| -------------------------- | ----------------------------- | ----------------------------- | +| Model Version | V2 | V2 | +| Optimize Option | QAT | LEARNED_SCALE | +| Quantization Strategy | W:8bit, A:8bit | W:4bit (The first and last layers are 8bit), A:8bit| +| Resource | Ascend 910; OS Euler2.8 | Ascend 910; OS Euler2.8 | +| uploaded Date | 06/06/2020 | 04/30/2021 | +| MindSpore Version | 0.3.0 | 1.3.0 | +| Dataset | ImageNet, 1.2W | ImageNet, 1.2W | +| batch_size | 130(8P) | | +| outputs | probability | probability | +| Accuracy | ACC1[71.78%] ACC5[90.90%] | | +| Speed | 200ms/step | | +| Total time | 5min | | +| Model for inference | | | # [Description of Random Situation](#contents) diff --git a/model_zoo/official/cv/mobilenetv2_quant/eval.py b/model_zoo/official/cv/mobilenetv2_quant/eval.py index b49c6848e41..5c914f1a1c9 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/eval.py +++ b/model_zoo/official/cv/mobilenetv2_quant/eval.py @@ -21,42 +21,67 @@ from mindspore import context from mindspore import nn from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.compression.common import QuantDtype from mindspore.compression.quant import QuantizationAwareTraining - +from mindspore.compression.quant.quantizer import OptimizeOption from src.mobilenetV2 import mobilenetV2 +from src.mobilenetv2_mix_quant import mobilenetv2_mix_quant from src.dataset import create_dataset -from src.config import config_ascend_quant -from src.config import config_gpu_quant +from src.config import config_ascend_quant, config_gpu_quant, config_lsq_ascend_quant, config_lsq_gpu_quant parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') parser.add_argument('--device_target', type=str, default=None, help='Run device target') +parser.add_argument('--optim_option', type=str, default="QAT", help='OptimizeOption') args_opt = parser.parse_args() if __name__ == '__main__': config_device_target = None device_id = int(os.getenv('DEVICE_ID')) if args_opt.device_target == "Ascend": - config_device_target = config_ascend_quant + if args_opt.optim_option == "LEARNED_SCALE": + config_device_target = config_lsq_ascend_quant + else: + config_device_target = config_ascend_quant + symmetric_list = [True, False] context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False) - symmetric_list = [True, False] elif args_opt.device_target == "GPU": - config_device_target = config_gpu_quant + if args_opt.optim_option == "LEARNED_SCALE": + config_device_target = config_lsq_gpu_quant + else: + config_device_target = config_gpu_quant + symmetric_list = [False, False] context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=device_id, save_graphs=False) - symmetric_list = [False, False] else: raise ValueError("Unsupported device target: {}.".format(args_opt.device_target)) - # define fusion network - network = mobilenetV2(num_classes=config_device_target.num_classes) - # convert fusion network to quantization aware network - quantizer = QuantizationAwareTraining(bn_fold=True, - per_channel=[True, False], - symmetric=symmetric_list) + + if args_opt.optim_option == "LEARNED_SCALE": + # define fusion network + network = mobilenetv2_mix_quant(num_classes=config_device_target.num_classes) + # convert fusion network to quantization aware network + quant_optim_otions = OptimizeOption.LEARNED_SCALE + quantizer = QuantizationAwareTraining(bn_fold=True, + per_channel=[True, False], + symmetric=[True, True], + narrow_range=[True, True], + quant_dtype=(QuantDtype.INT4, QuantDtype.INT8), + freeze_bn=0, + quant_delay=0, + one_conv_fold=True, + optimize_option=quant_optim_otions) + else: + # define fusion network + network = mobilenetV2(num_classes=config_device_target.num_classes) + # convert fusion network to quantization aware network + quantizer = QuantizationAwareTraining(bn_fold=True, + per_channel=[True, False], + symmetric=symmetric_list) network = quantizer.quantize(network) + # define network loss loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') diff --git a/model_zoo/official/cv/mobilenetv2_quant/export.py b/model_zoo/official/cv/mobilenetv2_quant/export.py index 0f2c28b8afa..a0ccc9c86f7 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/export.py +++ b/model_zoo/official/cv/mobilenetv2_quant/export.py @@ -19,26 +19,45 @@ import numpy as np import mindspore from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export +from mindspore.compression.common import QuantDtype from mindspore.compression.quant import QuantizationAwareTraining - +from mindspore.compression.quant.quantizer import OptimizeOption from src.mobilenetV2 import mobilenetV2 +from src.mobilenetv2_mix_quant import mobilenetv2_mix_quant from src.config import config_quant parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format") parser.add_argument('--device_target', type=str, default=None, help='Run device target') +parser.add_argument('--optim_option', type=str, default="QAT", help='OptimizeOption') args_opt = parser.parse_args() if __name__ == '__main__': cfg = config_quant(args_opt.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target, save_graphs=False) - # define fusion network - network = mobilenetV2(num_classes=cfg.num_classes) - # convert fusion network to quantization aware network - quantizer = QuantizationAwareTraining(bn_fold=True, - per_channel=[True, False], - symmetric=[True, False]) + + if args_opt.optim_option == "LEARNED_SCALE": + # define fusion network + network = mobilenetv2_mix_quant(num_classes=cfg.num_classes) + # convert fusion network to quantization aware network + quant_optim_otions = OptimizeOption.LEARNED_SCALE + quantizer = QuantizationAwareTraining(bn_fold=True, + per_channel=[True, False], + symmetric=[True, True], + narrow_range=[True, True], + quant_dtype=(QuantDtype.INT4, QuantDtype.INT8), + freeze_bn=0, + quant_delay=0, + one_conv_fold=True, + optimize_option=quant_optim_otions) + else: + # define fusion network + network = mobilenetV2(num_classes=cfg.num_classes) + # convert fusion network to quantization aware network + quantizer = QuantizationAwareTraining(bn_fold=True, + per_channel=[True, False], + symmetric=[True, False]) network = quantizer.quantize(network) # load checkpoint param_dict = load_checkpoint(args_opt.checkpoint_path) diff --git a/model_zoo/official/cv/mobilenetv2_quant/scripts/run_lsq_infer.sh b/model_zoo/official/cv/mobilenetv2_quant/scripts/run_lsq_infer.sh new file mode 100644 index 00000000000..92993a51933 --- /dev/null +++ b/model_zoo/official/cv/mobilenetv2_quant/scripts/run_lsq_infer.sh @@ -0,0 +1,54 @@ +#!/usr/bin/env bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +if [ $# != 3 ] +then + echo "Ascend: sh run_lsq_infer.sh [PLATFORM] [DATASET_PATH] [CHECKPOINT_PATH]" +exit 1 +fi + +# check dataset path +if [ ! -d $2 ] && [ ! -f $2 ] +then + echo "error: DATASET_PATH=$2 is not a directory or file" +exit 1 +fi + +# check checkpoint file +if [ ! -f $3 ] +then + echo "error: CHECKPOINT_PATH=$3 is not a file" +exit 1 +fi + +# set environment +BASEPATH=$(cd "`dirname $0`" || exit; pwd) +export DEVICE_ID=0 +export RANK_ID=0 +export RANK_SIZE=1 +if [ -d "../eval" ]; +then + rm -rf ../eval +fi +mkdir ../eval +cd ../eval || exit + +# launch +python ${BASEPATH}/../eval.py \ + --device_target=$1 \ + --dataset_path=$2 \ + --checkpoint_path=$3 \ + --optim_option="LEARNED_SCALE" \ + &> infer.log & # dataset val folder path diff --git a/model_zoo/official/cv/mobilenetv2_quant/scripts/run_lsq_train.sh b/model_zoo/official/cv/mobilenetv2_quant/scripts/run_lsq_train.sh new file mode 100644 index 00000000000..6e10f28de1a --- /dev/null +++ b/model_zoo/official/cv/mobilenetv2_quant/scripts/run_lsq_train.sh @@ -0,0 +1,177 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + + +get_gpu_device_num(){ + + #device_list=(${1//,/ }) + IFS=',' read -ra device_list <<<"$1" + device_used=(0 0 0 0 0 0 0 0) + device_num=0 + for var in "${device_list[@]}" + do + if [ $((var)) -lt 0 ] || [ $((var)) -ge 8 ] + then + echo "error: device id=${var} is incorrect, device id must be in range [0,8), please check your device id list!" + exit 1 + fi + + if [ ${device_used[$((var))]} -eq 0 ] + then + device_used[ $((var)) ]=1 + device_num=$((device_num+1)) + fi + done + + echo ${device_num} +} + + +run_ascend(){ + + if [ $# -gt 4 ] || [ $# -lt 4 ] + then + echo "Usage: bash run_lsq_train.sh [Ascend] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH]\n " + exit 1 + fi + PATH1=$(get_real_path $2) + PATH2=$(get_real_path $3) + PATH3=$(get_real_path $4) + + if [ ! -f $PATH1 ] + then + echo "error: RANK_TABLE_FILE=$PATH1 is not a file" + exit 1 + fi + + if [ ! -d $PATH2 ] + then + echo "error: DATASET_PATH=$PATH2 is not a directory" + exit 1 + fi + + if [ $# == 4 ] && [ ! -f $PATH3 ] + then + echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file" + exit 1 + fi + + + + #rank_file_name=${2##*/} + #IFS='_' read -ra array <<<"${rank_file_name}" + #device_id_list=${array[2]} + #first_device=${device_id_list:0:1} + #last_device=${device_list:${#device_list}-1:1} + #device_num=${#device_id_list} + cat $2 | awk -F "[device_id]" '/device_id/{print$0}' >temp.log + array=$(cat temp.log | awk -F "[:]" '/device_id/{print$2}') + + IFS=" " read -ra device_list <<<$array + first_device=${device_list[0]:1:1} + #device_num=${#device_list[*]} + device_num=$(cat temp.log | wc -l) + rm temp.log + ulimit -u unlimited + export DEVICE_NUM=${device_num} + export RANK_SIZE=${device_num} + export RANK_TABLE_FILE=$PATH1 + + export SERVER_ID=0 + rank_start=$((DEVICE_NUM * SERVER_ID)) + + rm -rf ./train + mkdir ./train + for((i=0; i<${device_num}; i++)) + do + export DEVICE_ID=$((first_device+i)) + export RANK_ID=$((rank_start + i)) + mkdir ./train/device$i + cp ../*.py ./train/device$i + cp *.sh ./train/device$i + cp -r ../src ./train/device$i + cd ./train/device$i || exit + echo "start training for rank $RANK_ID, device $DEVICE_ID" + env > env.log + python train.py --device_target=$1 --dataset_path=$PATH2 --pre_trained=$PATH3 \ + --optim_option="LEARNED_SCALE" &> train.log & + + cd ../.. || exit + done +} + +run_gpu(){ + if [ $# -gt 4 ] || [ $# -lt 4 ] + then + echo "Usage: bash run_lsq_train.sh [GPU] [DEVICE_ID_LIST] [DATASET_PATH] [PRETRAINED_CKPT_PATH]\n " + exit 1 + fi + + PATH1=$(get_real_path $3) + PATH2=$(get_real_path $4) + + if [ ! -d $PATH1 ] + then + echo "error: DATASET_PATH=$PATH1 is not a directory" + exit 1 + fi + + if [ $# == 4 ] && [ ! -f $PATH2 ] + then + echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file" + exit 1 + fi + + device_num=$(get_gpu_device_num $2) + + ulimit -u unlimited + export DEVICE_NUM=${device_num} + export RANK_SIZE=${device_num} + export CUDA_VISIBLE_DEVICES=$2 + + rm -rf ./train + mkdir ./train + cp ../*.py ./train + cp *.sh ./train + cp -r ../src ./train + cd ./train || exit + echo "start training" + env > env.log + + mpirun --allow-run-as-root -n ${RANK_SIZE} --output-filename log_output --merge-stderr-to-stdout \ + python train.py --device_target=$1 --dataset_path=$PATH1 --pre_trained=$PATH2 \ + --optim_option="LEARNED_SCALE" &> train.log & + + cd .. +} + + +if [ $1 = "Ascend" ] ; then + run_ascend "$@" +elif [ $1 = "GPU" ] ; then + run_gpu "$@" +else + echo "Unsupported device target: $1" +fi; + diff --git a/model_zoo/official/cv/mobilenetv2_quant/src/config.py b/model_zoo/official/cv/mobilenetv2_quant/src/config.py index ce2e613c187..2aaa8c954fb 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/src/config.py +++ b/model_zoo/official/cv/mobilenetv2_quant/src/config.py @@ -55,6 +55,45 @@ config_gpu_quant = ed({ "save_checkpoint_path": "./checkpoint", }) +config_lsq_gpu_quant = ed({ + "num_classes": 1000, + "image_height": 224, + "image_width": 224, + "batch_size": 128, + "epoch_size": 60, + "start_epoch": 200, + "warmup_epochs": 0, + "lr": 0.05, + "momentum": 0.9, + "weight_decay": 4e-5, + "label_smooth": 0.1, + "loss_scale": 1024, + "save_checkpoint": True, + "save_checkpoint_epochs": 1, + "keep_checkpoint_max": 300, + "save_checkpoint_path": "./checkpoint", +}) + +config_lsq_ascend_quant = ed({ + "num_classes": 1000, + "image_height": 224, + "image_width": 224, + "batch_size": 192, + "data_load_mode": "mindata", + "epoch_size": 40, + "start_epoch": 200, + "warmup_epochs": 0, + "lr": 0.05, + "momentum": 0.9, + "weight_decay": 4e-5, + "label_smooth": 0.1, + "loss_scale": 1024, + "save_checkpoint": True, + "save_checkpoint_epochs": 1, + "keep_checkpoint_max": 300, + "save_checkpoint_path": "./checkpoint", +}) + def config_quant(device_target): if device_target not in ["Ascend", "GPU"]: raise ValueError("Unsupported device target: {}.".format(device_target)) diff --git a/model_zoo/official/cv/mobilenetv2_quant/src/mobilenetv2_mix_quant.py b/model_zoo/official/cv/mobilenetv2_quant/src/mobilenetv2_mix_quant.py new file mode 100644 index 00000000000..c982e070c50 --- /dev/null +++ b/model_zoo/official/cv/mobilenetv2_quant/src/mobilenetv2_mix_quant.py @@ -0,0 +1,414 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""MobileNetV2 model define""" +import numpy as np +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.ops.operations import Add +from mindspore import Tensor +from mindspore.compression.common import QuantDtype +from mindspore.compression.quant import create_quant_config + +__all__ = ['mobilenetv2_mix_quant'] + + +def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class GlobalAvgPooling(nn.Cell): + """ + Global avg pooling definition. + + Args: + + Returns: + Tensor, output tensor. + + Examples: + >>> GlobalAvgPooling() + """ + + def __init__(self): + super(GlobalAvgPooling, self).__init__() + self.mean = P.ReduceMean(keep_dims=False) + + def construct(self, x): + x = self.mean(x, (2, 3)) + return x + + +class ConvBNReLU(nn.Cell): + """ + Convolution/Depthwise fused with Batchnorm and ReLU block definition. + + Args: + in_planes (int): Input channel. + out_planes (int): Output channel. + kernel_size (int): Input kernel size. + stride (int): Stride size for the first convolutional layer. Default: 1. + groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. + + Returns: + Tensor, output tensor. + + Examples: + >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) + """ + + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + super(ConvBNReLU, self).__init__() + padding = (kernel_size - 1) // 2 + self.conv = nn.Conv2dBnAct(in_planes, out_planes, kernel_size, + stride=stride, + pad_mode='pad', + padding=padding, + group=groups, + has_bn=True, + activation='relu6') + + def construct(self, x): + output = self.conv(x) + return output + +quant_config = create_quant_config(per_channel=(True, False), symmetric=(True, True), narrow_range=(True, True), + mode="LEARNED_SCALE") +class FirstQuantLayer(nn.Cell): + """ + The first quantization layer, which is fixed to 8bit. + """ + + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1): + super(FirstQuantLayer, self).__init__() + padding = (kernel_size - 1) // 2 + in_channels = in_planes + out_channels = out_planes + conv_inner = nn.Conv2dBnFoldQuantOneConv(in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + pad_mode='pad', + padding=padding, + quant_config=quant_config, + quant_dtype=QuantDtype.INT8) + activation = nn.ActQuant(activation=nn.ReLU6(), + quant_config=quant_config, + quant_dtype=QuantDtype.INT8) + self.features = nn.SequentialCell([conv_inner, activation]) + + def construct(self, x): + output = self.features(x) + return output + +class LastQuantLayer(nn.Cell): + """ + The last quantization layer, which is fixed to 8bit. + """ + + def __init__(self, in_channels, out_channels, has_bias, has_bn): + super(LastQuantLayer, self).__init__() + + self.dense_inner = nn.DenseQuant(in_channels, + out_channels, + has_bias=has_bias, + quant_config=quant_config, + quant_dtype=QuantDtype.INT8) + self.fake_quant_act = nn.FakeQuantWithMinMaxObserver(min_init=-16, + max_init=16, + ema=True, + quant_dtype=QuantDtype.INT8, + per_channel=False, + symmetric=True, + narrow_range=True, + mode="LEARNED_SCALE") + + def construct(self, x): + output = self.dense_inner(x) + output = self.fake_quant_act(output) + return output + +class InvertedResidual(nn.Cell): + """ + Mobilenetv2 residual block definition. + + Args: + inp (int): Input channel. + oup (int): Output channel. + stride (int): Stride size for the first convolutional layer. Default: 1. + expand_ratio (int): expand ration of input channel + + Returns: + Tensor, output tensor. + + Examples: + >>> ResidualBlock(3, 256, 1, 1) + """ + + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, + stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2dBnAct(hidden_dim, oup, kernel_size=1, stride=1, has_bn=True) + ]) + self.conv = nn.SequentialCell(layers) + self.add = Add() + self.cast = P.Cast() + + def construct(self, x): + identity = x + x = self.conv(x) + if self.use_res_connect: + return self.add(identity, x) + return x + + +class MobileNetV2Backbone(nn.Cell): + """ + MobileNetV2 architecture. + + Args: + class_num (int): number of classes. + width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1. + has_dropout (bool): Is dropout used. Default is false + inverted_residual_setting (list): Inverted residual settings. Default is None + round_nearest (list): Channel round to . Default is 8 + Returns: + Tensor, output tensor. + + Examples: + >>> MobileNetV2(num_classes=1000) + """ + + def __init__(self, width_mult=1., inverted_residual_setting=None, round_nearest=8, + input_channel=32, last_channel=1280): + super(MobileNetV2Backbone, self).__init__() + block = InvertedResidual + # setting of inverted residual blocks + self.cfgs = inverted_residual_setting + if inverted_residual_setting is None: + self.cfgs = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [FirstQuantLayer(3, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in self.cfgs: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + # building last several layers + features.append(ConvBNReLU(input_channel, self.out_channels, kernel_size=1)) + # make it nn.CellList + self.features = nn.SequentialCell(features) + self._initialize_weights() + + def construct(self, x): + x = self.features(x) + return x + + def _initialize_weights(self): + """ + Initialize weights. + + Args: + + Returns: + None. + + Examples: + >>> _initialize_weights() + """ + self.init_parameters_data() + for _, m in self.cells_and_names(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n), + m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.set_data( + Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + elif isinstance(m, nn.BatchNorm2d): + m.gamma.set_data( + Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) + m.beta.set_data( + Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) + + @property + def get_features(self): + return self.features + + +class MobileNetV2Head(nn.Cell): + """ + MobileNetV2 architecture. + + Args: + class_num (int): Number of classes. Default is 1000. + has_dropout (bool): Is dropout used. Default is false + Returns: + Tensor, output tensor. + + Examples: + >>> MobileNetV2(num_classes=1000) + """ + + def __init__(self, input_channel=1280, num_classes=1000, has_dropout=False, activation="None"): + super(MobileNetV2Head, self).__init__() + # mobilenet head + head = ([GlobalAvgPooling(), LastQuantLayer(input_channel, num_classes, has_bias=True, has_bn=False)] + if not has_dropout else + [GlobalAvgPooling(), nn.Dropout(0.2), LastQuantLayer(input_channel, num_classes, + has_bias=True, has_bn=False)]) + self.head = nn.SequentialCell(head) + self.need_activation = True + if activation == "Sigmoid": + self.activation = P.Sigmoid() + elif activation == "Softmax": + self.activation = P.Softmax() + else: + self.need_activation = False + self._initialize_weights() + + def construct(self, x): + x = self.head(x) + if self.need_activation: + x = self.activation(x) + return x + + def _initialize_weights(self): + """ + Initialize weights. + + Args: + + Returns: + None. + + Examples: + >>> _initialize_weights() + """ + self.init_parameters_data() + for _, m in self.cells_and_names(): + if isinstance(m, nn.Dense): + m.weight.set_data(Tensor(np.random.normal( + 0, 0.01, m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.set_data( + Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + @property + def get_head(self): + return self.head + + +class MobileNetV2(nn.Cell): + """ + MobileNetV2 architecture. + + Args: + class_num (int): number of classes. + width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1. + has_dropout (bool): Is dropout used. Default is false + inverted_residual_setting (list): Inverted residual settings. Default is None + round_nearest (list): Channel round to . Default is 8 + Returns: + Tensor, output tensor. + + Examples: + >>> MobileNetV2(backbone, head) + """ + + def __init__(self, num_classes=1000, width_mult=1., has_dropout=False, inverted_residual_setting=None, \ + round_nearest=8, input_channel=32, last_channel=1280): + super(MobileNetV2, self).__init__() + self.backbone = MobileNetV2Backbone(width_mult=width_mult, \ + inverted_residual_setting=inverted_residual_setting, \ + round_nearest=round_nearest, input_channel=input_channel, last_channel=last_channel).get_features + self.head = MobileNetV2Head(input_channel=self.backbone.out_channel, num_classes=num_classes, \ + has_dropout=has_dropout).get_head + + def construct(self, x): + x = self.backbone(x) + x = self.head(x) + return x + + +class MobileNetV2Combine(nn.Cell): + """ + MobileNetV2Combine architecture. + + Args: + backbone (Cell): the features extract layers. + head (Cell): the fully connected layers. + Returns: + Tensor, output tensor. + + Examples: + >>> MobileNetV2Combine(backbone, head) + """ + + def __init__(self, backbone, head): + super(MobileNetV2Combine, self).__init__(auto_prefix=False) + self.backbone = backbone + self.head = head + + def construct(self, x): + x = self.backbone(x) + x = self.head(x) + return x + + +def mobilenet_v2(backbone, head): + return MobileNetV2Combine(backbone, head) + + +def mobilenetv2_mix_quant(num_classes): + """ + MobileNetV2 quantization model, the first and last layers are fixed to 8bit by manual quantization cell, + and others `QuantDtype` are seted by the `QuantizationAwareTraining` API. + """ + backbone_net = MobileNetV2Backbone() + head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, + num_classes=num_classes) + net = mobilenet_v2(backbone_net, head_net) + return net diff --git a/model_zoo/official/cv/mobilenetv2_quant/train.py b/model_zoo/official/cv/mobilenetv2_quant/train.py index 65a03bae318..0ab98d590a7 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/train.py +++ b/model_zoo/official/cv/mobilenetv2_quant/train.py @@ -26,15 +26,18 @@ from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train.serialization import load_checkpoint from mindspore.communication.management import init, get_group_size, get_rank +from mindspore.compression.common import QuantDtype from mindspore.compression.quant import QuantizationAwareTraining +from mindspore.compression.quant.quantizer import OptimizeOption from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net from mindspore.common import set_seed from src.dataset import create_dataset from src.lr_generator import get_lr from src.utils import Monitor, CrossEntropyWithLabelSmooth -from src.config import config_ascend_quant, config_gpu_quant +from src.config import config_ascend_quant, config_gpu_quant, config_lsq_ascend_quant, config_lsq_gpu_quant from src.mobilenetV2 import mobilenetV2 +from src.mobilenetv2_mix_quant import mobilenetv2_mix_quant set_seed(1) @@ -42,6 +45,8 @@ parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') parser.add_argument('--pre_trained', type=str, default=None, help='Pertained checkpoint path') parser.add_argument('--device_target', type=str, default=None, help='Run device target') +parser.add_argument('--optim_option', type=str, default="QAT", help='If OptimizeOption is set to LEARNED_SCALE,' + 'the learned scale quant process is executed.') args_opt = parser.parse_args() if args_opt.device_target == "Ascend": @@ -66,7 +71,11 @@ else: def train_on_ascend(): - config = config_ascend_quant + if args_opt.optim_option == "LEARNED_SCALE": + config = config_lsq_ascend_quant + else: + config = config_ascend_quant + print("training args: {}".format(args_opt)) print("training configure: {}".format(config)) print("parallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size)) @@ -80,7 +89,10 @@ def train_on_ascend(): init() # define network - network = mobilenetV2(num_classes=config.num_classes) + if args_opt.optim_option == "LEARNED_SCALE": + network = mobilenetv2_mix_quant(num_classes=config.num_classes) + else: + network = mobilenetV2(num_classes=config.num_classes) # define loss if config.label_smooth > 0: loss = CrossEntropyWithLabelSmooth(smooth_factor=config.label_smooth, num_classes=config.num_classes) @@ -99,10 +111,22 @@ def train_on_ascend(): param_dict = load_checkpoint(args_opt.pre_trained) load_nonquant_param_into_quant_net(network, param_dict) # convert fusion network to quantization aware network - quantizer = QuantizationAwareTraining(bn_fold=True, - per_channel=[True, False], - symmetric=[True, False], - one_conv_fold=True) + if args_opt.optim_option == "LEARNED_SCALE": + quant_optim_otions = OptimizeOption.LEARNED_SCALE + quantizer = QuantizationAwareTraining(bn_fold=True, + per_channel=[True, False], + symmetric=[True, True], + narrow_range=[True, True], + quant_dtype=(QuantDtype.INT4, QuantDtype.INT8), + freeze_bn=0, + quant_delay=0, + one_conv_fold=True, + optimize_option=quant_optim_otions) + else: + quantizer = QuantizationAwareTraining(bn_fold=True, + per_channel=[True, False], + symmetric=[True, False], + one_conv_fold=True) network = quantizer.quantize(network) # get learning rate @@ -136,12 +160,19 @@ def train_on_ascend(): def train_on_gpu(): - config = config_gpu_quant + if args_opt.optim_option == "LEARNED_SCALE": + config = config_lsq_gpu_quant + else: + config = config_gpu_quant + print("training args: {}".format(args_opt)) print("training configure: {}".format(config)) # define network - network = mobilenetV2(num_classes=config.num_classes) + if args_opt.optim_option == "LEARNED_SCALE": + network = mobilenetv2_mix_quant(num_classes=config.num_classes) + else: + network = mobilenetV2(num_classes=config.num_classes) # define loss if config.label_smooth > 0: loss = CrossEntropyWithLabelSmooth(smooth_factor=config.label_smooth, @@ -163,11 +194,23 @@ def train_on_gpu(): load_nonquant_param_into_quant_net(network, param_dict) # convert fusion network to quantization aware network - quantizer = QuantizationAwareTraining(bn_fold=True, - per_channel=[True, False], - symmetric=[False, False], - freeze_bn=1000000, - quant_delay=step_size * 2) + if args_opt.optim_option == "LEARNED_SCALE": + quant_optim_otions = OptimizeOption.LEARNED_SCALE + quantizer = QuantizationAwareTraining(bn_fold=True, + per_channel=[True, False], + symmetric=[True, True], + narrow_range=[True, True], + quant_dtype=(QuantDtype.INT4, QuantDtype.INT8), + freeze_bn=0, + quant_delay=0, + one_conv_fold=True, + optimize_option=quant_optim_otions) + else: + quantizer = QuantizationAwareTraining(bn_fold=True, + per_channel=[True, False], + symmetric=[False, False], + freeze_bn=1000000, + quant_delay=step_size * 2) network = quantizer.quantize(network) # get learning rate diff --git a/tests/st/quantization/lenet_quant/test_lenet_quant.py b/tests/st/quantization/lenet_quant/test_lenet_quant.py index a2981ee21fd..cef21a3b77c 100644 --- a/tests/st/quantization/lenet_quant/test_lenet_quant.py +++ b/tests/st/quantization/lenet_quant/test_lenet_quant.py @@ -27,6 +27,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMoni from mindspore import load_checkpoint, load_param_into_net, export from mindspore.train import Model from mindspore.compression.quant import QuantizationAwareTraining +from mindspore.compression.quant.quantizer import OptimizeOption from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net from dataset import create_dataset from config import nonquant_cfg, quant_cfg @@ -58,7 +59,30 @@ def train_lenet(): dataset_sink_mode=True) -def train_lenet_quant(): +def eval_lenet(): + context.set_context(mode=context.GRAPH_MODE, device_target=device_target) + cfg = nonquant_cfg + ds_eval = create_dataset(os.path.join(data_path, "test"), cfg.batch_size, 1) + ckpt_path = './ckpt_lenet_noquant-10_1875.ckpt' + # define fusion network + network = LeNet5(cfg.num_classes) + net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) + # call back and monitor + model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) + # load quantization aware network checkpoint + param_dict = load_checkpoint(ckpt_path) + not_load_param = load_param_into_net(network, param_dict) + if not_load_param: + raise ValueError("Load param into net fail!") + + print("============== Starting Testing ==============") + acc = model.eval(ds_eval, dataset_sink_mode=True) + print("============== {} ==============".format(acc)) + assert acc['Accuracy'] > 0.98 + + +def train_lenet_quant(optim_option="QAT"): context.set_context(mode=context.GRAPH_MODE, device_target=device_target) cfg = quant_cfg ckpt_path = './ckpt_lenet_noquant-10_1875.ckpt' @@ -73,10 +97,21 @@ def train_lenet_quant(): load_nonquant_param_into_quant_net(network, param_dict) # convert fusion network to quantization aware network - quantizer = QuantizationAwareTraining(quant_delay=900, - bn_fold=False, - per_channel=[True, False], - symmetric=[True, False]) + if optim_option == "LEARNED_SCALE": + quant_optim_otions = OptimizeOption.LEARNED_SCALE + quantizer = QuantizationAwareTraining(bn_fold=False, + per_channel=[True, False], + symmetric=[True, True], + narrow_range=[True, True], + freeze_bn=0, + quant_delay=0, + one_conv_fold=True, + optimize_option=quant_optim_otions) + else: + quantizer = QuantizationAwareTraining(quant_delay=900, + bn_fold=False, + per_channel=[True, False], + symmetric=[True, False]) network = quantizer.quantize(network) # define network loss @@ -87,7 +122,7 @@ def train_lenet_quant(): # call back and monitor config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, keep_checkpoint_max=cfg.keep_checkpoint_max) - ckpt_callback = ModelCheckpoint(prefix="ckpt_lenet_quant", config=config_ckpt) + ckpt_callback = ModelCheckpoint(prefix="ckpt_lenet_quant"+optim_option, config=config_ckpt) # define model model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) @@ -98,19 +133,30 @@ def train_lenet_quant(): print("============== End Training ==============") -def eval_quant(): +def eval_quant(optim_option="QAT"): context.set_context(mode=context.GRAPH_MODE, device_target=device_target) cfg = quant_cfg ds_eval = create_dataset(os.path.join(data_path, "test"), cfg.batch_size, 1) - ckpt_path = './ckpt_lenet_quant-10_937.ckpt' + ckpt_path = './ckpt_lenet_quant'+optim_option+'-10_937.ckpt' # define fusion network network = LeNet5Fusion(cfg.num_classes) # convert fusion network to quantization aware network - quantizer = QuantizationAwareTraining(quant_delay=0, - bn_fold=False, - freeze_bn=10000, - per_channel=[True, False], - symmetric=[True, False]) + if optim_option == "LEARNED_SCALE": + quant_optim_otions = OptimizeOption.LEARNED_SCALE + quantizer = QuantizationAwareTraining(bn_fold=False, + per_channel=[True, False], + symmetric=[True, True], + narrow_range=[True, True], + freeze_bn=0, + quant_delay=0, + one_conv_fold=True, + optimize_option=quant_optim_otions) + else: + quantizer = QuantizationAwareTraining(quant_delay=0, + bn_fold=False, + freeze_bn=10000, + per_channel=[True, False], + symmetric=[True, False]) network = quantizer.quantize(network) # define loss @@ -132,17 +178,29 @@ def eval_quant(): print("============== {} ==============".format(acc)) assert acc['Accuracy'] > 0.98 -def export_lenet(): + +def export_lenet(optim_option="QAT"): context.set_context(mode=context.GRAPH_MODE, device_target=device_target) cfg = quant_cfg # define fusion network network = LeNet5Fusion(cfg.num_classes) # convert fusion network to quantization aware network - quantizer = QuantizationAwareTraining(quant_delay=0, - bn_fold=False, - freeze_bn=10000, - per_channel=[True, False], - symmetric=[True, False]) + if optim_option == "LEARNED_SCALE": + quant_optim_otions = OptimizeOption.LEARNED_SCALE + quantizer = QuantizationAwareTraining(bn_fold=False, + per_channel=[True, False], + symmetric=[True, True], + narrow_range=[True, True], + freeze_bn=0, + quant_delay=0, + one_conv_fold=True, + optimize_option=quant_optim_otions) + else: + quantizer = QuantizationAwareTraining(quant_delay=0, + bn_fold=False, + freeze_bn=10000, + per_channel=[True, False], + symmetric=[True, False]) network = quantizer.quantize(network) # export network @@ -155,9 +213,13 @@ def export_lenet(): @pytest.mark.env_onecard def test_lenet_quant(): train_lenet() + eval_lenet() train_lenet_quant() eval_quant() export_lenet() + train_lenet_quant(optim_option="LEARNED_SCALE") + eval_quant(optim_option="LEARNED_SCALE") + export_lenet(optim_option="LEARNED_SCALE") if __name__ == "__main__": diff --git a/tests/ut/python/train/quant/test_quant.py b/tests/ut/python/train/quant/test_quant.py index c8e46072fe5..5e5efb79401 100644 --- a/tests/ut/python/train/quant/test_quant.py +++ b/tests/ut/python/train/quant/test_quant.py @@ -21,6 +21,7 @@ from mindspore import Tensor from mindspore import nn from mindspore.compression.quant import QuantizationAwareTraining from mindspore.compression.export import quant_export +from mindspore.compression.quant.quantizer import OptimizeOption from model_zoo.official.cv.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2 context.set_context(mode=context.GRAPH_MODE, device_target="GPU") @@ -99,3 +100,56 @@ def test_qat_mobile_per_channel_ff(): # should load the checkpoint. mock here network.init_parameters_data() quant_export.export(network, img, file_name="quant.pb") + + +@pytest.mark.skip(reason="no `te.lang.cce` in ut env") +def test_lsq_lenet(): + img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32)) + net = LeNet5() + quantizer = QuantizationAwareTraining(bn_fold=True, + per_channel=[True, False], + symmetric=[True, True], + narrow_range=[True, True], + freeze_bn=0, + quant_delay=0, + one_conv_fold=True, + optimize_option=OptimizeOption.LEARNED_SCALE) + net = quantizer.quantize(net) + # should load the checkpoint. mock here + net.init_parameters_data() + quant_export.export(net, img, file_name="quant.pb") + + +@pytest.mark.skip(reason="no `te.lang.cce` in ut env") +def test_lsq_mobile_per_channel_tf(): + network = mobilenetV2(num_classes=1000) + img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) + quantizer = QuantizationAwareTraining(bn_fold=True, + per_channel=[True, False], + symmetric=[True, True], + narrow_range=[True, True], + freeze_bn=0, + quant_delay=0, + one_conv_fold=True, + optimize_option=OptimizeOption.LEARNED_SCALE) + network = quantizer.quantize(network) + # should load the checkpoint. mock here + network.init_parameters_data() + quant_export.export(network, img, file_name="quant.pb") + +@pytest.mark.skip(reason="no `te.lang.cce` in ut env") +def test_lsq_mobile_per_channel_ff(): + network = mobilenetV2(num_classes=1000) + img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) + quantizer = QuantizationAwareTraining(bn_fold=True, + per_channel=[False, False], + symmetric=[True, True], + narrow_range=[True, True], + freeze_bn=0, + quant_delay=0, + one_conv_fold=True, + optimize_option=OptimizeOption.LEARNED_SCALE) + network = quantizer.quantize(network) + # should load the checkpoint. mock here + network.init_parameters_data() + quant_export.export(network, img, file_name="quant.pb")