commit 5cde059ef3
!19854 clearing the untrusted code of quantization
Merge pull request !19854 from Erpim/master
@@ -42,7 +42,6 @@ const std::vector<size_t> &FakeLearnedScaleQuantPerChannelGradGpuKernel::GetWork
 bool FakeLearnedScaleQuantPerChannelGradGpuKernel::Init(const CNodePtr &kernel_node) {
   kernel_node_ = kernel_node;
   size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);

   if (input_num != 4) {
     MS_LOG(EXCEPTION) << "Input number is " << input_num
                       << ", but FakeLearnedScaleQuantPerChannelGrad GpuKernel OP needs 4 input.";
@@ -109,7 +108,7 @@ bool FakeLearnedScaleQuantPerChannelGradGpuKernel::Launch(const std::vector<Addr

   if (global_step_ >= quant_delay_) {
     CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
-                              cudaMemcpyAsync(grad_alpha, alpha_no_grad, sizeof(float) * num_channels_,
+                              cudaMemcpyAsync(grad_alpha, alpha_no_grad, sizeof(float) * kChannelLen,
                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "Copy gpu memory failed");
     CalLSQNudgePerChannel(input, quant_num_, input_alpha, input_quant_max, input_div_alpha, input_quant, neg_trunc_,
@@ -118,7 +117,7 @@ bool FakeLearnedScaleQuantPerChannelGradGpuKernel::Launch(const std::vector<Addr
                           neg_trunc_, num_channels_, reinterpret_cast<cudaStream_t>(stream_ptr));
   } else {
     CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
-                              cudaMemcpyAsync(grad_alpha, alpha_no_grad, sizeof(float) * num_channels_,
+                              cudaMemcpyAsync(grad_alpha, alpha_no_grad, sizeof(float) * kChannelLen,
                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "Copy gpu memory failed");
     CHECK_CUDA_RET_WITH_ERROR(kernel_node_,

@@ -36,7 +36,6 @@ const std::vector<size_t> &FakeLearnedScaleQuantPerLayerGradGpuKernel::GetWorksp
 bool FakeLearnedScaleQuantPerLayerGradGpuKernel::Init(const CNodePtr &kernel_node) {
   kernel_node_ = kernel_node;
   size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);

   if (input_num != 4) {
     MS_LOG(EXCEPTION) << "Input number is " << input_num
                       << ", but FakeLearnedScaleQuantPerLayerGrad GpuKernel OP needs 4 input.";

@@ -29,7 +29,7 @@ namespace mindspore {
 namespace opt {
 namespace {
 void CreateOutputsOfLSQPerLayerGradD(const FuncGraphPtr &graph, const CNodePtr &lsq_perlayer_grad_node,
-                                     std::vector<AnfNodePtr> *lsq_perlayer_grad_d_outputs) {
+                                     std::vector<AnfNodePtr> *const lsq_perlayer_grad_d_outputs) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(lsq_perlayer_grad_node);
   const auto &lsq_perlayer_grad_inputs = lsq_perlayer_grad_node->inputs();
@@ -58,7 +58,7 @@ void CreateOutputsOfLSQPerLayerGradD(const FuncGraphPtr &graph, const CNodePtr &

 void CreateOutputsOfLSQPerLayerReduceGrad(const FuncGraphPtr &graph, const CNodePtr &lsq_perlayer_grad_node,
                                           const std::vector<AnfNodePtr> &lsq_perlayer_grad_d_outputs,
-                                          std::vector<AnfNodePtr> *lsq_perlayer_reduce_grad_outputs) {
+                                          std::vector<AnfNodePtr> *const lsq_perlayer_reduce_grad_outputs) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(lsq_perlayer_grad_node);
   MS_EXCEPTION_IF_NULL(lsq_perlayer_reduce_grad_outputs);
@@ -86,7 +86,7 @@ void CreateOutputsOfLSQPerLayerReduceGrad(const FuncGraphPtr &graph, const CNode
 }

 void CreateOutputsOfLSQPerChannelGradD(const FuncGraphPtr &graph, const CNodePtr &lsq_perchannel_grad_node,
-                                       std::vector<AnfNodePtr> *lsq_perchannel_grad_d_outputs) {
+                                       std::vector<AnfNodePtr> *const lsq_perchannel_grad_d_outputs) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(lsq_perchannel_grad_node);
   const auto &lsq_perchannel_grad_inputs = lsq_perchannel_grad_node->inputs();
@@ -116,7 +116,7 @@ void CreateOutputsOfLSQPerChannelGradD(const FuncGraphPtr &graph, const CNodePtr

 void CreateOutputsOfLSQPerChannelReduceGrad(const FuncGraphPtr &graph, const CNodePtr &lsq_perchannel_grad_node,
                                             const std::vector<AnfNodePtr> &lsq_perchannel_grad_d_outputs,
-                                            std::vector<AnfNodePtr> *lsq_perchannel_reduce_grad_outputs) {
+                                            std::vector<AnfNodePtr> *const lsq_perchannel_reduce_grad_outputs) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(lsq_perchannel_grad_node);
   MS_EXCEPTION_IF_NULL(lsq_perchannel_reduce_grad_outputs);

@@ -30,7 +30,7 @@ from ...nn.layer import quant
 from ...ops import functional as F
 from ..common import QuantDtype
 from .quantizer import Quantizer, OptimizeOption
-from .quant_utils import compute_KL_threshold
+from .quant_utils import compute_kl_threshold


 __all__ = ["QuantizationAwareTraining", "create_quant_config"]
@@ -281,7 +281,8 @@ class QuantizationAwareTraining(Quantizer):
                                            mode=self.mode)
         self.eps = 1e-5

-    def _convert_op_name(self, name):
+    @staticmethod
+    def _convert_op_name(name):
         pattern = re.compile(r'([A-Z]{1})')
         name_new = re.sub(pattern, r'_\1', name).lower()
         if name_new[0] == '_':
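Note: `_convert_op_name` no longer takes `self`, so it becomes a staticmethod. The regex converts CamelCase op names to snake_case; a minimal standalone sketch of the same conversion (the leading-underscore trim is inferred from the `if name_new[0] == '_':` context line):

    import re

    def convert_op_name(name):
        # Insert '_' before every capital letter, lower-case the result,
        # and drop the leading '_' if one was produced.
        name_new = re.sub(r'([A-Z]{1})', r'_\1', name).lower()
        return name_new[1:] if name_new.startswith('_') else name_new

    print(convert_op_name("BatchNormFoldD"))  # -> batch_norm_fold_d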
@@ -382,7 +383,7 @@ class QuantizationAwareTraining(Quantizer):
                 scale_factor = (subcell.batchnorm.gamma.data.asnumpy() /
                                 np.sqrt(subcell.batchnorm.moving_variance.data.asnumpy() + self.eps))
                 subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1)
-                min_init, max_init = self._KL_init(subcell_weight_para, self.weight_dtype)
+                min_init, max_init = self._kl_init(subcell_weight_para, self.weight_dtype)
                 self.quant_config = self.quant_config._replace(
                     weight=self.quant_config.weight.partial_init(min_init=min_init, max_init=max_init))

@@ -485,7 +486,7 @@ class QuantizationAwareTraining(Quantizer):
                 scale_factor = (subcell.batchnorm.gamma.data.asnumpy() /
                                 np.sqrt(subcell.batchnorm.moving_variance.data.asnumpy() + self.eps))
                 subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1)
-                min_init, max_init = self._KL_init(subcell_weight_para, self.weight_dtype)
+                min_init, max_init = self._kl_init(subcell_weight_para, self.weight_dtype)
                 self.quant_config = self.quant_config._replace(
                     weight=self.quant_config.weight.partial_init(min_init=min_init, max_init=max_init))
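Both hunks above fold the BatchNorm statistics into the weights before the KL-based range search. A small numpy sketch of that folding step, mirroring the `scale_factor` lines (shapes hypothetical):

    import numpy as np

    def fold_bn_scale(weight, gamma, moving_variance, eps=1e-5):
        # Per-output-channel scale gamma / sqrt(var + eps), broadcast over OIHW.
        scale_factor = gamma / np.sqrt(moving_variance + eps)
        return weight * scale_factor.reshape(-1, 1, 1, 1)

    w = np.random.randn(8, 3, 3, 3).astype(np.float32)
    folded = fold_bn_scale(w, np.ones(8, np.float32), np.full(8, 0.5, np.float32))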
@@ -533,16 +534,16 @@ class QuantizationAwareTraining(Quantizer):
                                              quant_dtype=self.act_dtype)
         raise ValueError("Unsupported activation in auto quant: ", act_class)

-    def _KL_init(self, subcell_weight_para, weight_dtype):
+    def _kl_init(self, subcell_weight_para, weight_dtype):
         """
-        Calculate the value of max_init and min_init with compute_KL_threshold.
+        Calculate the value of max_init and min_init with compute_kl_threshold.
         """
         if self.weight_channel:
-            max_init = [compute_KL_threshold(weight_para_each, weight_dtype)
+            max_init = [compute_kl_threshold(weight_para_each, weight_dtype)
                         for weight_para_each in subcell_weight_para]
             min_init = [-x for x in max_init]
         else:
-            max_init = [compute_KL_threshold(subcell_weight_para, weight_dtype)]
+            max_init = [compute_kl_threshold(subcell_weight_para, weight_dtype)]
             min_init = [-x for x in max_init]
         return min_init, max_init
@@ -567,15 +568,17 @@ class QuantizationAwareTraining(Quantizer):
             raise ValueError("The `_set_mixed_bits` function is currently only valid for `LEARNED_SCALE` "
                              "optimize_option.")

-        self.quantizable_idx = []
+        quantizable_idx = []
         pass_cell = None
         for i, cell_and_name in enumerate(network.cells_and_names()):
             cell = cell_and_name[1]
             if isinstance(cell, (nn.Conv2dBnAct, nn.DenseBnAct)) and cell is not pass_cell:
-                self.quantizable_idx.append(i)
+                quantizable_idx.append(i)

-        assert len(self.quantizable_idx) == len(strategy)
-        quantizable_layer_bit_dict = {idx: bit for idx, bit in zip(self.quantizable_idx, strategy)}
+        if len(quantizable_idx) != len(strategy):
+            raise ValueError("The dimension of quantifiable layers is not consistent with that of strategy.")
+
+        quantizable_layer_bit_dict = {idx: bit for idx, bit in zip(quantizable_idx, strategy)}
         type_map = {
             QuantDtype.INT2.num_bits: QuantDtype.INT2,
             QuantDtype.INT3.num_bits: QuantDtype.INT3,
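The assert-to-ValueError change matters beyond style: an `assert` vanishes under `python -O`, so a length mismatch would silently zip mismatched lists. A minimal illustration (hypothetical names):

    def check_strategy(quantizable_idx, strategy):
        # Unlike `assert len(quantizable_idx) == len(strategy)`, this check
        # survives optimized mode and raises a catchable, descriptive error.
        if len(quantizable_idx) != len(strategy):
            raise ValueError("The dimension of quantifiable layers is not "
                             "consistent with that of strategy.")

    check_strategy([3, 7, 12], [8, 4, 8])   # passes
    # check_strategy([3, 7], [8, 4, 8])     # raises ValueError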
@@ -587,7 +590,7 @@ class QuantizationAwareTraining(Quantizer):
         }
         for i, cell_and_name in enumerate(network.cells_and_names()):
             cell = cell_and_name[1]
-            if i not in self.quantizable_idx:
+            if i not in quantizable_idx:
                 continue
             else:
                 if isinstance(cell, (nn.Conv2dBnAct, nn.DenseBnAct)):
@@ -598,7 +601,7 @@ class QuantizationAwareTraining(Quantizer):
                         scale_factor = (cell.conv.gamma.data.asnumpy() /
                                         np.sqrt(cell.conv.moving_variance.data.asnumpy() + self.eps))
                         subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1)
-                        min_init, max_init = self._KL_init(subcell_weight_para, cell.weight_dtype)
+                        min_init, max_init = self._kl_init(subcell_weight_para, cell.weight_dtype)
                         cell.conv.fake_quant_weight.reset(quant_dtype=cell.weight_dtype,
                                                           min_init=min_init,
                                                           max_init=max_init)
@@ -608,7 +611,7 @@ class QuantizationAwareTraining(Quantizer):
                         scale_factor = (cell.dense.gamma.data.asnumpy() /
                                         np.sqrt(cell.dense.moving_variance.data.asnumpy() + self.eps))
                         subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1)
-                        min_init, max_init = self._KL_init(subcell_weight_para, cell.weight_dtype)
+                        min_init, max_init = self._kl_init(subcell_weight_para, cell.weight_dtype)
                         cell.dense.fake_quant_weight.reset(quant_dtype=cell.weight_dtype,
                                                            min_init=min_init,
                                                            max_init=max_init)

@@ -222,7 +222,7 @@ def without_fold_batchnorm(weight, cell_quant):
     return weight, bias


-def compute_KL_threshold(data, bitwidth):
+def compute_kl_threshold(data, bitwidth):
     r"""
     Using KL-J Distance to calculate the clip threshold.

@@ -232,20 +232,18 @@ def compute_KL_threshold(data, bitwidth):
     Outputs:
         Tensor with Shape 1. Threshold to calculate the data.
     """
-    bitwidth = bitwidth.num_bits
-    data_min = 0
     data_max = np.abs(data).max()
     if data_max < 1e-5:
         return 1e-5
-    hist, bin_edges = np.histogram(np.abs(data), bins='sqrt', range=(data_min, data_max), density=True)
+    hist, bin_edges = np.histogram(np.abs(data), bins='sqrt', range=(0, data_max), density=True)
     # For the sake of high efficiency, we limit the maximum number of bins to 1024 in `sqrt` mode, If it exceeds the
     # largest size, turn to use the default bins config.
     largest_bin_size = 1024
     if hist.shape[0] > largest_bin_size:
-        hist, bin_edges = np.histogram(np.abs(data), range=(data_min, data_max), density=True)
+        hist, bin_edges = np.histogram(np.abs(data), range=(0, data_max), density=True)
     hist = hist / np.sum(hist)
     cumsum = np.cumsum(hist)
-    bit_pow_range = pow(2, int(bitwidth) - 1)
+    bit_pow_range = pow(2, int(bitwidth.num_bits) - 1)
     threshold = []
     scaling_factor = []
     kl = []
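The histogram set-up touched above (inlining `data_min = 0` and no longer rebinding the `bitwidth` parameter) can be exercised on its own; a numpy sketch of the `bins='sqrt'` fallback (hypothetical data):

    import numpy as np

    def abs_histogram(data, largest_bin_size=1024):
        # Histogram of |data| on [0, max]; if 'sqrt' would produce more than
        # 1024 bins, fall back to numpy's default bin count, as the comment says.
        data_max = np.abs(data).max()
        hist, edges = np.histogram(np.abs(data), bins='sqrt', range=(0, data_max), density=True)
        if hist.shape[0] > largest_bin_size:
            hist, edges = np.histogram(np.abs(data), range=(0, data_max), density=True)
        return hist / np.sum(hist), edges

    hist, edges = abs_histogram(np.random.randn(10 ** 7))
    print(hist.shape)   # 'sqrt' would need ~3163 bins here, so the default 10 are used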
@@ -349,11 +347,11 @@ def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_param
             subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1)

             if cell.fake_quant_weight.per_channel:
-                max_init = [compute_KL_threshold(weight_para_each, cell.fake_quant_weight.quant_dtype)
+                max_init = [compute_kl_threshold(weight_para_each, cell.fake_quant_weight.quant_dtype)
                             for weight_para_each in subcell_weight_para]
                 min_init = [-x for x in max_init]
             else:
-                max_init = [compute_KL_threshold(subcell_weight_para, cell.fake_quant_weight.quant_dtype)]
+                max_init = [compute_kl_threshold(subcell_weight_para, cell.fake_quant_weight.quant_dtype)]
                 min_init = [-x for x in max_init]

             cell.fake_quant_weight.reset(quant_dtype=cell.fake_quant_weight.quant_dtype,

@@ -57,4 +57,8 @@ class Quantizer(ABC):

     @abstractmethod
     def quantize(self, network):
-        pass
+        """
+        Quant API to convert input network to a quantization aware training network
+        Args:
+            network (Cell): network to be quantized.
+        """

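Dropping `pass` works because a docstring is itself a complete function body, and the abstract method still forces subclasses to override. A tiny standalone illustration:

    from abc import ABC, abstractmethod

    class Base(ABC):
        @abstractmethod
        def quantize(self, network):
            """Convert `network` to a quantization aware training network."""

    class Passthrough(Base):
        def quantize(self, network):
            return network

    Passthrough().quantize("net")   # Base() itself would raise TypeError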
@@ -1535,6 +1535,7 @@ class ActQuant(_QuantActivation):
                                           quant_dtype=quant_dtype,
                                           neg_trunc=self.neg_trunc,
                                           narrow_range=self.narrow_range)

     def construct(self, x):
         if self.fake_before:
             x = self.fake_quant_act_before(x)

@@ -125,7 +125,7 @@ def get_bprop_batchnorm_fold2(self):


 @bprop_getters.register(Q.BatchNormFoldD)
-def get_bprop_BatchNormFold(self):
+def get_bprop_batchnormfold(self):
     """Generate bprop for BatchNormFold for Ascend"""
     op = Q.BatchNormFoldGradD(self.epsilon, self.is_training, self.freeze_bn)

@@ -137,7 +137,7 @@ def get_bprop_BatchNormFold(self):


 @bprop_getters.register(P.BNTrainingReduce)
-def get_bprop_BNTrainingReduce(self):
+def get_bprop_bn_training_reduce(self):
     """Generate bprop for BNTrainingReduce for Ascend"""

     def bprop(x, out, dout):

@@ -57,6 +57,43 @@ def _batchnorm_fold_tbe():
     return


+def _batchnorm_fold_compute(x_input, x_sum, x_square_sum, mean, variance, momentum, epsilon):
+    """_batchnorm_fold_compute"""
+    shape_x = te.lang.cce.util.shape_to_list(x_input.shape)
+    num = shape_x[0] * shape_x[2] * shape_x[3]
+    num_rec = 1.0 / num
+
+    # compute the mean of x
+    batch_mean = te.lang.cce.vmuls(x_sum, num_rec)
+
+    # compute the variance of x
+    variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
+    mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
+    batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)
+    batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_var_biased, epsilon))
+    if num == 1:
+        batch_var_scaler = 0.0
+    else:
+        batch_var_scaler = float(num) / (num - 1)
+    batch_var_unbiased = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
+
+    factor = 1.0 - momentum
+    factor_reverse = momentum
+    mean_mul = te.lang.cce.vmuls(batch_mean, factor)
+    mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
+    mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)
+
+    var_mul = te.lang.cce.vmuls(batch_var_unbiased, factor)
+    var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
+    variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)
+
+    y = te.lang.cce.vadds(x_input, 0.0)
+    running_mean = te.lang.cce.vadds(mean, 0.0)
+    running_std = te.lang.cce.vsqrt(te.lang.cce.vadds(variance, epsilon))
+    res = [y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated]
+    return res
+
+
 @util.check_input_type(dict, dict, dict, dict, dict,
                        dict, dict, dict, dict, dict, dict, dict,
                        float, float, bool, int, str, str)
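The extracted `_batchnorm_fold_compute` only reorganizes code; the arithmetic is the standard folded-BatchNorm statistics. The same math in plain numpy (NCHW stand-in for the kernel's layout, data hypothetical):

    import numpy as np

    def batchnorm_fold_stats(x, mean, variance, momentum, epsilon=1e-5):
        # Batch mean/std over N, H, W plus the running-stat update.
        num = x.shape[0] * x.shape[2] * x.shape[3]
        batch_mean = x.sum(axis=(0, 2, 3)) / num
        batch_var_biased = (x * x).sum(axis=(0, 2, 3)) / num - batch_mean ** 2
        batch_std = np.sqrt(batch_var_biased + epsilon)
        scaler = float(num) / (num - 1) if num > 1 else 0.0
        batch_var_unbiased = batch_var_biased * scaler
        mean_updated = (1.0 - momentum) * batch_mean + momentum * mean
        variance_updated = (1.0 - momentum) * batch_var_unbiased + momentum * variance
        return batch_mean, batch_std, mean_updated, variance_updated

    x = np.random.randn(4, 8, 6, 6).astype(np.float32)
    stats = batchnorm_fold_stats(x, np.zeros(8), np.ones(8), momentum=0.9)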
@@ -108,39 +145,7 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
     mean = tvm.placeholder(shape_mean, name="mean", dtype=dtype_mean.lower())
     variance = tvm.placeholder(shape_mean, name="variance", dtype=dtype_variance.lower())

-    shape_x = te.lang.cce.util.shape_to_list(x_input.shape)
-    num = shape_x[0] * shape_x[2] * shape_x[3]
-    num_rec = 1.0 / num
-
-    # compute the mean of x
-    batch_mean = te.lang.cce.vmuls(x_sum, num_rec)
-
-    # compute the variance of x
-    variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
-    mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
-    batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)
-    batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_var_biased, epsilon))
-    if num == 1:
-        batch_var_scaler = 0.0
-    else:
-        batch_var_scaler = float(num) / (num - 1)
-    batch_var_unbiased = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
-
-    factor = 1.0 - momentum
-    factor_reverse = momentum
-    mean_mul = te.lang.cce.vmuls(batch_mean, factor)
-    mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
-    mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)
-
-    var_mul = te.lang.cce.vmuls(batch_var_unbiased, factor)
-    var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
-    variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)
-
-    y = te.lang.cce.vadds(x_input, 0.0)
-    running_mean = te.lang.cce.vadds(mean, 0.0)
-    running_std = te.lang.cce.vsqrt(te.lang.cce.vadds(variance, epsilon))
-    res = [y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated]
-
+    res = _batchnorm_fold_compute(x_input, x_sum, x_square_sum, mean, variance, momentum, epsilon)
     with tvm.target.cce():
         sch = generic.auto_schedule(res)
         config = {"name": kernel_name,

@@ -21,6 +21,7 @@ from te.platform.cce_build import build_config
 from topi import generic
 from topi.cce import util
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+from impl.bn_training_reduce import bn_training_reduce_schedule_nd

 SHAPE_SIZE_LIMIT = 2147483648

@@ -81,9 +82,8 @@ def batchnorm_fold2_grad_reduce(dout, x, dout_reduce, dout_x_reduce, kernel_name
     util.check_kernel_name(kernel_name)
     util.check_shape_rule(shape)
     util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
-    check_list = ["float16", "float32"]
     inp_dtype = x.get("dtype").lower()
-    if not inp_dtype in check_list:
+    if not inp_dtype in ["float16", "float32"]:
         raise RuntimeError("Dtype of input only support float16, float32")
     dout_t = tvm.placeholder(shape, name="dout", dtype=inp_dtype)
     x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)
@@ -100,7 +100,7 @@ def batchnorm_fold2_grad_reduce(dout, x, dout_reduce, dout_x_reduce, kernel_name

     te.lang.cce.cce_build_code(sch, config)
     return
-    from impl.bn_training_reduce import bn_training_reduce_schedule_nd
+
     sch, tensor_list = bn_training_reduce_schedule_nd(res_list)
     with build_config:
         tvm.build(sch, tensor_list, "cce", name=kernel_name)

@@ -95,12 +95,12 @@ def batchnorm_fold_grad(d_batch_mean, d_batch_std, x, batch_mean, batch_std, dx,
     dtype_mean = d_batch_mean.get("dtype").lower()
     if format_data == "NC1HWC0":
         if len(shape_x) != 5:
-            raise RuntimeError("batchnorm_fold only support shape 5D"
+            raise RuntimeError("batchnorm_fold grad only support shape 5D"
                                "when input format is NC1HWC0")
         shape_mean = (1, shape_x[1], 1, 1, shape_x[4])
     elif format_data == "NCHW":
         if len(shape_x) < 2 or len(shape_x) > 4:
-            raise RuntimeError("batchnorm_fold only support shape 2D to 4D")
+            raise RuntimeError("batchnorm_fold grad only support shape 2D to 4D")
         if shape_x[1] != shape_mean[0]:
             raise RuntimeError("data_format is NCHW, shape_bias must"
                                "be equal to the second axis of shape_x")

@@ -66,9 +66,8 @@ def correction_mul(x, batch_std, running_std, y, channel, kernel_name="correctio
     util.check_kernel_name(kernel_name)
     util.check_shape_rule(shape)
     util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
-    check_list = ["float16", "float32"]
     inp_dtype = x.get("dtype").lower()
-    if not inp_dtype in check_list:
+    if not inp_dtype in ["float16", "float32"]:
         raise RuntimeError("Dtype of input only support float16, float32")

     x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)

@@ -71,10 +71,9 @@ def fake_learned_scale_quant_perchannel_compute(input_data, alpha_data, quant_ma
     return res


-@util.check_input_type(dict, dict, dict, dict, bool, int, str)
-def fake_learned_scale_quant_perchannel(input_x, alpha, quant_max, out, neg_trunc, channel_axis,
-                                        kernel_name="fake_learned_scale_quant_perchannel"):
-    """FakeLearnedScaleQuantPerChannel"""
+def fake_learned_scale_quant_perchannel_param(input_x, alpha, quant_max, channel_axis,
+                                              kernel_name="fake_learned_scale_quant_perchannel"):
+    """Get and check FakeLearnedScaleQuantPerChannel parameters"""
     input_shape = input_x.get("shape")
     input_x_shape_ = input_x.get("ori_shape")
     input_x_format = input_x.get("format")
@@ -113,6 +112,16 @@ def fake_learned_scale_quant_perchannel(input_x, alpha, quant_max, out, neg_trun
     input_data = tvm.placeholder(input_shape, name="x", dtype=input_dtype)
     alpha_data = tvm.placeholder(shape_c, name="alpha_data", dtype=alpha_dtype)
     quant_max_data = tvm.placeholder(quant_max_shape, name="quant_max_data", dtype=quant_max_dtype)
+    return input_data, alpha_data, quant_max_data
+
+
+@util.check_input_type(dict, dict, dict, dict, bool, int, str)
+def fake_learned_scale_quant_perchannel(input_x, alpha, quant_max, out, neg_trunc, channel_axis,
+                                        kernel_name="fake_learned_scale_quant_perchannel"):
+    """FakeLearnedScaleQuantPerChannel"""
+    input_data, alpha_data, quant_max_data = \
+        fake_learned_scale_quant_perchannel_param(input_x, alpha, quant_max, channel_axis, kernel_name)
+
+    res = fake_learned_scale_quant_perchannel_compute(input_data, alpha_data, quant_max_data, neg_trunc, kernel_name)

     with tvm.target.cce():
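For context, this op family implements learned-scale (LSQ-style) fake quantization. The forward behaviour, sketched in numpy under the usual LSQ formulation (an assumption for illustration, not code taken from this diff):

    import numpy as np

    def fake_learned_scale_quant(x, alpha, quant_max, neg_trunc=False):
        # Normalize by the learned scale alpha, clip to [-1, 1]
        # ([0, 1] when neg_trunc), snap to the integer grid, rescale.
        lower = 0.0 if neg_trunc else -1.0
        x_div_alpha = np.clip(x / alpha, lower, 1.0)
        x_quant = np.round(x_div_alpha * quant_max) / quant_max
        return x_quant * alpha

    x = np.linspace(-2.0, 2.0, 9, dtype=np.float32)
    print(fake_learned_scale_quant(x, alpha=1.5, quant_max=127))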
@@ -58,6 +58,36 @@ def _fake_learned_scale_quant_perchannel_grad_d_tbe():
     return


+def _sign_function(dtype, input_div_alpha):
+    """sign function imp"""
+    if dtype == "float32":
+        data_min = tvm.const(SCALAR_MIN_FP32, dtype=dtype)
+        neg_data_min = tvm.const(NEG_SCALAR_MIN_FP32, dtype=dtype)
+    elif dtype == "float16":
+        data_min = tvm.const(SCALAR_MIN_FP16, dtype=dtype)
+        neg_data_min = tvm.const(NEG_SCALAR_MIN_FP16, dtype=dtype)
+    else:
+        data_min = tvm.const(1, dtype=dtype)
+        neg_data_min = tvm.const(-1, dtype=dtype)
+    vmax = te.lang.cce.vmaxs(input_div_alpha, neg_data_min)
+    vmin = te.lang.cce.vmins(vmax, data_min)
+    if dtype == "float32":
+        # max num of float32 is 2**126
+        max_support_fp32 = tvm.const(2 ** 62, dtype=dtype)
+        res_mul1 = te.lang.cce.vmuls(vmin, max_support_fp32)
+        res_mul2 = te.lang.cce.vmuls(res_mul1, max_support_fp32)
+        sign = te.lang.cce.vmuls(res_mul2, tvm.const(2 ** 2, dtype=dtype))
+    elif dtype == "float16":
+        # max num of float16 is 2**24
+        # but cce can only support 2**12, so use 12/12 to adaptor 24
+        max_support_fp16 = tvm.const(2 ** 12, dtype=dtype)
+        res_mul1 = te.lang.cce.vmuls(vmin, max_support_fp16)
+        sign = te.lang.cce.vmuls(res_mul1, max_support_fp16)
+    else:
+        sign = vmin
+    return sign
+
+
 @fusion_manager.register("fake_learned_scale_quant_perchannel_grad_d")
 def fake_learned_scale_quant_perchannel_grad_d_compute(dout, input_data, alpha_data, quant_max_data, neg_trunc,
                                                        kernel_name="fake_learned_scale_quant_perchannel_grad_d"):
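`_sign_function` builds sign(x) out of clamp and scalar-multiply ops only: x is clamped to [-FLT_MIN, FLT_MIN], then scaled back up by 2**62 * 2**62 * 2**2 = 2**126 = 1/FLT_MIN, so every value of at least FLT_MIN magnitude saturates to exactly +/-1 (the factor is split, apparently to stay within the scalar range the vmuls intrinsic accepts, as the fp16 comment suggests). A numpy rendering of the float32 branch:

    import numpy as np

    def sign_via_clamp_mul(x):
        tiny = np.float32(2.0 ** -126)                 # float32 FLT_MIN
        v = np.clip(x.astype(np.float32), -tiny, tiny)  # vmaxs then vmins
        v = v * np.float32(2.0 ** 62)
        v = v * np.float32(2.0 ** 62)
        return v * np.float32(2.0 ** 2)                 # total factor: 2**126

    print(sign_via_clamp_mul(np.array([-3.0, 0.0, 5.0])))   # [-1.  0.  1.]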
@@ -86,7 +116,6 @@ def fake_learned_scale_quant_perchannel_grad_d_compute(dout, input_data, alpha_d
     tensor_one = tvm.const(1.0, input_div_alpha.dtype)
     tensor_one = te.lang.cce.broadcast(tensor_one, shape)

-    #out_of_bounds = te.lang.cce.vcmpsel(te.lang.cce.vabs(input_div_alpha), 1.0, 'gt', 1.0, 0.0)
     out_of_upper_bounds = te.lang.cce.vcmpsel(input_div_alpha, 1.0, 'gt', 1.0, 0.0)
     if neg_trunc:
         out_of_lower_bounds = te.lang.cce.vcmpsel(input_div_alpha, 0.0, 'lt', 1.0, 0.0)
@@ -96,32 +125,7 @@ def fake_learned_scale_quant_perchannel_grad_d_compute(dout, input_data, alpha_d

     dx = te.lang.cce.vmul(dx, te.lang.cce.vsub(tensor_one, out_of_bounds))

-    # sign function imp
-    if dtype == "float32":
-        data_min = tvm.const(SCALAR_MIN_FP32, dtype=dtype)
-        neg_data_min = tvm.const(NEG_SCALAR_MIN_FP32, dtype=dtype)
-    elif dtype == "float16":
-        data_min = tvm.const(SCALAR_MIN_FP16, dtype=dtype)
-        neg_data_min = tvm.const(NEG_SCALAR_MIN_FP16, dtype=dtype)
-    else:
-        data_min = tvm.const(1, dtype=dtype)
-        neg_data_min = tvm.const(-1, dtype=dtype)
-    vmax = te.lang.cce.vmaxs(input_div_alpha, neg_data_min)
-    vmin = te.lang.cce.vmins(vmax, data_min)
-    if dtype == "float32":
-        # max num of float32 is 2**126
-        max_support_fp32 = tvm.const(2 ** 62, dtype=dtype)
-        res_mul1 = te.lang.cce.vmuls(vmin, max_support_fp32)
-        res_mul2 = te.lang.cce.vmuls(res_mul1, max_support_fp32)
-        sign = te.lang.cce.vmuls(res_mul2, tvm.const(2 ** 2, dtype=dtype))
-    elif dtype == "float16":
-        # max num of float16 is 2**24
-        # but cce can only support 2**12, so use 12/12 to adaptor 24
-        max_support_fp16 = tvm.const(2 ** 12, dtype=dtype)
-        res_mul1 = te.lang.cce.vmuls(vmin, max_support_fp16)
-        sign = te.lang.cce.vmuls(res_mul1, max_support_fp16)
-    else:
-        sign = vmin
+    sign = _sign_function(dtype, input_div_alpha)

     # The following lines are equivalent to :
     # dalpha_each = dout * sign if out of bounds
@@ -136,10 +140,9 @@ def fake_learned_scale_quant_perchannel_grad_d_compute(dout, input_data, alpha_d
     return [dx, dalpha_each]


-@util.check_input_type(dict, dict, dict, dict, dict, dict, bool, int, str)
-def fake_learned_scale_quant_perchannel_grad_d(dout, input_x, alpha, quant_max, dx, dalpha, neg_trunc,
-                                               channel_axis, kernel_name="fake_learned_scale_quant_perchannel_grad_d"):
-    """FakeLearnedScaleQuantPerChannelGradD"""
+def fake_learned_scale_quant_perchannel_grad_d_param(input_x, alpha, quant_max, channel_axis,
+                                                     kernel_name="fake_learned_scale_quant_perchannel_grad_d"):
+    """Get and check FakeLearnedScaleQuantPerChannelGradD parameters"""
     input_shape = input_x.get("shape")
     input_x_shape_ = input_x.get("ori_shape")
     input_x_format = input_x.get("format")
@@ -179,6 +182,16 @@ def fake_learned_scale_quant_perchannel_grad_d(dout, input_x, alpha, quant_max,
     input_data = tvm.placeholder(input_shape, name="x", dtype=input_dtype)
     alpha_data = tvm.placeholder(shape_c, name="alpha_data", dtype=alpha_dtype)
     quant_max_data = tvm.placeholder(quant_max_shape, name="quant_max_data", dtype=quant_max_dtype)
+    return dout_data, input_data, alpha_data, quant_max_data
+
+
+@util.check_input_type(dict, dict, dict, dict, dict, dict, bool, int, str)
+def fake_learned_scale_quant_perchannel_grad_d(dout, input_x, alpha, quant_max, dx, dalpha, neg_trunc,
+                                               channel_axis, kernel_name="fake_learned_scale_quant_perchannel_grad_d"):
+    """FakeLearnedScaleQuantPerChannelGradD"""
+    dout_data, input_data, alpha_data, quant_max_data = \
+        fake_learned_scale_quant_perchannel_grad_d_param(input_x, alpha, quant_max, channel_axis, kernel_name)
+
+    res = fake_learned_scale_quant_perchannel_grad_d_compute(dout_data, input_data, alpha_data, quant_max_data,
+                                                             neg_trunc, kernel_name)

@@ -71,10 +71,9 @@ def fake_learned_scale_quant_perlayer_compute(input_data, alpha_data, quant_max_
     return res


-@util.check_input_type(dict, dict, dict, dict, bool, str)
-def fake_learned_scale_quant_perlayer(input_x, alpha, quant_max, out, neg_trunc,
-                                      kernel_name="fake_learned_scale_quant_perlayer"):
-    """FakeLearnedScaleQuantPerLayer"""
+def fake_learned_scale_quant_perlayer_param(input_x, alpha, quant_max,
+                                            kernel_name="fake_learned_scale_quant_perlayer"):
+    """Get and check FakeLearnedScaleQuantPerLayer parameters"""
     input_shape = input_x.get("shape")
     input_dtype = input_x.get("dtype")
     alpha_shape = alpha.get("ori_shape")
@@ -105,6 +104,16 @@ def fake_learned_scale_quant_perlayer(input_x, alpha, quant_max, out, neg_trunc,
     input_data = tvm.placeholder(input_shape, name="x", dtype=input_dtype)
     alpha_data = tvm.placeholder(alpha_shape, name="alpha_data", dtype=alpha_dtype)
     quant_max_data = tvm.placeholder(quant_max_shape, name="quant_max_data", dtype=quant_max_dtype)
+    return input_data, alpha_data, quant_max_data
+
+
+@util.check_input_type(dict, dict, dict, dict, bool, str)
+def fake_learned_scale_quant_perlayer(input_x, alpha, quant_max, out, neg_trunc,
+                                      kernel_name="fake_learned_scale_quant_perlayer"):
+    """FakeLearnedScaleQuantPerLayer"""
+    input_data, alpha_data, quant_max_data = \
+        fake_learned_scale_quant_perlayer_param(input_x, alpha, quant_max, kernel_name)
+
+    res = fake_learned_scale_quant_perlayer_compute(input_data, alpha_data, quant_max_data, neg_trunc, kernel_name)

     with tvm.target.cce():

@@ -59,6 +59,36 @@ def _fake_learned_scale_quant_perlayer_grad_d_tbe():
     return


+def _sign_function(dtype, input_div_alpha):
+    """sign function imp"""
+    if dtype == "float32":
+        data_min = tvm.const(SCALAR_MIN_FP32, dtype=dtype)
+        neg_data_min = tvm.const(NEG_SCALAR_MIN_FP32, dtype=dtype)
+    elif dtype == "float16":
+        data_min = tvm.const(SCALAR_MIN_FP16, dtype=dtype)
+        neg_data_min = tvm.const(NEG_SCALAR_MIN_FP16, dtype=dtype)
+    else:
+        data_min = tvm.const(1, dtype=dtype)
+        neg_data_min = tvm.const(-1, dtype=dtype)
+    vmax = te.lang.cce.vmaxs(input_div_alpha, neg_data_min)
+    vmin = te.lang.cce.vmins(vmax, data_min)
+    if dtype == "float32":
+        # max num of float32 is 2**126
+        max_support_fp32 = tvm.const(2 ** 62, dtype=dtype)
+        res_mul1 = te.lang.cce.vmuls(vmin, max_support_fp32)
+        res_mul2 = te.lang.cce.vmuls(res_mul1, max_support_fp32)
+        sign = te.lang.cce.vmuls(res_mul2, tvm.const(2 ** 2, dtype=dtype))
+    elif dtype == "float16":
+        # max num of float16 is 2**24
+        # but cce can only support 2**12, so use 12/12 to adaptor 24
+        max_support_fp16 = tvm.const(2 ** 12, dtype=dtype)
+        res_mul1 = te.lang.cce.vmuls(vmin, max_support_fp16)
+        sign = te.lang.cce.vmuls(res_mul1, max_support_fp16)
+    else:
+        sign = vmin
+    return sign
+
+
 @fusion_manager.register("fake_learned_scale_quant_perlayer_grad_d")
 def fake_learned_scale_quant_perlayer_grad_d_compute(dout, input_data, alpha_data, quant_max_data, neg_trunc,
                                                      kernel_name="fake_learned_scale_quant_perlayer_grad_d"):
@@ -87,7 +117,6 @@ def fake_learned_scale_quant_perlayer_grad_d_compute(dout, input_data, alpha_dat
     tensor_one = tvm.const(1.0, input_div_alpha.dtype)
     tensor_one = te.lang.cce.broadcast(tensor_one, shape)

-    #out_of_bounds = te.lang.cce.vcmpsel(te.lang.cce.vabs(input_div_alpha), 1.0, 'gt', 1.0, 0.0)
     out_of_upper_bounds = te.lang.cce.vcmpsel(input_div_alpha, 1.0, 'gt', 1.0, 0.0)
     if neg_trunc:
         out_of_lower_bounds = te.lang.cce.vcmpsel(input_div_alpha, 0.0, 'lt', 1.0, 0.0)
@@ -97,32 +126,7 @@ def fake_learned_scale_quant_perlayer_grad_d_compute(dout, input_data, alpha_dat

     dx = te.lang.cce.vmul(dx, te.lang.cce.vsub(tensor_one, out_of_bounds))

-    # sign function imp
-    if dtype == "float32":
-        data_min = tvm.const(SCALAR_MIN_FP32, dtype=dtype)
-        neg_data_min = tvm.const(NEG_SCALAR_MIN_FP32, dtype=dtype)
-    elif dtype == "float16":
-        data_min = tvm.const(SCALAR_MIN_FP16, dtype=dtype)
-        neg_data_min = tvm.const(NEG_SCALAR_MIN_FP16, dtype=dtype)
-    else:
-        data_min = tvm.const(1, dtype=dtype)
-        neg_data_min = tvm.const(-1, dtype=dtype)
-    vmax = te.lang.cce.vmaxs(input_div_alpha, neg_data_min)
-    vmin = te.lang.cce.vmins(vmax, data_min)
-    if dtype == "float32":
-        # max num of float32 is 2**126
-        max_support_fp32 = tvm.const(2 ** 62, dtype=dtype)
-        res_mul1 = te.lang.cce.vmuls(vmin, max_support_fp32)
-        res_mul2 = te.lang.cce.vmuls(res_mul1, max_support_fp32)
-        sign = te.lang.cce.vmuls(res_mul2, tvm.const(2 ** 2, dtype=dtype))
-    elif dtype == "float16":
-        # max num of float16 is 2**24
-        # but cce can only support 2**12, so use 12/12 to adaptor 24
-        max_support_fp16 = tvm.const(2 ** 12, dtype=dtype)
-        res_mul1 = te.lang.cce.vmuls(vmin, max_support_fp16)
-        sign = te.lang.cce.vmuls(res_mul1, max_support_fp16)
-    else:
-        sign = vmin
+    sign = _sign_function(dtype, input_div_alpha)

     # The following lines are equivalent to :
     # dalpha_each = dout * sign if out of bounds
@@ -137,10 +141,9 @@ def fake_learned_scale_quant_perlayer_grad_d_compute(dout, input_data, alpha_dat
     return [dx, dalpha_each]


-@util.check_input_type(dict, dict, dict, dict, dict, dict, bool, str)
-def fake_learned_scale_quant_perlayer_grad_d(dout, input_x, alpha, quant_max, dx, dalpha, neg_trunc,
-                                             kernel_name="fake_learned_scale_quant_perlayer_grad_d"):
-    """FakeLearnedScaleQuantPerLayerGradD"""
+def fake_learned_scale_quant_perlayer_grad_d_param(input_x, alpha, quant_max,
+                                                   kernel_name="fake_learned_scale_quant_perlayer_grad_d"):
+    """Get and check FakeLearnedScaleQuantPerLayerGradD parameters"""
     input_shape = input_x.get("shape")
     input_dtype = input_x.get("dtype")
     alpha_shape = alpha.get("ori_shape")
@@ -172,6 +175,15 @@ def fake_learned_scale_quant_perlayer_grad_d(dout, input_x, alpha, quant_max, dx
     input_data = tvm.placeholder(input_shape, name="x", dtype=input_dtype)
     alpha_data = tvm.placeholder(alpha_shape, name="alpha_data", dtype=alpha_dtype)
     quant_max_data = tvm.placeholder(quant_max_shape, name="quant_max_data", dtype=quant_max_dtype)
+    return dout_data, input_data, alpha_data, quant_max_data
+
+@util.check_input_type(dict, dict, dict, dict, dict, dict, bool, str)
+def fake_learned_scale_quant_perlayer_grad_d(dout, input_x, alpha, quant_max, dx, dalpha, neg_trunc,
+                                             kernel_name="fake_learned_scale_quant_perlayer_grad_d"):
+    """FakeLearnedScaleQuantPerLayerGradD"""
+    dout_data, input_data, alpha_data, quant_max_data = \
+        fake_learned_scale_quant_perlayer_grad_d_param(input_x, alpha, quant_max, kernel_name)
+
+    res = fake_learned_scale_quant_perlayer_grad_d_compute(dout_data, input_data, alpha_data, quant_max_data,
+                                                           neg_trunc, kernel_name)

@@ -86,11 +86,9 @@ def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max,
     return res


-@util.check_input_type(dict, dict, dict, dict, bool, bool, int, int, str)
-def fake_quant_perchannel(x, min_val, max_val, y,
-                          symmetric, narrow_range, num_bits, channel_axis,
-                          kernel_name="fake_quant_perchannel"):
-    """FakeQuantPerChannel"""
+def fake_quant_perchannel_param(x, min_val, max_val, channel_axis,
+                                kernel_name="fake_quant_perchannel"):
+    """Get and check fake_quant_perchannel parameters"""
     x_shape = x.get("shape")
     x_shape_ = x.get("ori_shape")
     x_format = x.get("format")
@@ -120,15 +118,25 @@ def fake_quant_perchannel(x, min_val, max_val, y,
     util.check_dtype_rule(min_dtype, check_list)
     util.check_dtype_rule(max_dtype, check_list)

+    shape_c = [1] * len(x_shape)
+    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
+    if x_format == "NC1HWC0" and channel_axis_ == 1:
+        shape_c = min_val.get("shape")
+    return x_shape, shape_c, x_dtype
+
+
+@util.check_input_type(dict, dict, dict, dict, bool, bool, int, int, str)
+def fake_quant_perchannel(x, min_val, max_val, y,
+                          symmetric, narrow_range, num_bits, channel_axis,
+                          kernel_name="fake_quant_perchannel"):
+    """FakeQuantPerChannel"""
     quant_min = 0
     quant_max = 2 ** num_bits - 1
     if narrow_range:
         quant_min = quant_min + 1

-    shape_c = [1] * len(x_shape)
-    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
-    if x_format == "NC1HWC0" and channel_axis_ == 1:
-        shape_c = min_val.get("shape")
+    x_shape, shape_c, x_dtype = fake_quant_perchannel_param(x, min_val, max_val,
+                                                            channel_axis, kernel_name)
     input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
     min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
     max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)

@@ -110,11 +110,9 @@ def fake_quant_perchannel_grad_compute(dout, x, min_val, max_val, quant_min, qua
     return res


-@util.check_input_type(dict, dict, dict, dict, dict, bool, bool, int, int, str)
-def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
-                               symmetric, narrow_range, num_bits, channel_axis,
-                               kernel_name="fake_quant_perchannel_grad"):
-    """FakeQuantPerChannelGrad"""
+def fake_quant_perchannel_grad_param(x, min_val, max_val, channel_axis,
+                                     kernel_name="fake_quant_perchannel_grad"):
+    """Get and check FakeQuantPerChannelGrad parameters"""
     x_shape = x.get("shape")
     x_shape_ = x.get("ori_shape")
     x_format = x.get("format")
@@ -144,6 +142,18 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
     util.check_dtype_rule(min_dtype, check_list)
     util.check_dtype_rule(max_dtype, check_list)

+    shape_c = [1] * len(x_shape)
+    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
+    if x_format == "NC1HWC0" and channel_axis_ == 1:
+        shape_c = min_val.get("shape")
+    return x_shape, shape_c, x_dtype
+
+
+@util.check_input_type(dict, dict, dict, dict, dict, bool, bool, int, int, str)
+def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
+                               symmetric, narrow_range, num_bits, channel_axis,
+                               kernel_name="fake_quant_perchannel_grad"):
+    """FakeQuantPerChannelGrad"""
     if symmetric:
         quant_min = 0 - 2 ** (num_bits - 1)
         quant_max = 2 ** (num_bits - 1) - 1
@@ -153,10 +163,8 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
     if narrow_range:
         quant_min = quant_min + 1

-    shape_c = [1] * len(x_shape)
-    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
-    if x_format == "NC1HWC0" and channel_axis_ == 1:
-        shape_c = min_val.get("shape")
+    x_shape, shape_c, x_dtype = fake_quant_perchannel_grad_param(x, min_val, max_val,
+                                                                 channel_axis, kernel_name)
     dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
     input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
     min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
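Both wrappers derive the integer grid from `num_bits` before building placeholders; the same derivation collected in one standalone helper (plain Python, name hypothetical):

    def quant_range(num_bits, symmetric=False, narrow_range=False):
        # Mirrors the wrapper logic: signed symmetric grid or unsigned grid,
        # optionally dropping the most negative code.
        if symmetric:
            quant_min = 0 - 2 ** (num_bits - 1)
            quant_max = 2 ** (num_bits - 1) - 1
        else:
            quant_min = 0
            quant_max = 2 ** num_bits - 1
        if narrow_range:
            quant_min = quant_min + 1
        return quant_min, quant_max

    print(quant_range(8))                                     # (0, 255)
    print(quant_range(8, symmetric=True))                     # (-128, 127)
    print(quant_range(8, symmetric=True, narrow_range=True))  # (-127, 127)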
@@ -82,7 +82,6 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
                              kernel_name="minmax_update_perchannel"):
     """MinMaxUpdatePerChannel op"""
     x_shape = x.get("ori_shape")
     x_format = x.get("format")
     x_dtype = x.get("dtype")
     min_shape = min_val.get("ori_shape")
     min_dtype = min_val.get("dtype")

@@ -6,4 +6,4 @@
bprop.10:x*
bprop.10:out*
bprop.10:dout2
bprop.10:[CNode]12:2:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a2102a58399653345b09bd6f5b337c4b81c4f8900664c0abc09fb80f38f8e95be82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b224c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb365c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca0593a639478ea8dfad17fdbe39f66855cc459eb58bcaf5eac44185e03b16374a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260
bprop.10:[CNode]12:2:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a2102a58399653345b09bd6f5b337c4b81c4f8900664c0abc09fb80f38f8e95be82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b224c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d65c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca0593a639478ea8dfad17fdbe39f66855cc459eb58bcaf5eac44185e03b16374a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260

@@ -8,4 +8,4 @@
bprop.2:x*
bprop.2:out*
bprop.2:dout2
bprop.2:[CNode]4:3:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a2102a58399653345b09bd6f5b337c4b81c4f8900664c0abc09fb80f38f8e95be82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b224c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb365c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca0593a639478ea8dfad17fdbe39f66855cc459eb58bcaf5eac44185e03b16374a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260
bprop.2:[CNode]4:3:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a2102a58399653345b09bd6f5b337c4b81c4f8900664c0abc09fb80f38f8e95be82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b224c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3565f906930f68ca2413e9ad958d105e129e717cd183b95d11d65a8b0b030fc0d65c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca0593a639478ea8dfad17fdbe39f66855cc459eb58bcaf5eac44185e03b16374a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260