forked from mindspore-Ecosystem/mindspore
!7209 Add some fake-quant operators
Merge pull request !7209 from jiangzhenguang/add_fake_quant_operator
This commit is contained in:
commit
7b060b2562
@@ -34,6 +34,31 @@ def get_bprop_fakequant_with_minmax(self):
    return bprop


@bprop_getters.register(Q.FakeQuantWithMinMaxVars)
def get_bprop_fakequant_with_minmax_vars(self):
    """Generate bprop for FakeQuantWithMinMaxVars for Ascend"""
    op = Q.FakeQuantWithMinMaxVarsGradient(
        num_bits=self.num_bits, narrow_range=self.narrow_range)

    def bprop(x, x_min, x_max, out, dout):
        dx = op(dout, x, x_min, x_max)
        return dx, zeros_like(x_min), zeros_like(x_max)

    return bprop


@bprop_getters.register(Q.FakeQuantWithMinMaxVarsPerChannel)
def get_bprop_fakequant_with_minmax_vars_perchannel(self):
    """Generate bprop for FakeQuantWithMinMaxVarsPerChannel for Ascend"""
    op = Q.FakeQuantWithMinMaxVarsPerChannelGradient(
        num_bits=self.num_bits, narrow_range=self.narrow_range)

    def bprop(x, x_min, x_max, out, dout):
        dx = op(dout, x, x_min, x_max)
        return dx, zeros_like(x_min), zeros_like(x_max)

    return bprop


@bprop_getters.register(Q.FakeQuantPerChannel)
def get_bprop_fakequant_with_minmax_perchannel(self):
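Once these getters are registered, MindSpore's autodiff routes gradients of the new primitives through them. Below is a minimal usage sketch (not part of this commit); it assumes an Ascend target where the FakeQuantWithMinMaxVarsGradient TBE kernel added later in this change is available.

# Hedged usage sketch, not part of this commit: differentiating through the new
# primitive so the registered bprop above is invoked.  Assumes an Ascend backend
# with the FakeQuantWithMinMaxVarsGradient TBE kernel available.
import numpy as np
from mindspore import Tensor, nn
from mindspore import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops.operations import _quant_ops as Q


class FakeQuantNet(nn.Cell):
    def __init__(self):
        super(FakeQuantNet, self).__init__()
        self.fake_quant = Q.FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False)

    def construct(self, x, x_min, x_max):
        return self.fake_quant(x, x_min, x_max)


x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
x_min = Tensor(np.array([-6.0]), mstype.float32)
x_max = Tensor(np.array([6.0]), mstype.float32)

# get_all=True returns gradients w.r.t. every input; per the bprop above, dx comes
# from the fused gradient kernel and the min/max gradients are zeros_like tensors.
grad_all = C.GradOperation(get_all=True)
dx, dmin, dmax = grad_all(FakeQuantNet())(x, x_min, x_max)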
@@ -321,3 +321,7 @@ from .parallel_concat import _parallel_concat_tbe
from .adam_apply_one_assign import _adam_apply_one_assign_tbe
from .adam_apply_one_with_decay_assign import _adam_apply_one_with_decay_assign_tbe
from .ifmr import _ifmr_tbe
from .fake_quant_with_min_max_vars import _fake_quant_with_min_max_vars_tbe
from .fake_quant_with_min_max_vars_gradient import _fake_quant_with_min_max_vars_gradient_tbe
from .fake_quant_with_min_max_vars_per_channel import _fake_quant_with_min_max_vars_per_channel_tbe
from .fake_quant_with_min_max_vars_per_channel_gradient import _fake_quant_with_min_max_vars_per_channel_gradient_tbe
@@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""FakeQuantWithMinMaxVars op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

fake_quant_with_min_max_vars_op_info = TBERegOp("FakeQuantWithMinMaxVars") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("fake_quant_with_min_max_vars.so") \
    .compute_cost(10) \
    .kernel_name("fake_quant_with_min_max_vars") \
    .partial_flag(True) \
    .attr("num_bits", "optional", "int", "all") \
    .attr("narrow_range", "optional", "bool", "all") \
    .input(0, "x", False, "required", "all") \
    .input(1, "min", False, "required", "all") \
    .input(2, "max", False, "required", "all") \
    .output(0, "y", True, "required", "all") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(fake_quant_with_min_max_vars_op_info)
def _fake_quant_with_min_max_vars_tbe():
    """FakeQuantWithMinMaxVars TBE register"""
    return
@@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""FakeQuantWithMinMaxVarsGradient op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

fake_quant_with_min_max_vars_gradient_op_info = TBERegOp("FakeQuantWithMinMaxVarsGradient") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("fake_quant_with_min_max_vars_gradient.so") \
    .compute_cost(10) \
    .kernel_name("fake_quant_with_min_max_vars_gradient") \
    .partial_flag(True) \
    .attr("num_bits", "optional", "int", "all") \
    .attr("narrow_range", "optional", "bool", "all") \
    .input(0, "gradients", False, "required", "all") \
    .input(1, "x", False, "required", "all") \
    .input(2, "min", False, "required", "all") \
    .input(3, "max", False, "required", "all") \
    .output(0, "backprops_wrt_x", True, "required", "all") \
    .output(1, "backprops_wrt_min", True, "required", "all") \
    .output(2, "backprops_wrt_max", True, "required", "all") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(fake_quant_with_min_max_vars_gradient_op_info)
def _fake_quant_with_min_max_vars_gradient_tbe():
    """FakeQuantWithMinMaxVarsGradient TBE register"""
    return
@@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""FakeQuantWithMinMaxVarsPerChannel op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

fake_quant_with_min_max_vars_per_channel_op_info = TBERegOp("FakeQuantWithMinMaxVarsPerChannel") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("fake_quant_with_min_max_vars_per_channel.so") \
    .compute_cost(10) \
    .kernel_name("fake_quant_with_min_max_vars_per_channel") \
    .partial_flag(True) \
    .attr("num_bits", "optional", "int", "all") \
    .attr("narrow_range", "optional", "bool", "all") \
    .input(0, "x", False, "required", "all") \
    .input(1, "min", False, "required", "all") \
    .input(2, "max", False, "required", "all") \
    .output(0, "y", True, "required", "all") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(fake_quant_with_min_max_vars_per_channel_op_info)
def _fake_quant_with_min_max_vars_per_channel_tbe():
    """FakeQuantWithMinMaxVarsPerChannel TBE register"""
    return
@@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""FakeQuantWithMinMaxVarsPerChannelGradient op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

fake_quant_with_min_max_vars_per_channel_gradient_op_info = TBERegOp("FakeQuantWithMinMaxVarsPerChannelGradient") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("fake_quant_with_min_max_vars_per_channel_gradient.so") \
    .compute_cost(10) \
    .kernel_name("fake_quant_with_min_max_vars_per_channel_gradient") \
    .partial_flag(True) \
    .attr("num_bits", "optional", "int", "all") \
    .attr("narrow_range", "optional", "bool", "all") \
    .input(0, "gradients", False, "required", "all") \
    .input(1, "x", False, "required", "all") \
    .input(2, "min", False, "required", "all") \
    .input(3, "max", False, "required", "all") \
    .output(0, "backprops_wrt_x", True, "required", "all") \
    .output(1, "backprops_wrt_min", True, "required", "all") \
    .output(2, "backprops_wrt_max", True, "required", "all") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(fake_quant_with_min_max_vars_per_channel_gradient_op_info)
def _fake_quant_with_min_max_vars_per_channel_gradient_tbe():
    """FakeQuantWithMinMaxVarsPerChannelGradient TBE register"""
    return
@@ -23,6 +23,10 @@ from ...common import dtype as mstype

__all__ = ["MinMaxUpdatePerLayer",
           "MinMaxUpdatePerChannel",
           "FakeQuantWithMinMaxVars",
           "FakeQuantWithMinMaxVarsGradient",
           "FakeQuantWithMinMaxVarsPerChannel",
           "FakeQuantWithMinMaxVarsPerChannelGradient",
           "FakeQuantPerLayer",
           "FakeQuantPerLayerGrad",
           "FakeQuantPerChannel",

@@ -163,6 +167,237 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
        return min_type, max_type


class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
    r"""
    Fake-quantizes the input by min and max.

    Args:
        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
            If True, the quantization range is [1, 2^num_bits-1]. Otherwise, the quantization
            range is [0, 2^num_bits-1]. Default: False.

    Inputs:
        - **x** (Tensor) - Float32 tensor holding the data to be quantized; it also determines
          the shape of the output tensor.
        - **min** (Tensor) - Value of the min range of the input data x.
        - **max** (Tensor) - Value of the max range of the input data x.

    Outputs:
        - Tensor, with the same data type and shape as the input x.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
        >>> output_tensor = FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False)(
        ...     input_tensor, min_tensor, max_tensor)
        >>> output_tensor  # shape: (3, 16, 5, 5), data type: mstype.float32
    """
    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 narrow_range=False):
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)

    def check_broadcast(self, min_shape, input_shape):
        shape_val = 1
        for shape in input_shape:
            shape_val = shape_val * shape
        if min_shape[0] > 1 and min_shape[0] != shape_val:
            raise ValueError(f"For '{self.name}', the shape of 'min' cannot broadcast to the shape of 'x'.")

    def infer_shape(self, x_shape, min_shape, max_shape):
        validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name)
        self.check_broadcast(min_shape, x_shape)
        return x_shape

    def infer_dtype(self, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
        validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
        validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
        return x_type

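For reference, the computation this primitive describes is the usual TensorFlow-style fake quantization: min and max are first nudged so that zero is exactly representable on the integer grid, the input is clipped to the nudged range, rounded to num_bits levels, and mapped back to float. The following is a rough NumPy sketch of that math, illustrative only; the actual Ascend TBE kernel may differ in rounding and edge-case details.

# Rough NumPy reference of the fake-quant math described above; illustrative only.
import numpy as np

def fake_quant_with_min_max(x, x_min, x_max, num_bits=8, narrow_range=False):
    quant_min = 1 if narrow_range else 0
    quant_max = (1 << num_bits) - 1
    # Nudge min/max so that zero maps exactly onto an integer grid point.
    scale = (x_max - x_min) / (quant_max - quant_min)
    zero_point = np.clip(np.round(quant_min - x_min / scale), quant_min, quant_max)
    nudged_min = (quant_min - zero_point) * scale
    nudged_max = (quant_max - zero_point) * scale
    # Clip, quantize to the integer grid, then dequantize back to float.
    clipped = np.clip(x, nudged_min, nudged_max)
    quantized = np.round((clipped - nudged_min) / scale)
    return quantized * scale + nudged_min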
class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
    r"""
    Performs the gradient of the FakeQuantWithMinMaxVars operation.

    Args:
        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
            If True, the quantization range is [1, 2^num_bits-1]. Otherwise, the quantization
            range is [0, 2^num_bits-1]. Default: False.

    Inputs:
        - **gradients** (Tensor) - The gradient propagated from above the FakeQuantWithMinMaxVars operation.
        - **x** (Tensor) - Float32 tensor holding the data that was quantized; it also determines
          the shape of the output tensors.
        - **min** (Tensor) - Value of the min range of the input data x.
        - **max** (Tensor) - Value of the max range of the input data x.

    Outputs:
        - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and data type as input x.
        - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and data type as input min.
        - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and data type as input max.

    Examples:
        >>> gradients = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
        >>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsGradient(num_bits=8, narrow_range=False)(
        ...     gradients, input_tensor, min_tensor, max_tensor)
        >>> x_gradient    # shape: (3, 16, 5, 5), data type: mstype.float32
        >>> min_gradient  # shape: (1,), data type: mstype.float32
        >>> max_gradient  # shape: (1,), data type: mstype.float32
    """
    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 narrow_range=False):
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)

    def check_broadcast(self, min_shape, input_shape):
        shape_val = 1
        for shape in input_shape:
            shape_val = shape_val * shape
        if min_shape[0] > 1 and min_shape[0] != shape_val:
            raise ValueError(f"For '{self.name}', the shape of 'min' cannot broadcast to the shape of 'x'.")

    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
        validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name)
        self.check_broadcast(min_shape, x_shape)
        return x_shape, min_shape, max_shape

    def infer_dtype(self, dout_type, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name)
        validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
        validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
        validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
        return x_type, min_type, max_type

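The gradient follows the straight-through estimator: the incoming gradient passes through unchanged wherever x lies inside the nudged [min, max] range, and the out-of-range gradient mass is attributed to min (below the range) and max (above the range). A hedged NumPy sketch of that rule, again illustrative rather than the kernel's exact implementation:

# Straight-through-estimator gradient of fake quantization; illustrative only.
import numpy as np

def fake_quant_with_min_max_grad(dout, x, nudged_min, nudged_max):
    in_range = (x >= nudged_min) & (x <= nudged_max)
    dx = dout * in_range                     # pass-through inside the range
    dmin = np.sum(dout * (x < nudged_min))   # gradient w.r.t. the lower bound
    dmax = np.sum(dout * (x > nudged_max))   # gradient w.r.t. the upper bound
    return dx, dmin, dmax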
class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
    r"""
    Fake-quantizes the input, whose shape is one of [d], [b, d], [b, h, w, d], by per-channel min and max.

    Args:
        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
            If True, the quantization range is [1, 2^num_bits-1]. Otherwise, the quantization
            range is [0, 2^num_bits-1]. Default: False.

    Inputs:
        - **x** (Tensor) - Float32 tensor holding the data to be quantized; it also determines
          the shape of the output tensor.
        - **min** (Tensor) - Per-channel values of the min range of the input data x.
        - **max** (Tensor) - Per-channel values of the max range of the input data x.

    Outputs:
        - Tensor, with the same data type and shape as the input x.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6, -1, -2, -3]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6, 1, 2, 3]), mstype.float32)
        >>> output_tensor = FakeQuantWithMinMaxVarsPerChannel(num_bits=8, narrow_range=False)(
        ...     input_tensor, min_tensor, max_tensor)
        >>> output_tensor  # shape: (3, 16, 3, 4), data type: mstype.float32
    """
    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 narrow_range=False):
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)

    def infer_shape(self, x_shape, min_shape, max_shape):
        validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name)
        validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
        validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
        validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
        return x_type

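The per-channel variants apply the same math independently along the last axis: min and max are 1-D tensors of length x.shape[-1], and every channel gets its own nudged quantization range. A minimal sketch, reusing the per-layer reference function fake_quant_with_min_max shown earlier (assumed to be in scope); illustrative only:

# Per-channel fake quantization over the last axis; illustrative only.
import numpy as np

def fake_quant_per_channel(x, mins, maxs, num_bits=8, narrow_range=False):
    out = np.empty_like(x)
    for c in range(x.shape[-1]):
        out[..., c] = fake_quant_with_min_max(x[..., c], mins[c], maxs[c],
                                              num_bits, narrow_range)
    return out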
class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
    r"""
    Performs the gradient of the FakeQuantWithMinMaxVarsPerChannel operation.

    Args:
        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
            If True, the quantization range is [1, 2^num_bits-1]. Otherwise, the quantization
            range is [0, 2^num_bits-1]. Default: False.

    Inputs:
        - **gradients** (Tensor) - The gradient propagated from above the FakeQuantWithMinMaxVarsPerChannel operation.
        - **x** (Tensor) - Float32 tensor holding the data that was quantized; it also determines
          the shape of the output tensors.
        - **min** (Tensor) - Per-channel values of the min range of the input data x.
        - **max** (Tensor) - Per-channel values of the max range of the input data x.

    Outputs:
        - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and data type as input x.
        - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and data type as input min.
        - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and data type as input max.

    Examples:
        >>> gradients = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
        >>> input_tensor = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6, -1, -2, -3]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6, 1, 2, 3]), mstype.float32)
        >>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsPerChannelGradient(
        ...     num_bits=8, narrow_range=False)(
        ...     gradients, input_tensor, min_tensor, max_tensor)
        >>> x_gradient    # shape: (3, 16, 3, 4), data type: mstype.float32
        >>> min_gradient  # shape: (4,), data type: mstype.float32
        >>> max_gradient  # shape: (4,), data type: mstype.float32
    """
    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 narrow_range=False):
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)

    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
        validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name)
        validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name)
        return x_shape, min_shape, max_shape

    def infer_dtype(self, dout_type, x_type, min_type, max_type):
        valid_types = (mstype.float16, mstype.float32)
        validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name)
        validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
        validator.check_tensor_type_same({'min': min_type}, valid_types, self.name)
        validator.check_tensor_type_same({'max': max_type}, valid_types, self.name)
        return x_type, min_type, max_type

class FakeQuantPerLayer(PrimitiveWithInfer):
    r"""
    Simulates the quantize and dequantize operations at training time.
@@ -26,6 +26,7 @@ from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops.operations._quant_ops import FakeQuantWithMinMaxVars, FakeQuantWithMinMaxVarsPerChannel
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
@@ -1029,6 +1030,18 @@ test_case_math_ops = [
        'desc_inputs': [Tensor(np.array([[True, False, False], [False, True, True]])),
                        [2, 3], [2, 3]],
        'desc_bprop': [[2, 3]]}),
    ('FakeQuantWithMinMaxVars', {
        'block': FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False),
        'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 5), mstype.float32),
                        Tensor(np.array([-6]), mstype.float32),
                        Tensor(np.array([6]), mstype.float32)],
        'desc_bprop': [Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)]}),
    ('FakeQuantWithMinMaxVarsPerChannel', {
        'block': FakeQuantWithMinMaxVarsPerChannel(num_bits=8, narrow_range=False),
        'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4), mstype.float32),
                        Tensor(np.array([-6, -1, -2, -3]), mstype.float32),
                        Tensor(np.array([6, 1, 2, 3]), mstype.float32)],
        'desc_bprop': [Tensor(np.random.rand(3, 16, 5, 4), mstype.float32)]}),
    ('Rank', {
        'block': P.Rank(),
        'desc_inputs': [[2, 3]],
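These entries are compiled and differentiated by the mindspore_test pipeline. Roughly, the new FakeQuantWithMinMaxVarsPerChannel entry corresponds to the following direct call, shown here only as a hedged illustration (not part of the commit; running it requires an Ascend backend with the TBE kernels added above):

# Hedged sketch of what the 'FakeQuantWithMinMaxVarsPerChannel' entry exercises.
import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.ops.operations._quant_ops import FakeQuantWithMinMaxVarsPerChannel

x = Tensor(np.random.rand(3, 16, 5, 4), mstype.float32)
min_t = Tensor(np.array([-6, -1, -2, -3]), mstype.float32)
max_t = Tensor(np.array([6, 1, 2, 3]), mstype.float32)
out = FakeQuantWithMinMaxVarsPerChannel(num_bits=8, narrow_range=False)(x, min_t, max_t)
assert out.shape == (3, 16, 5, 4)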