forked from mindspore-Ecosystem/mindspore
!2405 change some comment name in the whole project
Merge pull request !2405 from chenzhongming/master
This commit is contained in:
commit
f1a69de0b6
2
build.sh
2
build.sh
|
@ -248,7 +248,7 @@ checkopts()
|
|||
done
|
||||
}
|
||||
checkopts "$@"
|
||||
echo "---------------- mindspore: build start ----------------"
|
||||
echo "---------------- MindSpore: build start ----------------"
|
||||
mkdir -pv "${BUILD_PATH}/package/mindspore/lib"
|
||||
git submodule update --init graphengine
|
||||
if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
|
||||
|
|
|
@ -36,7 +36,7 @@ class Monitor {
|
|||
~Monitor() = default;
|
||||
|
||||
// Functor for Perf Monitor main loop.
|
||||
// This function will be the entry point of Mindspore::Dataset::Task
|
||||
// This function will be the entry point of mindspore::Dataset::Task
|
||||
Status operator()();
|
||||
|
||||
int64_t GetSamplingInterval() { return sampling_interval_; }
|
||||
|
|
|
@ -29,7 +29,7 @@
|
|||
|
||||
// brief mindspore namespace.
|
||||
//
|
||||
// mindspore namespace is the top level namespace of Mindsporeession project.
|
||||
// mindspore namespace is the top level namespace of MindSpore project.
|
||||
// Other namespaces should be sub-namespaces of the mindspore namespace in the ME project.
|
||||
namespace mindspore {
|
||||
|
||||
|
|
|
@ -91,7 +91,7 @@ using mindspore::device::DeviceAddress;
|
|||
using DeviceAddressPtr = std::shared_ptr<mindspore::device::DeviceAddress>;
|
||||
// brief mindspore namespace.
|
||||
//
|
||||
// mindspore namespace is the top level namespace of Mindsporeession project.
|
||||
// mindspore namespace is the top level namespace of MindSpore project.
|
||||
// Other namespaces should be sub-namespaces of the mindspore namespace in the ME project.
|
||||
namespace mindspore {
|
||||
// brief mindspore::tensor namespace
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include <thrust/execution_policy.h>
|
||||
#include <thrust/reduce.h>
|
||||
#include <thrust/pair.h>
|
||||
#include "fake_quant_per_channel_impl.cuh"
|
||||
#include "fake_quant_perchannel_impl.cuh"
|
||||
#include "device/gpu/cuda_common.h"
|
||||
|
||||
/**
|
||||
|
@ -113,44 +113,6 @@ void CalFakeQuantizePerChannel(const float *input, float *output, const int tota
|
|||
input, output, total_size, channel_size, nudge_min, nudge_max, scale, symmetric);
|
||||
}
|
||||
|
||||
/**
|
||||
* UpdateInputMinMaxPerChannel or UpdateInputMinMaxPerChannel With EMA.
|
||||
* @param input_min
|
||||
* @param input_max
|
||||
* @param min
|
||||
* @param max
|
||||
* @return
|
||||
*/
|
||||
__global__ void UpdateInputMinMaxPerChannel(float *input_min, float *input_max, float *input, int channels,
|
||||
int per_channel_nums, bool ema, float ema_decay) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) {
|
||||
thrust::pair<float *, float *> sum =
|
||||
thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1));
|
||||
if (ema) {
|
||||
input_min[i] = ema_decay * sum.first[0] + (1 - ema_decay) * input_min[i];
|
||||
input_max[i] = ema_decay * sum.second[0] + (1 - ema_decay) * input_max[i];
|
||||
} else {
|
||||
input_min[i] = sum.first[0];
|
||||
input_max[i] = sum.second[0];
|
||||
}
|
||||
input_min[i] = input_min[i] > 0 ? 0 : input_min[i];
|
||||
input_max[i] = input_max[i] < 0 ? 0 : input_max[i];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void UpdateInputMinMaxPerChannelWithEMA(float *input_min, float *input_max, float min, float max,
|
||||
const float decay) {
|
||||
*input_min = decay * (min) + (1 - decay) * (*input_min);
|
||||
*input_max = decay * (max) + (1 - decay) * (*input_max);
|
||||
}
|
||||
|
||||
void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, const int total_size, const int channel_size,
|
||||
const float ema_decay, const bool ema, cudaStream_t cuda_stream) {
|
||||
int per_channel_num = total_size / channel_size;
|
||||
UpdateInputMinMaxPerChannel<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(
|
||||
input_min, input_max, input, channel_size, per_channel_num, ema, ema_decay);
|
||||
}
|
||||
|
||||
__global__ void FakeQuantizePerChannelGrad(const float *input, const float *gradient, float *output,
|
||||
const int total_size, const int channel_size, const float *nudge_min,
|
||||
const float *nudge_max) {
|
|
@ -18,7 +18,7 @@
|
|||
#include <thrust/device_vector.h>
|
||||
#include <thrust/pair.h>
|
||||
#include "device/gpu/cuda_common.h"
|
||||
#include "fake_quant_impl.cuh"
|
||||
#include "fake_quant_perlayer_impl.cuh"
|
||||
|
||||
__global__ void FakeQuantize(const float *input, float *output, const int size, const float *nudge_min,
|
||||
const float *nudge_max, const float *scale) {
|
|
@ -0,0 +1,104 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <thrust/extrema.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
#include <thrust/reduce.h>
|
||||
#include <thrust/pair.h>
|
||||
#include "minmax_update_impl.cuh"
|
||||
#include "device/gpu/cuda_common.h"
|
||||
|
||||
/**
 * Per-layer min/max update using an exponential moving average.
 *
 * Blends the newly observed extrema (`min`, `max`) with the running statistics read from
 * input_min[0] / input_max[0], stores the blend in output_min[0] / output_max[0], and then clamps
 * the stored range so that it always contains zero. When `symmetric` is non-zero the range is
 * additionally made symmetric around zero.
 *
 * Only element 0 of every buffer is accessed and no thread index is used, so each launched thread
 * performs the same writes.
 *
 * @param input_min   running minimum (element 0 is read)
 * @param input_max   running maximum (element 0 is read)
 * @param output_min  receives the updated minimum (element 0 is written)
 * @param output_max  receives the updated maximum (element 0 is written)
 * @param min         minimum observed in the current input
 * @param max         maximum observed in the current input
 * @param decay       weight given to the new observation; (1 - decay) weights the running value
 * @param symmetric   non-zero to force output_min[0] == -output_max[0]
 */
__global__ void UpdateInputMinMaxPerLayerWithEMA(const float *input_min, const float *input_max, float *output_min,
                                                 float *output_max, const float min, const float max, const float decay,
                                                 const float symmetric) {
  const float ema_min = decay * (min) + (1 - decay) * (input_min[0]);
  const float ema_max = decay * (max) + (1 - decay) * (input_max[0]);

  // Clamp the blended values themselves. Clamping input_min / input_max here (as the previous
  // code did) overwrote the moving average that had just been stored in the outputs.
  output_min[0] = ema_min > 0 ? 0 : ema_min;
  output_max[0] = ema_max < 0 ? 0 : ema_max;

  if (symmetric) {
    // First widen max to max(|min|, max); afterwards mirror it onto min.
    output_max[0] = fabsf(output_min[0]) < output_max[0] ? output_max[0] : -output_min[0];
    output_min[0] = fabsf(output_min[0]) < output_max[0] ? -output_max[0] : output_min[0];
  }
  return;
}
|
||||
|
||||
/**
 * Per-layer min/max update without averaging.
 *
 * Stores `min` / `max` in output_min[0] / output_max[0] after clamping them so the range contains
 * zero; if `symmetric` is non-zero the range is then mirrored around zero.
 * Only element 0 of each buffer is written and no thread index is used.
 *
 * @param output_min  receives the clamped minimum (element 0)
 * @param output_max  receives the clamped maximum (element 0)
 * @param min         minimum observed in the current input
 * @param max         maximum observed in the current input
 * @param symmetric   non-zero to force output_min[0] == -output_max[0]
 */
__global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max,
                                          const float symmetric) {
  float &lo = output_min[0];
  float &hi = output_max[0];

  // The stored range must always include zero.
  lo = (min > 0) ? 0 : min;
  hi = (max < 0) ? 0 : max;

  if (!symmetric) {
    return;
  }

  // Grow the upper bound to the larger magnitude first ...
  if (!(abs(lo) < hi)) {
    hi = -lo;
  }
  // ... then mirror it onto the lower bound when that one is the smaller magnitude.
  if (abs(lo) < hi) {
    lo = -hi;
  }
  return;
}
|
||||
|
||||
/**
 * Per-channel min/max update, optionally with an exponential moving average.
 *
 * Channel i covers the contiguous range [i * per_channel_nums, (i + 1) * per_channel_nums) of
 * `input`. Channels are distributed over threads with a grid-stride loop, so any 1-D launch
 * configuration processes every channel; each thread reduces its own channel(s) with
 * thrust::minmax_element.
 *
 * For every channel the observed extrema are either stored directly (ema == false) or blended with
 * input_min[i] / input_max[i] (ema == true), then clamped so the range contains zero, and finally
 * mirrored around zero when `symmetric` is set.
 *
 * @param input             input values, laid out channel after channel
 * @param input_min         running per-channel minimum (read only when ema is true)
 * @param input_max         running per-channel maximum (read only when ema is true)
 * @param output_min        receives the updated per-channel minimum
 * @param output_max        receives the updated per-channel maximum
 * @param channels          number of channels
 * @param per_channel_nums  number of elements per channel; must be > 0 because the range minimum is dereferenced
 * @param ema               whether to blend with the running statistics
 * @param ema_decay         weight given to the new observation when ema is true
 * @param symmetric         whether to force output_min[i] == -output_max[i]
 */
__global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min,
                                            float *output_max, int channels, int per_channel_nums, bool ema,
                                            float ema_decay, bool symmetric) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) {
    // Use a 64-bit offset: i * per_channel_nums can exceed the int range for large tensors.
    float *channel_begin = input + static_cast<long long>(i) * per_channel_nums;
    float *channel_end = channel_begin + per_channel_nums;
    thrust::pair<float *, float *> extrema = thrust::minmax_element(thrust::device, channel_begin, channel_end);

    float new_min = extrema.first[0];
    float new_max = extrema.second[0];
    if (ema) {
      new_min = ema_decay * new_min + (1 - ema_decay) * input_min[i];
      new_max = ema_decay * new_max + (1 - ema_decay) * input_max[i];
    }

    // Clamp the freshly computed values. The previous code clamped input_min[i] / input_max[i]
    // instead, which discarded the result computed above.
    output_min[i] = new_min > 0 ? 0 : new_min;
    output_max[i] = new_max < 0 ? 0 : new_max;

    if (symmetric) {
      // First widen max to max(|min|, max); afterwards mirror it onto min.
      output_max[i] = fabsf(output_min[i]) < output_max[i] ? output_max[i] : -output_min[i];
      output_min[i] = fabsf(output_min[i]) < output_max[i] ? -output_max[i] : output_min[i];
    }
  }
  return;
}
|
||||
|
||||
/**
 * Host launcher for the per-channel min/max update.
 *
 * Splits `total_num` elements evenly over `channel_num` channels and launches
 * UpdateInputMinMaxPerChannel on `cuda_stream`. The launch is asynchronous with respect to the
 * host; no synchronization is performed here.
 *
 * @param input        device pointer to the input values
 * @param input_min    device pointer to the running per-channel minimum
 * @param input_max    device pointer to the running per-channel maximum
 * @param output_min   device pointer receiving the updated per-channel minimum
 * @param output_max   device pointer receiving the updated per-channel maximum
 * @param total_num    total number of input elements
 * @param channel_num  number of channels
 * @param ema_decay    weight of the new observation when `ema` is true
 * @param ema          whether to blend with the running statistics
 * @param symmetric    whether to make each channel's range symmetric around zero
 * @param cuda_stream  stream on which the kernel is enqueued
 */
