!2441 add fake quant test case for gpu

Merge pull request !2441 from chenzhongming/master
mindspore-ci-bot 2020-06-23 18:57:21 +08:00 committed by Gitee
commit 8870956954
13 changed files with 1703 additions and 123 deletions

View File

@@ -20,7 +20,6 @@
 #include <thrust/reduce.h>
 #include <thrust/pair.h>
 #include "fake_quant_perchannel_impl.cuh"
-#include "device/gpu/cuda_common.h"
 /**
  * Find the nudge min, max and scale value as output.
@@ -34,13 +33,17 @@
  * @param channel_num
  * @return
  */
-__global__ void NudgeMinMaxPerChannel(const float *input_min, const float *input_max, const float quant_min,
-                                      const float quant_max, float *nudge_min, float *nudge_max, float *scale,
-                                      int channel_num) {
+__global__ void NudgeMinMaxPerChannel(float *input_min, float *input_max, const float quant_min, const float quant_max,
+                                      float *nudge_min, float *nudge_max, float *scale, int channel_num,
+                                      const bool symmetric) {
   float zp_from_min = 0.f;
   float nudge_zp = 0.f;
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_num; i += blockDim.x * gridDim.x) {
+    if (symmetric) {
+      input_max[i] = abs(input_min[0]) < input_max[i] ? input_max[i] : -input_min[i];
+      input_min[i] = abs(input_min[i]) < input_max[i] ? -input_max[i] : input_min[i];
+    }
     if ((quant_max - quant_min) == 0 || (input_max[i] - input_min[i]) == 0) {
       scale[i] = 0.f;
       zp_from_min = 0.f;
@@ -62,11 +65,11 @@ __global__ void NudgeMinMaxPerChannel(const float *input_min, const float *input
   }
 }
-void CalNudgePerChannel(const float *input_min, const float *input_max, const float quant_min, const float quant_max,
-                        float *nudge_min, float *nudge_max, float *scale, const int channel_num,
+void CalNudgePerChannel(float *input_min, float *input_max, const float quant_min, const float quant_max,
+                        float *nudge_min, float *nudge_max, float *scale, const int channel_num, const bool symmetric,
                         cudaStream_t cuda_stream) {
   NudgeMinMaxPerChannel<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
-    input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale, channel_num);
+    input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale, channel_num, symmetric);
 }
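
For reference, the nudging arithmetic these kernels implement can be sketched in plain Python; this is an illustrative reconstruction of the standard TF-style fake-quant nudging, not code from this commit:

def nudge_min_max(input_min, input_max, quant_min, quant_max):
    # Derive a scale from the float range, then force the zero point onto
    # the integer grid and recompute the representable (nudged) range.
    if (quant_max - quant_min) == 0 or (input_max - input_min) == 0:
        return 0.0, 0.0, 0.0
    scale = (input_max - input_min) / (quant_max - quant_min)
    zp_from_min = quant_min - input_min / scale
    nudge_zp = min(max(round(zp_from_min), quant_min), quant_max)
    return (quant_min - nudge_zp) * scale, (quant_max - nudge_zp) * scale, scale

# min=-0.1, max=63.65 with an 8-bit regular range nudges to [0.0, 63.75]:
print(nudge_min_max(-0.1, 63.65, 0.0, 255.0))  # (0.0, 63.75, 0.25)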
/**
@@ -80,9 +83,8 @@ void CalNudgePerChannel(const float *input_min, const float *input_max, const fl
  * @param scale - array
  * @return
  */
-__global__ void FakeQuantizePerChannel(const float *input, float *output, const int total_size, const int channel_size,
-                                       const float *nudge_min, const float *nudge_max, const float *scale,
-                                       bool symmetric) {
+__global__ void FakeQuantPerChannel(const float *input, float *output, const int total_size, const int channel_size,
+                                    const float *nudge_min, const float *nudge_max, const float *scale) {
   float input_x = 0.f;
   int nudge_input = 0;
   int channel_idx = 0;
@@ -106,16 +108,15 @@ __global__ void FakeQuantizePerChannel(const float *input, float *output, const
   }
 }
-void CalFakeQuantizePerChannel(const float *input, float *output, const int total_size, const int channel_size,
-                               const float *nudge_min, const float *nudge_max, const float *scale, bool symmetric,
-                               cudaStream_t cuda_stream) {
-  FakeQuantizePerChannel<<<GET_BLOCKS(total_size), GET_THREADS, 0, cuda_stream>>>(
-    input, output, total_size, channel_size, nudge_min, nudge_max, scale, symmetric);
+void CalFakeQuantPerChannel(const float *input, float *output, const int total_size, const int channel_size,
+                            const float *nudge_min, const float *nudge_max, const float *scale,
+                            cudaStream_t cuda_stream) {
+  FakeQuantPerChannel<<<GET_BLOCKS(total_size), GET_THREADS, 0, cuda_stream>>>(input, output, total_size, channel_size,
+                                                                               nudge_min, nudge_max, scale);
 }
-__global__ void FakeQuantizePerChannelGrad(const float *input, const float *gradient, float *output,
-                                           const int total_size, const int channel_size, const float *nudge_min,
-                                           const float *nudge_max) {
+__global__ void FakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, const int total_size,
+                                        const int channel_size, const float *nudge_min, const float *nudge_max) {
   int channel_idx = 0;
   int per_channel_num = total_size / channel_size;
@@ -129,9 +130,9 @@ __global__ void FakeQuantizePerChannelGrad(const float *input, const float *grad
   }
 }
-void CalFakeQuantizePerChannelGrad(const float *input, const float *gradient, float *output, const int total_num,
-                                   const int channel_num, const float *nudge_min, const float *nudge_max,
-                                   cudaStream_t cuda_stream) {
-  FakeQuantizePerChannelGrad<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
-    input, gradient, output, total_num, channel_num, nudge_min, nudge_max);
+void CalFakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, const int total_num,
+                                const int channel_num, const float *nudge_min, const float *nudge_max,
+                                cudaStream_t cuda_stream) {
+  FakeQuantPerChannelGrad<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(input, gradient, output, total_num,
+                                                                                    channel_num, nudge_min, nudge_max);
 }
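
The per-element body of FakeQuantPerChannel is collapsed by the diff view; as an assumption about what it computes, per-channel fake quantization clamps each element to its channel's nudged range, snaps it to the quantization grid, and dequantizes it again. A NumPy sketch with the per-channel parameters broadcast along the last axis:

import numpy as np

def fake_quant_per_channel(x, nudge_min, nudge_max, scale):
    # Clamp into the nudged range, quantize to the grid, then dequantize.
    clamped = np.clip(x, nudge_min, nudge_max)
    return np.round((clamped - nudge_min) / scale) * scale + nudge_min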

View File

@@ -14,22 +14,21 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_
+#include "device/gpu/cuda_common.h"
-void CalNudgePerChannel(const float* input_min, const float* input_max, const float quant_min, const float quant_max,
-                        float* nudge_min, float* nudge_max, float* scale, const int channel_num,
+void CalNudgePerChannel(float *input_min, float *input_max, const float quant_min, const float quant_max,
+                        float *nudge_min, float *nudge_max, float *scale, const int channel_num, const bool symmetric,
                         cudaStream_t cuda_stream);
-void CalFakeQuantizePerChannel(const float* input, float* output, const int total_num, const int channel_num,
-                               const float* nudge_min, const float* nudge_max, const float* scale, bool symmetric,
-                               cudaStream_t cuda_stream);
+void CalFakeQuantPerChannel(const float *input, float *output, const int total_num, const int channel_num,
+                            const float *nudge_min, const float *nudge_max, const float *scale,
+                            cudaStream_t cuda_stream);
-void CalMinMaxPerChannel(float* input, float* input_min, float* input_max, const int total_num, const int channel_num,
-                         const float ema_decay, const bool ema, cudaStream_t cuda_stream);
-void CalFakeQuantizePerChannelGrad(const float* input, const float* gradient, float* output, const int total_num,
-                                   const int channel_num, const float* nudge_min, const float* nudge_max,
-                                   cudaStream_t cuda_stream);
+void CalFakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, const int total_num,
+                                const int channel_num, const float *nudge_min, const float *nudge_max,
+                                cudaStream_t cuda_stream);
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_

View File

@@ -17,11 +17,10 @@
 #include <thrust/extrema.h>
 #include <thrust/device_vector.h>
 #include <thrust/pair.h>
-#include "device/gpu/cuda_common.h"
 #include "fake_quant_perlayer_impl.cuh"
-__global__ void FakeQuantize(const float *input, float *output, const int size, const float *nudge_min,
-                             const float *nudge_max, const float *scale) {
+__global__ void FakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min,
+                                  const float *nudge_max, const float *scale) {
   float input_x = 0.f;
   int nudge_input = 0;
@@ -43,8 +42,8 @@ __global__ void FakeQuantize(const float *input, float *output, const int size,
   return;
 }
-__global__ void FakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size,
-                                 const float *nudge_min, const float *nudge_max) {
+__global__ void FakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size,
+                                      const float *nudge_min, const float *nudge_max) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
     if (input[i] < nudge_min[0] || input[i] > nudge_max[0]) {
       output[i] = 0;
@@ -55,12 +54,18 @@ __global__ void FakeQuantizeGrad(const float *input, const float *gradient, floa
   return;
 }
-__global__ void NudgeMinMax(const float *input_min, const float *input_max, const float quant_min,
-                            const float quant_max, float *nudge_min, float *nudge_max, float *scale) {
+__global__ void NudgeMinMaxPerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max,
+                                    float *nudge_min, float *nudge_max, float *scale, const bool symmetric) {
   float zp_from_min = 0.f;
   scale[0] = 0.f;
   nudge_max[0] = 0.f;
   nudge_min[0] = 0.f;
+  if (symmetric) {
+    input_max[0] = abs(input_min[0]) < input_max[0] ? input_max[0] : -input_min[0];
+    input_min[0] = abs(input_min[0]) < input_max[0] ? -input_max[0] : input_min[0];
+  }
   if ((quant_max - quant_min) == 0 || (input_max[0] - input_min[0]) == 0) {
     scale[0] = 0.f;
     zp_from_min = 0.f;
@@ -83,53 +88,24 @@ __global__ void NudgeMinMax(const float *input_min, const float *input_max, cons
   return;
 }
-__global__ void UpdateInputMinMaxWithEMA(float *input_min, float *input_max, const float min, const float max,
-                                         const float decay) {
-  input_min[0] = decay * (min) + (1 - decay) * (input_min[0]);
-  input_min[0] = input_min[0] > 0 ? 0 : input_min[0];
-  input_max[0] = decay * (max) + (1 - decay) * (input_max[0]);
-  input_max[0] = input_max[0] < 0 ? 0 : input_max[0];
-}
-__global__ void UpdateInputMinMax(float *input_min, float *input_max, const float min, const float max) {
-  input_min[0] = min > 0 ? 0 : min;
-  input_max[0] = max < 0 ? 0 : max;
-}
-void CalFakeQuantize(const float *input, float *output, const int size, const float *nudge_min, const float *nudge_max,
-                     const float *scale, bool symmetric, cudaStream_t cuda_stream) {
-  FakeQuantize<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max, scale);
-  return;
-}
-void CalFakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size,
-                         const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream) {
-  FakeQuantizeGrad<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, gradient, output, size, nudge_min,
-                                                                      nudge_max);
-  return;
-}
-void CalNudge(const float *input_min, const float *input_max, const float quant_min, const float quant_max,
-              float *nudge_min, float *nudge_max, float *scale, cudaStream_t cuda_stream) {
-  NudgeMinMax<<<1, 1, 0, cuda_stream>>>(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale);
-  return;
-}
-void CalMinMax(float *input, float *input_min, float *input_max, const int size, const float ema_decay, const bool ema,
-               cudaStream_t cuda_stream) {
-  float minel = 0.f;
-  float maxel = 0.f;
-  auto policy = thrust::cuda::par.on(cuda_stream);
-  thrust::pair<thrust::device_ptr<float>, thrust::device_ptr<float>> tuple;
-  tuple = thrust::minmax_element(policy, thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + size);
-  minel = tuple.first[0];
-  maxel = tuple.second[0];
-  if (ema) {
-    UpdateInputMinMaxWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, minel, maxel, ema_decay);
-  } else {
-    UpdateInputMinMax<<<1, 1, 0, cuda_stream>>>(input_min, input_max, minel, maxel);
-  }
-}
+void CalFakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min,
+                          const float *nudge_max, const float *scale, cudaStream_t cuda_stream) {
+  FakeQuantPerLayer<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max,
+                                                                       scale);
+  return;
+}
+void CalFakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size,
+                              const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream) {
+  FakeQuantPerLayerGrad<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, gradient, output, size, nudge_min,
+                                                                           nudge_max);
+  return;
+}
+void CalNudgePerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max,
+                      float *nudge_min, float *nudge_max, float *scale, const bool symmetric,
+                      cudaStream_t cuda_stream) {
+  NudgeMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale,
+                                                symmetric);
+  return;
+}
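
The new symmetric branch in both nudge kernels widens the narrower side of the (min, max) range before nudging, so the range becomes [-M, M] with M = max(|input_min|, input_max). A sketch of that adjustment (mirroring the ternaries above, not code from this commit):

def make_symmetric(input_min, input_max):
    if abs(input_min) < input_max:
        return -input_max, input_max  # the max side dominates
    return input_min, -input_min      # the min side dominates

print(make_symmetric(-0.5, 2.0))  # (-2.0, 2.0)
print(make_symmetric(-3.0, 2.0))  # (-3.0, 3.0)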

View File

@@ -14,19 +14,18 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_
+#include "device/gpu/cuda_common.h"
-void CalFakeQuantize(const float *input, float *output, const int size, const float *nudge_min, const float *nudge_max,
-                     const float *scale, bool symmetric, cudaStream_t cuda_stream);
-void CalFakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size,
-                         const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream);
-void CalNudge(const float *input_min, const float *input_max, const float quant_min, const float quant_max,
-              float *nudge_min, float *nudge_max, float *scale, cudaStream_t cuda_stream);
-void CalMinMax(float *input, float *input_min, float *input_max, const int size, const float ema_decay, const bool ema,
-               cudaStream_t cuda_stream);
+void CalNudgePerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max,
+                      float *nudge_min, float *nudge_max, float *scale, const bool symmetric, cudaStream_t cuda_stream);
+void CalFakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min,
+                          const float *nudge_max, const float *scale, cudaStream_t cuda_stream);
+void CalFakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size,
+                              const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream);
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_

View File

@@ -102,9 +102,9 @@ void FakeQuantPerChannelGpuKernel::InitSizeLists() {
 void FakeQuantPerChannelGpuKernel::CalFakeQuantize(float *input, float *output, float *input_min, float *input_max,
                                                    float *nudge_min, float *nudge_max, float *scale, void *stream_ptr) {
   CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_,
-                     reinterpret_cast<cudaStream_t>(stream_ptr));
-  CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, scale,
-                            symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
+                     symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
+  CalFakeQuantPerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, scale,
+                         reinterpret_cast<cudaStream_t>(stream_ptr));
 }
 bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,

View File

@@ -119,9 +119,9 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
   int total_size = input_size_ / sizeof(float);
   if (global_step_ >= quant_delay_) {
     CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_,
-                       reinterpret_cast<cudaStream_t>(stream_ptr));
-    CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max,
-                                  reinterpret_cast<cudaStream_t>(stream_ptr));
+                       symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalFakeQuantPerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max,
+                               reinterpret_cast<cudaStream_t>(stream_ptr));
   } else {
     CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,
                               reinterpret_cast<cudaStream_t>(stream_ptr)),

View File

@@ -117,10 +117,10 @@ bool FakeQuantPerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs, c
     // control flow for quant_delay
     if (global_step_ >= quant_delay_) {
       // real launch
-      CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
-               reinterpret_cast<cudaStream_t>(stream_ptr));
-      CalFakeQuantize(input, output, quant_num_, nudge_min, nudge_max, scale, symmetric_,
-                      reinterpret_cast<cudaStream_t>(stream_ptr));
+      CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_,
+                       reinterpret_cast<cudaStream_t>(stream_ptr));
+      CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale,
+                           reinterpret_cast<cudaStream_t>(stream_ptr));
     } else {
       CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice,
                                 reinterpret_cast<cudaStream_t>(stream_ptr)),
@@ -129,10 +129,10 @@ bool FakeQuantPerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs, c
       global_step_++;
     } else {
       // real launch
-      CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
-               reinterpret_cast<cudaStream_t>(stream_ptr));
-      CalFakeQuantize(input, output, quant_num_, nudge_min, nudge_max, scale, symmetric_,
-                      reinterpret_cast<cudaStream_t>(stream_ptr));
+      CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_,
+                       reinterpret_cast<cudaStream_t>(stream_ptr));
+      CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale,
+                           reinterpret_cast<cudaStream_t>(stream_ptr));
     }
     return true;

View File

@@ -115,10 +115,10 @@ bool FakeQuantPerLayerGradGpuKernel::Launch(const std::vector<AddressPtr> &input
   }
   if (global_step_ >= quant_delay_) {
-    CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
-             reinterpret_cast<cudaStream_t>(stream_ptr));
-    CalFakeQuantizeGrad(input, gradient, output, quant_num_, nudge_min, nudge_max,
-                        reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_,
+                     reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalFakeQuantPerLayerGrad(input, gradient, output, quant_num_, nudge_min, nudge_max,
+                             reinterpret_cast<cudaStream_t>(stream_ptr));
   } else {
     CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,
                               reinterpret_cast<cudaStream_t>(stream_ptr)),

View File

@@ -150,7 +150,7 @@ class ConvertToQuantNetwork:
             prefix = name
             add_quant = _AddFakeQuantAfterSubCell(prim_op,
                                                   num_bits=self.act_bits,
-                                                  quant_delay=self.act_delay,
+                                                  quant_delay=self.act_qdelay,
                                                   per_channel=self.act_channel,
                                                   symmetric=self.act_symmetric,
                                                   narrow_range=self.act_range)
@@ -408,19 +408,19 @@ def convert_quant_network(network,
     Args:
         network (Cell): Obtain a pipeline through network for saving graph summary.
-        quant_delay (int or tuple): Number of steps after which weights and activations are quantized during
-            eval. The first element represent weights and second element represent data flow. Default: (0, 0)
         bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False.
         freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 0.
-        num_bits (int or tuple): Number of bits to use for quantizing weights and activations. The first
+        quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
+            eval. The first element represent weights and second element represent data flow. Default: (0, 0)
+        num_bits (int, list or tuple): Number of bits to use for quantizing weights and activations. The first
             element represent weights and second element represent data flow. Default: (8, 8)
-        per_channel (int or tuple): Quantization granularity based on layer or on channel. If `True`
+        per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`
             then base on per channel otherwise base on per layer. The first element represent weights
             and second element represent data flow. Default: (False, False)
-        symmetric (int or tuple): Quantization algorithm use symmetric or not. If `True` then base on
-            symmetric otherwise base on assymmetric. The first element represent weights and second
+        symmetric (bool, list or tuple): Quantization algorithm use symmetric or not. If `True` then base on
+            symmetric otherwise base on asymmetric. The first element represent weights and second
             element represent data flow. Default: (False, False)
-        narrow_range (int or tuple): Quantization algorithm use narrow range or not. If `True` then base
+        narrow_range (bool, list or tuple): Quantization algorithm use narrow range or not. If `True` then base
             on narrow range otherwise base on off narrow range. The first element represent weights and
             second element represent data flow. Default: (False, False)
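
A hypothetical call matching the Args above (keyword names and defaults are taken from this docstring and the import path is assumed, so treat it as a sketch rather than verified API usage):

from mindspore.train.quant import quant

# Each two-element tuple is (weights, activations), as documented above.
quant_net = quant.convert_quant_network(network,
                                        bn_fold=False,
                                        freeze_bn=0,
                                        quant_delay=(0, 0),
                                        num_bits=(8, 8),
                                        per_channel=(True, False),
                                        symmetric=(False, False),
                                        narrow_range=(False, False))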

View File

@@ -0,0 +1,625 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore import nn
from mindspore.ops.operations import _quant_ops as Q
context.set_context(device_target='GPU', device_id=0)
class Net(nn.Cell):
def __init__(self, num_bits=8, symmetric=False, narrow_range=False, channel_axis=1):
super(Net, self).__init__()
self.op = Q.FakeQuantPerChannel(num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
channel_axis=channel_axis)
def construct(self, x, minq, maxq):
return self.op(x, minq, maxq)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel1():
# WithVarsPerChannel_ZeroMinAndMax
x = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)
min_val = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)
max_val = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)
expect = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=False, channel_axis=0)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel2():
# WithVarsPerChannelDim1NudgedDown_RegularRange
# scale 1/4, zp 0.4, nudge 0. nudged ranges [0.0, 63.75]
x = np.array([-0.1, 0.0, 63.75, 63.8]).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32)
max_val = np.array([63.65, 63.65, 63.65, 63.65]).astype(np.float32)
expect = np.array([0.0, 0.0, 63.75, 63.75]).astype(np.float32)
net = Net(num_bits=8, narrow_range=False, channel_axis=0)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
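The comment above can be checked by hand: with min = -0.1, max = 63.65 and an 8-bit regular range, scale = (63.65 + 0.1) / 255 = 0.25 and zp_from_min = 0 - (-0.1) / 0.25 = 0.4, which rounds to a nudged zero point of 0, giving the nudged range [(0 - 0) * 0.25, (255 - 0) * 0.25] = [0.0, 63.75]; inputs are then clamped to that range and snapped to multiples of 0.25, which is exactly the expected output.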
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel3():
# WithVarsPerChannelDim1NudgedDown_NarrowRange
# scale 1/4, zp 1.4, nudge 1. nudged ranges [0.0, 63.5]
x = np.array([-0.1, 0.0, 63.5, 63.6]).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32)
max_val = np.array([63.4, 63.4, 63.4, 63.4]).astype(np.float32)
expect = np.array([0.0, 0.0, 63.5, 63.5]).astype(np.float32)
net = Net(num_bits=8, narrow_range=True, channel_axis=0)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel4():
# WithVarsPerChannelDim1NudgedUp_RegularRange
# [-0.125, 63.625]
# scale 1/4, zp: 0.5, nudge 1. nudged range [-0.25, 63.5]
x = np.array([-0.26, -0.25, -0.24, 63.6]).astype(np.float32)
expect = np.array([-0.25, -0.25, -0.25, 63.5]).astype(np.float32)
min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32)
max_val = np.array([63.625, 63.625, 63.625, 63.625]).astype(np.float32)
net = Net(num_bits=8, narrow_range=False, channel_axis=0)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel5():
# WithVarsPerChannelDim1NudgedUp_NarrowRange
# scale 1/4, zp: 1.5, nudge 2. nudged range [-0.25, 63.25]
x = np.array([-0.26, -0.25, -0.24, 63.3]).astype(np.float32)
expect = np.array([-0.25, -0.25, -0.25, 63.25]).astype(np.float32)
min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32)
max_val = np.array([63.375, 63.375, 63.375, 63.375]).astype(np.float32)
net = Net(num_bits=8, narrow_range=True, channel_axis=0)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel6():
# WithVarsPerChannelDim2NudgedDown_RegularRange
# scale 1/4, zp: 0.4, nudge 0. nudged range [0.0, 63.75]
x = np.array([-0.1, 0.0, 0.1, 0.25, 63.75, 63.80]
).reshape(2, 3).astype(np.float32)
expect = np.array([-0.0, 0.0, 0.0, 0.25, 63.75, 63.75]).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1]).reshape(3).astype(np.float32)
max_val = np.array([63.65, 63.65, 63.65]).reshape(3).astype(np.float32)
net = Net(num_bits=8, narrow_range=False, channel_axis=1)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel7():
# WithVarsPerChannelDim2NudgedDown_NarrowRange
# scale 1/4, zp: 1.4, nudge 1. nudged range [0.0, 63.5]
x = np.array([-0.1, 0.0, 0.1, 0.25, 63.5, 63.6]
).reshape(2, 3).astype(np.float32)
expect = np.array([0.0, 0.0, 0.0, 0.25, 63.5, 63.5]).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1]).reshape(3).astype(np.float32)
max_val = np.array([63.4, 63.4, 63.4]).reshape(3).astype(np.float32)
net = Net(num_bits=8, narrow_range=True, channel_axis=1)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel8():
# WithVarsPerChannelDim2NudgedUp_RegularRange
# scale 1/4, zp: 0.5, nudge 1. nudged range [-0.25, 63.5]
x = np.array([-0.26, -0.25, -0.24, 0.0, 63.5, 63.6]
).reshape(2, 3).astype(np.float32)
expect = np.array([-0.25, -0.25, -0.25, 0.0, 63.5, 63.5]
).astype(np.float32)
min_val = np.array([-0.125, -0.125, -0.125]).reshape(3).astype(np.float32)
max_val = np.array([63.625, 63.625, 63.625]).reshape(3).astype(np.float32)
net = Net(num_bits=8, narrow_range=False, channel_axis=1)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel9():
# WithVarsPerChannelDim2NudgedUp_NarrowRange
# scale 1/4, zp: 1.5, nudge 2. nudged range [-0.25, 63.25]
x = np.array([-0.26, -0.25, -0.24, 0.0, 63.25, 63.3]
).reshape(2, 3).astype(np.float32)
expect = np.array(
[-0.25, -0.25, -0.25, 0.0, 63.25, 63.25]).astype(np.float32)
min_val = np.array([-0.125, -0.125, -0.125]).reshape(3).astype(np.float32)
max_val = np.array([63.375, 63.375, 63.375]).reshape(3).astype(np.float32)
net = Net(num_bits=8, narrow_range=True, channel_axis=1)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel10():
# WithVarsPerChannelDim4NudgedDown_RegularRange
# scale 1/4, zp: 0.4, nudge 0. nudged range [0.0, 63.75]
x = np.array([-0.1, 0.0, 0.1, 0.25, 0.5, 0.75,
1.0, 1.25, 1.5, 1.75, 2.0, 2.25,
63.0, 63.25, 63.5, 63.7, 63.75, 63.8,
63.9, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
expect = np.array([0.0, 0.0, 0.0, 0.25, 0.5, 0.75,
1.0, 1.25, 1.5, 1.75, 2.0, 2.25,
63.0, 63.25, 63.5, 63.75, 63.75, 63.75,
63.75, 63.75, 63.75, 63.75, 63.75, 63.75]).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32)
max_val = np.array([63.65, 63.65, 63.65, 63.65]
).reshape(4).astype(np.float32)
net = Net(num_bits=8, narrow_range=False, channel_axis=1)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel11():
# WithVarsPerChannelDim4NudgedDown_NarrowRange
# scale 1/4, zp: 1.4, nudge 1. nudged range [0.0, 63.5]
x = np.array([-0.1, 0.0, 0.1, 0.25, 0.5, 0.75,
1.0, 1.25, 1.5, 1.75, 2.0, 2.25,
63.0, 63.25, 63.3, 63.4, 63.5, 63.6,
63.7, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
expect = np.array([0.0, 0.0, 0.0, 0.25, 0.5, 0.75,
1.0, 1.25, 1.5, 1.75, 2.0, 2.25,
63.0, 63.25, 63.25, 63.5, 63.5, 63.5,
63.5, 63.5, 63.5, 63.5, 63.5, 63.5]).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32)
max_val = np.array([63.4, 63.4, 63.4, 63.4]).reshape(4).astype(np.float32)
net = Net(num_bits=8, narrow_range=True, channel_axis=1)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel12():
# WithVarsPerChannelDim4NudgedUp_RegularRange
# scale 1/4, zp: 0.5, nudge 1. nudged range [-0.25, 63.5]
x = np.array([-0.3, -0.25, -0.2, 0.0, 0.25, 0.5,
0.75, 1.0, 1.25, 1.5, 1.75, 2.0,
63.0, 63.25, 63.4, 63.5, 63.6, 63.7,
100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
expect = np.array([-0.25, -0.25, -0.25, 0.0, 0.25, 0.5,
0.75, 1.0, 1.25, 1.5, 1.75, 2.0,
63.0, 63.25, 63.5, 63.5, 63.5, 63.5,
63.5, 63.5, 63.5, 63.5, 63.5, 63.5]).astype(np.float32)
min_val = np.array([-0.125, -0.125, -0.125, -0.125]
).reshape(4).astype(np.float32)
max_val = np.array([63.625, 63.625, 63.625, 63.625]
).reshape(4).astype(np.float32)
net = Net(num_bits=8, narrow_range=False, channel_axis=1)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel13():
# WithVarsPerChannelDim4NudgedUp_NarrowRange
# scale 1/4, zp: 1.5, nudge 2. nudged range [-0.25, 63.25]
x = np.array([-0.3, -0.25, -0.2, 0.0, 0.25, 0.5,
0.75, 1.0, 1.25, 1.5, 1.75, 2.0,
63.0, 63.2, 63.25, 63.3, 63.4, 63.5,
100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
expect = np.array([-0.25, -0.25, -0.25, 0.0, 0.25, 0.5,
0.75, 1.0, 1.25, 1.5, 1.75, 2.0,
63.0, 63.25, 63.25, 63.25, 63.25, 63.25,
63.25, 63.25, 63.25, 63.25, 63.25, 63.25]).astype(np.float32)
min_val = np.array([-0.125, -0.125, -0.125, -0.125]
).reshape(4).astype(np.float32)
max_val = np.array([63.375, 63.375, 63.375, 63.375]
).reshape(4).astype(np.float32)
net = Net(num_bits=8, narrow_range=True, channel_axis=1)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel14():
# WithVarsPerChannelDim1NudgedDown_4Bits_RegularRange
# scale 1/2, zp: 0.2, nudge 0. nudged range [0.0, 7.5]
x = np.array([-0.1, 0.0, 7.5, 7.6]).reshape(4).astype(np.float32)
expect = np.array([0.0, 0.0, 7.5, 7.5]).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32)
max_val = np.array([7.4, 7.4, 7.4, 7.4]).reshape(4).astype(np.float32)
net = Net(num_bits=4, narrow_range=False, channel_axis=0)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel15():
# WithVarsPerChannelDim1NudgedDown_4Bits_NarrowRange
# scale 1/2, zp: 1.2, nudge 1. nudged range [0.0, 7.0]
x = np.array([-0.1, 0.0, 7.0, 7.1]).reshape(4).astype(np.float32)
expect = np.array([0.0, 0.0, 7.0, 7.0]).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32)
max_val = np.array([6.9, 6.9, 6.9, 6.9]).reshape(4).astype(np.float32)
net = Net(num_bits=4, narrow_range=True, channel_axis=0)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel16():
# WithVarsPerChannelDim1NudgedUp_4Bits_RegularRange
# scale 1/2, zp: 0.8, nudge 1. nudged range [-0.5, 7.0]
x = np.array([-0.6, -0.5, 7.0, 7.1]).reshape(4).astype(np.float32)
expect = np.array([-0.5, -0.5, 7.0, 7.0]).astype(np.float32)
min_val = np.array([-0.4, -0.4, -0.4, -0.4]).reshape(4).astype(np.float32)
max_val = np.array([7.1, 7.1, 7.1, 7.1]).reshape(4).astype(np.float32)
net = Net(num_bits=4, narrow_range=False, channel_axis=0)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel17():
# WithVarsPerChannelDim1NudgedUp_4Bits_NarrowRange
# scale 1/2, zp: 1.8, nudge 2. nudged range [-0.5, 6.5]
x = np.array([-0.6, -0.5, 6.5, 6.6]).reshape(4).astype(np.float32)
expect = np.array([-0.5, -0.5, 6.5, 6.5]).astype(np.float32)
min_val = np.array([-0.4, -0.4, -0.4, -0.4]).reshape(4).astype(np.float32)
max_val = np.array([6.6, 6.6, 6.6, 6.6]).reshape(4).astype(np.float32)
net = Net(num_bits=4, narrow_range=True, channel_axis=0)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel18():
# WithVarsPerChannelDim2NudgedDown_4Bits_RegularRange
# scale 1/2, zp: 0.2, nudge 0. nudged range [0.0, 7.5]
x = np.array([-0.1, 0.0, 0.1, 0.5, 7.5, 7.6]
).reshape(2, 3).astype(np.float32)
expect = np.array([0.0, 0.0, 0.0, 0.5, 7.5, 7.5]).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1]).reshape(3).astype(np.float32)
max_val = np.array([7.4, 7.4, 7.4]).reshape(3).astype(np.float32)
net = Net(num_bits=4, narrow_range=False, channel_axis=1)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel19():
# WithVarsPerChannelDim2NudgedDown_4Bits_NarrowRange
# scale 1/2, zp: 1.2, nudge 1. nudged range [0.0, 7.0]
x = np.array([-0.1, 0.0, 0.1, 0.5, 7.0, 7.1]
).reshape(2, 3).astype(np.float32)
expect = np.array([0.0, 0.0, 0.0, 0.5, 7.0, 7.0]).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1]).reshape(3).astype(np.float32)
max_val = np.array([6.9, 6.9, 6.9]).reshape(3).astype(np.float32)
net = Net(num_bits=4, narrow_range=True, channel_axis=1)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel20():
# WithVarsPerChannelDim2NudgedUp_4Bits_RegularRange
# scale 1/2, zp: 0.8, nudge 1. nudged range [-0.5, 7.0]
x = np.array([-0.51, -0.5, -0.24, 0.0, 7.0, 7.1]
).reshape(2, 3).astype(np.float32)
expect = np.array([-0.5, -0.5, 0.0, 0.0, 7.0, 7.0]).astype(np.float32)
min_val = np.array([-0.4, -0.4, -0.4]).reshape(3).astype(np.float32)
max_val = np.array([7.1, 7.1, 7.1]).reshape(3).astype(np.float32)
net = Net(num_bits=4, narrow_range=False, channel_axis=1)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel21():
# WithVarsPerChannelDim2NudgedUp_4Bits_NarrowRange
# scale 1/2, zp: 1.8, nudge 2. nudged range [-0.5, 6.5]
x = np.array([-0.6, -0.5, -0.24, 0.0, 6.5, 6.6]
).reshape(2, 3).astype(np.float32)
expect = np.array([-0.5, -0.5, 0.0, 0.0, 6.5, 6.5]).astype(np.float32)
min_val = np.array([-0.4, -0.4, -0.4]).reshape(3).astype(np.float32)
max_val = np.array([6.6, 6.6, 6.6]).reshape(3).astype(np.float32)
net = Net(num_bits=4, narrow_range=True, channel_axis=1)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel22():
# WithVarsPerChannelDim4NudgedDown_4Bits_RegularRange
# scale 1/2, zp: 0.2, nudge 0. nudged range [0.0, 7.5]
x = np.array([-0.1, 0.0, 0.1, 0.5, 1.0, 1.5,
1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
6.0, 6.5, 7.0, 7.4, 7.5, 7.7,
7.8, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
expect = np.array([0.0, 0.0, 0.0, 0.5, 1.0, 1.5,
1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
6.0, 6.5, 7.0, 7.5, 7.5, 7.5,
7.5, 7.5, 7.5, 7.5, 7.5, 7.5]).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32)
max_val = np.array([7.4, 7.4, 7.4, 7.4]).reshape(4).astype(np.float32)
net = Net(num_bits=4, narrow_range=False, channel_axis=1)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel23():
# WithVarsPerChannelDim4NudgedDown_4Bits_NarrowRange
# scale 1/2, zp: 1.2, nudge 1. nudged range [0.0, 7.0]
x = np.array([-0.1, 0.0, 0.1, 0.5, 1.0, 1.5,
1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
6.0, 6.5, 6.8, 6.9, 7.0, 7.1,
7.2, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
expect = np.array([0.0, 0.0, 0.0, 0.5, 1.0, 1.5,
1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
6.0, 6.5, 7.0, 7.0, 7.0, 7.0,
7.0, 7.0, 7.0, 7.0, 7.0, 7.0]).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32)
max_val = np.array([6.9, 6.9, 6.9, 6.9]).reshape(4).astype(np.float32)
net = Net(num_bits=4, narrow_range=True, channel_axis=1)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel24():
# WithVarsPerChannelDim4NudgedUp_4Bits_RegularRange
# scale 1/2, zp: 0.8, nudge 1. nudged range [-0.5, 7.0]
x = np.array([-0.6, -0.5, -0.4, 0.0, 0.5, 1.0,
1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
6.0, 6.5, 6.9, 7.0, 7.1, 7.7,
100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
expect = np.array([-0.5, -0.5, -0.5, 0.0, 0.5, 1.0,
1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
6.0, 6.5, 7.0, 7.0, 7.0, 7.0,
7.0, 7.0, 7.0, 7.0, 7.0, 7.0]).astype(np.float32)
min_val = np.array([-0.4, -0.4, -0.4, -0.4]).reshape(4).astype(np.float32)
max_val = np.array([7.1, 7.1, 7.1, 7.1]).reshape(4).astype(np.float32)
net = Net(num_bits=4, narrow_range=False, channel_axis=1)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel25():
# WithVarsPerChannelDim4NudgedUp_4Bits_NarrowRange
# scale 1/2, zp: 1.8, nudge 2. nudged range [-0.5, 6.5]
x = np.array([-0.6, -0.5, -0.4, 0.0, 0.5, 1.0,
1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
5.5, 6.0, 6.4, 6.5, 6.6, 6.7,
100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
expect = np.array([-0.5, -0.5, -0.5, 0.0, 0.5, 1.0,
1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
5.5, 6.0, 6.5, 6.5, 6.5, 6.5,
6.5, 6.5, 6.5, 6.5, 6.5, 6.5]).astype(np.float32)
min_val = np.array([-0.4, -0.4, -0.4, -0.4]).reshape(4).astype(np.float32)
max_val = np.array([6.6, 6.6, 6.6, 6.6]).reshape(4).astype(np.float32)
net = Net(num_bits=4, narrow_range=True, channel_axis=1)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)

View File

@@ -0,0 +1,373 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops.operations import _quant_ops as Q
context.set_context(device_target='GPU', device_id=0)
class Net(nn.Cell):
def __init__(self, num_bits=8, narrow_range=False):
super(Net, self).__init__()
self.op = Q.FakeQuantPerChannelGrad(
num_bits=num_bits, narrow_range=narrow_range)
def construct(self, dout, x, minq, maxq):
return self.op(dout, x, minq, maxq)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad1():
# WithVarsPerChannelDim1GradientNudgedDown_ZeroMinAndMax
dout = np.random.uniform(-1, 1, size=[4]).astype('float32')
x = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)
min_val = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)
max_val = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)
expect = dout
net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("=" * 40)
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
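Every gradient case below exercises the same straight-through-estimator rule that FakeQuantPerChannelGrad implements: the incoming gradient passes through where the input lies inside the nudged range and is zeroed outside it. A NumPy sketch of that reference behaviour, assuming the nudged range is computed as in the forward op:

import numpy as np

def fake_quant_grad(dout, x, nudge_min, nudge_max):
    # Straight-through estimator: gradient flows only inside the nudged range.
    return np.where((x >= nudge_min) & (x <= nudge_max), dout, 0.0).astype(np.float32)

dout = np.array([0.1, 0.2, 0.3, 0.4], np.float32)
x = np.array([-0.1, 0.0, 63.75, 63.8], np.float32)
print(fake_quant_grad(dout, x, 0.0, 63.75))  # [0.  0.2 0.3 0. ]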
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad2():
# WithVarsPerChannelDim1GradientNudgedDown_RegularRange
dout = np.random.uniform(-1, 1, size=[4]).astype('float32')
x = np.array([-0.1, 0.0, 63.75, 63.8]).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32)
max_val = np.array([63.65, 63.65, 63.65, 63.65]).astype(np.float32)
expect = np.array([0.0, dout[1], dout[2], 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("=" * 40)
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad3():
# WithVarsPerChannelDim1GradientNudgedDown_NarrowRange
dout = np.random.uniform(-1, 1, size=[4]).astype('float32')
x = np.array([-0.1, 0.0, 63.5, 63.6]).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32)
max_val = np.array([63.4, 63.4, 63.4, 63.4]).astype(np.float32)
expect = np.array([0.0, dout[1], dout[2], 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("=" * 40)
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad4():
# WithVarsPerChannelDim1GradientNudgedUp_RegularRange
dout = np.random.uniform(-1, 1, size=[4]).astype('float32')
x = np.array([-0.3, -0.25, 63.5, 63.6]).astype(np.float32)
min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32)
max_val = np.array([63.625, 63.625, 63.625, 63.625]).astype(np.float32)
expect = np.array([0.0, dout[1], dout[2], 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("=" * 40)
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad5():
# WithVarsPerChannelDim1GradientNudgedUp_NarrowRange
dout = np.random.uniform(-1, 1, size=[4]).astype('float32')
x = np.array([-0.3, -0.25, 63.25, 63.3]).astype(np.float32)
min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32)
max_val = np.array([63.375, 63.375, 63.375, 63.375]).astype(np.float32)
expect = np.array([0.0, dout[1], dout[2], 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("=" * 40)
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad6():
# WithVarsPerChannelDim2GradientNudgedDown_RegularRange
read_dout = np.random.uniform(-1, 1, size=[3, 2]).astype('float32')
x = np.array([-0.1, 0.0, 0.1, 0.25, 63.75, 63.8]
).reshape(3, 2).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1]).astype(np.float32)
max_val = np.array([63.65, 63.65, 63.65]).astype(np.float32)
dout = read_dout.flatten()
expect = np.array([0.0, dout[1], dout[2], dout[3],
dout[4], 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(read_dout), Tensor(
x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("=" * 40)
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad7():
# WithVarsPerChannelDim2GradientNudgedDown_NarrowRange
read_dout = np.random.uniform(-1, 1, size=[3, 2]).astype('float32')
x = np.array([-0.1, 0.0, 0.1, 0.25, 63.5, 63.6]
).reshape(3, 2).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1]).astype(np.float32)
max_val = np.array([63.4, 63.4, 63.4]).astype(np.float32)
dout = read_dout.flatten()
expect = np.array([0.0, dout[1], dout[2], dout[3],
dout[4], 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(read_dout), Tensor(
x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("=" * 40)
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad8():
# WithVarsPerChannelDim2GradientNudgedUp_RegularRange
read_dout = np.random.uniform(-1, 1, size=[3, 2]).astype('float32')
x = np.array([-0.3, -0.25, -0.2, 0.0, 63.5, 63.6]
).reshape(3, 2).astype(np.float32)
min_val = np.array([-0.125, -0.125, -0.125]).astype(np.float32)
max_val = np.array([63.625, 63.625, 63.625]).astype(np.float32)
dout = read_dout.flatten()
expect = np.array([0.0, dout[1], dout[2], dout[3],
dout[4], 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(read_dout), Tensor(
x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("=" * 40)
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad9():
# WithVarsPerChannelDim2GradientNudgedUp_NarrowRange
read_dout = np.random.uniform(-1, 1, size=[3, 2]).astype('float32')
x = np.array([-0.3, -0.25, -0.2, 0.0, 63.25, 63.3]
).reshape(3, 2).astype(np.float32)
min_val = np.array([-0.125, -0.125, -0.125]).astype(np.float32)
max_val = np.array([63.375, 63.375, 63.375]).astype(np.float32)
dout = read_dout.flatten()
expect = np.array([0.0, dout[1], dout[2], dout[3],
dout[4], 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(read_dout), Tensor(
x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("=" * 40)
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad10():
# WithVarsPerChannelDim4GradientNudgedDown_RegularRange
read_dout = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32')
x = np.array([-0.1, 0.0, 63.75, 63.8, -0.1, 0.0,
63.75, 63.8, -0.1, 0.0, 63.75, 63.8,
-0.1, 0.0, 63.75, 63.8, -0.1, 0.0,
63.75, 63.8, -0.1, 0.0, 63.75, 63.8]).reshape(4, 3, 2, 1).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32)
max_val = np.array([63.65, 63.65, 63.65, 63.65]).astype(np.float32)
dout = read_dout.flatten()
expect = np.array([0.0, dout[1], dout[2], 0.0,
0.0, dout[5], dout[6], 0.0,
0.0, dout[9], dout[10], 0.0,
0.0, dout[13], dout[14], 0.0,
0.0, dout[17], dout[18], 0.0,
0.0, dout[21], dout[22], 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(read_dout), Tensor(
x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("=" * 40)
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad11():
# WithVarsPerChannelDim4GradientNudgedDown_NarrowRange
read_dout = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32')
x = np.array([-0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5,
63.6, -0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5, 63.6]).reshape(4, 3, 2, 1).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32)
max_val = np.array([63.4, 63.4, 63.4, 63.4]).astype(np.float32)
dout = read_dout.flatten()
expect = np.array([0.0, dout[1], dout[2], 0.0,
0.0, dout[5], dout[6], 0.0,
0.0, dout[9], dout[10], 0.0,
0.0, dout[13], dout[14], 0.0,
0.0, dout[17], dout[18], 0.0,
0.0, dout[21], dout[22], 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(read_dout), Tensor(
x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("=" * 40)
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad12():
# WithVarsPerChannelDim4GradientNudgedUp_RegularRange
read_dout = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32')
x = np.array([-0.3, -0.25, 63.5, 63.6, -0.3, -0.25,
63.5, 63.6, -0.3, -0.25, 63.5, 63.6,
-0.3, -0.25, 63.5, 63.6, -0.3, -0.25,
63.5, 63.6, -0.3, -0.25, 63.5, 63.6]).reshape(4, 3, 2, 1).astype(np.float32)
min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32)
max_val = np.array([63.625, 63.625, 63.625, 63.625]).astype(np.float32)
dout = read_dout.flatten()
expect = np.array([0.0, dout[1], dout[2], 0.0,
0.0, dout[5], dout[6], 0.0,
0.0, dout[9], dout[10], 0.0,
0.0, dout[13], dout[14], 0.0,
0.0, dout[17], dout[18], 0.0,
0.0, dout[21], dout[22], 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(read_dout), Tensor(
x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("=" * 40)
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad13():
# WithVarsPerChannelDim4GradientNudgedUp_NarrowRange
read_dout = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32')
x = np.array([-0.3, -0.25, 63.25, 63.3, -0.3, -0.25,
63.25, 63.3, -0.3, -0.25, 63.25, 63.3,
-0.3, -0.25, 63.25, 63.3, -0.3, -0.25,
63.25, 63.3, -0.3, -0.25, 63.25, 63.3]).reshape(4, 3, 2, 1).astype(np.float32)
min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32)
max_val = np.array([63.375, 63.375, 63.375, 63.375]).astype(np.float32)
dout = read_dout.flatten()
expect = np.array([0.0, dout[1], dout[2], 0.0,
0.0, dout[5], dout[6], 0.0,
0.0, dout[9], dout[10], 0.0,
0.0, dout[13], dout[14], 0.0,
0.0, dout[17], dout[18], 0.0,
0.0, dout[21], dout[22], 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(read_dout), Tensor(
x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("=" * 40)
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)

View File

@@ -0,0 +1,386 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore.common.tensor import Tensor
import mindspore.nn as nn
from mindspore.ops.operations import _quant_ops as Q

context.set_context(device_target='GPU', device_id=0)


class Net(nn.Cell):
def __init__(self,
num_bits=8,
quant_delay=0,
symmetric=False,
narrow_range=False,
training=True):
super(Net, self).__init__()
self.fake_quant = Q.FakeQuantPerLayer(num_bits=num_bits,
quant_delay=quant_delay,
symmetric=symmetric,
narrow_range=narrow_range,
training=training)
def construct(self, x, minq, maxq):
return self.fake_quant(x, minq, maxq)
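

# The expect arrays below can be reproduced with a small NumPy model of the
# min/max nudging scheme the kernel appears to follow (a sketch under that
# assumption, not the production implementation; fake_quant_ref is a
# hypothetical helper, and min_val/max_val are Python floats here).
def fake_quant_ref(x, min_val, max_val, num_bits=8, narrow_range=False):
    quant_min = 1 if narrow_range else 0
    quant_max = (1 << num_bits) - 1
    if max_val == min_val:
        # degenerate range: the op passes zeros through (see test_fake_quant1)
        return np.zeros_like(x)
    scale = (max_val - min_val) / (quant_max - quant_min)
    # nudge the zero point onto the integer grid, rounding half away from zero
    zp = np.floor(quant_min - min_val / scale + 0.5)
    nudged_zp = min(max(zp, quant_min), quant_max)
    nudged_min = (quant_min - nudged_zp) * scale
    nudged_max = (quant_max - nudged_zp) * scale
    clamped = np.clip(x, nudged_min, nudged_max)
    return np.floor((clamped - nudged_min) / scale + 0.5) * scale + nudged_min
# For example, fake_quant_ref(x, -0.1, 63.65) reproduces the expect array of
# test_fake_quant4 below.
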
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant1():
    # Degenerate range (min == max == 0): the op passes zeros straight through.
    # (8, false, 0.0f, 0.0f, TensorShape({2, 3}),
    #  {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f},
    #  {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
x = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).reshape(2, 3).astype(np.float32)
min_val = np.array([0]).reshape(1).astype(np.float32)
max_val = np.array([0]).reshape(1).astype(np.float32)
expect = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant2():
    # WithVarsNoNudging_RegularRange: min/max already sit on the quantization
    # grid, so no nudging occurs.
    # 8, false, -10.0f, 53.75f, TensorShape({2, 3}),
    # {-10.1f, -10.0f, -9.9f, -9.75f, 53.75f, 53.8f},
    # {-10.0f, -10.0f, -10.0f, -9.75f, 53.75f, 53.75f});
x = np.array([-10.1, -10.0, -9.9, -9.75, 53.75, 53.8]).reshape(2, 3).astype(np.float32)
min_val = np.array([-10.0]).reshape(1).astype(np.float32)
max_val = np.array([53.75]).reshape(1).astype(np.float32)
expect = np.array([-10.0, -10.0, -10.0, -9.75, 53.75, 53.75]).astype(np.float32)
net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant3():
# WithVarsNoNudging_NarrowRange
x = np.array([-10.1, -10.0, -9.90, -9.75, 53.5, 53.6]).reshape(2, 3).astype(np.float32)
min_val = np.array([-10.0]).reshape(1).astype(np.float32)
max_val = np.array([53.5]).reshape(1).astype(np.float32)
expect = np.array([-10.0, -10.0, -10.0, -9.75, 53.5, 53.5]).astype(np.float32)
net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant4():
# WithVarsNudgedDown_RegularRange
x = np.array([-0.1, 0.0, 0.1, 0.25, 63.75, 63.8]).reshape(2, 3).astype(np.float32)
min_val = np.array([-0.1]).reshape(1).astype(np.float32)
max_val = np.array([63.65]).reshape(1).astype(np.float32)
expect = np.array([-0.0, 0.0, 0.0, 0.25, 63.75, 63.75]).astype(np.float32)
net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
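

# Worked numbers for the "nudged down" case above, under the reference model
# sketched earlier: scale = (63.65 + 0.1) / 255 = 0.25 and the raw zero point
# 0 - (-0.1) / 0.25 = 0.4 rounds down to 0, so the nudged range becomes
# [0.0, 63.75]; -0.1 collapses to 0.0 and 63.8 clamps to 63.75.
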
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant5():
# WithVarsNudgedDown_NarrowRange
x = np.array([-0.1, 0.0, 0.1, 0.25, 63.5, 63.6]).reshape(2, 3).astype(np.float32)
min_val = np.array([-0.1]).reshape(1).astype(np.float32)
max_val = np.array([63.4]).reshape(1).astype(np.float32)
expect = np.array([-0.0, 0.0, 0.0, 0.25, 63.5, 63.5]).astype(np.float32)
net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant6():
# WithVarsNudgedUp_RegularRange
x = np.array([-0.26, -0.25, -0.24, 0.0, 63.5, 63.6]).reshape(2, 3).astype(np.float32)
min_val = np.array([-0.125]).reshape(1).astype(np.float32)
max_val = np.array([63.625]).reshape(1).astype(np.float32)
expect = np.array([-0.25, -0.25, -0.25, 0.0, 63.5, 63.5]).astype(np.float32)
net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
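

# Worked numbers for the "nudged up" case above: scale = (63.625 + 0.125) /
# 255 = 0.25, the raw zero point 0.125 / 0.25 = 0.5 rounds up to 1, and the
# range shifts to [-0.25, 63.5]; hence -0.26 clamps to -0.25 and 63.6 to 63.5.
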
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant7():
# WithVarsNudgedUp_NarrowRange
x = np.array([-0.26, -0.25, -0.24, 0.0, 63.25, 63.3]).reshape(2, 3).astype(np.float32)
min_val = np.array([-0.125]).reshape(1).astype(np.float32)
max_val = np.array([63.375]).reshape(1).astype(np.float32)
expect = np.array([-0.25, -0.25, -0.25, 0.0, 63.25, 63.25]).astype(np.float32)
net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant8():
# WithVarsNudgedZeroIs255_RegularRange
x = np.array([-63.80, -63.75, -63.70, -63.5, 0.0, 0.1]).reshape(2, 3).astype(np.float32)
min_val = np.array([-63.65]).reshape(1).astype(np.float32)
max_val = np.array([0.1]).reshape(1).astype(np.float32)
expect = np.array([-63.75, -63.75, -63.75, -63.5, 0.0, 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant9():
# WithVarsNudgedZeroIs255_NarrowRange
x = np.array([-63.6, -63.5, -63.4, -63.25, 0.0, 0.1]).reshape(2, 3).astype(np.float32)
min_val = np.array([-63.4]).reshape(1).astype(np.float32)
max_val = np.array([0.1]).reshape(1).astype(np.float32)
expect = np.array([-63.5, -63.5, -63.5, -63.25, 0.0, 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant10():
# WithVarsNoNudging_4Bits_RegularRange
x = np.array([-6.1, -6.0, -5.9, -5.5, 1.5, 1.6]).reshape(2, 3).astype(np.float32)
min_val = np.array([-6.0]).reshape(1).astype(np.float32)
max_val = np.array([1.5]).reshape(1).astype(np.float32)
expect = np.array([-6.0, -6.0, -6.0, -5.5, 1.5, 1.5]).astype(np.float32)
net = Net(num_bits=4, narrow_range=False)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant11():
# WithVarsNoNudging_4Bits_NarrowRange
x = np.array([-6.1, -6.0, -5.9, -5.5, 1.0, 1.1]).reshape(2, 3).astype(np.float32)
min_val = np.array([-6.0]).reshape(1).astype(np.float32)
max_val = np.array([1.0]).reshape(1).astype(np.float32)
expect = np.array([-6.0, -6.0, -6.0, -5.5, 1.0, 1.0]).astype(np.float32)
net = Net(num_bits=4, narrow_range=True)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant12():
# WithVarsNudgedDown_4Bits_RegularRange
x = np.array([-0.1, 0.0, 0.1, 0.5, 7.5, 7.6]).reshape(2, 3).astype(np.float32)
min_val = np.array([-0.1]).reshape(1).astype(np.float32)
max_val = np.array([7.4]).reshape(1).astype(np.float32)
expect = np.array([-0.0, 0.0, 0.0, 0.5, 7.5, 7.5]).astype(np.float32)
net = Net(num_bits=4, narrow_range=False)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant13():
# WithVarsNudgedDown_4Bits_NarrowRange
x = np.array([-0.1, 0.0, 0.1, 0.5, 7.0, 7.1]).reshape(2, 3).astype(np.float32)
min_val = np.array([-0.1]).reshape(1).astype(np.float32)
max_val = np.array([6.9]).reshape(1).astype(np.float32)
expect = np.array([-0.0, 0.0, 0.0, 0.5, 7.0, 7.0]).astype(np.float32)
net = Net(num_bits=4, narrow_range=True)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant14():
# WithVarsNudgedUp_4Bits_RegularRange
x = np.array([-0.6, -0.5, -0.24, 0.0, 7.0, 7.1]).reshape(2, 3).astype(np.float32)
min_val = np.array([-0.4]).reshape(1).astype(np.float32)
max_val = np.array([7.1]).reshape(1).astype(np.float32)
expect = np.array([-0.5, -0.5, -0.00, 0.0, 7.0, 7.0]).astype(np.float32)
net = Net(num_bits=4, narrow_range=False)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant15():
# WithVarsNudgedUp_4Bits_NarrowRange
x = np.array([-0.6, -0.5, -0.24, 0.0, 6.5, 6.6]).reshape(2, 3).astype(np.float32)
min_val = np.array([-0.4]).reshape(1).astype(np.float32)
max_val = np.array([6.6]).reshape(1).astype(np.float32)
expect = np.array([-0.5, -0.5, -0.00, 0.0, 6.5, 6.5]).astype(np.float32)
net = Net(num_bits=4, narrow_range=True)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant16():
# WithVarsNudgedZero15_4Bits_RegularRange
x = np.array([-7.6, -7.5, -7.4, -7.2, 0.0, 0.1]).reshape(2, 3).astype(np.float32)
min_val = np.array([-7.3]).reshape(1).astype(np.float32)
max_val = np.array([0.2]).reshape(1).astype(np.float32)
expect = np.array([-7.5, -7.5, -7.5, -7.0, 0.0, 0.0]).astype(np.float32)
net = Net(num_bits=4, narrow_range=False)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
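

# In the 4-bit case above, scale = (0.2 + 7.3) / 15 = 0.5 and the raw zero
# point 0 + 7.3 / 0.5 = 14.6 rounds to 15, the top of the 4-bit range; the
# nudged range becomes [-7.5, 0.0], which is what "NudgedZero15" refers to.
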
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant17():
# WithVarsNudgedZero15_4Bits_NarrowRange
x = np.array([-7.1, -7.0, -6.9, -6.5, 0.0, 0.1]).reshape(2, 3).astype(np.float32)
min_val = np.array([-6.8]).reshape(1).astype(np.float32)
max_val = np.array([0.2]).reshape(1).astype(np.float32)
expect = np.array([-7.0, -7.0, -7.0, -6.5, 0.0, 0.0]).astype(np.float32)
net = Net(num_bits=4, narrow_range=True)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)

View File

@ -0,0 +1,221 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops.operations import _quant_ops as Q

context.set_context(device_target='GPU', device_id=0)


class Net(nn.Cell):
def __init__(self, num_bits=8, narrow_range=False):
super(Net, self).__init__()
self.op = Q.FakeQuantPerLayerGrad(num_bits=num_bits, narrow_range=narrow_range)
def construct(self, dout, x, minq, maxq):
return self.op(dout, x, minq, maxq)
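

# The gradient op behaves as a straight-through estimator: dout passes
# through where x lies inside the nudged [min, max] window and is zeroed
# elsewhere. A NumPy sketch under that assumption (fake_quant_grad_ref is a
# hypothetical helper, not the kernel; min_val/max_val are Python floats):
def fake_quant_grad_ref(dout, x, min_val, max_val, num_bits=8, narrow_range=False):
    quant_min = 1 if narrow_range else 0
    quant_max = (1 << num_bits) - 1
    if max_val == min_val:
        # degenerate range: everything passes (see test_fake_quant_grad5)
        return dout
    scale = (max_val - min_val) / (quant_max - quant_min)
    # nudge the zero point, rounding half away from zero like C's round()
    zp = np.floor(quant_min - min_val / scale + 0.5)
    nudged_zp = min(max(zp, quant_min), quant_max)
    nudged_min = (quant_min - nudged_zp) * scale
    nudged_max = (quant_max - nudged_zp) * scale
    return np.where((x >= nudged_min) & (x <= nudged_max), dout, 0.0)
# For example, fake_quant_grad_ref(dout, x, -0.125, 63.625) reproduces the
# expect array of test_fake_quant_grad1 below.
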
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad1():
# WithArgsGradient RegularRange
dout = np.random.uniform(-1, 1, size=[6]).astype('float32')
x = np.array([-0.26, -0.25, -0.24, 0.0, 63.5, 63.6]).astype(np.float32)
min_val = np.array([-0.125]).reshape(1).astype(np.float32)
max_val = np.array([63.625]).reshape(1).astype(np.float32)
expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad2():
# WithArgsGradient NarrowRange
dout = np.random.uniform(-1, 1, size=[6]).astype('float32')
x = np.array([-0.26, -0.25, -0.24, 0.0, 63.25, 63.3]).astype(np.float32)
min_val = np.array([-0.125]).reshape(1).astype(np.float32)
max_val = np.array([63.375]).reshape(1).astype(np.float32)
expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad3():
# WithArgsGradient_4Bits_RegularRange
dout = np.random.uniform(-1, 1, size=[6]).astype('float32')
x = np.array([-0.6, -0.5, -0.4, 0.0, 7.0, 7.1]).astype(np.float32)
min_val = np.array([-0.4]).reshape(1).astype(np.float32)
max_val = np.array([7.1]).reshape(1).astype(np.float32)
expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32)
net = Net(num_bits=4, narrow_range=False)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad4():
# WithArgsGradient_4Bits_NarrowRange
dout = np.random.uniform(-1, 1, size=[6]).astype('float32')
x = np.array([-0.6, -0.5, -0.4, 0.0, 6.5, 6.6]).astype(np.float32)
min_val = np.array([-0.4]).reshape(1).astype(np.float32)
max_val = np.array([6.6]).reshape(1).astype(np.float32)
expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32)
net = Net(num_bits=4, narrow_range=True)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad5():
# FakeQuantWithMinMaxVarsGradient
dout = np.random.uniform(-1, 1, size=[6]).astype('float32')
x = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).astype(np.float32)
min_val = np.array([0.0]).reshape(1).astype(np.float32)
max_val = np.array([0.0]).reshape(1).astype(np.float32)
expect = dout
net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad6():
# WithVarsGradient_RegularRange
dout = np.random.uniform(-1, 1, size=[6]).astype('float32')
x = np.array([-0.26, -0.25, -0.24, 0.0, 63.5, 63.6]).astype(np.float32)
min_val = np.array([-0.125]).reshape(1).astype(np.float32)
max_val = np.array([63.625]).reshape(1).astype(np.float32)
expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad7():
# WithVarsGradient_NarrowRange
dout = np.random.uniform(-1, 1, size=[6]).astype('float32')
x = np.array([-0.26, -0.25, -0.24, 0.0, 63.25, 63.3]).astype(np.float32)
min_val = np.array([-0.125]).reshape(1).astype(np.float32)
max_val = np.array([63.375]).reshape(1).astype(np.float32)
expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32)
net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad8():
# WithVarsGradient_4Bits_RegularRange
dout = np.random.uniform(-1, 1, size=[6]).astype('float32')
x = np.array([-0.6, -0.5, -0.4, 0.0, 7.0, 7.1]).astype(np.float32)
min_val = np.array([-0.4]).reshape(1).astype(np.float32)
max_val = np.array([7.1]).reshape(1).astype(np.float32)
expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32)
net = Net(num_bits=4, narrow_range=False)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad9():
# WithVarsGradient_4Bits_NarrowRange
dout = np.random.uniform(-1, 1, size=[6]).astype('float32')
x = np.array([-0.6, -0.5, -0.4, 0.0, 6.5, 6.6]).astype(np.float32)
min_val = np.array([-0.4]).reshape(1).astype(np.float32)
max_val = np.array([6.6]).reshape(1).astype(np.float32)
expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32)
net = Net(num_bits=4, narrow_range=True)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))
error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)