forked from mindspore-Ecosystem/mindspore
bug fix in fake quant ops
This commit is contained in:
parent
eaaf824f18
commit
97a548789a
|
@ -15,6 +15,7 @@
|
|||
|
||||
"""Operators for quantization."""
|
||||
|
||||
import mindspore.context as context
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
||||
|
@ -82,6 +83,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
|
|||
narrow_range=False,
|
||||
training=True):
|
||||
"""init FakeQuantPerLayer OP"""
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
from mindspore.ops._op_impl._custom_op import fake_quant_perlayer
|
||||
if num_bits not in self.support_quant_bit:
|
||||
raise ValueError(
|
||||
f"For '{self.name}' attr \'num_bits\' is not support.")
|
||||
|
@ -143,6 +146,8 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
|
|||
quant_delay=0,
|
||||
symmetric=False,
|
||||
narrow_range=False):
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
from mindspore.ops._op_impl._custom_op import fake_quant_perlayer_grad
|
||||
if num_bits not in self.support_quant_bit:
|
||||
raise ValueError(
|
||||
f"For '{self.name}' attr \'num_bits\' is not support.")
|
||||
|
@ -222,6 +227,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
|
|||
training=True,
|
||||
channel_axis=1):
|
||||
"""init FakeQuantPerChannel OP"""
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
from mindspore.ops._op_impl._custom_op import fake_quant_perchannel
|
||||
if num_bits not in self.support_quant_bit:
|
||||
raise ValueError(
|
||||
f"For '{self.name}' Attr \'num_bits\' is not support.")
|
||||
|
@ -286,6 +293,8 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer):
|
|||
narrow_range=False,
|
||||
channel_axis=1):
|
||||
"""init FakeQuantPerChannelGrad Fill"""
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
from mindspore.ops._op_impl._custom_op import fake_quant_perchannel_grad
|
||||
if num_bits not in self.support_quant_bit:
|
||||
raise ValueError(
|
||||
f"For '{self.name}' attr \'num_bits\' is not support.")
|
||||
|
@ -454,6 +463,8 @@ class CorrectionMul(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, channel_axis=0):
|
||||
"""init correction mul layer"""
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
from mindspore.ops._op_impl._custom_op import correction_mul
|
||||
self.channel_axis = channel_axis
|
||||
self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'],
|
||||
outputs=['out'])
|
||||
|
@ -486,6 +497,8 @@ class CorrectionMulGrad(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, channel_axis=0):
|
||||
"""init correction mul layer"""
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
from mindspore.ops._op_impl._custom_op import correction_mul_grad
|
||||
self.channel_axis = channel_axis
|
||||
self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'],
|
||||
outputs=['dx', 'd_gamma'])
|
||||
|
@ -847,9 +860,8 @@ class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
|
|||
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
|
||||
training=True):
|
||||
"""init FakeQuantMinMaxPerLayerUpdate OP"""
|
||||
from mindspore.ops._op_impl._custom_op import correction_mul, correction_mul_grad
|
||||
from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max, fake_quant_with_min_max_grad
|
||||
from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max_update
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update
|
||||
if num_bits not in self.support_quant_bit:
|
||||
raise ValueError(
|
||||
f"For '{self.name}' attr \'num_bits\' is not support.")
|
||||
|
@ -922,6 +934,8 @@ class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
|
|||
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
|
||||
training=True, channel_axis=1):
|
||||
"""init FakeQuantPerChannelUpdate OP for Ascend"""
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update
|
||||
if num_bits not in self.support_quant_bit:
|
||||
raise ValueError(
|
||||
f"For '{self.name}' attr \'num_bits\' is not support.")
|
||||
|
|
Loading…
Reference in New Issue