bug fix in fake quant ops

This commit is contained in:
chenzomi 2020-06-08 17:44:15 +08:00
parent eaaf824f18
commit 97a548789a
1 changed files with 17 additions and 3 deletions

View File

@ -15,6 +15,7 @@
"""Operators for quantization."""
import mindspore.context as context
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ..primitive import PrimitiveWithInfer, prim_attr_register
@ -82,6 +83,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
narrow_range=False,
training=True):
"""init FakeQuantPerLayer OP"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_perlayer
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
@ -143,6 +146,8 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
quant_delay=0,
symmetric=False,
narrow_range=False):
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_perlayer_grad
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
@ -222,6 +227,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
training=True,
channel_axis=1):
"""init FakeQuantPerChannel OP"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_perchannel
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' Attr \'num_bits\' is not support.")
@ -286,6 +293,8 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer):
narrow_range=False,
channel_axis=1):
"""init FakeQuantPerChannelGrad Fill"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_perchannel_grad
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
@ -454,6 +463,8 @@ class CorrectionMul(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, channel_axis=0):
"""init correction mul layer"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import correction_mul
self.channel_axis = channel_axis
self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'],
outputs=['out'])
@ -486,6 +497,8 @@ class CorrectionMulGrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, channel_axis=0):
"""init correction mul layer"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import correction_mul_grad
self.channel_axis = channel_axis
self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'],
outputs=['dx', 'd_gamma'])
@ -847,9 +860,8 @@ class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True):
"""init FakeQuantMinMaxPerLayerUpdate OP"""
from mindspore.ops._op_impl._custom_op import correction_mul, correction_mul_grad
from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max, fake_quant_with_min_max_grad
from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max_update
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
@ -922,6 +934,8 @@ class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True, channel_axis=1):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")