From a834a6308eafd8f1877fe23cceb3d1e51c66c943 Mon Sep 17 00:00:00 2001 From: chenzomi Date: Sat, 20 Jun 2020 14:27:37 +0800 Subject: [PATCH] change some comment name in the whole project --- build.sh | 2 +- mindspore/ccsrc/dataset/engine/perf/monitor.h | 2 +- mindspore/ccsrc/ir/meta_tensor.h | 2 +- mindspore/ccsrc/ir/tensor.h | 2 +- ..._impl.cu => fake_quant_perchannel_impl.cu} | 40 +----- ...mpl.cuh => fake_quant_perchannel_impl.cuh} | 0 ...nt_impl.cu => fake_quant_perlayer_impl.cu} | 2 +- ..._impl.cuh => fake_quant_perlayer_impl.cuh} | 0 .../gpu/cuda_impl/minmax_update_impl.cu | 104 +++++++++++++++ .../gpu/cuda_impl/minmax_update_impl.cuh | 30 +++++ ...cc => fake_quant_perchannel_gpu_kernel.cc} | 113 ++++++----------- ...l.h => fake_quant_perchannel_gpu_kernel.h} | 22 ++-- ... fake_quant_perchannel_grad_gpu_kernel.cc} | 54 ++++---- ...> fake_quant_perchannel_grad_gpu_kernel.h} | 6 +- ...l.cc => fake_quant_perlayer_gpu_kernel.cc} | 109 ++++++---------- ...nel.h => fake_quant_perlayer_gpu_kernel.h} | 22 ++-- ...=> fake_quant_perlayer_grad_gpu_kernel.cc} | 96 ++++++-------- ... => fake_quant_perlayer_grad_gpu_kernel.h} | 17 ++- .../minmax_update_perchannel_gpu_kernel.cc | 119 ++++++++++++++++++ .../minmax_update_perchannel_gpu_kernel.h | 60 +++++++++ .../minmax_update_perlayer_gpu_kernel.cc | 115 +++++++++++++++++ .../quant/minmax_update_perlayer_gpu_kernel.h | 59 +++++++++ mindspore/ccsrc/utils/lineage.proto | 2 +- mindspore/ccsrc/utils/summary.proto | 2 +- mindspore/ccsrc/vm/transform.h | 2 +- mindspore/nn/layer/quant.py | 52 +++++--- mindspore/ops/_grad/grad_quant_ops.py | 10 +- .../fake_quant_minmax_perchannel_update.py | 4 +- .../fake_quant_minmax_perlayer_update.py | 8 +- mindspore/ops/operations/_quant_ops.py | 24 ++-- mindspore/ops/operations/array_ops.py | 2 +- mindspore/train/callback/_loss_monitor.py | 4 +- mindspore/train/quant/__init__.py | 4 +- mindspore/train/quant/quant.py | 33 ++--- mindspore/train/summary/_summary_adapter.py | 2 +- model_zoo/lenet_quant/README.md | 30 ++--- model_zoo/lenet_quant/eval_quant.py | 4 +- model_zoo/lenet_quant/train_quant.py | 4 +- .../mindspore_test_framework/utils/keyword.py | 4 +- 39 files changed, 757 insertions(+), 410 deletions(-) rename mindspore/ccsrc/kernel/gpu/cuda_impl/{fake_quant_per_channel_impl.cu => fake_quant_perchannel_impl.cu} (73%) rename mindspore/ccsrc/kernel/gpu/cuda_impl/{fake_quant_per_channel_impl.cuh => fake_quant_perchannel_impl.cuh} (100%) rename mindspore/ccsrc/kernel/gpu/cuda_impl/{fake_quant_impl.cu => fake_quant_perlayer_impl.cu} (96%) rename mindspore/ccsrc/kernel/gpu/cuda_impl/{fake_quant_impl.cuh => fake_quant_perlayer_impl.cuh} (100%) create mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cu create mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cuh rename mindspore/ccsrc/kernel/gpu/quant/{fake_quant_per_channel_gpu_kernel.cc => fake_quant_perchannel_gpu_kernel.cc} (53%) rename mindspore/ccsrc/kernel/gpu/quant/{fake_quant_per_channel_gpu_kernel.h => fake_quant_perchannel_gpu_kernel.h} (75%) rename mindspore/ccsrc/kernel/gpu/quant/{fake_quant_per_channel_grad_gpu_kernel.cc => fake_quant_perchannel_grad_gpu_kernel.cc} (74%) rename mindspore/ccsrc/kernel/gpu/quant/{fake_quant_per_channel_grad_gpu_kernel.h => fake_quant_perchannel_grad_gpu_kernel.h} (91%) rename mindspore/ccsrc/kernel/gpu/quant/{fake_quant_gpu_kernel.cc => fake_quant_perlayer_gpu_kernel.cc} (50%) rename mindspore/ccsrc/kernel/gpu/quant/{fake_quant_gpu_kernel.h => fake_quant_perlayer_gpu_kernel.h} (77%) rename mindspore/ccsrc/kernel/gpu/quant/{fake_quant_grad_gpu_kernel.cc => fake_quant_perlayer_grad_gpu_kernel.cc} (51%) rename mindspore/ccsrc/kernel/gpu/quant/{fake_quant_grad_gpu_kernel.h => fake_quant_perlayer_grad_gpu_kernel.h} (77%) create mode 100644 mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.cc create mode 100644 mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h create mode 100644 mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.cc create mode 100644 mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h diff --git a/build.sh b/build.sh index dfed66aadf5..0060a7bbb3e 100755 --- a/build.sh +++ b/build.sh @@ -243,7 +243,7 @@ checkopts() done } checkopts "$@" -echo "---------------- mindspore: build start ----------------" +echo "---------------- MindSpore: build start ----------------" mkdir -pv "${BUILD_PATH}/package/mindspore/lib" git submodule update --init graphengine diff --git a/mindspore/ccsrc/dataset/engine/perf/monitor.h b/mindspore/ccsrc/dataset/engine/perf/monitor.h index 11b3149ede1..2a482a6ad71 100644 --- a/mindspore/ccsrc/dataset/engine/perf/monitor.h +++ b/mindspore/ccsrc/dataset/engine/perf/monitor.h @@ -36,7 +36,7 @@ class Monitor { ~Monitor() = default; // Functor for Perf Monitor main loop. - // This function will be the entry point of Mindspore::Dataset::Task + // This function will be the entry point of mindspore::Dataset::Task Status operator()(); int64_t GetSamplingInterval() { return sampling_interval_; } diff --git a/mindspore/ccsrc/ir/meta_tensor.h b/mindspore/ccsrc/ir/meta_tensor.h index a85ef77e832..d78caf3b5dd 100644 --- a/mindspore/ccsrc/ir/meta_tensor.h +++ b/mindspore/ccsrc/ir/meta_tensor.h @@ -29,7 +29,7 @@ // brief mindspore namespace. // -// mindspore namespace is the top level namespace of Mindsporeession project. +// mindspore namespace is the top level namespace of MindSpore project. // Other namespace should be a sub namespace of mindspore namespace in the ME project. namespace mindspore { diff --git a/mindspore/ccsrc/ir/tensor.h b/mindspore/ccsrc/ir/tensor.h index 700dcd49102..9d368ab4272 100644 --- a/mindspore/ccsrc/ir/tensor.h +++ b/mindspore/ccsrc/ir/tensor.h @@ -90,7 +90,7 @@ using mindspore::device::DeviceAddress; using DeviceAddressPtr = std::shared_ptr; // brief mindspore namespace. // -// mindspore namespace is the top level namespace of Mindsporeession project. +// mindspore namespace is the top level namespace of MindSpore project. // Other namespace should be a sub namespace of mindspore namespace in the ME project. namespace mindspore { // brief mindspore::tensor namespace diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cu similarity index 73% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cu rename to mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cu index b9aac9bdc38..75c5eacb25b 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cu @@ -19,7 +19,7 @@ #include #include #include -#include "fake_quant_per_channel_impl.cuh" +#include "fake_quant_perchannel_impl.cuh" #include "device/gpu/cuda_common.h" /** @@ -113,44 +113,6 @@ void CalFakeQuantizePerChannel(const float *input, float *output, const int tota input, output, total_size, channel_size, nudge_min, nudge_max, scale, symmetric); } -/** - * UpdateInputMinMaxPerChannel or UpdateInputMinMaxPerChannel With EMA. - * @param input_min - * @param input_max - * @param min - * @param max - * @return - */ -__global__ void UpdateInputMinMaxPerChannel(float *input_min, float *input_max, float *input, int channels, - int per_channel_nums, bool ema, float ema_decay) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) { - thrust::pair sum = - thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1)); - if (ema) { - input_min[i] = ema_decay * sum.first[0] + (1 - ema_decay) * input_min[i]; - input_max[i] = ema_decay * sum.second[0] + (1 - ema_decay) * input_max[i]; - } else { - input_min[i] = sum.first[0]; - input_max[i] = sum.second[0]; - } - input_min[i] = input_min[i] > 0 ? 0 : input_min[i]; - input_max[i] = input_max[i] < 0 ? 0 : input_max[i]; - } -} - -__global__ void UpdateInputMinMaxPerChannelWithEMA(float *input_min, float *input_max, float min, float max, - const float decay) { - *input_min = decay * (min) + (1 - decay) * (*input_min); - *input_max = decay * (max) + (1 - decay) * (*input_max); -} - -void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, const int total_size, const int channel_size, - const float ema_decay, const bool ema, cudaStream_t cuda_stream) { - int per_channel_num = total_size / channel_size; - UpdateInputMinMaxPerChannel<<>>( - input_min, input_max, input, channel_size, per_channel_num, ema, ema_decay); -} - __global__ void FakeQuantizePerChannelGrad(const float *input, const float *gradient, float *output, const int total_size, const int channel_size, const float *nudge_min, const float *nudge_max) { diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cuh rename to mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cu similarity index 96% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_impl.cu rename to mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cu index db3f8a857f1..11a25ba2947 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cu @@ -18,7 +18,7 @@ #include #include #include "device/gpu/cuda_common.h" -#include "fake_quant_impl.cuh" +#include "fake_quant_perlayer_impl.cuh" __global__ void FakeQuantize(const float *input, float *output, const int size, const float *nudge_min, const float *nudge_max, const float *scale) { diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_impl.cuh rename to mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cu new file mode 100644 index 00000000000..4d313fa9313 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cu @@ -0,0 +1,104 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include "minmax_update_impl.cuh" +#include "device/gpu/cuda_common.h" + +__global__ void UpdateInputMinMaxPerLayerWithEMA(const float *input_min, const float *input_max, float *output_min, + float *output_max, const float min, const float max, const float decay, + const float symmetric) { + output_min[0] = decay * (min) + (1 - decay) * (input_min[0]); + output_min[0] = input_min[0] > 0 ? 0 : input_min[0]; + output_max[0] = decay * (max) + (1 - decay) * (input_max[0]); + output_max[0] = input_max[0] < 0 ? 0 : input_max[0]; + + if (symmetric) { + output_max[0] = abs(output_min[0]) < output_max[0] ? output_max[0] : -output_min[0]; + output_min[0] = abs(output_min[0]) < output_max[0] ? -output_max[0] : output_min[0]; + } + return; +} + +__global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max, + const float symmetric) { + output_min[0] = min > 0 ? 0 : min; + output_max[0] = max < 0 ? 0 : max; + + if (symmetric) { + output_max[0] = abs(output_min[0]) < output_max[0] ? output_max[0] : -output_min[0]; + output_min[0] = abs(output_min[0]) < output_max[0] ? -output_max[0] : output_min[0]; + } + return; +} + +__global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, + float *output_max, int channels, int per_channel_nums, bool ema, + float ema_decay, bool symmetric) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) { + thrust::pair sum = + thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1)); + if (ema) { + output_min[i] = ema_decay * sum.first[0] + (1 - ema_decay) * input_min[i]; + output_max[i] = ema_decay * sum.second[0] + (1 - ema_decay) * input_max[i]; + } else { + output_min[i] = sum.first[0]; + output_max[i] = sum.second[0]; + } + output_min[i] = input_min[i] > 0 ? 0 : input_min[i]; + output_max[i] = input_max[i] < 0 ? 0 : input_max[i]; + + if (symmetric) { + output_max[i] = abs(output_min[i]) < output_max[i] ? output_max[i] : -output_min[i]; + output_min[i] = abs(output_min[i]) < output_max[i] ? -output_max[i] : output_min[i]; + } + } + return; +} + +void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max, + const int total_num, const int channel_num, const float ema_decay, const bool ema, + const bool symmetric, cudaStream_t cuda_stream) { + int per_channel_num = total_num / channel_num; + UpdateInputMinMaxPerChannel<<>>( + input, input_min, input_max, output_min, output_max, channel_num, per_channel_num, ema, ema_decay, symmetric); + return; +} + +void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max, + const int total_num, const float ema_decay, const bool ema, const bool symmetric, + cudaStream_t cuda_stream) { + float minel = 0.f; + float maxel = 0.f; + auto policy = thrust::cuda::par.on(cuda_stream); + thrust::pair, thrust::device_ptr> tuple; + tuple = + thrust::minmax_element(policy, thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + total_num); + minel = tuple.first[0]; + maxel = tuple.second[0]; + + if (ema) { + UpdateInputMinMaxPerLayerWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, output_min, output_max, minel, + maxel, ema_decay, symmetric); + } else { + UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel, symmetric); + } + return; +} diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cuh new file mode 100644 index 00000000000..e0e5a731d3f --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cuh @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_ + +#include "device/gpu/cuda_common.h" + +void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max, + const int total_num, const int channel_num, const float ema_decay, const bool ema, + const bool symmetric, cudaStream_t cuda_stream); + +void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max, + const int size, const float ema_decay, const bool ema, const bool symmetric, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.cc similarity index 53% rename from mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc rename to mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.cc index ea1fea33227..ffed550fbbc 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include "kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.h" -#include "kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cuh" +#include "kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h" +#include "kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh" #include #include #include @@ -25,21 +25,15 @@ namespace mindspore { namespace kernel { FakeQuantPerChannelGpuKernel::FakeQuantPerChannelGpuKernel() : input_size_(0), - min_size_(0), - max_size_(0), - output_size_(0), - workspace_size_(0), + num_channels_(0), num_bits_(0), + training_(false), + symmetric_(false), + narrow_range_(false), + quant_delay_(0), quant_min_(0), quant_max_(0), - quant_delay_(0), - ema_(false), - ema_decay_(0), - global_step_(0), - training_(false), - channel_out_(0), - narrow_range_(false), - symmetric_(false) {} + global_step_(0) {} const std::vector &FakeQuantPerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; } @@ -60,90 +54,56 @@ bool FakeQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) { return false; } + // get attribute num_bits_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); - ema_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema")); - ema_decay_ = 1.0 - GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay")); + training_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training")); + symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); + narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); + quant_delay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")); if (num_bits_ <= 2 || num_bits_ >= 16) { MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << "is out of range, expected between 2 and 16."; return false; } - quant_delay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")); if (quant_delay_ < 0) { MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << " is less then 0, require larger than 0."; return false; } - training_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training")); - - symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); - if (symmetric_) { - quant_min_ = 0 - (1 << (num_bits_ - 1)); - quant_max_ = (1 << (num_bits_ - 1)) - 1; - } else { - quant_min_ = 0; - quant_max_ = (1 << num_bits_) - 1; - } - - narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); + // quant min and max value + quant_min_ = 0; + quant_max_ = (1 << num_bits_) - 1; if (narrow_range_) { quant_min_++; } // shape info for gpu auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - channel_out_ = SizeToInt(input_shape[0]); - min_size_ = sizeof(float) * channel_out_; - max_size_ = sizeof(float) * channel_out_; + num_channels_ = SizeToInt(input_shape[0]); input_size_ = sizeof(float); for (size_t i = 0; i < input_shape.size(); i++) { input_size_ *= input_shape[i]; } - output_size_ = input_size_; - InitSizeLists(); return true; } void FakeQuantPerChannelGpuKernel::InitSizeLists() { - input_size_list_.push_back(input_size_); // input in tensor - input_size_list_.push_back(min_size_); // min one scalar - input_size_list_.push_back(max_size_); // max on scalar - output_size_list_.push_back(output_size_); // output in tensor - workspace_size_list_.push_back(sizeof(float) * channel_out_); // scale in channel - workspace_size_list_.push_back(sizeof(float) * channel_out_); // min in channel - workspace_size_list_.push_back(sizeof(float) * channel_out_); // max in channel + input_size_list_.push_back(input_size_); // input in tensor + input_size_list_.push_back(sizeof(float) * num_channels_); // min one scalar + input_size_list_.push_back(sizeof(float) * num_channels_); // max on scalar + output_size_list_.push_back(input_size_); // output in tensor + workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel + workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel + workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel } -void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForTraining(float *input, float *output, float *input_min, - float *input_max, float *d_nudge_min, float *d_nudge_max, - float *d_scale, void *stream_ptr) { - // calculate the input min and max according by the parameter ema and ema_decay. - CalMinMaxPerChannel(input, input_min, input_max, input_size_ / sizeof(float), channel_out_, ema_decay_, ema_, - reinterpret_cast(stream_ptr)); - // control flow for quant_delay - if (global_step_ >= quant_delay_) { - // real launch - CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_, - reinterpret_cast(stream_ptr)); - CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max, - d_scale, symmetric_, reinterpret_cast(stream_ptr)); - } else { - CHECK_CUDA_RET_WITH_ERROR( - cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)), - "Copy gpu memory failed."); - } - global_step_++; -} - -void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForInfer(float *input, float *output, float *input_min, - float *input_max, float *d_nudge_min, float *d_nudge_max, - float *d_scale, void *stream_ptr) { - // real launch - CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_, +void FakeQuantPerChannelGpuKernel::CalFakeQuantize(float *input, float *output, float *input_min, float *input_max, + float *nudge_min, float *nudge_max, float *scale, void *stream_ptr) { + CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_, reinterpret_cast(stream_ptr)); - CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max, d_scale, + CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, scale, symmetric_, reinterpret_cast(stream_ptr)); } @@ -155,9 +115,9 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector &inputs, float *input = GetDeviceAddress(inputs, 0); float *input_min = GetDeviceAddress(inputs, 1); float *input_max = GetDeviceAddress(inputs, 2); - float *d_scale = GetDeviceAddress(workspace, 0); - float *d_nudge_min = GetDeviceAddress(workspace, 1); - float *d_nudge_max = GetDeviceAddress(workspace, 2); + float *scale = GetDeviceAddress(workspace, 0); + float *nudge_min = GetDeviceAddress(workspace, 1); + float *nudge_max = GetDeviceAddress(workspace, 2); if (input == nullptr) { MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input is null."; @@ -167,9 +127,16 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector &inputs, } if (training_) { - CalFakeQuantizeForTraining(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr); + if (global_step_ >= quant_delay_) { + CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr); + } else { + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed."); + } + global_step_++; } else { - CalFakeQuantizeForInfer(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr); + CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr); } return true; diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h similarity index 75% rename from mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.h rename to mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h index bea1a7421fc..122fe96af32 100755 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h @@ -39,31 +39,23 @@ class FakeQuantPerChannelGpuKernel : public GpuKernel { void InitSizeLists() override; private: - void CalFakeQuantizeForTraining(float *input, float *output, float *input_min, float *input_max, float *d_nudge_min, - float *d_nudge_max, float *d_scale, void *stream_ptr); - void CalFakeQuantizeForInfer(float *input, float *output, float *input_min, float *input_max, float *d_nudge_min, - float *d_nudge_max, float *d_scale, void *stream_ptr); + void CalFakeQuantize(float *input, float *output, float *input_min, float *input_max, float *nudge_min, + float *nudge_max, float *scale, void *stream_ptr); size_t input_size_; - size_t min_size_; - size_t max_size_; - size_t output_size_; - size_t workspace_size_; std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; + int num_channels_; int num_bits_; + bool training_; + bool symmetric_; + bool narrow_range_; + int quant_delay_; float quant_min_; float quant_max_; - int quant_delay_; - bool ema_; - float ema_decay_; int global_step_; - bool training_; - int channel_out_; - bool narrow_range_; - bool symmetric_; }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc similarity index 74% rename from mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.cc rename to mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc index b43e178eb1d..a57516eb2c7 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc @@ -14,21 +14,17 @@ * limitations under the License. */ -#include "kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.h" -#include "kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cuh" +#include "kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h" +#include "kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh" namespace mindspore { namespace kernel { FakeQuantPerChannelGradGpuKernel::FakeQuantPerChannelGradGpuKernel() : input_size_(0), - min_size_(0), - max_size_(0), - output_size_(0), - workspace_size_(0), num_bits_(0), quant_min_(0), quant_max_(0), - channel_out_(0), + num_channels_(0), quant_delay_(0), global_step_(0), narrow_range_(false), @@ -64,42 +60,34 @@ bool FakeQuantPerChannelGradGpuKernel::Init(const CNodePtr &kernel_node) { } symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); - if (symmetric_) { - quant_min_ = 0 - (1 << (num_bits_ - 1)); - quant_max_ = (1 << (num_bits_ - 1)) - 1; - } else { - quant_min_ = 0; - quant_max_ = (1 << num_bits_) - 1; - } - narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); + + // quant min and max value + quant_min_ = 0; + quant_max_ = (1 << num_bits_) - 1; if (narrow_range_) { quant_min_++; } auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - channel_out_ = SizeToInt(input_shape[0]); - min_size_ = sizeof(float) * channel_out_; - max_size_ = sizeof(float) * channel_out_; + num_channels_ = SizeToInt(input_shape[0]); input_size_ = sizeof(float); for (size_t i = 0; i < input_shape.size(); i++) { input_size_ *= input_shape[i]; } - output_size_ = input_size_; - InitSizeLists(); return true; } void FakeQuantPerChannelGradGpuKernel::InitSizeLists() { - input_size_list_.push_back(input_size_); // gradient - input_size_list_.push_back(input_size_); // input - input_size_list_.push_back(min_size_); // min - input_size_list_.push_back(max_size_); // max - output_size_list_.push_back(output_size_); - workspace_size_list_.push_back(sizeof(float) * channel_out_); // scale in channel - workspace_size_list_.push_back(sizeof(float) * channel_out_); // min in channel - workspace_size_list_.push_back(sizeof(float) * channel_out_); // max in channel + input_size_list_.push_back(input_size_); // gradient + input_size_list_.push_back(input_size_); // input + input_size_list_.push_back(sizeof(float) * num_channels_); // min + input_size_list_.push_back(sizeof(float) * num_channels_); // max + output_size_list_.push_back(input_size_); // output + workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel + workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel + workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel } bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector &inputs, @@ -111,9 +99,9 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector &inp float *input = GetDeviceAddress(inputs, 1); float *input_min = GetDeviceAddress(inputs, 2); float *input_max = GetDeviceAddress(inputs, 3); - float *d_scale = GetDeviceAddress(workspace, 0); - float *d_nudge_min = GetDeviceAddress(workspace, 1); - float *d_nudge_max = GetDeviceAddress(workspace, 2); + float *scale = GetDeviceAddress(workspace, 0); + float *nudge_min = GetDeviceAddress(workspace, 1); + float *nudge_max = GetDeviceAddress(workspace, 2); if (gradient == nullptr) { MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel gradient is null"; @@ -130,9 +118,9 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector &inp int total_size = input_size_ / sizeof(float); if (global_step_ >= quant_delay_) { - CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_, + CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_, reinterpret_cast(stream_ptr)); - CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, channel_out_, d_nudge_min, d_nudge_max, + CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max, reinterpret_cast(stream_ptr)); } else { CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice, diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h similarity index 91% rename from mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.h rename to mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h index fe760d85d24..d863a2c99f5 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h @@ -40,10 +40,6 @@ class FakeQuantPerChannelGradGpuKernel : public GpuKernel { private: size_t input_size_; - size_t min_size_; - size_t max_size_; - size_t output_size_; - size_t workspace_size_; std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; @@ -51,7 +47,7 @@ class FakeQuantPerChannelGradGpuKernel : public GpuKernel { int num_bits_; float quant_min_; float quant_max_; - int channel_out_; + int num_channels_; int quant_delay_; int global_step_; bool narrow_range_; diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.cc similarity index 50% rename from mindspore/ccsrc/kernel/gpu/quant/fake_quant_gpu_kernel.cc rename to mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.cc index 31f37bd7333..845fb5b923d 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include "kernel/gpu/quant/fake_quant_gpu_kernel.h" -#include "kernel/gpu/cuda_impl/fake_quant_impl.cuh" +#include "kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.h" +#include "kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh" #include #include #include @@ -23,31 +23,25 @@ namespace mindspore { namespace kernel { -FakeQuantGpuKernel::FakeQuantGpuKernel() +FakeQuantPerLayerGpuKernel::FakeQuantPerLayerGpuKernel() : input_size_(0), - min_size_(0), - max_size_(0), - output_size_(0), - workspace_size_(0), - num_bits_(0), quant_min_(0), quant_max_(0), - quant_num_(0), - quant_delay_(0), - ema_(false), - ema_decay_(0), + quant_num_(1), global_step_(0), + num_bits_(0), + quant_delay_(0), training_(false), narrow_range_(false), symmetric_(false) {} -const std::vector &FakeQuantGpuKernel::GetInputSizeList() const { return input_size_list_; } +const std::vector &FakeQuantPerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; } -const std::vector &FakeQuantGpuKernel::GetOutputSizeList() const { return output_size_list_; } +const std::vector &FakeQuantPerLayerGpuKernel::GetOutputSizeList() const { return output_size_list_; } -const std::vector &FakeQuantGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } +const std::vector &FakeQuantPerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } -bool FakeQuantGpuKernel::Init(const CNodePtr &kernel_node) { +bool FakeQuantPerLayerGpuKernel::Init(const CNodePtr &kernel_node) { size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 3) { MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output."; @@ -59,95 +53,73 @@ bool FakeQuantGpuKernel::Init(const CNodePtr &kernel_node) { } num_bits_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); - ema_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema")); - ema_decay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay")); + quant_delay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")); training_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training")); + symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); + narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); if (num_bits_ <= 2 || num_bits_ >= 16) { MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; } - quant_delay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")); if (quant_delay_ < 0) { MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << "is less then 0, require larger than 0."; } - symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); - if (symmetric_) { - quant_min_ = 0 - (1 << (num_bits_ - 1)); - quant_max_ = (1 << (num_bits_ - 1)) - 1; - } else { - quant_min_ = 0; - quant_max_ = (1 << num_bits_) - 1; - } - - narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); + // quant min and max value + quant_min_ = 0; + quant_max_ = (1 << num_bits_) - 1; if (narrow_range_) { quant_min_++; } - if (quant_num_ == 0) { - quant_num_ = 1; - } + // init size auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); for (size_t i = 0; i < input_shape.size(); ++i) { quant_num_ *= SizeToInt(input_shape[i]); } - input_size_ = sizeof(float); - min_size_ = sizeof(float); - max_size_ = sizeof(float); for (size_t i = 0; i < input_shape.size(); i++) { input_size_ *= input_shape[i]; } - output_size_ = input_size_; InitSizeLists(); return true; } -void FakeQuantGpuKernel::InitSizeLists() { - input_size_list_.push_back(input_size_); // input - input_size_list_.push_back(min_size_); // min - input_size_list_.push_back(max_size_); // max - output_size_list_.push_back(output_size_); - workspace_size_list_.push_back(workspace_size_); +void FakeQuantPerLayerGpuKernel::InitSizeLists() { + input_size_list_.push_back(input_size_); // x + input_size_list_.push_back(sizeof(float)); // min + input_size_list_.push_back(sizeof(float)); // max + output_size_list_.push_back(input_size_); // y + workspace_size_list_.push_back(sizeof(float)); // scale + workspace_size_list_.push_back(sizeof(float)); // nudge_min + workspace_size_list_.push_back(sizeof(float)); // nudge_max } -bool FakeQuantGpuKernel::Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { +bool FakeQuantPerLayerGpuKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { float *output = GetDeviceAddress(outputs, 0); float *input = GetDeviceAddress(inputs, 0); float *input_min = GetDeviceAddress(inputs, 1); float *input_max = GetDeviceAddress(inputs, 2); + float *scale = GetDeviceAddress(workspace, 0); + float *nudge_min = GetDeviceAddress(workspace, 1); + float *nudge_max = GetDeviceAddress(workspace, 2); if (input == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input x is null."; + MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input x is null."; } - if (input_min == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input min is null."; + if (input_min == nullptr || input_max == nullptr) { + MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input min or input max is null."; } - if (input_max == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input max is null."; - } - - // Allocate space for device copies - int size = sizeof(float); - float *d_scale = nullptr; - float *d_nudge_min = nullptr; - float *d_nudge_max = nullptr; - CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast(&d_scale), size), "Malloc gpu memory failed"); - CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast(&d_nudge_min), size), "Malloc gpu memory failed"); - CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast(&d_nudge_max), size), "Malloc gpu memory failed"); if (training_) { - // calculate the input min and max according by the parameter ema and ema_decay. - CalMinMax(input, input_min, input_max, quant_num_, ema_decay_, ema_, reinterpret_cast(stream_ptr)); // control flow for quant_delay if (global_step_ >= quant_delay_) { // real launch - CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, + CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, reinterpret_cast(stream_ptr)); - CalFakeQuantize(input, output, quant_num_, d_nudge_min, d_nudge_max, d_scale, symmetric_, + CalFakeQuantize(input, output, quant_num_, nudge_min, nudge_max, scale, symmetric_, reinterpret_cast(stream_ptr)); } else { CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, @@ -157,20 +129,15 @@ bool FakeQuantGpuKernel::Launch(const std::vector &inputs, const std global_step_++; } else { // real launch - CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, + CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, reinterpret_cast(stream_ptr)); - CalFakeQuantize(input, output, quant_num_, d_nudge_min, d_nudge_max, d_scale, symmetric_, + CalFakeQuantize(input, output, quant_num_, nudge_min, nudge_max, scale, symmetric_, reinterpret_cast(stream_ptr)); } - // Cleanup - CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed"); - CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed"); - CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed"); - return true; } -MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantGpuKernel) +MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantPerLayerGpuKernel) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.h similarity index 77% rename from mindspore/ccsrc/kernel/gpu/quant/fake_quant_gpu_kernel.h rename to mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.h index 5a594c615f4..38810e06dfb 100755 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_ +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_ #include #include "kernel/gpu/gpu_kernel.h" @@ -23,10 +23,10 @@ namespace mindspore { namespace kernel { -class FakeQuantGpuKernel : public GpuKernel { +class FakeQuantPerLayerGpuKernel : public GpuKernel { public: - FakeQuantGpuKernel(); - ~FakeQuantGpuKernel() = default; + FakeQuantPerLayerGpuKernel(); + ~FakeQuantPerLayerGpuKernel() = default; const std::vector &GetInputSizeList() const override; const std::vector &GetOutputSizeList() const override; @@ -40,22 +40,16 @@ class FakeQuantGpuKernel : public GpuKernel { private: size_t input_size_; - size_t min_size_; - size_t max_size_; - size_t output_size_; - size_t workspace_size_; std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; - int num_bits_; float quant_min_; float quant_max_; int quant_num_; - int quant_delay_; - bool ema_; - float ema_decay_; int global_step_; + int num_bits_; + int quant_delay_; bool training_; bool narrow_range_; bool symmetric_; @@ -63,4 +57,4 @@ class FakeQuantGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_ +#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc similarity index 51% rename from mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.cc rename to mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc index b9dcf6c6c31..9c6584e2396 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc @@ -14,33 +14,30 @@ * limitations under the License. */ -#include "kernel/gpu/quant/fake_quant_grad_gpu_kernel.h" -#include "kernel/gpu/cuda_impl/fake_quant_impl.cuh" +#include "kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h" +#include "kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh" namespace mindspore { namespace kernel { -FakeQuantGradGpuKernel::FakeQuantGradGpuKernel() +FakeQuantPerLayerGradGpuKernel::FakeQuantPerLayerGradGpuKernel() : input_size_(0), - min_size_(0), - max_size_(0), - output_size_(0), workspace_size_(0), num_bits_(0), quant_min_(0), quant_max_(0), - quant_size_(0), + quant_num_(1), quant_delay_(0), global_step_(0), narrow_range_(false), symmetric_(false) {} -const std::vector &FakeQuantGradGpuKernel::GetInputSizeList() const { return input_size_list_; } +const std::vector &FakeQuantPerLayerGradGpuKernel::GetInputSizeList() const { return input_size_list_; } -const std::vector &FakeQuantGradGpuKernel::GetOutputSizeList() const { return output_size_list_; } +const std::vector &FakeQuantPerLayerGradGpuKernel::GetOutputSizeList() const { return output_size_list_; } -const std::vector &FakeQuantGradGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } +const std::vector &FakeQuantPerLayerGradGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } -bool FakeQuantGradGpuKernel::Init(const CNodePtr &kernel_node) { +bool FakeQuantPerLayerGradGpuKernel::Init(const CNodePtr &kernel_node) { size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 4) { MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 output."; @@ -62,87 +59,66 @@ bool FakeQuantGradGpuKernel::Init(const CNodePtr &kernel_node) { } symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); - if (symmetric_) { - quant_min_ = 0 - (1 << (num_bits_ - 1)); - quant_max_ = (1 << (num_bits_ - 1)) - 1; - } else { - quant_min_ = 0; - quant_max_ = (1 << num_bits_) - 1; - } - narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); + + // quant min and max value + quant_min_ = 0; + quant_max_ = (1 << num_bits_) - 1; if (narrow_range_) { quant_min_++; } - if (quant_size_ == 0) { - quant_size_ = 1; - } + // init size auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); for (size_t i = 0; i < input_shape.size(); ++i) { - quant_size_ *= SizeToInt(input_shape[i]); + quant_num_ *= SizeToInt(input_shape[i]); } - input_size_ = sizeof(float); - min_size_ = sizeof(float); - max_size_ = sizeof(float); for (size_t i = 0; i < input_shape.size(); i++) { input_size_ *= input_shape[i]; } - output_size_ = input_size_; - InitSizeLists(); return true; } -void FakeQuantGradGpuKernel::InitSizeLists() { - input_size_list_.push_back(input_size_); // gradient - input_size_list_.push_back(input_size_); // input - input_size_list_.push_back(min_size_); // min - input_size_list_.push_back(max_size_); // max - output_size_list_.push_back(output_size_); +void FakeQuantPerLayerGradGpuKernel::InitSizeLists() { + input_size_list_.push_back(input_size_); // gradient + input_size_list_.push_back(input_size_); // input + input_size_list_.push_back(sizeof(float)); // min + input_size_list_.push_back(sizeof(float)); // max + output_size_list_.push_back(input_size_); // output + workspace_size_list_.push_back(sizeof(float)); // scale + workspace_size_list_.push_back(sizeof(float)); // nudge_min + workspace_size_list_.push_back(sizeof(float)); // nudge_max } -bool FakeQuantGradGpuKernel::Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) { +bool FakeQuantPerLayerGradGpuKernel::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { float *output = GetDeviceAddress(outputs, 0); float *gradient = GetDeviceAddress(inputs, 0); float *input = GetDeviceAddress(inputs, 1); float *input_min = GetDeviceAddress(inputs, 2); float *input_max = GetDeviceAddress(inputs, 3); + float *scale = GetDeviceAddress(workspace, 0); + float *nudge_min = GetDeviceAddress(workspace, 1); + float *nudge_max = GetDeviceAddress(workspace, 2); if (gradient == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel gradient is null"; + MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel gradient is null"; } if (input == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input is null."; + MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input is null."; } - if (input_min == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input min is null."; - } - if (input_max == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input max is null."; + if (input_min == nullptr || input_max == nullptr) { + MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input min or max is null."; } if (global_step_ >= quant_delay_) { - float *d_scale = nullptr; - float *d_nudge_min = nullptr; - float *d_nudge_max = nullptr; - int size = sizeof(float); - // Allocate space for device copies - CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast(&d_scale), size), "Malloc gpu memory failed"); - CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast(&d_nudge_min), size), "Malloc gpu memory failed"); - CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast(&d_nudge_max), size), "Malloc gpu memory failed"); - - CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, + CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, reinterpret_cast(stream_ptr)); - CalFakeQuantizeGrad(input, gradient, output, quant_size_, d_nudge_min, d_nudge_max, + CalFakeQuantizeGrad(input, gradient, output, quant_num_, nudge_min, nudge_max, reinterpret_cast(stream_ptr)); - - // Cleanup - CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed"); - CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed"); - CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed"); } else { CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)), @@ -152,6 +128,6 @@ bool FakeQuantGradGpuKernel::Launch(const std::vector &inputs, const return true; } -MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantGradGpuKernel) +MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantPerLayerGradGpuKernel) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h similarity index 77% rename from mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.h rename to mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h index cfde98355c6..ae2ea5bfacc 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_ +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ #include #include "kernel/gpu/gpu_kernel.h" @@ -23,10 +23,10 @@ namespace mindspore { namespace kernel { -class FakeQuantGradGpuKernel : public GpuKernel { +class FakeQuantPerLayerGradGpuKernel : public GpuKernel { public: - FakeQuantGradGpuKernel(); - ~FakeQuantGradGpuKernel() = default; + FakeQuantPerLayerGradGpuKernel(); + ~FakeQuantPerLayerGradGpuKernel() = default; const std::vector &GetInputSizeList() const override; const std::vector &GetOutputSizeList() const override; @@ -40,9 +40,6 @@ class FakeQuantGradGpuKernel : public GpuKernel { private: size_t input_size_; - size_t min_size_; - size_t max_size_; - size_t output_size_; size_t workspace_size_; std::vector input_size_list_; std::vector output_size_list_; @@ -51,7 +48,7 @@ class FakeQuantGradGpuKernel : public GpuKernel { int num_bits_; float quant_min_; float quant_max_; - int quant_size_; + int quant_num_; int quant_delay_; int global_step_; bool narrow_range_; @@ -60,4 +57,4 @@ class FakeQuantGradGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_ +#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.cc new file mode 100644 index 00000000000..ae9df3355bb --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.cc @@ -0,0 +1,119 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h" +#include "kernel/gpu/cuda_impl/minmax_update_impl.cuh" +#include +#include +#include +#include + +namespace mindspore { +namespace kernel { +MinMaxUpdatePerChannelGpuKernel::MinMaxUpdatePerChannelGpuKernel() + : input_size_(0), + num_bits_(0), + quant_min_(0), + quant_max_(0), + quant_num_(1), + ema_(false), + ema_decay_(0), + num_channels_(0), + narrow_range_(false), + symmetric_(false) {} + +const std::vector &MinMaxUpdatePerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; } + +const std::vector &MinMaxUpdatePerChannelGpuKernel::GetOutputSizeList() const { return output_size_list_; } + +const std::vector &MinMaxUpdatePerChannelGpuKernel::GetWorkspaceSizeList() const { + return workspace_size_list_; +} + +bool MinMaxUpdatePerChannelGpuKernel::Init(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output."; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 2) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output."; + } + + num_bits_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); + ema_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema")); + ema_decay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay")); + symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); + narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); + + if (num_bits_ <= 2 || num_bits_ >= 16) { + MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; + } + + // quant min and max + quant_min_ = 0; + quant_max_ = (1 << num_bits_) - 1; + if (narrow_range_) { + quant_min_++; + } + + // init size + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + num_channels_ = SizeToInt(input_shape[0]); + for (size_t i = 0; i < input_shape.size(); ++i) { + quant_num_ *= SizeToInt(input_shape[i]); + } + input_size_ = sizeof(float); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + InitSizeLists(); + return true; +} + +void MinMaxUpdatePerChannelGpuKernel::InitSizeLists() { + input_size_list_.push_back(input_size_); // input + input_size_list_.push_back(sizeof(float) * num_channels_); // min + input_size_list_.push_back(sizeof(float) * num_channels_); // max + output_size_list_.push_back(sizeof(float) * num_channels_); // output min + output_size_list_.push_back(sizeof(float) * num_channels_); // output max +} + +bool MinMaxUpdatePerChannelGpuKernel::Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) { + float *output_min = GetDeviceAddress(outputs, 0); + float *output_max = GetDeviceAddress(outputs, 1); + float *input = GetDeviceAddress(inputs, 0); + float *input_min = GetDeviceAddress(inputs, 1); + float *input_max = GetDeviceAddress(inputs, 2); + + if (input == nullptr) { + MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input x is null."; + } + if (input_min == nullptr || input_max == nullptr) { + MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input min or input max is null."; + } + + // calculate the input min and max according by the parameter ema and ema_decay. + CalMinMaxPerChannel(input, input_min, input_max, output_min, output_max, input_size_ / sizeof(float), num_channels_, + ema_decay_, ema_, symmetric_, reinterpret_cast(stream_ptr)); + return true; +} + +MS_REG_GPU_KERNEL(MinMaxUpdatePerChannel, MinMaxUpdatePerChannelGpuKernel) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h new file mode 100644 index 00000000000..5a35d4da320 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ + +#include +#include "kernel/gpu/gpu_kernel.h" +#include "kernel/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class MinMaxUpdatePerChannelGpuKernel : public GpuKernel { + public: + MinMaxUpdatePerChannelGpuKernel(); + ~MinMaxUpdatePerChannelGpuKernel() = default; + + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + bool Init(const CNodePtr &kernel) override; + + protected: + void InitSizeLists() override; + + private: + size_t input_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + int num_bits_; + float quant_min_; + float quant_max_; + int quant_num_; + bool ema_; + float ema_decay_; + int num_channels_; + bool narrow_range_; + bool symmetric_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.cc new file mode 100644 index 00000000000..8ba1d363dd6 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.cc @@ -0,0 +1,115 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h" +#include "kernel/gpu/cuda_impl/minmax_update_impl.cuh" +#include +#include +#include +#include + +namespace mindspore { +namespace kernel { +MinMaxUpdatePerLayerGpuKernel::MinMaxUpdatePerLayerGpuKernel() + : input_size_(0), + num_bits_(0), + quant_min_(0), + quant_max_(0), + quant_num_(1), + ema_(false), + ema_decay_(0), + narrow_range_(false), + symmetric_(false) {} + +const std::vector &MinMaxUpdatePerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; } + +const std::vector &MinMaxUpdatePerLayerGpuKernel::GetOutputSizeList() const { return output_size_list_; } + +const std::vector &MinMaxUpdatePerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } + +bool MinMaxUpdatePerLayerGpuKernel::Init(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output."; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 2) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output."; + } + + num_bits_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); + ema_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema")); + ema_decay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay")); + symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); + narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); + + if (num_bits_ <= 2 || num_bits_ >= 16) { + MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; + } + + // quant min and max + quant_min_ = 0; + quant_max_ = (1 << num_bits_) - 1; + if (narrow_range_) { + quant_min_++; + } + + // init size + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shape.size(); ++i) { + quant_num_ *= SizeToInt(input_shape[i]); + } + input_size_ = sizeof(float); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + InitSizeLists(); + return true; +} + +void MinMaxUpdatePerLayerGpuKernel::InitSizeLists() { + input_size_list_.push_back(input_size_); // input + input_size_list_.push_back(sizeof(float)); // input min + input_size_list_.push_back(sizeof(float)); // input max + output_size_list_.push_back(sizeof(float)); // output min + output_size_list_.push_back(sizeof(float)); // output max +} + +bool MinMaxUpdatePerLayerGpuKernel::Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) { + float *output_min = GetDeviceAddress(outputs, 0); + float *output_max = GetDeviceAddress(outputs, 1); + float *input = GetDeviceAddress(inputs, 0); + float *input_min = GetDeviceAddress(inputs, 1); + float *input_max = GetDeviceAddress(inputs, 2); + + if (input == nullptr) { + MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input x is null."; + } + if (input_min == nullptr || input_max == nullptr) { + MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input min or input max is null."; + } + + CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_, symmetric_, + reinterpret_cast(stream_ptr)); + + return true; +} + +MS_REG_GPU_KERNEL(MinMaxUpdatePerLayer, MinMaxUpdatePerLayerGpuKernel) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h new file mode 100644 index 00000000000..527de20427f --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ + +#include +#include "kernel/gpu/gpu_kernel.h" +#include "kernel/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class MinMaxUpdatePerLayerGpuKernel : public GpuKernel { + public: + MinMaxUpdatePerLayerGpuKernel(); + ~MinMaxUpdatePerLayerGpuKernel() = default; + + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + bool Init(const CNodePtr &kernel) override; + + protected: + void InitSizeLists() override; + + private: + size_t input_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + int num_bits_; + float quant_min_; + float quant_max_; + int quant_num_; + bool ema_; + float ema_decay_; + bool narrow_range_; + bool symmetric_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/utils/lineage.proto b/mindspore/ccsrc/utils/lineage.proto index 510e58fc553..dec6f9a3f6e 100644 --- a/mindspore/ccsrc/utils/lineage.proto +++ b/mindspore/ccsrc/utils/lineage.proto @@ -28,7 +28,7 @@ message LineageEvent { oneof what { // An event file was started, with the specified version. - // Now version is "Mindspore.Event:1" + // Now version is "MindSpore.Event:1" string version = 3; // Train lineage diff --git a/mindspore/ccsrc/utils/summary.proto b/mindspore/ccsrc/utils/summary.proto index 6ea6ce08b83..f4a2ce957b1 100644 --- a/mindspore/ccsrc/utils/summary.proto +++ b/mindspore/ccsrc/utils/summary.proto @@ -32,7 +32,7 @@ message Event { oneof what { // An event file was started, with the specified version. - // Now version is "Mindspore.Event:1" + // Now version is "MindSpore.Event:1" string version = 3; // GraphDef. diff --git a/mindspore/ccsrc/vm/transform.h b/mindspore/ccsrc/vm/transform.h index 3a1da0ff429..ebe49908b7b 100644 --- a/mindspore/ccsrc/vm/transform.h +++ b/mindspore/ccsrc/vm/transform.h @@ -32,7 +32,7 @@ #include "vm/segment_runner.h" #include "vm/backend.h" -// mindspore namespace is the top level namespace of Mindsporeession project. +// mindspore namespace is the top level namespace of MindSpore project. // Other namespace should be a sub namespace of mindspore namespace in the ME project. namespace mindspore { extern const char kMsVm[]; diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index 6c01aa54044..8c444362b25 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Aware quantization.""" +"""Quantization aware.""" from functools import partial import numpy as np @@ -172,7 +172,7 @@ class DenseBnAct(Cell): Tensor of shape :math:`(N, out\_channels)`. Examples: - >>> net = nn.Dense(3, 4) + >>> net = nn.DenseBnAct(3, 4) >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) >>> net(input) """ @@ -271,7 +271,7 @@ class BatchNormFoldCell(Cell): class FakeQuantWithMinMax(Cell): r""" - Aware Quantization op. This OP provide Fake quantization observer function on data with min and max. + Quantization aware op. This OP provide Fake quantization observer function on data with min and max. Args: min_init (int, float): The dimension of channel or 1(layer). Default: -6. @@ -338,22 +338,30 @@ class FakeQuantWithMinMax(Cell): # init fake quant relative op if per_channel: quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis) - ema_fun = partial(Q.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis) + ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis) else: quant_fun = Q.FakeQuantPerLayer - ema_fun = Q.FakeQuantMinMaxPerLayerUpdate + ema_fun = Q.MinMaxUpdatePerLayer if self.is_ascend: self.fake_quant = quant_fun(num_bits=self.num_bits, symmetric=self.symmetric, narrow_range=self.narrow_range) else: - self.fake_quant = quant_fun(num_bits=self.num_bits, - ema=self.ema, - ema_decay=ema_decay, - quant_delay=quant_delay, - symmetric=self.symmetric, - narrow_range=self.narrow_range) + self.fake_quant_train = quant_fun(num_bits=self.num_bits, + ema=self.ema, + ema_decay=ema_decay, + quant_delay=quant_delay, + symmetric=self.symmetric, + narrow_range=self.narrow_range, + training=True) + self.fake_quant_infer = quant_fun(num_bits=self.num_bits, + ema=self.ema, + ema_decay=ema_decay, + quant_delay=quant_delay, + symmetric=self.symmetric, + narrow_range=self.narrow_range, + training=False) self.ema_update = ema_fun(num_bits=self.num_bits, ema=self.ema, ema_decay=self.ema_decay, @@ -368,16 +376,24 @@ class FakeQuantWithMinMax(Cell): return s def construct(self, x): - if self.is_ascend and self.training: - min_up, max_up = self.ema_update(x, self.minq, self.maxq) - out = self.fake_quant(x, min_up, max_up) - P.Assign()(self.minq, min_up) - P.Assign()(self.maxq, max_up) + if self.is_ascend: + if self.training: + min_up, max_up = self.ema_update(x, self.minq, self.maxq) + out = self.fake_quant(x, min_up, max_up) + P.Assign()(self.minq, min_up) + P.Assign()(self.maxq, max_up) + else: + out = self.fake_quant(x, self.minq, self.maxq) else: - out = self.fake_quant(x, self.minq, self.maxq) + if self.training: + min_up, max_up = self.ema_update(x, self.minq, self.maxq) + out = self.fake_quant_train(x, min_up, max_up) + P.Assign()(self.minq, min_up) + P.Assign()(self.maxq, max_up) + else: + out = self.fake_quant_infer(x, self.minq, self.maxq) return out - class Conv2dBatchNormQuant(Cell): r""" 2D convolution with BatchNormal op folded layer. diff --git a/mindspore/ops/_grad/grad_quant_ops.py b/mindspore/ops/_grad/grad_quant_ops.py index da19662e979..a2b0ba8d97e 100644 --- a/mindspore/ops/_grad/grad_quant_ops.py +++ b/mindspore/ops/_grad/grad_quant_ops.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ -"""Generate bprop for aware quantization ops""" +"""Generate bprop for quantization aware ops""" from .. import operations as P from ..operations import _quant_ops as Q @@ -133,9 +133,9 @@ def get_bprop_batchnorm_fold2_(self): return bprop -@bprop_getters.register(Q.FakeQuantMinMaxPerLayerUpdate) +@bprop_getters.register(Q.MinMaxUpdatePerLayer) def get_bprop_fakequant_with_minmax_per_layer_update(self): - """Generate bprop for FakeQuantMinMaxPerLayerUpdate for Ascend""" + """Generate bprop for MinMaxUpdatePerLayer for Ascend""" def bprop(x, x_min, x_max, out, dout): return zeros_like(x), zeros_like(x_min), zeros_like(x_max) @@ -143,9 +143,9 @@ def get_bprop_fakequant_with_minmax_per_layer_update(self): return bprop -@bprop_getters.register(Q.FakeQuantMinMaxPerChannelUpdate) +@bprop_getters.register(Q.MinMaxUpdatePerChannel) def get_bprop_fakequant_with_minmax_per_channel_update(self): - """Generate bprop for FakeQuantMinMaxPerChannelUpdate for Ascend""" + """Generate bprop for MinMaxUpdatePerChannel for Ascend""" def bprop(x, x_min, x_max, out, dout): return zeros_like(x), zeros_like(x_min), zeros_like(x_max) diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py index 7694753d8f5..3560c20d7b0 100644 --- a/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py +++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================ -"""FakeQuantMinMaxPerChannelUpdate op""" +"""MinMaxUpdatePerChannel op""" import te.lang.cce from te import tvm from te.platform.fusion_manager import fusion_manager @@ -23,7 +23,7 @@ from topi.cce import util from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType -fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChannelUpdate") \ +fake_quant_min_max_per_channel_update_op_info = TBERegOp("MinMaxUpdatePerChannel") \ .fusion_type("OPAQUE") \ .async_flag(False) \ .binfile_name("fake_quant_min_max_per_channel_update.so") \ diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perlayer_update.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perlayer_update.py index 0ad2315bb3f..554b66b95ba 100644 --- a/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perlayer_update.py +++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perlayer_update.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ -"""FakeQuantMinMaxPerLayerUpdate op""" +"""MinMaxUpdatePerLayer op""" from functools import reduce as functools_reduce import te.lang.cce from te import tvm @@ -23,7 +23,7 @@ from topi.cce import util from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType -fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \ +fake_quant_minmax_update_op_info = TBERegOp("MinMaxUpdatePerLayer") \ .fusion_type("OPAQUE") \ .async_flag(False) \ .binfile_name("fake_quant_minmax_update.so") \ @@ -48,14 +48,14 @@ fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \ @op_info_register(fake_quant_minmax_update_op_info) def _fake_quant_minmax_update_tbe(): - """FakeQuantMinMaxPerLayerUpdate TBE register""" + """MinMaxUpdatePerLayer TBE register""" return @fusion_manager.register("fake_quant_minmax_update") def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training, kernel_name="fake_quant_minmax_update"): - """FakeQuantMinMaxPerLayerUpdate compute""" + """MinMaxUpdatePerLayer compute""" shape = te.lang.cce.util.shape_to_list(x.shape) shape_min = te.lang.cce.util.shape_to_list(min_val.shape) min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype) diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py index dd46fa491a3..b668fd0d83d 100644 --- a/mindspore/ops/operations/_quant_ops.py +++ b/mindspore/ops/operations/_quant_ops.py @@ -25,8 +25,8 @@ __all__ = ["FakeQuantPerLayer", "FakeQuantPerLayerGrad", "FakeQuantPerChannel", "FakeQuantPerChannelGrad", - "FakeQuantMinMaxPerLayerUpdate", - "FakeQuantMinMaxPerChannelUpdate", + "MinMaxUpdatePerLayer", + "MinMaxUpdatePerChannel", "BatchNormFold", "BatchNormFoldGrad", "CorrectionMul", @@ -47,11 +47,11 @@ class FakeQuantPerLayer(PrimitiveWithInfer): Simulate the quantize and dequantize operations in training time. Args: - num_bits (int) : Number bits for aware quantilization. Default: 8. + num_bits (int) : Number bits for quantization aware. Default: 8. ema (bool): Use EMA algorithm update value min and max. Default: False. ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. quant_delay (int): Quantilization delay parameter. Before delay step in training time not update - simulate aware quantize funcion. After delay step in training time begin simulate the aware + simulate quantization aware funcion. After delay step in training time begin simulate the aware quantize funcion. Default: 0. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. @@ -834,12 +834,12 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer): return dout_type, dout_type -class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer): +class MinMaxUpdatePerLayer(PrimitiveWithInfer): r""" Update min and max value for fake quant per layer op. Args: - num_bits (int) : Number bits for aware quantilization. Default: 8. + num_bits (int) : Number bits for quantization aware. Default: 8. ema (bool): Use EMA algorithm update value min and max. Default: False. ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. @@ -858,14 +858,14 @@ class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer): >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) >>> min_tensor = Tensor(np.array([-6]), mstype.float32) >>> max_tensor = Tensor(np.array([6]), mstype.float32) - >>> output_tensor = FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor) + >>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor) """ support_quant_bit = [4, 7, 8] @prim_attr_register def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, training=True): - """init FakeQuantMinMaxPerLayerUpdate OP""" + """init MinMaxUpdatePerLayer OP""" if context.get_context('device_target') == "Ascend": from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update if num_bits not in self.support_quant_bit: @@ -907,12 +907,12 @@ class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer): return min_type, max_type -class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer): +class MinMaxUpdatePerChannel(PrimitiveWithInfer): r""" Update min and max value for fake quant per layer op. Args: - num_bits (int) : Number bits for aware quantilization. Default: 8. + num_bits (int) : Number bits for quantization aware. Default: 8. ema (bool): Use EMA algorithm update value min and max. Default: False. ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. @@ -932,14 +932,14 @@ class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer): >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) - >>> output_tensor = FakeQuantWithMinMax(num_bits=8)(x, min, max) + >>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max) """ support_quant_bit = [4, 7, 8] @prim_attr_register def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, training=True, channel_axis=1): - """init FakeQuantPerChannelUpdate OP for Ascend""" + """init MinMaxUpdatePerChannel OP for Ascend""" if context.get_context('device_target') == "Ascend": from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update if num_bits not in self.support_quant_bit: diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 2bb8a17a504..0cf725ecaaf 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1914,7 +1914,7 @@ class Eye(PrimitiveWithInfer): Inputs: - **n** (int) - Number of rows of returned tensor - **m** (int) - Number of columns of returned tensor - - **t** (mindspore.dtype) - Mindspore's dtype, The data type of the returned tensor. + - **t** (mindspore.dtype) - MindSpore's dtype, The data type of the returned tensor. Outputs: Tensor, a tensor with ones on the diagonal and zeros elsewhere. diff --git a/mindspore/train/callback/_loss_monitor.py b/mindspore/train/callback/_loss_monitor.py index 22b1342873b..3c1da218c21 100644 --- a/mindspore/train/callback/_loss_monitor.py +++ b/mindspore/train/callback/_loss_monitor.py @@ -76,7 +76,7 @@ class LossMonitor(Callback): step_loss = np.mean(step_loss.asnumpy()) self.losses.append(step_loss) - cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num) if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)): raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. " @@ -88,6 +88,6 @@ class LossMonitor(Callback): print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], " "loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format( cb_params.cur_epoch_num - 1, cb_params.epoch_num, - cur_step_in_epoch, cb_params.batch_num, + cur_step_in_epoch, int(cb_params.batch_num), step_loss, np.mean(self.losses), step_mseconds), flush=True) diff --git a/mindspore/train/quant/__init__.py b/mindspore/train/quant/__init__.py index 531db34b2b7..51e8c20ded0 100644 --- a/mindspore/train/quant/__init__.py +++ b/mindspore/train/quant/__init__.py @@ -15,10 +15,10 @@ """ quantization. -User can use aware quantization to train a model. Mindspore supports quantization aware training, +User can use quantization aware to train a model. MindSpore supports quantization aware training, which models quantization errors in both the forward and backward passes using fake-quantization ops. Note that the entire computation is carried out in floating point. At the end of quantization -aware training, Mindspore provides conversion functions to convert the trained model into lower precision. +aware training, MindSpore provides conversion functions to convert the trained model into lower precision. """ from .quant import convert_quant_network diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py index a8f381425cf..937e54a7e48 100644 --- a/mindspore/train/quant/quant.py +++ b/mindspore/train/quant/quant.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""aware quantization.""" +"""quantization aware.""" import copy import re import numpy as np +import mindspore.context as context from ... import log as logger from ... import nn, ops @@ -234,7 +235,7 @@ class ConvertToQuantNetwork: subcell.has_act = True subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits, - quant_delay=self.act_delay, + quant_delay=self.act_qdelay, per_channel=self.act_channel, symmetric=self.act_symmetric, narrow_range=self.act_range) @@ -403,29 +404,30 @@ def convert_quant_network(network, narrow_range=(False, False) ): r""" - Create aware quantizaiton training network. + Create quantization aware training network. Args: network (Cell): Obtain a pipeline through network for saving graph summary. - quant_delay (int): Number of steps after which weights and activations are quantized during - eval. The first element represent weights and second element represent data flow. Default: [0, 0] + quant_delay (int or tuple): Number of steps after which weights and activations are quantized during + eval. The first element represent weights and second element represent data flow. Default: (0, 0) bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False. freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 0. - num_bits (list of int): Number of bits to use for quantizing weights and activations. The first - element represent weights and second element represent data flow. Default: [8, 8] - per_channel (list of bool): Quantization granularity based on layer or on channel. If `True` + num_bits (int or tuple): Number of bits to use for quantizing weights and activations. The first + element represent weights and second element represent data flow. Default: (8, 8) + per_channel (int or tuple): Quantization granularity based on layer or on channel. If `True` then base on per channel otherwise base on per layer. The first element represent weights - and second element represent data flow. Default: [False, False] - symmetric (list of bool): Quantization algorithm use symmetric or not. If `True` then base on + and second element represent data flow. Default: (False, False) + symmetric (int or tuple): Quantization algorithm use symmetric or not. If `True` then base on symmetric otherwise base on assymmetric. The first element represent weights and second - element represent data flow. Default: [False, False] - narrow_range (list of bool): Quantization algorithm use narrow range or not. If `True` then base + element represent data flow. Default: (False, False) + narrow_range (int or tuple): Quantization algorithm use narrow range or not. If `True` then base on narrow range otherwise base on off narrow range. The first element represent weights and - second element represent data flow. Default: [False, False] + second element represent data flow. Default: (False, False) Returns: - Cell, Network which has change to aware quantization training network cell. + Cell, Network which has change to quantization aware training network cell. """ + support_device = ["Ascend", "GPU"] def convert2list(name, value): if not isinstance(value, list) and not isinstance(value, tuple): value = [value] @@ -439,6 +441,9 @@ def convert_quant_network(network, symmetric = convert2list("symmetric", symmetric) narrow_range = convert2list("narrow range", narrow_range) + if context.get_context('device_target') not in support_device: + raise KeyError("Not support {} backend.".format(context.get_context('device_target'))) + net = ConvertToQuantNetwork(network=network, quant_delay=quant_delay, bn_fold=bn_fold, diff --git a/mindspore/train/summary/_summary_adapter.py b/mindspore/train/summary/_summary_adapter.py index 40e32b1c6ad..1ae5bdd2d59 100644 --- a/mindspore/train/summary/_summary_adapter.py +++ b/mindspore/train/summary/_summary_adapter.py @@ -30,7 +30,7 @@ MS_IMAGE_TENSOR_FORMAT = 'NCHW' # Set the Event mark EVENT_FILE_NAME_MARK = ".out.events.summary." # Set the init event of version and mark -EVENT_FILE_INIT_VERSION_MARK = "Mindspore.Event:" +EVENT_FILE_INIT_VERSION_MARK = "MindSpore.Event:" EVENT_FILE_INIT_VERSION = 1 F32_MIN, F32_MAX = np.finfo(np.float32).min, np.finfo(np.float32).max diff --git a/model_zoo/lenet_quant/README.md b/model_zoo/lenet_quant/README.md index b3bac22c0de..26cdcc3ecd6 100644 --- a/model_zoo/lenet_quant/README.md +++ b/model_zoo/lenet_quant/README.md @@ -2,13 +2,13 @@ ## Description -Training LeNet with MNIST dataset in MindSpore with quantization aware trainging. +Training LeNet with MNIST dataset in MindSpore with quantization aware training. This is the simple and basic tutorial for constructing a network in MindSpore with quantization aware. In this tutorial, you will: -1. Train a Mindspore fusion model for MNIST from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`. +1. Train a MindSpore fusion model for MNIST from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`. 2. Fine tune the fusion model by applying the quantization aware training auto network converter API `convert_quant_network`, after the network convergence then export a quantization aware model checkpoint file. 3. Use the quantization aware model to create an actually quantized model for the Ascend inference backend. 4. See the persistence of accuracy in inference backend and a 4x smaller model. To see the latency benefits on mobile, try out the Ascend inference backend examples. @@ -24,10 +24,10 @@ Install MindSpore base on the ascend device and GPU device from [MindSpore](http ```python pip uninstall -y mindspore-ascend pip uninstall -y mindspore-gpu -pip install mindspore-ascend-0.4.0.whl +pip install mindspore-ascend.whl ``` -then you will get the following display +Then you will get the following display ```bash @@ -87,7 +87,7 @@ class LeNet5(nn.Cell): return x ``` -get the MNIST from scratch dataset. +Get the MNIST from scratch dataset. ```Python ds_train = create_dataset(os.path.join(args.data_path, "train"), @@ -97,7 +97,7 @@ step_size = ds_train.get_dataset_size() ### Train model -Load teh Lenet fusion network, traing network using loss `nn.SoftmaxCrossEntropyWithLogits` with optimization `nn.Momentum`. +Load the Lenet fusion network, training network using loss `nn.SoftmaxCrossEntropyWithLogits` with optimization `nn.Momentum`. ```Python # Define the network @@ -133,7 +133,7 @@ After all the following we will get the loss value of each step as following: >>> Epoch: [ 10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234] ``` -To save your time, just run this command. +Also, you can just run this command instead. ```python python train.py --data_path MNIST_Data --device_target Ascend @@ -165,17 +165,17 @@ Note that the resulting model is quantization aware but not quantized (e.g. the # define funsion network network = LeNet5Fusion(cfg.num_classes) -# load aware quantizaiton network checkpoint +# load quantization aware network checkpoint param_dict = load_checkpoint(args.ckpt_path) load_param_into_net(network, param_dict) -# convert funsion netwrok to aware quantizaiton network +# convert funsion netwrok to quantization aware network network = quant.convert_quant_network(network) ``` ### load checkpoint -after convert to quantization aware network, we can load the checkpoint file. +After convert to quantization aware network, we can load the checkpoint file. ```python config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, @@ -186,7 +186,7 @@ model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) ### train quantization aware model -To save your time, just run this command. +Also, you can just run this command instread. ```python python train_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt @@ -210,18 +210,18 @@ Procedure of quantization aware model evaluation is different from normal. Becau # define funsion network network = LeNet5Fusion(cfg.num_classes) -# load aware quantizaiton network checkpoint +# load quantization aware network checkpoint param_dict = load_checkpoint(args.ckpt_path) load_param_into_net(network, param_dict) -# convert funsion netwrok to aware quantizaiton network +# convert funsion netwrok to quantization aware network network = quant.convert_quant_network(network ``` -To save your time, just run this command. +Also, you can just run this command insread. ```python -python eval.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt +python eval_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt ``` The top1 accuracy would display on shell. diff --git a/model_zoo/lenet_quant/eval_quant.py b/model_zoo/lenet_quant/eval_quant.py index 0ff943f8cdc..492f6d36b2c 100644 --- a/model_zoo/lenet_quant/eval_quant.py +++ b/model_zoo/lenet_quant/eval_quant.py @@ -50,7 +50,7 @@ if __name__ == "__main__": # define funsion network network = LeNet5Fusion(cfg.num_classes) - # convert funsion netwrok to aware quantizaiton network + # convert funsion netwrok to quantization aware network network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") @@ -60,7 +60,7 @@ if __name__ == "__main__": ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) - # load aware quantizaiton network checkpoint + # load quantization aware network checkpoint param_dict = load_checkpoint(args.ckpt_path) load_param_into_net(network, param_dict) diff --git a/model_zoo/lenet_quant/train_quant.py b/model_zoo/lenet_quant/train_quant.py index 3de700af78c..04f595f322c 100644 --- a/model_zoo/lenet_quant/train_quant.py +++ b/model_zoo/lenet_quant/train_quant.py @@ -50,10 +50,10 @@ if __name__ == "__main__": # define funsion network network = LeNet5Fusion(cfg.num_classes) - # load aware quantizaiton network checkpoint + # load quantization aware network checkpoint param_dict = load_checkpoint(args.ckpt_path) load_param_into_net(network, param_dict) - # convert funsion netwrok to aware quantizaiton network + # convert funsion netwrok to quantization aware network network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") diff --git a/tests/mindspore_test_framework/utils/keyword.py b/tests/mindspore_test_framework/utils/keyword.py index 56c27b0d044..cee0f14ff83 100644 --- a/tests/mindspore_test_framework/utils/keyword.py +++ b/tests/mindspore_test_framework/utils/keyword.py @@ -18,14 +18,14 @@ import sys -class _MindsporeTestFrameworkkeyword: +class _MindSporeTestFrameworkkeyword: def __setattr__(self, name, value): if name in self.__dict__: raise TypeError("can not rebind keyword (%s)" % name) self.__dict__[name] = value -keyword = _MindsporeTestFrameworkkeyword() +keyword = _MindSporeTestFrameworkkeyword() keyword.function = "function" keyword.inputs = "inputs"