void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
                         const int total_num, const int channel_num, const float ema_decay, const bool ema,
                         const bool symmetric, cudaStream_t cuda_stream) {
  // Without channels there is nothing to update, and the division below would be by zero.
  if (channel_num <= 0) {
    return;
  }
  const int per_channel_num = total_num / channel_num;
  // The kernel dereferences the minimum of each channel range, so an empty range must not be launched.
  if (per_channel_num <= 0) {
    return;
  }
  UpdateInputMinMaxPerChannel<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
    input, input_min, input_max, output_min, output_max, channel_num, per_channel_num, ema, ema_decay, symmetric);
  return;
}
|
||||
|
||||
/**
 * Host launcher for the per-layer min/max update.
 *
 * Finds the smallest and largest of the `total_num` input elements with thrust::minmax_element on
 * `cuda_stream`, reads both values into host floats, and passes them by value to a single-thread
 * kernel: UpdateInputMinMaxPerLayerWithEMA when `ema` is set, UpdateInputMinMaxPerLayer otherwise.
 *
 * @param input        device pointer to the input values
 * @param input_min    device pointer to the running minimum (forwarded only when `ema` is true)
 * @param input_max    device pointer to the running maximum (forwarded only when `ema` is true)
 * @param output_min   device pointer receiving the updated minimum
 * @param output_max   device pointer receiving the updated maximum
 * @param total_num    number of input elements
 * @param ema_decay    weight of the new observation when `ema` is true
 * @param ema          whether to blend with the running statistics
 * @param symmetric    whether to make the range symmetric around zero
 * @param cuda_stream  stream used for the reduction and the kernel launch
 */
