!2405 change some comment name in the whole project

Merge pull request !2405 from chenzhongming/master
This commit is contained in:
mindspore-ci-bot 2020-06-22 14:40:25 +08:00 committed by Gitee
commit f1a69de0b6
39 changed files with 757 additions and 410 deletions

View File

@ -248,7 +248,7 @@ checkopts()
done
}
checkopts "$@"
echo "---------------- mindspore: build start ----------------"
echo "---------------- MindSpore: build start ----------------"
mkdir -pv "${BUILD_PATH}/package/mindspore/lib"
git submodule update --init graphengine
if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then

View File

@ -36,7 +36,7 @@ class Monitor {
~Monitor() = default;
// Functor for Perf Monitor main loop.
// This function will be the entry point of Mindspore::Dataset::Task
// This function will be the entry point of mindspore::Dataset::Task
Status operator()();
int64_t GetSamplingInterval() { return sampling_interval_; }

View File

@ -29,7 +29,7 @@
// brief mindspore namespace.
//
// mindspore namespace is the top level namespace of Mindsporeession project.
// mindspore namespace is the top level namespace of MindSpore project.
// Other namespace should be a sub namespace of mindspore namespace in the ME project.
namespace mindspore {

View File

@ -91,7 +91,7 @@ using mindspore::device::DeviceAddress;
using DeviceAddressPtr = std::shared_ptr<mindspore::device::DeviceAddress>;
// brief mindspore namespace.
//
// mindspore namespace is the top level namespace of Mindsporeession project.
// mindspore namespace is the top level namespace of MindSpore project.
// Other namespace should be a sub namespace of mindspore namespace in the ME project.
namespace mindspore {
// brief mindspore::tensor namespace

View File

@ -19,7 +19,7 @@
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include <thrust/pair.h>
#include "fake_quant_per_channel_impl.cuh"
#include "fake_quant_perchannel_impl.cuh"
#include "device/gpu/cuda_common.h"
/**
@ -113,44 +113,6 @@ void CalFakeQuantizePerChannel(const float *input, float *output, const int tota
input, output, total_size, channel_size, nudge_min, nudge_max, scale, symmetric);
}
/**
* UpdateInputMinMaxPerChannel or UpdateInputMinMaxPerChannel With EMA.
* @param input_min
* @param input_max
* @param min
* @param max
* @return
*/
__global__ void UpdateInputMinMaxPerChannel(float *input_min, float *input_max, float *input, int channels,
int per_channel_nums, bool ema, float ema_decay) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) {
thrust::pair<float *, float *> sum =
thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1));
if (ema) {
input_min[i] = ema_decay * sum.first[0] + (1 - ema_decay) * input_min[i];
input_max[i] = ema_decay * sum.second[0] + (1 - ema_decay) * input_max[i];
} else {
input_min[i] = sum.first[0];
input_max[i] = sum.second[0];
}
input_min[i] = input_min[i] > 0 ? 0 : input_min[i];
input_max[i] = input_max[i] < 0 ? 0 : input_max[i];
}
}
__global__ void UpdateInputMinMaxPerChannelWithEMA(float *input_min, float *input_max, float min, float max,
const float decay) {
*input_min = decay * (min) + (1 - decay) * (*input_min);
*input_max = decay * (max) + (1 - decay) * (*input_max);
}
void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, const int total_size, const int channel_size,
const float ema_decay, const bool ema, cudaStream_t cuda_stream) {
int per_channel_num = total_size / channel_size;
UpdateInputMinMaxPerChannel<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(
input_min, input_max, input, channel_size, per_channel_num, ema, ema_decay);
}
__global__ void FakeQuantizePerChannelGrad(const float *input, const float *gradient, float *output,
const int total_size, const int channel_size, const float *nudge_min,
const float *nudge_max) {

View File

@ -18,7 +18,7 @@
#include <thrust/device_vector.h>
#include <thrust/pair.h>
#include "device/gpu/cuda_common.h"
#include "fake_quant_impl.cuh"
#include "fake_quant_perlayer_impl.cuh"
__global__ void FakeQuantize(const float *input, float *output, const int size, const float *nudge_min,
const float *nudge_max, const float *scale) {

View File

@ -0,0 +1,104 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrust/extrema.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include <thrust/pair.h>
#include "minmax_update_impl.cuh"
#include "device/gpu/cuda_common.h"
__global__ void UpdateInputMinMaxPerLayerWithEMA(const float *input_min, const float *input_max, float *output_min,
float *output_max, const float min, const float max, const float decay,
const float symmetric) {
output_min[0] = decay * (min) + (1 - decay) * (input_min[0]);
output_min[0] = input_min[0] > 0 ? 0 : input_min[0];
output_max[0] = decay * (max) + (1 - decay) * (input_max[0]);
output_max[0] = input_max[0] < 0 ? 0 : input_max[0];
if (symmetric) {
output_max[0] = abs(output_min[0]) < output_max[0] ? output_max[0] : -output_min[0];
output_min[0] = abs(output_min[0]) < output_max[0] ? -output_max[0] : output_min[0];
}
return;
}
__global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max,
const float symmetric) {
output_min[0] = min > 0 ? 0 : min;
output_max[0] = max < 0 ? 0 : max;
if (symmetric) {
output_max[0] = abs(output_min[0]) < output_max[0] ? output_max[0] : -output_min[0];
output_min[0] = abs(output_min[0]) < output_max[0] ? -output_max[0] : output_min[0];
}
return;
}
__global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min,
float *output_max, int channels, int per_channel_nums, bool ema,
float ema_decay, bool symmetric) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) {
thrust::pair<float *, float *> sum =
thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1));
if (ema) {
output_min[i] = ema_decay * sum.first[0] + (1 - ema_decay) * input_min[i];
output_max[i] = ema_decay * sum.second[0] + (1 - ema_decay) * input_max[i];
} else {
output_min[i] = sum.first[0];
output_max[i] = sum.second[0];
}
output_min[i] = input_min[i] > 0 ? 0 : input_min[i];
output_max[i] = input_max[i] < 0 ? 0 : input_max[i];
if (symmetric) {
output_max[i] = abs(output_min[i]) < output_max[i] ? output_max[i] : -output_min[i];
output_min[i] = abs(output_min[i]) < output_max[i] ? -output_max[i] : output_min[i];
}
}
return;
}
void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
const int total_num, const int channel_num, const float ema_decay, const bool ema,
const bool symmetric, cudaStream_t cuda_stream) {
int per_channel_num = total_num / channel_num;
UpdateInputMinMaxPerChannel<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
input, input_min, input_max, output_min, output_max, channel_num, per_channel_num, ema, ema_decay, symmetric);
return;
}
void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
const int total_num, const float ema_decay, const bool ema, const bool symmetric,
cudaStream_t cuda_stream) {
float minel = 0.f;
float maxel = 0.f;
auto policy = thrust::cuda::par.on(cuda_stream);
thrust::pair<thrust::device_ptr<float>, thrust::device_ptr<float>> tuple;
tuple =
thrust::minmax_element(policy, thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + total_num);
minel = tuple.first[0];
maxel = tuple.second[0];
if (ema) {
UpdateInputMinMaxPerLayerWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, output_min, output_max, minel,
maxel, ema_decay, symmetric);
} else {
UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel, symmetric);
}
return;
}

View File

@ -0,0 +1,30 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_
#include "device/gpu/cuda_common.h"
void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
const int total_num, const int channel_num, const float ema_decay, const bool ema,
const bool symmetric, cudaStream_t cuda_stream);
void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
const int size, const float ema_decay, const bool ema, const bool symmetric,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cuh"
#include "kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh"
#include <thrust/extrema.h>
#include <thrust/pair.h>
#include <thrust/device_vector.h>
@ -25,21 +25,15 @@ namespace mindspore {
namespace kernel {
FakeQuantPerChannelGpuKernel::FakeQuantPerChannelGpuKernel()
: input_size_(0),
min_size_(0),
max_size_(0),
output_size_(0),
workspace_size_(0),
num_channels_(0),
num_bits_(0),
training_(false),
symmetric_(false),
narrow_range_(false),
quant_delay_(0),
quant_min_(0),
quant_max_(0),
quant_delay_(0),
ema_(false),
ema_decay_(0),
global_step_(0),
training_(false),
channel_out_(0),
narrow_range_(false),
symmetric_(false) {}
global_step_(0) {}
const std::vector<size_t> &FakeQuantPerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; }
@ -60,90 +54,56 @@ bool FakeQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
return false;
}
// get attribute
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
ema_decay_ = 1.0 - GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << "is out of range, expected between 2 and 16.";
return false;
}
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
if (quant_delay_ < 0) {
MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << " is less then 0, require larger than 0.";
return false;
}
training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
if (symmetric_) {
quant_min_ = 0 - (1 << (num_bits_ - 1));
quant_max_ = (1 << (num_bits_ - 1)) - 1;
} else {
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
}
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
// quant min and max value
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
// shape info for gpu
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
channel_out_ = SizeToInt(input_shape[0]);
min_size_ = sizeof(float) * channel_out_;
max_size_ = sizeof(float) * channel_out_;
num_channels_ = SizeToInt(input_shape[0]);
input_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
output_size_ = input_size_;
InitSizeLists();
return true;
}
void FakeQuantPerChannelGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // input in tensor
input_size_list_.push_back(min_size_); // min one scalar
input_size_list_.push_back(max_size_); // max on scalar
output_size_list_.push_back(output_size_); // output in tensor
workspace_size_list_.push_back(sizeof(float) * channel_out_); // scale in channel
workspace_size_list_.push_back(sizeof(float) * channel_out_); // min in channel
workspace_size_list_.push_back(sizeof(float) * channel_out_); // max in channel
input_size_list_.push_back(input_size_); // input in tensor
input_size_list_.push_back(sizeof(float) * num_channels_); // min one scalar
input_size_list_.push_back(sizeof(float) * num_channels_); // max on scalar
output_size_list_.push_back(input_size_); // output in tensor
workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel
workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel
workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel
}
void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForTraining(float *input, float *output, float *input_min,
float *input_max, float *d_nudge_min, float *d_nudge_max,
float *d_scale, void *stream_ptr) {
// calculate the input min and max according by the parameter ema and ema_decay.
CalMinMaxPerChannel(input, input_min, input_max, input_size_ / sizeof(float), channel_out_, ema_decay_, ema_,
reinterpret_cast<cudaStream_t>(stream_ptr));
// control flow for quant_delay
if (global_step_ >= quant_delay_) {
// real launch
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max,
d_scale, symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CHECK_CUDA_RET_WITH_ERROR(
cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"Copy gpu memory failed.");
}
global_step_++;
}
void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForInfer(float *input, float *output, float *input_min,
float *input_max, float *d_nudge_min, float *d_nudge_max,
float *d_scale, void *stream_ptr) {
// real launch
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
void FakeQuantPerChannelGpuKernel::CalFakeQuantize(float *input, float *output, float *input_min, float *input_max,
float *nudge_min, float *nudge_max, float *scale, void *stream_ptr) {
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max, d_scale,
CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, scale,
symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
}
@ -155,9 +115,9 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
float *input = GetDeviceAddress<float>(inputs, 0);
float *input_min = GetDeviceAddress<float>(inputs, 1);
float *input_max = GetDeviceAddress<float>(inputs, 2);
float *d_scale = GetDeviceAddress<float>(workspace, 0);
float *d_nudge_min = GetDeviceAddress<float>(workspace, 1);
float *d_nudge_max = GetDeviceAddress<float>(workspace, 2);
float *scale = GetDeviceAddress<float>(workspace, 0);
float *nudge_min = GetDeviceAddress<float>(workspace, 1);
float *nudge_max = GetDeviceAddress<float>(workspace, 2);
if (input == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input is null.";
@ -167,9 +127,16 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
}
if (training_) {
CalFakeQuantizeForTraining(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
if (global_step_ >= quant_delay_) {
CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr);
} else {
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"Copy gpu memory failed.");
}
global_step_++;
} else {
CalFakeQuantizeForInfer(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr);
}
return true;