void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
                       const int total_num, const float ema_decay, const bool ema, const bool symmetric,
                       cudaStream_t cuda_stream) {
  const thrust::device_ptr<float> first = thrust::device_pointer_cast(input);
  const thrust::pair<thrust::device_ptr<float>, thrust::device_ptr<float>> extrema =
    thrust::minmax_element(thrust::cuda::par.on(cuda_stream), first, first + total_num);

  // Dereferencing a device_ptr yields the element's value on the host side.
  const float minel = *extrema.first;
  const float maxel = *extrema.second;

  if (!ema) {
    UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel, symmetric);
    return;
  }
  UpdateInputMinMaxPerLayerWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, output_min, output_max, minel,
                                                             maxel, ema_decay, symmetric);
  return;
}
|
|
@ -0,0 +1,30 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_
|
||||
|
||||
#include "device/gpu/cuda_common.h"
|
||||
|
||||
void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
|
||||
const int total_num, const int channel_num, const float ema_decay, const bool ema,
|
||||
const bool symmetric, cudaStream_t cuda_stream);
|
||||
|
||||
void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
|
||||
const int size, const float ema_decay, const bool ema, const bool symmetric,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.h"
|
||||
#include "kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cuh"
|
||||
#include "kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h"
|
||||
#include "kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh"
|
||||
#include <thrust/extrema.h>
|
||||
#include <thrust/pair.h>
|
||||
#include <thrust/device_vector.h>
|
||||
|
@ -25,21 +25,15 @@ namespace mindspore {
|
|||
namespace kernel {
|
||||
FakeQuantPerChannelGpuKernel::FakeQuantPerChannelGpuKernel()
|
||||
: input_size_(0),
|
||||
min_size_(0),
|
||||
max_size_(0),
|
||||
output_size_(0),
|
||||
workspace_size_(0),
|
||||
num_channels_(0),
|
||||
num_bits_(0),
|
||||
training_(false),
|
||||
symmetric_(false),
|
||||
narrow_range_(false),
|
||||
quant_delay_(0),
|
||||
quant_min_(0),
|
||||
quant_max_(0),
|
||||
quant_delay_(0),
|
||||
ema_(false),
|
||||
ema_decay_(0),
|
||||
global_step_(0),
|
||||
training_(false),
|
||||
channel_out_(0),
|
||||
narrow_range_(false),
|
||||
symmetric_(false) {}
|
||||
global_step_(0) {}
|
||||
|
||||
const std::vector<size_t> &FakeQuantPerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; }
|
||||
|
||||
|
@ -60,90 +54,56 @@ bool FakeQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
|
|||
return false;
|
||||
}
|
||||
|
||||
// get attribute
|
||||
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
|
||||
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
|
||||
ema_decay_ = 1.0 - GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
|
||||
training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
|
||||
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
|
||||
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
|
||||
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
|
||||
|
||||
if (num_bits_ <= 2 || num_bits_ >= 16) {
|
||||
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << "is out of range, expected between 2 and 16.";
|
||||
return false;
|
||||
}
|
||||
|
||||
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
|
||||
if (quant_delay_ < 0) {
|
||||
MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << " is less then 0, require larger than 0.";
|
||||
return false;
|
||||
}
|
||||
|
||||
training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
|
||||
|
||||
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
|
||||
if (symmetric_) {
|
||||
quant_min_ = 0 - (1 << (num_bits_ - 1));
|
||||
quant_max_ = (1 << (num_bits_ - 1)) - 1;
|
||||
} else {
|
||||
quant_min_ = 0;
|
||||
quant_max_ = (1 << num_bits_) - 1;
|
||||
}
|
||||
|
||||
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
|
||||
// quant min and max value
|
||||
quant_min_ = 0;
|
||||
quant_max_ = (1 << num_bits_) - 1;
|
||||
if (narrow_range_) {
|
||||
quant_min_++;
|
||||
}
|
||||
|
||||
// shape info for gpu
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
channel_out_ = SizeToInt(input_shape[0]);
|
||||
min_size_ = sizeof(float) * channel_out_;
|
||||
max_size_ = sizeof(float) * channel_out_;
|
||||
num_channels_ = SizeToInt(input_shape[0]);
|
||||
input_size_ = sizeof(float);
|
||||
for (size_t i = 0; i < input_shape.size(); i++) {
|
||||
input_size_ *= input_shape[i];
|
||||
}
|
||||
output_size_ = input_size_;
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
void FakeQuantPerChannelGpuKernel::InitSizeLists() {
|
||||
input_size_list_.push_back(input_size_); // input in tensor
|
||||
input_size_list_.push_back(min_size_); // min one scalar
|
||||
input_size_list_.push_back(max_size_); // max on scalar
|
||||
output_size_list_.push_back(output_size_); // output in tensor
|
||||
workspace_size_list_.push_back(sizeof(float) * channel_out_); // scale in channel
|
||||
workspace_size_list_.push_back(sizeof(float) * channel_out_); // min in channel
|
||||
workspace_size_list_.push_back(sizeof(float) * channel_out_); // max in channel
|
||||
input_size_list_.push_back(input_size_); // input in tensor
|
||||
input_size_list_.push_back(sizeof(float) * num_channels_); // min one scalar
|
||||
input_size_list_.push_back(sizeof(float) * num_channels_); // max on scalar
|
||||
output_size_list_.push_back(input_size_); // output in tensor
|
||||
workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel
|
||||
workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel
|
||||
workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel
|
||||
}
|
||||
|
||||
void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForTraining(float *input, float *output, float *input_min,
|
||||
float *input_max, float *d_nudge_min, float *d_nudge_max,
|
||||
float *d_scale, void *stream_ptr) {
|
||||
// calculate the input min and max according by the parameter ema and ema_decay.
|
||||
CalMinMaxPerChannel(input, input_min, input_max, input_size_ / sizeof(float), channel_out_, ema_decay_, ema_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
// control flow for quant_delay
|
||||
if (global_step_ >= quant_delay_) {
|
||||
// real launch
|
||||
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max,
|
||||
d_scale, symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
} else {
|
||||
CHECK_CUDA_RET_WITH_ERROR(
|
||||
cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Copy gpu memory failed.");
|
||||
}
|
||||
global_step_++;
|
||||
}
|
||||
|
||||
void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForInfer(float *input, float *output, float *input_min,
|
||||
float *input_max, float *d_nudge_min, float *d_nudge_max,
|
||||
float *d_scale, void *stream_ptr) {
|
||||
// real launch
|
||||
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
|
||||
void FakeQuantPerChannelGpuKernel::CalFakeQuantize(float *input, float *output, float *input_min, float *input_max,
|
||||
float *nudge_min, float *nudge_max, float *scale, void *stream_ptr) {
|
||||
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max, d_scale,
|
||||
CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, scale,
|
||||
symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
}
|
||||
|
||||
|
@ -155,9 +115,9 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
|
|||
float *input = GetDeviceAddress<float>(inputs, 0);
|
||||
float *input_min = GetDeviceAddress<float>(inputs, 1);
|
||||
float *input_max = GetDeviceAddress<float>(inputs, 2);
|
||||
float *d_scale = GetDeviceAddress<float>(workspace, 0);
|
||||
float *d_nudge_min = GetDeviceAddress<float>(workspace, 1);
|
||||
float *d_nudge_max = GetDeviceAddress<float>(workspace, 2);
|
||||
float *scale = GetDeviceAddress<float>(workspace, 0);
|
||||
float *nudge_min = GetDeviceAddress<float>(workspace, 1);
|
||||
float *nudge_max = GetDeviceAddress<float>(workspace, 2);
|
||||
|
||||
if (input == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input is null.";
|
||||
|
@ -167,9 +127,16 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
|
|||
}
|
||||
|
||||
if (training_) {
|
||||
CalFakeQuantizeForTraining(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
|
||||
if (global_step_ >= quant_delay_) {
|
||||
CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr);
|
||||
} else {
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Copy gpu memory failed.");
|
||||
}
|
||||
global_step_++;
|
||||
} else {
|
||||
CalFakeQuantizeForInfer(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
|
||||
CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr);
|
||||
}
|
||||
|
||||
return true;
|
|
@ -39,31 +39,23 @@ class FakeQuantPerChannelGpuKernel : public GpuKernel {
|
|||
void InitSizeLists() override;
|
||||
|
||||
private:
|
||||
void CalFakeQuantizeForTraining(float *input, float *output, float *input_min, float *input_max, float *d_nudge_min,
|
||||
float *d_nudge_max, float *d_scale, void *stream_ptr);
|
||||
void CalFakeQuantizeForInfer(float *input, float *output, float *input_min, float *input_max, float *d_nudge_min,
|
||||
float *d_nudge_max, float *d_scale, void *stream_ptr);
|
||||
void CalFakeQuantize(float *input, float *output, float *input_min, float *input_max, float *nudge_min,
|
||||
float *nudge_max, float *scale, void *stream_ptr);
|
||||
|
||||
size_t input_size_;
|
||||
size_t min_size_;
|
||||
size_t max_size_;
|
||||
size_t output_size_;
|
||||
size_t workspace_size_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
||||
int num_channels_;
|
||||
int num_bits_;
|
||||
bool training_;
|
||||
bool symmetric_;
|
||||
bool narrow_range_;
|
||||
int quant_delay_;
|
||||
float quant_min_;
|
||||
float quant_max_;
|
||||
int quant_delay_;
|
||||
bool ema_;
|
||||
float ema_decay_;
|
||||
int global_step_;
|
||||
bool training_;
|
||||
int channel_out_;
|
||||
bool narrow_range_;
|
||||
bool symmetric_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -14,21 +14,17 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.h"
|
||||
#include "kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cuh"
|
||||
#include "kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h"
|
||||
#include "kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
FakeQuantPerChannelGradGpuKernel::FakeQuantPerChannelGradGpuKernel()
|
||||
: input_size_(0),
|
||||
min_size_(0),
|
||||
max_size_(0),
|
||||
output_size_(0),
|
||||
workspace_size_(0),
|
||||
num_bits_(0),
|
||||
quant_min_(0),
|
||||
quant_max_(0),
|
||||
channel_out_(0),
|
||||
num_channels_(0),
|
||||
quant_delay_(0),
|
||||
global_step_(0),
|
||||
narrow_range_(false),
|
||||
|
@ -64,42 +60,34 @@ bool FakeQuantPerChannelGradGpuKernel::Init(const CNodePtr &kernel_node) {
|
|||
}
|
||||
|
||||
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
|
||||
if (symmetric_) {
|
||||
quant_min_ = 0 - (1 << (num_bits_ - 1));
|
||||
quant_max_ = (1 << (num_bits_ - 1)) - 1;
|
||||
} else {
|
||||
quant_min_ = 0;
|
||||
quant_max_ = (1 << num_bits_) - 1;
|
||||
}
|
||||
|
||||
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
|
||||
|
||||
// quant min and max value
|
||||
quant_min_ = 0;
|
||||
quant_max_ = (1 << num_bits_) - 1;
|
||||
if (narrow_range_) {
|
||||
quant_min_++;
|
||||
}
|
||||
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
channel_out_ = SizeToInt(input_shape[0]);
|
||||
min_size_ = sizeof(float) * channel_out_;
|
||||
max_size_ = sizeof(float) * channel_out_;
|
||||
num_channels_ = SizeToInt(input_shape[0]);
|
||||
input_size_ = sizeof(float);
|
||||
for (size_t i = 0; i < input_shape.size(); i++) {
|
||||
input_size_ *= input_shape[i];
|
||||
}
|
||||
output_size_ = input_size_;
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
void FakeQuantPerChannelGradGpuKernel::InitSizeLists() {
|
||||
input_size_list_.push_back(input_size_); // gradient
|
||||
input_size_list_.push_back(input_size_); // input
|
||||
input_size_list_.push_back(min_size_); // min
|
||||
input_size_list_.push_back(max_size_); // max
|
||||
output_size_list_.push_back(output_size_);
|
||||
workspace_size_list_.push_back(sizeof(float) * channel_out_); // scale in channel
|
||||
workspace_size_list_.push_back(sizeof(float) * channel_out_); // min in channel
|
||||
workspace_size_list_.push_back(sizeof(float) * channel_out_); // max in channel
|
||||
input_size_list_.push_back(input_size_); // gradient
|
||||
input_size_list_.push_back(input_size_); // input
|
||||
input_size_list_.push_back(sizeof(float) * num_channels_); // min
|
||||
input_size_list_.push_back(sizeof(float) * num_channels_); // max
|
||||
output_size_list_.push_back(input_size_); // output
|
||||
workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel
|
||||
workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel
|
||||
workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel
|
||||
}
|
||||
|
||||
bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
|
||||
|
@ -111,9 +99,9 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
|
|||
float *input = GetDeviceAddress<float>(inputs, 1);
|
||||
float *input_min = GetDeviceAddress<float>(inputs, 2);
|
||||
float *input_max = GetDeviceAddress<float>(inputs, 3);
|
||||
float *d_scale = GetDeviceAddress<float>(workspace, 0);
|
||||
float *d_nudge_min = GetDeviceAddress<float>(workspace, 1);
|
||||
float *d_nudge_max = GetDeviceAddress<float>(workspace, 2);
|
||||
float *scale = GetDeviceAddress<float>(workspace, 0);
|
||||
float *nudge_min = GetDeviceAddress<float>(workspace, 1);
|
||||
float *nudge_max = GetDeviceAddress<float>(workspace, 2);
|
||||
|
||||
if (gradient == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel gradient is null";
|
||||
|
@ -130,9 +118,9 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
|
|||
|
||||
int total_size = input_size_ / sizeof(float);
|
||||
if (global_step_ >= quant_delay_) {
|
||||
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
|
||||
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, channel_out_, d_nudge_min, d_nudge_max,
|
||||
CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
} else {
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,
|
|
@ -40,10 +40,6 @@ class FakeQuantPerChannelGradGpuKernel : public GpuKernel {
|
|||
|
||||
private:
|
||||
size_t input_size_;
|
||||
size_t min_size_;
|
||||
size_t max_size_;
|
||||
size_t output_size_;
|
||||
size_t workspace_size_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
@ -51,7 +47,7 @@ class FakeQuantPerChannelGradGpuKernel : public GpuKernel {
|
|||
int num_bits_;
|
||||
float quant_min_;
|
||||
float quant_max_;
|
||||
int channel_out_;
|
||||
int num_channels_;
|
||||
int quant_delay_;
|
||||
int global_step_;
|
||||
bool narrow_range_;
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel/gpu/quant/fake_quant_gpu_kernel.h"
|
||||
#include "kernel/gpu/cuda_impl/fake_quant_impl.cuh"
|
||||
#include "kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.h"
|
||||
#include "kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh"
|
||||
#include <thrust/extrema.h>
|
||||
#include <thrust/pair.h>
|
||||
#include <thrust/device_vector.h>
|
||||
|
@ -23,31 +23,25 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
FakeQuantGpuKernel::FakeQuantGpuKernel()
|
||||
FakeQuantPerLayerGpuKernel::FakeQuantPerLayerGpuKernel()
|
||||
: input_size_(0),
|
||||
min_size_(0),
|
||||
max_size_(0),
|
||||
output_size_(0),
|
||||
workspace_size_(0),
|
||||
num_bits_(0),
|
||||
quant_min_(0),
|
||||
quant_max_(0),
|
||||
quant_num_(0),
|
||||
quant_delay_(0),
|
||||
ema_(false),
|
||||
ema_decay_(0),
|
||||
quant_num_(1),
|
||||
global_step_(0),
|
||||
num_bits_(0),
|
||||
quant_delay_(0),
|
||||
training_(false),
|
||||
narrow_range_(false),
|
||||
symmetric_(false) {}
|
||||
|
||||
const std::vector<size_t> &FakeQuantGpuKernel::GetInputSizeList() const { return input_size_list_; }
|
||||
const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; }
|
||||
|
||||
const std::vector<size_t> &FakeQuantGpuKernel::GetOutputSizeList() const { return output_size_list_; }
|
||||
const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetOutputSizeList() const { return output_size_list_; }
|
||||
|
||||
const std::vector<size_t> &FakeQuantGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
|
||||
const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
|
||||
|
||||
bool FakeQuantGpuKernel::Init(const CNodePtr &kernel_node) {
|
||||
bool FakeQuantPerLayerGpuKernel::Init(const CNodePtr &kernel_node) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 3) {
|
||||
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output.";
|
||||
|
@ -59,95 +53,73 @@ bool FakeQuantGpuKernel::Init(const CNodePtr &kernel_node) {
|
|||
}
|
||||
|
||||
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
|
||||
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
|
||||
ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
|
||||
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
|
||||
training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
|
||||
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
|
||||
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
|
||||
|
||||
if (num_bits_ <= 2 || num_bits_ >= 16) {
|
||||
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
|
||||
}
|
||||
|
||||
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
|
||||
if (quant_delay_ < 0) {
|
||||
MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << "is less then 0, require larger than 0.";
|
||||
}
|
||||
|
||||
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
|
||||
if (symmetric_) {
|
||||
quant_min_ = 0 - (1 << (num_bits_ - 1));
|
||||
quant_max_ = (1 << (num_bits_ - 1)) - 1;
|
||||
} else {
|
||||
quant_min_ = 0;
|
||||
quant_max_ = (1 << num_bits_) - 1;
|
||||
}
|
||||
|
||||
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
|
||||
// quant min and max value
|
||||
quant_min_ = 0;
|
||||
quant_max_ = (1 << num_bits_) - 1;
|
||||
if (narrow_range_) {
|
||||
quant_min_++;
|
||||
}
|
||||
|
||||
if (quant_num_ == 0) {
|
||||
quant_num_ = 1;
|
||||
}
|
||||
// init size
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
for (size_t i = 0; i < input_shape.size(); ++i) {
|
||||
quant_num_ *= SizeToInt(input_shape[i]);
|
||||
}
|
||||
|
||||
input_size_ = sizeof(float);
|
||||
min_size_ = sizeof(float);
|
||||
max_size_ = sizeof(float);
|
||||
for (size_t i = 0; i < input_shape.size(); i++) {
|
||||
input_size_ *= input_shape[i];
|
||||
}
|
||||
output_size_ = input_size_;
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
void FakeQuantGpuKernel::InitSizeLists() {
|
||||
input_size_list_.push_back(input_size_); // input
|
||||
input_size_list_.push_back(min_size_); // min
|
||||
input_size_list_.push_back(max_size_); // max
|
||||
output_size_list_.push_back(output_size_);
|
||||
workspace_size_list_.push_back(workspace_size_);
|
||||
void FakeQuantPerLayerGpuKernel::InitSizeLists() {
|
||||
input_size_list_.push_back(input_size_); // x
|
||||
input_size_list_.push_back(sizeof(float)); // min
|
||||
input_size_list_.push_back(sizeof(float)); // max
|
||||
output_size_list_.push_back(input_size_); // y
|
||||
workspace_size_list_.push_back(sizeof(float)); // scale
|
||||
workspace_size_list_.push_back(sizeof(float)); // nudge_min
|
||||
workspace_size_list_.push_back(sizeof(float)); // nudge_max
|
||||
}
|
||||
|
||||
bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
bool FakeQuantPerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
float *output = GetDeviceAddress<float>(outputs, 0);
|
||||
float *input = GetDeviceAddress<float>(inputs, 0);
|
||||
float *input_min = GetDeviceAddress<float>(inputs, 1);
|
||||
float *input_max = GetDeviceAddress<float>(inputs, 2);
|
||||
float *scale = GetDeviceAddress<float>(workspace, 0);
|
||||
float *nudge_min = GetDeviceAddress<float>(workspace, 1);
|
||||
float *nudge_max = GetDeviceAddress<float>(workspace, 2);
|
||||
|
||||
if (input == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input x is null.";
|
||||
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input x is null.";
|
||||
}
|
||||
if (input_min == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input min is null.";
|
||||
if (input_min == nullptr || input_max == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input min or input max is null.";
|
||||
}
|
||||
if (input_max == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input max is null.";
|
||||
}
|
||||
|
||||
// Allocate space for device copies
|
||||
int size = sizeof(float);
|
||||
float *d_scale = nullptr;
|
||||
float *d_nudge_min = nullptr;
|
||||
float *d_nudge_max = nullptr;
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), size), "Malloc gpu memory failed");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), size), "Malloc gpu memory failed");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), size), "Malloc gpu memory failed");
|
||||
|
||||
if (training_) {
|
||||
// calculate the input min and max according by the parameter ema and ema_decay.
|
||||
CalMinMax(input, input_min, input_max, quant_num_, ema_decay_, ema_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
// control flow for quant_delay
|
||||
if (global_step_ >= quant_delay_) {
|
||||
// real launch
|
||||
CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale,
|
||||
CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CalFakeQuantize(input, output, quant_num_, d_nudge_min, d_nudge_max, d_scale, symmetric_,
|
||||
CalFakeQuantize(input, output, quant_num_, nudge_min, nudge_max, scale, symmetric_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
} else {
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice,
|
||||
|
@ -157,20 +129,15 @@ bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
|||
global_step_++;
|
||||
} else {
|
||||
// real launch
|
||||
CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale,
|
||||
CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CalFakeQuantize(input, output, quant_num_, d_nudge_min, d_nudge_max, d_scale, symmetric_,
|
||||
CalFakeQuantize(input, output, quant_num_, nudge_min, nudge_max, scale, symmetric_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantGpuKernel)
|
||||
MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantPerLayerGpuKernel)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "kernel/gpu/gpu_kernel.h"
|
||||
|
@ -23,10 +23,10 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class FakeQuantGpuKernel : public GpuKernel {
|
||||
class FakeQuantPerLayerGpuKernel : public GpuKernel {
|
||||
public:
|
||||
FakeQuantGpuKernel();
|
||||
~FakeQuantGpuKernel() = default;
|
||||
FakeQuantPerLayerGpuKernel();
|
||||
~FakeQuantPerLayerGpuKernel() = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override;
|
||||
const std::vector<size_t> &GetOutputSizeList() const override;
|
||||
|
@ -40,22 +40,16 @@ class FakeQuantGpuKernel : public GpuKernel {
|
|||
|
||||
private:
|
||||
size_t input_size_;
|
||||
size_t min_size_;
|
||||
size_t max_size_;
|
||||
size_t output_size_;
|
||||
size_t workspace_size_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
||||
int num_bits_;
|
||||
float quant_min_;
|
||||
float quant_max_;
|
||||
int quant_num_;
|
||||
int quant_delay_;
|
||||
bool ema_;
|
||||
float ema_decay_;
|
||||
int global_step_;
|
||||
int num_bits_;
|
||||
int quant_delay_;
|
||||
bool training_;
|
||||
bool narrow_range_;
|
||||
bool symmetric_;
|
||||
|
@ -63,4 +57,4 @@ class FakeQuantGpuKernel : public GpuKernel {
|
|||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_
|
|
@ -14,33 +14,30 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel/gpu/quant/fake_quant_grad_gpu_kernel.h"
|
||||
#include "kernel/gpu/cuda_impl/fake_quant_impl.cuh"
|
||||
#include "kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h"
|
||||
#include "kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
FakeQuantGradGpuKernel::FakeQuantGradGpuKernel()
|
||||
FakeQuantPerLayerGradGpuKernel::FakeQuantPerLayerGradGpuKernel()
|
||||
: input_size_(0),
|
||||
min_size_(0),
|
||||
max_size_(0),
|
||||
output_size_(0),
|
||||
workspace_size_(0),
|
||||
num_bits_(0),
|
||||
quant_min_(0),
|
||||
quant_max_(0),
|
||||
quant_size_(0),
|
||||
quant_num_(1),
|
||||
quant_delay_(0),
|
||||
global_step_(0),
|
||||
narrow_range_(false),
|
||||
symmetric_(false) {}
|
||||
|
||||
const std::vector<size_t> &FakeQuantGradGpuKernel::GetInputSizeList() const { return input_size_list_; }
|
||||
const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetInputSizeList() const { return input_size_list_; }
|
||||
|
||||
const std::vector<size_t> &FakeQuantGradGpuKernel::GetOutputSizeList() const { return output_size_list_; }
|
||||
const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetOutputSizeList() const { return output_size_list_; }
|
||||
|
||||
const std::vector<size_t> &FakeQuantGradGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
|
||||
const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
|
||||
|
||||
bool FakeQuantGradGpuKernel::Init(const CNodePtr &kernel_node) {
|
||||
bool FakeQuantPerLayerGradGpuKernel::Init(const CNodePtr &kernel_node) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 4) {
|
||||
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 output.";
|
||||
|
@ -62,87 +59,66 @@ bool FakeQuantGradGpuKernel::Init(const CNodePtr &kernel_node) {
|
|||
}
|
||||
|
||||
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
|
||||
if (symmetric_) {
|
||||
quant_min_ = 0 - (1 << (num_bits_ - 1));
|
||||
quant_max_ = (1 << (num_bits_ - 1)) - 1;
|
||||
} else {
|
||||
quant_min_ = 0;
|
||||
quant_max_ = (1 << num_bits_) - 1;
|
||||
}
|
||||
|
||||
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
|
||||
|
||||
// quant min and max value
|
||||
quant_min_ = 0;
|
||||
quant_max_ = (1 << num_bits_) - 1;
|
||||
if (narrow_range_) {
|
||||
quant_min_++;
|
||||
}
|
||||
|
||||
if (quant_size_ == 0) {
|
||||
quant_size_ = 1;
|
||||
}
|
||||
// init size
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
for (size_t i = 0; i < input_shape.size(); ++i) {
|
||||
quant_size_ *= SizeToInt(input_shape[i]);
|
||||
quant_num_ *= SizeToInt(input_shape[i]);
|
||||
}
|
||||
|
||||
input_size_ = sizeof(float);
|
||||
min_size_ = sizeof(float);
|
||||
max_size_ = sizeof(float);
|
||||
for (size_t i = 0; i < input_shape.size(); i++) {
|
||||
input_size_ *= input_shape[i];
|
||||
}
|
||||
output_size_ = input_size_;
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
void FakeQuantGradGpuKernel::InitSizeLists() {
|
||||
input_size_list_.push_back(input_size_); // gradient
|
||||
input_size_list_.push_back(input_size_); // input
|
||||
input_size_list_.push_back(min_size_); // min
|
||||
input_size_list_.push_back(max_size_); // max
|
||||
output_size_list_.push_back(output_size_);
|
||||
void FakeQuantPerLayerGradGpuKernel::InitSizeLists() {
|
||||
input_size_list_.push_back(input_size_); // gradient
|
||||
input_size_list_.push_back(input_size_); // input
|
||||
input_size_list_.push_back(sizeof(float)); // min
|
||||
input_size_list_.push_back(sizeof(float)); // max
|
||||
output_size_list_.push_back(input_size_); // output
|
||||
workspace_size_list_.push_back(sizeof(float)); // scale
|
||||
workspace_size_list_.push_back(sizeof(float)); // nudge_min
|
||||
workspace_size_list_.push_back(sizeof(float)); // nudge_max
|
||||
}
|
||||
|
||||
bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
bool FakeQuantPerLayerGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
float *output = GetDeviceAddress<float>(outputs, 0);
|
||||
float *gradient = GetDeviceAddress<float>(inputs, 0);
|
||||
float *input = GetDeviceAddress<float>(inputs, 1);
|
||||
float *input_min = GetDeviceAddress<float>(inputs, 2);
|
||||
float *input_max = GetDeviceAddress<float>(inputs, 3);
|
||||
float *scale = GetDeviceAddress<float>(workspace, 0);
|
||||
float *nudge_min = GetDeviceAddress<float>(workspace, 1);
|
||||
float *nudge_max = GetDeviceAddress<float>(workspace, 2);
|
||||
|
||||
if (gradient == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel gradient is null";
|
||||
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel gradient is null";
|
||||
}
|
||||
if (input == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input is null.";
|
||||
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input is null.";
|
||||
}
|
||||
if (input_min == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input min is null.";
|
||||
}
|
||||
if (input_max == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input max is null.";
|
||||
if (input_min == nullptr || input_max == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input min or max is null.";
|
||||
}
|
||||
|
||||
if (global_step_ >= quant_delay_) {
|
||||
float *d_scale = nullptr;
|
||||
float *d_nudge_min = nullptr;
|
||||
float *d_nudge_max = nullptr;
|
||||
int size = sizeof(float);
|
||||
// Allocate space for device copies
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), size), "Malloc gpu memory failed");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), size), "Malloc gpu memory failed");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), size), "Malloc gpu memory failed");
|
||||
|
||||
CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale,
|
||||
CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CalFakeQuantizeGrad(input, gradient, output, quant_size_, d_nudge_min, d_nudge_max,
|
||||
CalFakeQuantizeGrad(input, gradient, output, quant_num_, nudge_min, nudge_max,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
// Cleanup
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
|
||||
} else {
|
||||
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
|
@ -152,6 +128,6 @@ bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const
|
|||
return true;
|
||||
}
|
||||
|
||||
MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantGradGpuKernel)
|
||||
MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantPerLayerGradGpuKernel)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "kernel/gpu/gpu_kernel.h"
|
||||
|
@ -23,10 +23,10 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class FakeQuantGradGpuKernel : public GpuKernel {
|
||||
class FakeQuantPerLayerGradGpuKernel : public GpuKernel {
|
||||
public:
|
||||
FakeQuantGradGpuKernel();
|
||||
~FakeQuantGradGpuKernel() = default;
|
||||
FakeQuantPerLayerGradGpuKernel();
|
||||
~FakeQuantPerLayerGradGpuKernel() = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override;
|
||||
const std::vector<size_t> &GetOutputSizeList() const override;
|
||||
|
@ -40,9 +40,6 @@ class FakeQuantGradGpuKernel : public GpuKernel {
|
|||
|
||||
private:
|
||||
size_t input_size_;
|
||||
size_t min_size_;
|
||||
size_t max_size_;
|
||||
size_t output_size_;
|
||||
size_t workspace_size_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
|
@ -51,7 +48,7 @@ class FakeQuantGradGpuKernel : public GpuKernel {
|
|||
int num_bits_;
|
||||
float quant_min_;
|
||||
float quant_max_;
|
||||
int quant_size_;
|
||||
int quant_num_;
|
||||
int quant_delay_;
|
||||
int global_step_;
|
||||
bool narrow_range_;
|
||||
|
@ -60,4 +57,4 @@ class FakeQuantGradGpuKernel : public GpuKernel {
|
|||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_
|
|
@ -0,0 +1,119 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h"
|
||||
#include "kernel/gpu/cuda_impl/minmax_update_impl.cuh"
|
||||
#include <thrust/extrema.h>
|
||||
#include <thrust/pair.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// Builds an unconfigured kernel. All sizes and attributes start at zero/false; the
// real values are read from the graph node in Init().
MinMaxUpdatePerChannelGpuKernel::MinMaxUpdatePerChannelGpuKernel()
    : input_size_(0),  // byte size of the input tensor, computed in Init()
      num_bits_(0),
      quant_min_(0),
      quant_max_(0),
      quant_num_(1),  // starts at 1 because Init() multiplies the input dimensions into it
      ema_(false),
      ema_decay_(0),
      num_channels_(0),  // taken from input dimension 0 in Init()
      narrow_range_(false),
      symmetric_(false) {
}
|
||||
|
||||
const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; }
|
||||
|
||||
const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetOutputSizeList() const { return output_size_list_; }
|
||||
|
||||
const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetWorkspaceSizeList() const {
|
||||
return workspace_size_list_;
|
||||
}
|
||||
|
||||
bool MinMaxUpdatePerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 3) {
|
||||
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output.";
|
||||
}
|
||||
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 2) {
|
||||
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output.";
|
||||
}
|
||||
|
||||
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
|
||||
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
|
||||
ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
|
||||
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
|
||||
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
|
||||
|
||||
if (num_bits_ <= 2 || num_bits_ >= 16) {
|
||||
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
|
||||
}
|
||||
|
||||
// quant min and max
|
||||
quant_min_ = 0;
|
||||
quant_max_ = (1 << num_bits_) - 1;
|
||||
if (narrow_range_) {
|
||||
quant_min_++;
|
||||
}
|
||||
|
||||
// init size
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
num_channels_ = SizeToInt(input_shape[0]);
|
||||
for (size_t i = 0; i < input_shape.size(); ++i) {
|
||||
quant_num_ *= SizeToInt(input_shape[i]);
|
||||
}
|
||||
input_size_ = sizeof(float);
|
||||
for (size_t i = 0; i < input_shape.size(); i++) {
|
||||
input_size_ *= input_shape[i];
|
||||
}
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
void MinMaxUpdatePerChannelGpuKernel::InitSizeLists() {
|
||||
input_size_list_.push_back(input_size_); // input
|
||||
input_size_list_.push_back(sizeof(float) * num_channels_); // min
|
||||
input_size_list_.push_back(sizeof(float) * num_channels_); // max
|
||||
output_size_list_.push_back(sizeof(float) * num_channels_); // output min
|
||||
output_size_list_.push_back(sizeof(float) * num_channels_); // output max
|
||||
}
|
||||
|
||||
bool MinMaxUpdatePerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
float *output_min = GetDeviceAddress<float>(outputs, 0);
|
||||
float *output_max = GetDeviceAddress<float>(outputs, 1);
|
||||
float *input = GetDeviceAddress<float>(inputs, 0);
|
||||
float *input_min = GetDeviceAddress<float>(inputs, 1);
|
||||
float *input_max = GetDeviceAddress<float>(inputs, 2);
|
||||
|
||||
if (input == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input x is null.";
|
||||
}
|
||||
if (input_min == nullptr || input_max == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input min or input max is null.";
|
||||
}
|
||||
|
||||
// calculate the input min and max according by the parameter ema and ema_decay.
|
||||
CalMinMaxPerChannel(input, input_min, input_max, output_min, output_max, input_size_ / sizeof(float), num_channels_,
|
||||
ema_decay_, ema_, symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_REG_GPU_KERNEL(MinMaxUpdatePerChannel, MinMaxUpdatePerChannelGpuKernel)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "kernel/gpu/gpu_kernel.h"
|
||||
#include "kernel/gpu/gpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class MinMaxUpdatePerChannelGpuKernel : public GpuKernel {
|
||||
public:
|
||||
MinMaxUpdatePerChannelGpuKernel();
|
||||
~MinMaxUpdatePerChannelGpuKernel() = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override;
|
||||
const std::vector<size_t> &GetOutputSizeList() const override;
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
||||
bool Init(const CNodePtr &kernel) override;
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override;
|
||||
|
||||
private:
|
||||
size_t input_size_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
||||
int num_bits_;
|
||||
float quant_min_;
|
||||
float quant_max_;
|
||||
int quant_num_;
|
||||
bool ema_;
|
||||
float ema_decay_;
|
||||
int num_channels_;
|
||||
bool narrow_range_;
|
||||
bool symmetric_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_
|
|
@ -0,0 +1,115 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h"
|
||||
#include "kernel/gpu/cuda_impl/minmax_update_impl.cuh"
|
||||
#include <thrust/extrema.h>
|
||||
#include <thrust/pair.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// Builds an unconfigured kernel. All sizes and attributes start at zero/false; the
// real values are read from the graph node in Init().
MinMaxUpdatePerLayerGpuKernel::MinMaxUpdatePerLayerGpuKernel()
    : input_size_(0),  // byte size of the input tensor, computed in Init()
      num_bits_(0),
      quant_min_(0),
      quant_max_(0),
      quant_num_(1),  // starts at 1 because Init() multiplies the input dimensions into it
      ema_(false),
      ema_decay_(0),
      narrow_range_(false),
      symmetric_(false) {
}
|
||||
|
||||
const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; }
|
||||
|
||||
const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetOutputSizeList() const { return output_size_list_; }
|
||||
|
||||
const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
|
||||
|
||||
bool MinMaxUpdatePerLayerGpuKernel::Init(const CNodePtr &kernel_node) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 3) {
|
||||
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output.";
|
||||
}
|
||||
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 2) {
|
||||
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output.";
|
||||
}
|
||||
|
||||
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
|
||||
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
|
||||
ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
|
||||
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
|
||||
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
|
||||
|
||||
if (num_bits_ <= 2 || num_bits_ >= 16) {
|
||||
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
|
||||
}
|
||||
|
||||
// quant min and max
|
||||
quant_min_ = 0;
|
||||
quant_max_ = (1 << num_bits_) - 1;
|
||||
if (narrow_range_) {
|
||||
quant_min_++;
|
||||
}
|
||||
|
||||
// init size
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
for (size_t i = 0; i < input_shape.size(); ++i) {
|
||||
quant_num_ *= SizeToInt(input_shape[i]);
|
||||
}
|
||||
input_size_ = sizeof(float);
|
||||
for (size_t i = 0; i < input_shape.size(); i++) {
|
||||
input_size_ *= input_shape[i];
|
||||
}
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
void MinMaxUpdatePerLayerGpuKernel::InitSizeLists() {
|
||||
input_size_list_.push_back(input_size_); // input
|
||||
input_size_list_.push_back(sizeof(float)); // input min
|
||||
input_size_list_.push_back(sizeof(float)); // input max
|
||||
output_size_list_.push_back(sizeof(float)); // output min
|
||||
output_size_list_.push_back(sizeof(float)); // output max
|
||||
}
|
||||
|
||||
bool MinMaxUpdatePerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
float *output_min = GetDeviceAddress<float>(outputs, 0);
|
||||
float *output_max = GetDeviceAddress<float>(outputs, 1);
|
||||
float *input = GetDeviceAddress<float>(inputs, 0);
|
||||
float *input_min = GetDeviceAddress<float>(inputs, 1);
|
||||
float *input_max = GetDeviceAddress<float>(inputs, 2);
|
||||
|
||||
if (input == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input x is null.";
|
||||
}
|
||||
if (input_min == nullptr || input_max == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input min or input max is null.";
|
||||
}
|
||||
|
||||
CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_, symmetric_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_REG_GPU_KERNEL(MinMaxUpdatePerLayer, MinMaxUpdatePerLayerGpuKernel)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "kernel/gpu/gpu_kernel.h"
|
||||
#include "kernel/gpu/gpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class MinMaxUpdatePerLayerGpuKernel : public GpuKernel {
|
||||
public:
|
||||
MinMaxUpdatePerLayerGpuKernel();
|
||||
~MinMaxUpdatePerLayerGpuKernel() = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override;
|
||||
const std::vector<size_t> &GetOutputSizeList() const override;
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
||||
bool Init(const CNodePtr &kernel) override;
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override;
|
||||
|
||||
private:
|
||||
size_t input_size_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
||||
int num_bits_;
|
||||
float quant_min_;
|
||||
float quant_max_;
|
||||
int quant_num_;
|
||||
bool ema_;
|
||||
float ema_decay_;
|
||||
bool narrow_range_;
|
||||
bool symmetric_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_
|
|
@ -28,7 +28,7 @@ message LineageEvent {
|
|||
|
||||
oneof what {
|
||||
// An event file was started, with the specified version.
|
||||
// Now version is "Mindspore.Event:1"
|
||||
// Now version is "MindSpore.Event:1"
|
||||
string version = 3;
|
||||
|
||||
// Train lineage
|
||||
|
|
|
@ -32,7 +32,7 @@ message Event {
|
|||
|
||||
oneof what {
|
||||
// An event file was started, with the specified version.
|
||||
// Now version is "Mindspore.Event:1"
|
||||
// Now version is "MindSpore.Event:1"
|
||||
string version = 3;
|
||||
|
||||
// GraphDef.
|
||||
|
|
|
@ -32,7 +32,7 @@
|
|||
#include "vm/segment_runner.h"
|
||||
#include "vm/backend.h"
|
||||
|
||||
// mindspore namespace is the top level namespace of Mindsporeession project.
|
||||
// mindspore namespace is the top level namespace of MindSpore project.
|
||||
// Other namespace should be a sub namespace of mindspore namespace in the ME project.
|
||||
namespace mindspore {
|
||||
extern const char kMsVm[];
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Aware quantization."""
|
||||
"""Quantization aware."""
|
||||
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
|
@ -172,7 +172,7 @@ class DenseBnAct(Cell):
|
|||
Tensor of shape :math:`(N, out\_channels)`.
|
||||
|
||||
Examples:
|
||||
>>> net = nn.Dense(3, 4)
|
||||
>>> net = nn.DenseBnAct(3, 4)
|
||||
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
|
||||
>>> net(input)
|
||||
"""
|
||||
|
@ -271,7 +271,7 @@ class BatchNormFoldCell(Cell):
|
|||
|
||||
class FakeQuantWithMinMax(Cell):
|
||||
r"""
|
||||
Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
|
||||
Quantization aware op. This OP provides fake quantization observer function on data with min and max.
|
||||
|
||||
Args:
|
||||
min_init (int, float): The dimension of channel or 1(layer). Default: -6.
|
||||
|
@ -338,22 +338,30 @@ class FakeQuantWithMinMax(Cell):
|
|||
# init fake quant relative op
|
||||
if per_channel:
|
||||
quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis)
|
||||
ema_fun = partial(Q.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis)
|
||||
ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis)
|
||||
else:
|
||||
quant_fun = Q.FakeQuantPerLayer
|
||||
ema_fun = Q.FakeQuantMinMaxPerLayerUpdate
|
||||
ema_fun = Q.MinMaxUpdatePerLayer
|
||||
|
||||
if self.is_ascend:
|
||||
self.fake_quant = quant_fun(num_bits=self.num_bits,
|
||||
symmetric=self.symmetric,
|
||||
narrow_range=self.narrow_range)
|
||||
else:
|
||||
self.fake_quant = quant_fun(num_bits=self.num_bits,
|
||||
ema=self.ema,
|
||||
ema_decay=ema_decay,
|
||||
quant_delay=quant_delay,
|
||||
symmetric=self.symmetric,
|
||||
narrow_range=self.narrow_range)
|
||||
self.fake_quant_train = quant_fun(num_bits=self.num_bits,
|
||||
ema=self.ema,
|
||||
ema_decay=ema_decay,
|
||||
quant_delay=quant_delay,
|
||||
symmetric=self.symmetric,
|
||||
narrow_range=self.narrow_range,
|
||||
training=True)
|
||||
self.fake_quant_infer = quant_fun(num_bits=self.num_bits,
|
||||
ema=self.ema,
|
||||
ema_decay=ema_decay,
|
||||
quant_delay=quant_delay,
|
||||
symmetric=self.symmetric,
|
||||
narrow_range=self.narrow_range,
|
||||
training=False)
|
||||
self.ema_update = ema_fun(num_bits=self.num_bits,
|
||||
ema=self.ema,
|
||||
ema_decay=self.ema_decay,
|
||||
|
@ -368,16 +376,24 @@ class FakeQuantWithMinMax(Cell):
|
|||
return s
|
||||
|
||||
def construct(self, x):
|
||||
if self.is_ascend and self.training:
|
||||
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
|
||||
out = self.fake_quant(x, min_up, max_up)
|
||||
P.Assign()(self.minq, min_up)
|
||||
P.Assign()(self.maxq, max_up)
|
||||
if self.is_ascend:
|
||||
if self.training:
|
||||
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
|
||||
out = self.fake_quant(x, min_up, max_up)
|
||||
P.Assign()(self.minq, min_up)
|
||||
P.Assign()(self.maxq, max_up)
|
||||
else:
|
||||
out = self.fake_quant(x, self.minq, self.maxq)
|
||||
else:
|
||||
out = self.fake_quant(x, self.minq, self.maxq)
|
||||
if self.training:
|
||||
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
|
||||
out = self.fake_quant_train(x, min_up, max_up)
|
||||
P.Assign()(self.minq, min_up)
|
||||
P.Assign()(self.maxq, max_up)
|
||||
else:
|
||||
out = self.fake_quant_infer(x, self.minq, self.maxq)
|
||||
return out
|
||||
|
||||
|
||||
class Conv2dBatchNormQuant(Cell):
|
||||
r"""
|
||||
2D convolution with BatchNormal op folded layer.
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Generate bprop for aware quantization ops"""
|
||||
"""Generate bprop for quantization aware ops"""
|
||||
|
||||
from .. import operations as P
|
||||
from ..operations import _quant_ops as Q
|
||||
|
@ -133,9 +133,9 @@ def get_bprop_batchnorm_fold2_(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(Q.FakeQuantMinMaxPerLayerUpdate)
|
||||
@bprop_getters.register(Q.MinMaxUpdatePerLayer)
|
||||
def get_bprop_fakequant_with_minmax_per_layer_update(self):
|
||||
"""Generate bprop for FakeQuantMinMaxPerLayerUpdate for Ascend"""
|
||||
"""Generate bprop for MinMaxUpdatePerLayer for Ascend"""
|
||||
|
||||
def bprop(x, x_min, x_max, out, dout):
|
||||
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
|
||||
|
@ -143,9 +143,9 @@ def get_bprop_fakequant_with_minmax_per_layer_update(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(Q.FakeQuantMinMaxPerChannelUpdate)
|
||||
@bprop_getters.register(Q.MinMaxUpdatePerChannel)
|
||||
def get_bprop_fakequant_with_minmax_per_channel_update(self):
|
||||
"""Generate bprop for FakeQuantMinMaxPerChannelUpdate for Ascend"""
|
||||
"""Generate bprop for MinMaxUpdatePerChannel for Ascend"""
|
||||
|
||||
def bprop(x, x_min, x_max, out, dout):
|
||||
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""FakeQuantMinMaxPerChannelUpdate op"""
|
||||
"""MinMaxUpdatePerChannel op"""
|
||||
import te.lang.cce
|
||||
from te import tvm
|
||||
from te.platform.fusion_manager import fusion_manager
|
||||
|
@ -23,7 +23,7 @@ from topi.cce import util
|
|||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
|
||||
fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChannelUpdate") \
|
||||
fake_quant_min_max_per_channel_update_op_info = TBERegOp("MinMaxUpdatePerChannel") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("fake_quant_min_max_per_channel_update.so") \
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""FakeQuantMinMaxPerLayerUpdate op"""
|
||||
"""MinMaxUpdatePerLayer op"""
|
||||
from functools import reduce as functools_reduce
|
||||
import te.lang.cce
|
||||
from te import tvm
|
||||
|
@ -23,7 +23,7 @@ from topi.cce import util
|
|||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
|
||||
fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
|
||||
fake_quant_minmax_update_op_info = TBERegOp("MinMaxUpdatePerLayer") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("fake_quant_minmax_update.so") \
|
||||
|
@ -48,14 +48,14 @@ fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
|
|||
|
||||
@op_info_register(fake_quant_minmax_update_op_info)
|
||||
def _fake_quant_minmax_update_tbe():
|
||||
"""FakeQuantMinMaxPerLayerUpdate TBE register"""
|
||||
"""MinMaxUpdatePerLayer TBE register"""
|
||||
return
|
||||
|
||||
|
||||
@fusion_manager.register("fake_quant_minmax_update")
|
||||
def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
|
||||
kernel_name="fake_quant_minmax_update"):
|
||||
"""FakeQuantMinMaxPerLayerUpdate compute"""
|
||||
"""MinMaxUpdatePerLayer compute"""
|
||||
shape = te.lang.cce.util.shape_to_list(x.shape)
|
||||
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
|
||||
min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
|
||||
|
|
|
@ -25,8 +25,8 @@ __all__ = ["FakeQuantPerLayer",
|
|||
"FakeQuantPerLayerGrad",
|
||||
"FakeQuantPerChannel",
|
||||
"FakeQuantPerChannelGrad",
|
||||
"FakeQuantMinMaxPerLayerUpdate",
|
||||
"FakeQuantMinMaxPerChannelUpdate",
|
||||
"MinMaxUpdatePerLayer",
|
||||
"MinMaxUpdatePerChannel",
|
||||
"BatchNormFold",
|
||||
"BatchNormFoldGrad",
|
||||
"CorrectionMul",
|
||||
|
@ -47,11 +47,11 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
|
|||
Simulate the quantize and dequantize operations in training time.
|
||||
|
||||
Args:
|
||||
num_bits (int) : Number bits for aware quantilization. Default: 8.
|
||||
num_bits (int) : Number bits for quantization aware. Default: 8.
|
||||
ema (bool): Use EMA algorithm update value min and max. Default: False.
|
||||
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
|
||||
quant_delay (int): Quantization delay parameter. Before delay step in training time not update
|
||||
simulate aware quantize funcion. After delay step in training time begin simulate the aware
|
||||
simulate quantization aware function. After delay step in training time begin simulate the aware
|
||||
quantize function. Default: 0.
|
||||
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
|
||||
|
@ -834,12 +834,12 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
|
|||
return dout_type, dout_type
|
||||
|
||||
|
||||
class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
|
||||
class MinMaxUpdatePerLayer(PrimitiveWithInfer):
|
||||
r"""
|
||||
Update min and max value for fake quant per layer op.
|
||||
|
||||
Args:
|
||||
num_bits (int) : Number bits for aware quantilization. Default: 8.
|
||||
num_bits (int) : Number bits for quantization aware. Default: 8.
|
||||
ema (bool): Use EMA algorithm update value min and max. Default: False.
|
||||
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
|
||||
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||
|
@ -858,14 +858,14 @@ class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
|
|||
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
|
||||
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
|
||||
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
|
||||
>>> output_tensor = FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor)
|
||||
>>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
|
||||
"""
|
||||
support_quant_bit = [4, 7, 8]
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
|
||||
training=True):
|
||||
"""init FakeQuantMinMaxPerLayerUpdate OP"""
|
||||
"""init MinMaxUpdatePerLayer OP"""
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update
|
||||
if num_bits not in self.support_quant_bit:
|
||||
|
@ -907,12 +907,12 @@ class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
|
|||
return min_type, max_type
|
||||
|
||||
|
||||
class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
|
||||
class MinMaxUpdatePerChannel(PrimitiveWithInfer):
|
||||
r"""
|
||||
Update min and max value for fake quant per channel op.
|
||||
|
||||
Args:
|
||||
num_bits (int) : Number bits for aware quantilization. Default: 8.
|
||||
num_bits (int) : Number bits for quantization aware. Default: 8.
|
||||
ema (bool): Use EMA algorithm update value min and max. Default: False.
|
||||
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
|
||||
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||
|
@ -932,14 +932,14 @@ class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
|
|||
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
|
||||
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
|
||||
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
|
||||
>>> output_tensor = FakeQuantWithMinMax(num_bits=8)(x, min, max)
|
||||
>>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
|
||||
"""
|
||||
support_quant_bit = [4, 7, 8]
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
|
||||
training=True, channel_axis=1):
|
||||
"""init FakeQuantPerChannelUpdate OP for Ascend"""
|
||||
"""init MinMaxUpdatePerChannel OP for Ascend"""
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update
|
||||
if num_bits not in self.support_quant_bit:
|
||||
|
|
|
@ -1932,7 +1932,7 @@ class Eye(PrimitiveWithInfer):
|
|||
Inputs:
|
||||
- **n** (int) - Number of rows of returned tensor
|
||||
- **m** (int) - Number of columns of returned tensor
|
||||
- **t** (mindspore.dtype) - Mindspore's dtype, The data type of the returned tensor.
|
||||
- **t** (mindspore.dtype) - MindSpore's dtype, The data type of the returned tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, a tensor with ones on the diagonal and zeros elsewhere.
|
||||
|
|
|
@ -76,7 +76,7 @@ class LossMonitor(Callback):
|
|||
step_loss = np.mean(step_loss.asnumpy())
|
||||
|
||||
self.losses.append(step_loss)
|
||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
|
||||
cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num)
|
||||
|
||||
if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)):
|
||||
raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. "
|
||||
|
@ -88,6 +88,6 @@ class LossMonitor(Callback):
|
|||
print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
|
||||
"loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format(
|
||||
cb_params.cur_epoch_num - 1, cb_params.epoch_num,
|
||||
cur_step_in_epoch, cb_params.batch_num,
|
||||
cur_step_in_epoch, int(cb_params.batch_num),
|
||||
step_loss, np.mean(self.losses),
|
||||
step_mseconds), flush=True)
|
||||
|
|
|
@ -15,10 +15,10 @@
|
|||
"""
|
||||
quantization.
|
||||
|
||||
User can use aware quantization to train a model. Mindspore supports quantization aware training,
|
||||
User can use quantization aware to train a model. MindSpore supports quantization aware training,
|
||||
which models quantization errors in both the forward and backward passes using fake-quantization
|
||||
ops. Note that the entire computation is carried out in floating point. At the end of quantization
|
||||
aware training, Mindspore provides conversion functions to convert the trained model into lower precision.
|
||||
aware training, MindSpore provides conversion functions to convert the trained model into lower precision.
|
||||
"""
|
||||
|
||||
from .quant import convert_quant_network
|
||||
|
|
|
@ -12,12 +12,13 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""aware quantization."""
|
||||
"""quantization aware."""
|
||||
|
||||
import copy
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
|
||||
from ... import log as logger
|
||||
from ... import nn, ops
|
||||
|
@ -234,7 +235,7 @@ class ConvertToQuantNetwork:
|
|||
subcell.has_act = True
|
||||
subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
|
||||
num_bits=self.act_bits,
|
||||
quant_delay=self.act_delay,
|
||||
quant_delay=self.act_qdelay,
|
||||
per_channel=self.act_channel,
|
||||
symmetric=self.act_symmetric,
|
||||
narrow_range=self.act_range)
|
||||
|
@ -403,29 +404,30 @@ def convert_quant_network(network,
|
|||
narrow_range=(False, False)
|
||||
):
|
||||
r"""
|
||||
Create aware quantizaiton training network.
|
||||
Create quantization aware training network.
|
||||
|
||||
Args:
|
||||
network (Cell): Obtain a pipeline through network for saving graph summary.
|
||||
quant_delay (int): Number of steps after which weights and activations are quantized during
|
||||
eval. The first element represent weights and second element represent data flow. Default: [0, 0]
|
||||
quant_delay (int or tuple): Number of steps after which weights and activations are quantized during
|
||||
eval. The first element represent weights and second element represent data flow. Default: (0, 0)
|
||||
bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False.
|
||||
freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 0.
|
||||
num_bits (list of int): Number of bits to use for quantizing weights and activations. The first
|
||||
element represent weights and second element represent data flow. Default: [8, 8]
|
||||
per_channel (list of bool): Quantization granularity based on layer or on channel. If `True`
|
||||
num_bits (int or tuple): Number of bits to use for quantizing weights and activations. The first
|
||||
element represent weights and second element represent data flow. Default: (8, 8)
|
||||
per_channel (int or tuple): Quantization granularity based on layer or on channel. If `True`
|
||||
then base on per channel otherwise base on per layer. The first element represent weights
|
||||
and second element represent data flow. Default: [False, False]
|
||||
symmetric (list of bool): Quantization algorithm use symmetric or not. If `True` then base on
|
||||
and second element represent data flow. Default: (False, False)
|
||||
symmetric (int or tuple): Quantization algorithm use symmetric or not. If `True` then base on
|
||||
symmetric otherwise base on asymmetric. The first element represent weights and second
|
||||
element represent data flow. Default: [False, False]
|
||||
narrow_range (list of bool): Quantization algorithm use narrow range or not. If `True` then base
|
||||
element represent data flow. Default: (False, False)
|
||||
narrow_range (int or tuple): Quantization algorithm use narrow range or not. If `True` then base
|
||||
on narrow range otherwise base on off narrow range. The first element represent weights and
|
||||
second element represent data flow. Default: [False, False]
|
||||
second element represent data flow. Default: (False, False)
|
||||
|
||||
Returns:
|
||||
Cell, Network which has change to aware quantization training network cell.
|
||||
Cell, Network which has changed to quantization aware training network cell.
|
||||
"""
|
||||
support_device = ["Ascend", "GPU"]
|
||||
def convert2list(name, value):
|
||||
if not isinstance(value, list) and not isinstance(value, tuple):
|
||||
value = [value]
|
||||
|
@ -439,6 +441,9 @@ def convert_quant_network(network,
|
|||
symmetric = convert2list("symmetric", symmetric)
|
||||
narrow_range = convert2list("narrow range", narrow_range)
|
||||
|
||||
if context.get_context('device_target') not in support_device:
|
||||
raise KeyError("Not support {} backend.".format(context.get_context('device_target')))
|
||||
|
||||
net = ConvertToQuantNetwork(network=network,
|
||||
quant_delay=quant_delay,
|
||||
bn_fold=bn_fold,
|
||||
|
|
|
@ -30,7 +30,7 @@ MS_IMAGE_TENSOR_FORMAT = 'NCHW'
|
|||
# Set the Event mark
|
||||
EVENT_FILE_NAME_MARK = ".out.events.summary."
|
||||
# Set the init event of version and mark
|
||||
EVENT_FILE_INIT_VERSION_MARK = "Mindspore.Event:"
|
||||
EVENT_FILE_INIT_VERSION_MARK = "MindSpore.Event:"
|
||||
EVENT_FILE_INIT_VERSION = 1
|
||||
|
||||
F32_MIN, F32_MAX = np.finfo(np.float32).min, np.finfo(np.float32).max
|
||||
|
|
|
@ -2,13 +2,13 @@
|
|||
|
||||
## Description
|
||||
|
||||
Training LeNet with MNIST dataset in MindSpore with quantization aware trainging.
|
||||
Training LeNet with MNIST dataset in MindSpore with quantization aware training.
|
||||
|
||||
This is the simple and basic tutorial for constructing a network in MindSpore with quantization aware.
|
||||
|
||||
In this tutorial, you will:
|
||||
|
||||
1. Train a Mindspore fusion model for MNIST from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`.
|
||||
1. Train a MindSpore fusion model for MNIST from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`.
|
||||
2. Fine tune the fusion model by applying the quantization aware training auto network converter API `convert_quant_network`, after the network convergence then export a quantization aware model checkpoint file.
|
||||
3. Use the quantization aware model to create an actually quantized model for the Ascend inference backend.
|
||||
4. See the persistence of accuracy in inference backend and a 4x smaller model. To see the latency benefits on mobile, try out the Ascend inference backend examples.
|
||||
|
@ -24,10 +24,10 @@ Install MindSpore base on the ascend device and GPU device from [MindSpore](http
|
|||
```python
|
||||
pip uninstall -y mindspore-ascend
|
||||
pip uninstall -y mindspore-gpu
|
||||
pip install mindspore-ascend-0.4.0.whl
|
||||
pip install mindspore-ascend.whl
|
||||
```
|
||||
|
||||
then you will get the following display
|
||||
Then you will get the following display
|
||||
|
||||
|
||||
```bash
|
||||
|
@ -87,7 +87,7 @@ class LeNet5(nn.Cell):
|
|||
return x
|
||||
```
|
||||
|
||||
get the MNIST from scratch dataset.
|
||||
Get the MNIST from scratch dataset.
|
||||
|
||||
```Python
|
||||
ds_train = create_dataset(os.path.join(args.data_path, "train"),
|
||||
|
@ -97,7 +97,7 @@ step_size = ds_train.get_dataset_size()
|
|||
|
||||
### Train model
|
||||
|
||||
Load teh Lenet fusion network, traing network using loss `nn.SoftmaxCrossEntropyWithLogits` with optimization `nn.Momentum`.
|
||||
Load the Lenet fusion network, training network using loss `nn.SoftmaxCrossEntropyWithLogits` with optimization `nn.Momentum`.
|
||||
|
||||
```Python
|
||||
# Define the network
|
||||
|
@ -133,7 +133,7 @@ After all the following we will get the loss value of each step as following:
|
|||
>>> Epoch: [ 10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
|
||||
```
|
||||
|
||||
To save your time, just run this command.
|
||||
Also, you can just run this command instead.
|
||||
|
||||
```python
|
||||
python train.py --data_path MNIST_Data --device_target Ascend
|
||||
|
@ -165,17 +165,17 @@ Note that the resulting model is quantization aware but not quantized (e.g. the
|
|||
# define fusion network
|
||||
network = LeNet5Fusion(cfg.num_classes)
|
||||
|
||||
# load aware quantizaiton network checkpoint
|
||||
# load quantization aware network checkpoint
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
load_param_into_net(network, param_dict)
|
||||
|
||||
# convert funsion netwrok to aware quantizaiton network
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network)
|
||||
```
|
||||
|
||||
### load checkpoint
|
||||
|
||||
after convert to quantization aware network, we can load the checkpoint file.
|
||||
After converting to a quantization aware network, we can load the checkpoint file.
|
||||
|
||||
```python
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
|
||||
|
@ -186,7 +186,7 @@ model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
|||
|
||||
### train quantization aware model
|
||||
|
||||
To save your time, just run this command.
|
||||
Also, you can just run this command instead.
|
||||
|
||||
```python
|
||||
python train_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
|
||||
|
@ -210,18 +210,18 @@ Procedure of quantization aware model evaluation is different from normal. Becau
|
|||
# define fusion network
|
||||
network = LeNet5Fusion(cfg.num_classes)
|
||||
|
||||
# load aware quantizaiton network checkpoint
|
||||
# load quantization aware network checkpoint
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
load_param_into_net(network, param_dict)
|
||||
|
||||
# convert funsion netwrok to aware quantizaiton network
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network
|
||||
```
|
||||
|
||||
To save your time, just run this command.
|
||||
Also, you can just run this command instead.
|
||||
|
||||
```python
|
||||
python eval.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
|
||||
python eval_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
|
||||
```
|
||||
|
||||
The top1 accuracy would display on shell.
|
||||
|
|
|
@ -50,7 +50,7 @@ if __name__ == "__main__":
|
|||
|
||||
# define fusion network
|
||||
network = LeNet5Fusion(cfg.num_classes)
|
||||
# convert funsion netwrok to aware quantizaiton network
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
|
||||
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
||||
|
@ -60,7 +60,7 @@ if __name__ == "__main__":
|
|||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
|
||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
# load aware quantizaiton network checkpoint
|
||||
# load quantization aware network checkpoint
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
load_param_into_net(network, param_dict)
|
||||
|
||||
|
|
|
@ -50,10 +50,10 @@ if __name__ == "__main__":
|
|||
|
||||
# define fusion network
|
||||
network = LeNet5Fusion(cfg.num_classes)
|
||||
# load aware quantizaiton network checkpoint
|
||||
# load quantization aware network checkpoint
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
load_param_into_net(network, param_dict)
|
||||
# convert funsion netwrok to aware quantizaiton network
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
|
||||
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
||||
|
|
|
@ -18,14 +18,14 @@
|
|||
import sys
|
||||
|
||||
|
||||
class _MindsporeTestFrameworkkeyword:
|
||||
class _MindSporeTestFrameworkkeyword:
|
||||
def __setattr__(self, name, value):
|
||||
if name in self.__dict__:
|
||||
raise TypeError("can not rebind keyword (%s)" % name)
|
||||
self.__dict__[name] = value
|
||||
|
||||
|
||||
keyword = _MindsporeTestFrameworkkeyword()
|
||||
keyword = _MindSporeTestFrameworkkeyword()
|
||||
|
||||
keyword.function = "function"
|
||||
keyword.inputs = "inputs"
|
||||
|
|
Loading…
Reference in New Issue