View File

@ -39,31 +39,23 @@ class FakeQuantPerChannelGpuKernel : public GpuKernel {
void InitSizeLists() override;
private:
void CalFakeQuantizeForTraining(float *input, float *output, float *input_min, float *input_max, float *d_nudge_min,
float *d_nudge_max, float *d_scale, void *stream_ptr);
void CalFakeQuantizeForInfer(float *input, float *output, float *input_min, float *input_max, float *d_nudge_min,
float *d_nudge_max, float *d_scale, void *stream_ptr);
void CalFakeQuantize(float *input, float *output, float *input_min, float *input_max, float *nudge_min,
float *nudge_max, float *scale, void *stream_ptr);
size_t input_size_;
size_t min_size_;
size_t max_size_;
size_t output_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int num_channels_;
int num_bits_;
bool training_;
bool symmetric_;
bool narrow_range_;
int quant_delay_;
float quant_min_;
float quant_max_;
int quant_delay_;
bool ema_;
float ema_decay_;
int global_step_;
bool training_;
int channel_out_;
bool narrow_range_;
bool symmetric_;
};
} // namespace kernel
} // namespace mindspore

View File

@ -14,21 +14,17 @@
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cuh"
#include "kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh"
namespace mindspore {
namespace kernel {
FakeQuantPerChannelGradGpuKernel::FakeQuantPerChannelGradGpuKernel()
: input_size_(0),
min_size_(0),
max_size_(0),
output_size_(0),
workspace_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
channel_out_(0),
num_channels_(0),
quant_delay_(0),
global_step_(0),
narrow_range_(false),
@ -64,42 +60,34 @@ bool FakeQuantPerChannelGradGpuKernel::Init(const CNodePtr &kernel_node) {
}
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
if (symmetric_) {
quant_min_ = 0 - (1 << (num_bits_ - 1));
quant_max_ = (1 << (num_bits_ - 1)) - 1;
} else {
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
}
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
// quant min and max value
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
channel_out_ = SizeToInt(input_shape[0]);
min_size_ = sizeof(float) * channel_out_;
max_size_ = sizeof(float) * channel_out_;
num_channels_ = SizeToInt(input_shape[0]);
input_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
output_size_ = input_size_;
InitSizeLists();
return true;
}
void FakeQuantPerChannelGradGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // gradient
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(min_size_); // min
input_size_list_.push_back(max_size_); // max
output_size_list_.push_back(output_size_);
workspace_size_list_.push_back(sizeof(float) * channel_out_); // scale in channel
workspace_size_list_.push_back(sizeof(float) * channel_out_); // min in channel
workspace_size_list_.push_back(sizeof(float) * channel_out_); // max in channel
input_size_list_.push_back(input_size_); // gradient
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(sizeof(float) * num_channels_); // min
input_size_list_.push_back(sizeof(float) * num_channels_); // max
output_size_list_.push_back(input_size_); // output
workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel
workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel
workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel
}
bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
@ -111,9 +99,9 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
float *input = GetDeviceAddress<float>(inputs, 1);
float *input_min = GetDeviceAddress<float>(inputs, 2);
float *input_max = GetDeviceAddress<float>(inputs, 3);
float *d_scale = GetDeviceAddress<float>(workspace, 0);
float *d_nudge_min = GetDeviceAddress<float>(workspace, 1);
float *d_nudge_max = GetDeviceAddress<float>(workspace, 2);
float *scale = GetDeviceAddress<float>(workspace, 0);
float *nudge_min = GetDeviceAddress<float>(workspace, 1);
float *nudge_max = GetDeviceAddress<float>(workspace, 2);
if (gradient == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel gradient is null";
@ -130,9 +118,9 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
int total_size = input_size_ / sizeof(float);
if (global_step_ >= quant_delay_) {
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, channel_out_, d_nudge_min, d_nudge_max,
CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,

View File

@ -40,10 +40,6 @@ class FakeQuantPerChannelGradGpuKernel : public GpuKernel {
private:
size_t input_size_;
size_t min_size_;
size_t max_size_;
size_t output_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
@ -51,7 +47,7 @@ class FakeQuantPerChannelGradGpuKernel : public GpuKernel {
int num_bits_;
float quant_min_;
float quant_max_;
int channel_out_;
int num_channels_;
int quant_delay_;
int global_step_;
bool narrow_range_;

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_impl.cuh"
#include "kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh"
#include <thrust/extrema.h>
#include <thrust/pair.h>
#include <thrust/device_vector.h>
@ -23,31 +23,25 @@
namespace mindspore {
namespace kernel {
FakeQuantGpuKernel::FakeQuantGpuKernel()
FakeQuantPerLayerGpuKernel::FakeQuantPerLayerGpuKernel()
: input_size_(0),
min_size_(0),
max_size_(0),
output_size_(0),
workspace_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
quant_num_(0),
quant_delay_(0),
ema_(false),
ema_decay_(0),
quant_num_(1),
global_step_(0),
num_bits_(0),
quant_delay_(0),
training_(false),
narrow_range_(false),
symmetric_(false) {}
const std::vector<size_t> &FakeQuantGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &FakeQuantGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &FakeQuantGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool FakeQuantGpuKernel::Init(const CNodePtr &kernel_node) {
bool FakeQuantPerLayerGpuKernel::Init(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output.";
@ -59,95 +53,73 @@ bool FakeQuantGpuKernel::Init(const CNodePtr &kernel_node) {
}
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
}
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
if (quant_delay_ < 0) {
MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << "is less then 0, require larger than 0.";
}
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
if (symmetric_) {
quant_min_ = 0 - (1 << (num_bits_ - 1));
quant_max_ = (1 << (num_bits_ - 1)) - 1;
} else {
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
}
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
// quant min and max value
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
if (quant_num_ == 0) {
quant_num_ = 1;
}
// init size
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < input_shape.size(); ++i) {
quant_num_ *= SizeToInt(input_shape[i]);
}
input_size_ = sizeof(float);
min_size_ = sizeof(float);
max_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
output_size_ = input_size_;
InitSizeLists();
return true;
}
void FakeQuantGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(min_size_); // min
input_size_list_.push_back(max_size_); // max
output_size_list_.push_back(output_size_);
workspace_size_list_.push_back(workspace_size_);
void FakeQuantPerLayerGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // x
input_size_list_.push_back(sizeof(float)); // min
input_size_list_.push_back(sizeof(float)); // max
output_size_list_.push_back(input_size_); // y
workspace_size_list_.push_back(sizeof(float)); // scale
workspace_size_list_.push_back(sizeof(float)); // nudge_min
workspace_size_list_.push_back(sizeof(float)); // nudge_max
}
bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
bool FakeQuantPerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
float *output = GetDeviceAddress<float>(outputs, 0);
float *input = GetDeviceAddress<float>(inputs, 0);
float *input_min = GetDeviceAddress<float>(inputs, 1);
float *input_max = GetDeviceAddress<float>(inputs, 2);
float *scale = GetDeviceAddress<float>(workspace, 0);
float *nudge_min = GetDeviceAddress<float>(workspace, 1);
float *nudge_max = GetDeviceAddress<float>(workspace, 2);
if (input == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input x is null.";
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input x is null.";
}
if (input_min == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input min is null.";
if (input_min == nullptr || input_max == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input min or input max is null.";
}
if (input_max == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input max is null.";
}
// Allocate space for device copies
int size = sizeof(float);
float *d_scale = nullptr;
float *d_nudge_min = nullptr;
float *d_nudge_max = nullptr;
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), size), "Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), size), "Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), size), "Malloc gpu memory failed");
if (training_) {
// calculate the input min and max according by the parameter ema and ema_decay.
CalMinMax(input, input_min, input_max, quant_num_, ema_decay_, ema_, reinterpret_cast<cudaStream_t>(stream_ptr));
// control flow for quant_delay
if (global_step_ >= quant_delay_) {
// real launch
CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale,
CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantize(input, output, quant_num_, d_nudge_min, d_nudge_max, d_scale, symmetric_,
CalFakeQuantize(input, output, quant_num_, nudge_min, nudge_max, scale, symmetric_,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice,
@ -157,20 +129,15 @@ bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std
global_step_++;
} else {
// real launch
CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale,
CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantize(input, output, quant_num_, d_nudge_min, d_nudge_max, d_scale, symmetric_,
CalFakeQuantize(input, output, quant_num_, nudge_min, nudge_max, scale, symmetric_,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
// Cleanup
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
return true;
}
MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantGpuKernel)
MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantPerLayerGpuKernel)
} // namespace kernel
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
@ -23,10 +23,10 @@
namespace mindspore {
namespace kernel {
class FakeQuantGpuKernel : public GpuKernel {
class FakeQuantPerLayerGpuKernel : public GpuKernel {
public:
FakeQuantGpuKernel();
~FakeQuantGpuKernel() = default;
FakeQuantPerLayerGpuKernel();
~FakeQuantPerLayerGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
@ -40,22 +40,16 @@ class FakeQuantGpuKernel : public GpuKernel {
private:
size_t input_size_;
size_t min_size_;
size_t max_size_;
size_t output_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int num_bits_;
float quant_min_;
float quant_max_;
int quant_num_;
int quant_delay_;
bool ema_;
float ema_decay_;
int global_step_;
int num_bits_;
int quant_delay_;
bool training_;
bool narrow_range_;
bool symmetric_;
@ -63,4 +57,4 @@ class FakeQuantGpuKernel : public GpuKernel {
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_

View File

@ -14,33 +14,30 @@
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_grad_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_impl.cuh"
#include "kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh"
namespace mindspore {
namespace kernel {
FakeQuantGradGpuKernel::FakeQuantGradGpuKernel()
FakeQuantPerLayerGradGpuKernel::FakeQuantPerLayerGradGpuKernel()
: input_size_(0),
min_size_(0),
max_size_(0),
output_size_(0),
workspace_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
quant_size_(0),
quant_num_(1),
quant_delay_(0),
global_step_(0),
narrow_range_(false),
symmetric_(false) {}
const std::vector<size_t> &FakeQuantGradGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &FakeQuantGradGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &FakeQuantGradGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool FakeQuantGradGpuKernel::Init(const CNodePtr &kernel_node) {
bool FakeQuantPerLayerGradGpuKernel::Init(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 4) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 output.";
@ -62,87 +59,66 @@ bool FakeQuantGradGpuKernel::Init(const CNodePtr &kernel_node) {
}
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
if (symmetric_) {
quant_min_ = 0 - (1 << (num_bits_ - 1));
quant_max_ = (1 << (num_bits_ - 1)) - 1;
} else {
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
}
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
// quant min and max value
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
if (quant_size_ == 0) {
quant_size_ = 1;
}
// init size
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < input_shape.size(); ++i) {
quant_size_ *= SizeToInt(input_shape[i]);
quant_num_ *= SizeToInt(input_shape[i]);
}
input_size_ = sizeof(float);
min_size_ = sizeof(float);
max_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
output_size_ = input_size_;
InitSizeLists();
return true;
}
void FakeQuantGradGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // gradient
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(min_size_); // min
input_size_list_.push_back(max_size_); // max
output_size_list_.push_back(output_size_);
void FakeQuantPerLayerGradGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // gradient
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(sizeof(float)); // min
input_size_list_.push_back(sizeof(float)); // max
output_size_list_.push_back(input_size_); // output
workspace_size_list_.push_back(sizeof(float)); // scale
workspace_size_list_.push_back(sizeof(float)); // nudge_min
workspace_size_list_.push_back(sizeof(float)); // nudge_max
}
bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
bool FakeQuantPerLayerGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
float *output = GetDeviceAddress<float>(outputs, 0);
float *gradient = GetDeviceAddress<float>(inputs, 0);
float *input = GetDeviceAddress<float>(inputs, 1);
float *input_min = GetDeviceAddress<float>(inputs, 2);
float *input_max = GetDeviceAddress<float>(inputs, 3);
float *scale = GetDeviceAddress<float>(workspace, 0);
float *nudge_min = GetDeviceAddress<float>(workspace, 1);
float *nudge_max = GetDeviceAddress<float>(workspace, 2);
if (gradient == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel gradient is null";
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel gradient is null";
}
if (input == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input is null.";
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input is null.";
}
if (input_min == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input min is null.";
}
if (input_max == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input max is null.";
if (input_min == nullptr || input_max == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input min or max is null.";
}
if (global_step_ >= quant_delay_) {
float *d_scale = nullptr;
float *d_nudge_min = nullptr;
float *d_nudge_max = nullptr;
int size = sizeof(float);
// Allocate space for device copies
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), size), "Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), size), "Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), size), "Malloc gpu memory failed");
CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale,
CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizeGrad(input, gradient, output, quant_size_, d_nudge_min, d_nudge_max,
CalFakeQuantizeGrad(input, gradient, output, quant_num_, nudge_min, nudge_max,
reinterpret_cast<cudaStream_t>(stream_ptr));
// Cleanup
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
} else {
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
@ -152,6 +128,6 @@ bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const
return true;
}
MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantGradGpuKernel)
MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantPerLayerGradGpuKernel)
} // namespace kernel
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
@ -23,10 +23,10 @@
namespace mindspore {
namespace kernel {
class FakeQuantGradGpuKernel : public GpuKernel {
class FakeQuantPerLayerGradGpuKernel : public GpuKernel {
public:
FakeQuantGradGpuKernel();
~FakeQuantGradGpuKernel() = default;
FakeQuantPerLayerGradGpuKernel();
~FakeQuantPerLayerGradGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
@ -40,9 +40,6 @@ class FakeQuantGradGpuKernel : public GpuKernel {
private:
size_t input_size_;
size_t min_size_;
size_t max_size_;
size_t output_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
@ -51,7 +48,7 @@ class FakeQuantGradGpuKernel : public GpuKernel {
int num_bits_;
float quant_min_;
float quant_max_;
int quant_size_;
int quant_num_;
int quant_delay_;
int global_step_;
bool narrow_range_;
@ -60,4 +57,4 @@ class FakeQuantGradGpuKernel : public GpuKernel {
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_

View File

@ -0,0 +1,119 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/minmax_update_impl.cuh"
#include <thrust/extrema.h>
#include <thrust/pair.h>
#include <thrust/device_vector.h>
#include <cuda_runtime_api.h>
namespace mindspore {
namespace kernel {
MinMaxUpdatePerChannelGpuKernel::MinMaxUpdatePerChannelGpuKernel()
: input_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
quant_num_(1),
ema_(false),
ema_decay_(0),
num_channels_(0),
narrow_range_(false),
symmetric_(false) {}
const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetWorkspaceSizeList() const {
return workspace_size_list_;
}
bool MinMaxUpdatePerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 2) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output.";
}
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
}
// quant min and max
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
// init size
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
num_channels_ = SizeToInt(input_shape[0]);
for (size_t i = 0; i < input_shape.size(); ++i) {
quant_num_ *= SizeToInt(input_shape[i]);
}
input_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
InitSizeLists();
return true;
}
void MinMaxUpdatePerChannelGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(sizeof(float) * num_channels_); // min
input_size_list_.push_back(sizeof(float) * num_channels_); // max
output_size_list_.push_back(sizeof(float) * num_channels_); // output min
output_size_list_.push_back(sizeof(float) * num_channels_); // output max
}
bool MinMaxUpdatePerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
float *output_min = GetDeviceAddress<float>(outputs, 0);
float *output_max = GetDeviceAddress<float>(outputs, 1);
float *input = GetDeviceAddress<float>(inputs, 0);
float *input_min = GetDeviceAddress<float>(inputs, 1);
float *input_max = GetDeviceAddress<float>(inputs, 2);
if (input == nullptr) {
MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input x is null.";
}
if (input_min == nullptr || input_max == nullptr) {
MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input min or input max is null.";
}
// calculate the input min and max according by the parameter ema and ema_decay.
CalMinMaxPerChannel(input, input_min, input_max, output_min, output_max, input_size_ / sizeof(float), num_channels_,
ema_decay_, ema_, symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
MS_REG_GPU_KERNEL(MinMaxUpdatePerChannel, MinMaxUpdatePerChannelGpuKernel)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,60 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class MinMaxUpdatePerChannelGpuKernel : public GpuKernel {
public:
MinMaxUpdatePerChannelGpuKernel();
~MinMaxUpdatePerChannelGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
const std::vector<size_t> &GetWorkspaceSizeList() const override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
bool Init(const CNodePtr &kernel) override;
protected:
void InitSizeLists() override;
private:
size_t input_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int num_bits_;
float quant_min_;
float quant_max_;
int quant_num_;
bool ema_;
float ema_decay_;
int num_channels_;
bool narrow_range_;
bool symmetric_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_

View File

@ -0,0 +1,115 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/minmax_update_impl.cuh"
#include <thrust/extrema.h>
#include <thrust/pair.h>
#include <thrust/device_vector.h>
#include <cuda_runtime_api.h>
namespace mindspore {
namespace kernel {
MinMaxUpdatePerLayerGpuKernel::MinMaxUpdatePerLayerGpuKernel()
: input_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
quant_num_(1),
ema_(false),
ema_decay_(0),
narrow_range_(false),
symmetric_(false) {}
const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool MinMaxUpdatePerLayerGpuKernel::Init(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 2) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output.";
}
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
}
// quant min and max
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
// init size
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < input_shape.size(); ++i) {
quant_num_ *= SizeToInt(input_shape[i]);
}
input_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
InitSizeLists();
return true;
}
void MinMaxUpdatePerLayerGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(sizeof(float)); // input min
input_size_list_.push_back(sizeof(float)); // input max
output_size_list_.push_back(sizeof(float)); // output min
output_size_list_.push_back(sizeof(float)); // output max
}
bool MinMaxUpdatePerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
float *output_min = GetDeviceAddress<float>(outputs, 0);
float *output_max = GetDeviceAddress<float>(outputs, 1);
float *input = GetDeviceAddress<float>(inputs, 0);
float *input_min = GetDeviceAddress<float>(inputs, 1);
float *input_max = GetDeviceAddress<float>(inputs, 2);
if (input == nullptr) {
MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input x is null.";
}
if (input_min == nullptr || input_max == nullptr) {
MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input min or input max is null.";
}
CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_, symmetric_,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
MS_REG_GPU_KERNEL(MinMaxUpdatePerLayer, MinMaxUpdatePerLayerGpuKernel)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class MinMaxUpdatePerLayerGpuKernel : public GpuKernel {
public:
MinMaxUpdatePerLayerGpuKernel();
~MinMaxUpdatePerLayerGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
const std::vector<size_t> &GetWorkspaceSizeList() const override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
bool Init(const CNodePtr &kernel) override;
protected:
void InitSizeLists() override;
private:
size_t input_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int num_bits_;
float quant_min_;
float quant_max_;
int quant_num_;
bool ema_;
float ema_decay_;
bool narrow_range_;
bool symmetric_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_

View File

@ -28,7 +28,7 @@ message LineageEvent {
oneof what {
// An event file was started, with the specified version.
// Now version is "Mindspore.Event:1"
// Now version is "MindSpore.Event:1"
string version = 3;
// Train lineage

View File

@ -32,7 +32,7 @@ message Event {
oneof what {
// An event file was started, with the specified version.
// Now version is "Mindspore.Event:1"
// Now version is "MindSpore.Event:1"
string version = 3;
// GraphDef.

View File

@ -32,7 +32,7 @@
#include "vm/segment_runner.h"
#include "vm/backend.h"
// mindspore namespace is the top level namespace of Mindsporeession project.
// mindspore namespace is the top level namespace of MindSpore project.
// Other namespace should be a sub namespace of mindspore namespace in the ME project.
namespace mindspore {
extern const char kMsVm[];

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Aware quantization."""
"""Quantization aware."""
from functools import partial
import numpy as np
@ -172,7 +172,7 @@ class DenseBnAct(Cell):
Tensor of shape :math:`(N, out\_channels)`.
Examples:
>>> net = nn.Dense(3, 4)
>>> net = nn.DenseBnAct(3, 4)
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> net(input)
"""
@ -271,7 +271,7 @@ class BatchNormFoldCell(Cell):
class FakeQuantWithMinMax(Cell):
r"""
Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
Quantization aware op. This OP provide Fake quantization observer function on data with min and max.
Args:
min_init (int, float): The dimension of channel or 1(layer). Default: -6.
@ -338,22 +338,30 @@ class FakeQuantWithMinMax(Cell):
# init fake quant relative op
if per_channel:
quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis)
ema_fun = partial(Q.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis)
ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis)
else:
quant_fun = Q.FakeQuantPerLayer
ema_fun = Q.FakeQuantMinMaxPerLayerUpdate
ema_fun = Q.MinMaxUpdatePerLayer
if self.is_ascend:
self.fake_quant = quant_fun(num_bits=self.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range)
else:
self.fake_quant = quant_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range)
self.fake_quant_train = quant_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=True)
self.fake_quant_infer = quant_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=ema_decay,
quant_delay=quant_delay,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=False)
self.ema_update = ema_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
@ -368,16 +376,24 @@ class FakeQuantWithMinMax(Cell):
return s
def construct(self, x):
if self.is_ascend and self.training:
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
out = self.fake_quant(x, min_up, max_up)
P.Assign()(self.minq, min_up)
P.Assign()(self.maxq, max_up)
if self.is_ascend:
if self.training:
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
out = self.fake_quant(x, min_up, max_up)
P.Assign()(self.minq, min_up)
P.Assign()(self.maxq, max_up)
else:
out = self.fake_quant(x, self.minq, self.maxq)
else:
out = self.fake_quant(x, self.minq, self.maxq)
if self.training:
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
out = self.fake_quant_train(x, min_up, max_up)
P.Assign()(self.minq, min_up)
P.Assign()(self.maxq, max_up)
else:
out = self.fake_quant_infer(x, self.minq, self.maxq)
return out
class Conv2dBatchNormQuant(Cell):
r"""
2D convolution with BatchNormal op folded layer.

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""Generate bprop for aware quantization ops"""
"""Generate bprop for quantization aware ops"""
from .. import operations as P
from ..operations import _quant_ops as Q
@ -133,9 +133,9 @@ def get_bprop_batchnorm_fold2_(self):
return bprop
@bprop_getters.register(Q.FakeQuantMinMaxPerLayerUpdate)
@bprop_getters.register(Q.MinMaxUpdatePerLayer)
def get_bprop_fakequant_with_minmax_per_layer_update(self):
"""Generate bprop for FakeQuantMinMaxPerLayerUpdate for Ascend"""
"""Generate bprop for MinMaxUpdatePerLayer for Ascend"""
def bprop(x, x_min, x_max, out, dout):
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
@ -143,9 +143,9 @@ def get_bprop_fakequant_with_minmax_per_layer_update(self):
return bprop
@bprop_getters.register(Q.FakeQuantMinMaxPerChannelUpdate)
@bprop_getters.register(Q.MinMaxUpdatePerChannel)
def get_bprop_fakequant_with_minmax_per_channel_update(self):
"""Generate bprop for FakeQuantMinMaxPerChannelUpdate for Ascend"""
"""Generate bprop for MinMaxUpdatePerChannel for Ascend"""
def bprop(x, x_min, x_max, out, dout):
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)

View File

@ -14,7 +14,7 @@
# limitations under the License.
# ============================================================================
"""FakeQuantMinMaxPerChannelUpdate op"""
"""MinMaxUpdatePerChannel op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
@ -23,7 +23,7 @@ from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChannelUpdate") \
fake_quant_min_max_per_channel_update_op_info = TBERegOp("MinMaxUpdatePerChannel") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_min_max_per_channel_update.so") \

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""FakeQuantMinMaxPerLayerUpdate op"""
"""MinMaxUpdatePerLayer op"""
from functools import reduce as functools_reduce
import te.lang.cce
from te import tvm
@ -23,7 +23,7 @@ from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
fake_quant_minmax_update_op_info = TBERegOp("MinMaxUpdatePerLayer") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("fake_quant_minmax_update.so") \
@ -48,14 +48,14 @@ fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
@op_info_register(fake_quant_minmax_update_op_info)
def _fake_quant_minmax_update_tbe():
"""FakeQuantMinMaxPerLayerUpdate TBE register"""
"""MinMaxUpdatePerLayer TBE register"""
return
@fusion_manager.register("fake_quant_minmax_update")
def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
kernel_name="fake_quant_minmax_update"):
"""FakeQuantMinMaxPerLayerUpdate compute"""
"""MinMaxUpdatePerLayer compute"""
shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)

View File

@ -25,8 +25,8 @@ __all__ = ["FakeQuantPerLayer",
"FakeQuantPerLayerGrad",
"FakeQuantPerChannel",
"FakeQuantPerChannelGrad",
"FakeQuantMinMaxPerLayerUpdate",
"FakeQuantMinMaxPerChannelUpdate",
"MinMaxUpdatePerLayer",
"MinMaxUpdatePerChannel",
"BatchNormFold",
"BatchNormFoldGrad",
"CorrectionMul",
@ -47,11 +47,11 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
Simulate the quantize and dequantize operations in training time.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
num_bits (int) : Number bits for quantization aware. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
quant_delay (int): Quantilization delay parameter. Before delay step in training time not update
simulate aware quantize funcion. After delay step in training time begin simulate the aware
simulate quantization aware funcion. After delay step in training time begin simulate the aware
quantize funcion. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
@ -834,12 +834,12 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
return dout_type, dout_type
class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
class MinMaxUpdatePerLayer(PrimitiveWithInfer):
r"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
num_bits (int) : Number bits for quantization aware. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
@ -858,14 +858,14 @@ class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor)
>>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True):
"""init FakeQuantMinMaxPerLayerUpdate OP"""
"""init MinMaxUpdatePerLayer OP"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update
if num_bits not in self.support_quant_bit:
@ -907,12 +907,12 @@ class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
return min_type, max_type
class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
class MinMaxUpdatePerChannel(PrimitiveWithInfer):
r"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
num_bits (int) : Number bits for quantization aware. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
@ -932,14 +932,14 @@ class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = FakeQuantWithMinMax(num_bits=8)(x, min, max)
>>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True, channel_axis=1):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
"""init MinMaxUpdatePerChannel OP for Ascend"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update
if num_bits not in self.support_quant_bit:

View File

@ -1932,7 +1932,7 @@ class Eye(PrimitiveWithInfer):
Inputs:
- **n** (int) - Number of rows of returned tensor
- **m** (int) - Number of columns of returned tensor
- **t** (mindspore.dtype) - Mindspore's dtype, The data type of the returned tensor.
- **t** (mindspore.dtype) - MindSpore's dtype, The data type of the returned tensor.
Outputs:
Tensor, a tensor with ones on the diagonal and zeros elsewhere.

View File

@ -76,7 +76,7 @@ class LossMonitor(Callback):
step_loss = np.mean(step_loss.asnumpy())
self.losses.append(step_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num)
if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)):
raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. "
@ -88,6 +88,6 @@ class LossMonitor(Callback):
print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
"loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format(
cb_params.cur_epoch_num - 1, cb_params.epoch_num,
cur_step_in_epoch, cb_params.batch_num,
cur_step_in_epoch, int(cb_params.batch_num),
step_loss, np.mean(self.losses),
step_mseconds), flush=True)

View File

@ -15,10 +15,10 @@
"""
quantization.
User can use aware quantization to train a model. Mindspore supports quantization aware training,
User can use quantization aware to train a model. MindSpore supports quantization aware training,
which models quantization errors in both the forward and backward passes using fake-quantization
ops. Note that the entire computation is carried out in floating point. At the end of quantization
aware training, Mindspore provides conversion functions to convert the trained model into lower precision.
aware training, MindSpore provides conversion functions to convert the trained model into lower precision.
"""
from .quant import convert_quant_network

View File

@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""aware quantization."""
"""quantization aware."""
import copy
import re
import numpy as np
import mindspore.context as context
from ... import log as logger
from ... import nn, ops
@ -234,7 +235,7 @@ class ConvertToQuantNetwork:
subcell.has_act = True
subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
num_bits=self.act_bits,
quant_delay=self.act_delay,
quant_delay=self.act_qdelay,
per_channel=self.act_channel,
symmetric=self.act_symmetric,
narrow_range=self.act_range)
@ -403,29 +404,30 @@ def convert_quant_network(network,
narrow_range=(False, False)
):
r"""
Create aware quantizaiton training network.
Create quantization aware training network.
Args:
network (Cell): Obtain a pipeline through network for saving graph summary.
quant_delay (int): Number of steps after which weights and activations are quantized during
eval. The first element represent weights and second element represent data flow. Default: [0, 0]
quant_delay (int or tuple): Number of steps after which weights and activations are quantized during
eval. The first element represent weights and second element represent data flow. Default: (0, 0)
bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False.
freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 0.
num_bits (list of int): Number of bits to use for quantizing weights and activations. The first
element represent weights and second element represent data flow. Default: [8, 8]
per_channel (list of bool): Quantization granularity based on layer or on channel. If `True`
num_bits (int or tuple): Number of bits to use for quantizing weights and activations. The first
element represent weights and second element represent data flow. Default: (8, 8)
per_channel (int or tuple): Quantization granularity based on layer or on channel. If `True`
then base on per channel otherwise base on per layer. The first element represent weights
and second element represent data flow. Default: [False, False]
symmetric (list of bool): Quantization algorithm use symmetric or not. If `True` then base on
and second element represent data flow. Default: (False, False)
symmetric (int or tuple): Quantization algorithm use symmetric or not. If `True` then base on
symmetric otherwise base on assymmetric. The first element represent weights and second
element represent data flow. Default: [False, False]
narrow_range (list of bool): Quantization algorithm use narrow range or not. If `True` then base
element represent data flow. Default: (False, False)
narrow_range (int or tuple): Quantization algorithm use narrow range or not. If `True` then base
on narrow range otherwise base on off narrow range. The first element represent weights and
second element represent data flow. Default: [False, False]
second element represent data flow. Default: (False, False)
Returns:
Cell, Network which has change to aware quantization training network cell.
Cell, Network which has change to quantization aware training network cell.
"""
support_device = ["Ascend", "GPU"]
def convert2list(name, value):
if not isinstance(value, list) and not isinstance(value, tuple):
value = [value]
@ -439,6 +441,9 @@ def convert_quant_network(network,
symmetric = convert2list("symmetric", symmetric)
narrow_range = convert2list("narrow range", narrow_range)
if context.get_context('device_target') not in support_device:
raise KeyError("Not support {} backend.".format(context.get_context('device_target')))
net = ConvertToQuantNetwork(network=network,
quant_delay=quant_delay,
bn_fold=bn_fold,

View File

@ -30,7 +30,7 @@ MS_IMAGE_TENSOR_FORMAT = 'NCHW'
# Set the Event mark
EVENT_FILE_NAME_MARK = ".out.events.summary."
# Set the init event of version and mark
EVENT_FILE_INIT_VERSION_MARK = "Mindspore.Event:"
EVENT_FILE_INIT_VERSION_MARK = "MindSpore.Event:"
EVENT_FILE_INIT_VERSION = 1
F32_MIN, F32_MAX = np.finfo(np.float32).min, np.finfo(np.float32).max

View File

@ -2,13 +2,13 @@
## Description
Training LeNet with MNIST dataset in MindSpore with quantization aware trainging.
Training LeNet with MNIST dataset in MindSpore with quantization aware training.
This is the simple and basic tutorial for constructing a network in MindSpore with quantization aware.
In this tutorial, you will:
1. Train a Mindspore fusion model for MNIST from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`.
1. Train a MindSpore fusion model for MNIST from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`.
2. Fine tune the fusion model by applying the quantization aware training auto network converter API `convert_quant_network`, after the network convergence then export a quantization aware model checkpoint file.
3. Use the quantization aware model to create an actually quantized model for the Ascend inference backend.
4. See the persistence of accuracy in inference backend and a 4x smaller model. To see the latency benefits on mobile, try out the Ascend inference backend examples.
@ -24,10 +24,10 @@ Install MindSpore base on the ascend device and GPU device from [MindSpore](http
```python
pip uninstall -y mindspore-ascend
pip uninstall -y mindspore-gpu
pip install mindspore-ascend-0.4.0.whl
pip install mindspore-ascend.whl
```
then you will get the following display
Then you will get the following display
```bash
@ -87,7 +87,7 @@ class LeNet5(nn.Cell):
return x
```
get the MNIST from scratch dataset.
Get the MNIST from scratch dataset.
```Python
ds_train = create_dataset(os.path.join(args.data_path, "train"),
@ -97,7 +97,7 @@ step_size = ds_train.get_dataset_size()
### Train model
Load teh Lenet fusion network, traing network using loss `nn.SoftmaxCrossEntropyWithLogits` with optimization `nn.Momentum`.
Load the Lenet fusion network, training network using loss `nn.SoftmaxCrossEntropyWithLogits` with optimization `nn.Momentum`.
```Python
# Define the network
@ -133,7 +133,7 @@ After all the following we will get the loss value of each step as following:
>>> Epoch: [ 10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
```
To save your time, just run this command.
Also, you can just run this command instead.
```python
python train.py --data_path MNIST_Data --device_target Ascend
@ -165,17 +165,17 @@ Note that the resulting model is quantization aware but not quantized (e.g. the
# define funsion network
network = LeNet5Fusion(cfg.num_classes)
# load aware quantizaiton network checkpoint
# load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
# convert funsion netwrok to aware quantizaiton network
# convert funsion netwrok to quantization aware network
network = quant.convert_quant_network(network)
```
### load checkpoint
after convert to quantization aware network, we can load the checkpoint file.
After convert to quantization aware network, we can load the checkpoint file.
```python
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
@ -186,7 +186,7 @@ model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
### train quantization aware model
To save your time, just run this command.
Also, you can just run this command instread.
```python
python train_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
@ -210,18 +210,18 @@ Procedure of quantization aware model evaluation is different from normal. Becau
# define funsion network
network = LeNet5Fusion(cfg.num_classes)
# load aware quantizaiton network checkpoint
# load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
# convert funsion netwrok to aware quantizaiton network
# convert funsion netwrok to quantization aware network
network = quant.convert_quant_network(network
```
To save your time, just run this command.
Also, you can just run this command insread.
```python
python eval.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
python eval_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
```
The top1 accuracy would display on shell.

View File

@ -50,7 +50,7 @@ if __name__ == "__main__":
# define funsion network
network = LeNet5Fusion(cfg.num_classes)
# convert funsion netwrok to aware quantizaiton network
# convert funsion netwrok to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
@ -60,7 +60,7 @@ if __name__ == "__main__":
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
# load aware quantizaiton network checkpoint
# load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)

View File

@ -50,10 +50,10 @@ if __name__ == "__main__":
# define funsion network
network = LeNet5Fusion(cfg.num_classes)
# load aware quantizaiton network checkpoint
# load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
# convert funsion netwrok to aware quantizaiton network
# convert funsion netwrok to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")

View File

@ -18,14 +18,14 @@
import sys
class _MindsporeTestFrameworkkeyword:
class _MindSporeTestFrameworkkeyword:
def __setattr__(self, name, value):
if name in self.__dict__:
raise TypeError("can not rebind keyword (%s)" % name)
self.__dict__[name] = value
keyword = _MindsporeTestFrameworkkeyword()
keyword = _MindSporeTestFrameworkkeyword()
keyword.function = "function"
keyword.inputs = "inputs"