forked from mindspore-Ecosystem/mindspore
Add quant ops.
This commit is contained in:
parent
9ae5f96988
commit
37ee246c70
|
@ -176,3 +176,27 @@ def get_bprop_fakequant_with_minmax_per_channel_update(self):
|
|||
return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(Q.ActsULQ)
|
||||
def get_bprop_acts_ulq(self):
|
||||
"""Grad definition for 'ActsULQ' operation"""
|
||||
op = Q.ActsULQInputGrad()
|
||||
op1 = Q.ActULQClampMinGrad()
|
||||
op2 = Q.ActULQClampMaxGrad()
|
||||
def bprop(x, clamp_min, clamp_max, out, dout):
|
||||
dx = op(dout[0], out[1], out[2])
|
||||
dx1 = op1(dout[0], out[1], out[3])
|
||||
dx2 = op2(dout[0], out[2], out[3])
|
||||
return (dx, dx1, dx2)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(Q.WtsARQ)
|
||||
def get_bprop_wts_arq(self):
|
||||
"""Grad definition for 'WtsArq' operation"""
|
||||
def bprop(w, w_min, w_max, out, dout):
|
||||
return (dout, zeros_like(w_min), zeros_like(w_max))
|
||||
|
||||
return bprop
|
||||
|
|
|
@ -325,6 +325,11 @@ from .parallel_concat import _parallel_concat_tbe
|
|||
from .adam_apply_one_assign import _adam_apply_one_assign_tbe
|
||||
from .adam_apply_one_with_decay_assign import _adam_apply_one_with_decay_assign_tbe
|
||||
from .ifmr import _ifmr_tbe
|
||||
from .acts_ulq import _acts_ulq_tbe
|
||||
from .acts_ulq_input_grad import _acts_ulq_input_grad_tbe
|
||||
from .act_ulq_clamp_min_grad import _act_ulq_clamp_min_grad_tbe
|
||||
from .act_ulq_clamp_max_grad import _act_ulq_clamp_max_grad_tbe
|
||||
from .wts_arq import _wts_arq_tbe
|
||||
from .fake_quant_with_min_max_vars import _fake_quant_with_min_max_vars_tbe
|
||||
from .fake_quant_with_min_max_vars_gradient import _fake_quant_with_min_max_vars_gradient_tbe
|
||||
from .fake_quant_with_min_max_vars_per_channel import _fake_quant_with_min_max_vars_per_channel_tbe
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ActULQClampMaxGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
act_ulq_clamp_max_grad_op_info = TBERegOp("ActULQClampMaxGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("act_ulq_clamp_max_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("act_ulq_clamp_max_grad") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "input_x", False, "required", "all") \
|
||||
.input(1, "input_y", False, "required", "all") \
|
||||
.input(2, "input_z", False, "required", "all") \
|
||||
.output(0, "output", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.BOOL_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.BOOL_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(act_ulq_clamp_max_grad_op_info)
|
||||
def _act_ulq_clamp_max_grad_tbe():
|
||||
"""ActULQClampMaxGrad TBE register"""
|
||||
return
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ActULQClampMinGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
act_ulq_clamp_min_grad_op_info = TBERegOp("ActULQClampMinGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("act_ulq_clamp_min_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("act_ulq_clamp_min_grad") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "input_x", False, "required", "all") \
|
||||
.input(1, "input_y", False, "required", "all") \
|
||||
.input(2, "input_z", False, "required", "all") \
|
||||
.output(0, "output", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.BOOL_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.BOOL_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(act_ulq_clamp_min_grad_op_info)
|
||||
def _act_ulq_clamp_min_grad_tbe():
|
||||
"""ActULQClampMinGrad TBE register"""
|
||||
return
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ActsULQ op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
acts_ulq_op_info = TBERegOp("ActsULQ") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("acts_ulq.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("acts_ulq") \
|
||||
.partial_flag(True) \
|
||||
.attr("fixed_min", "optional", "bool", "all") \
|
||||
.attr("num_bits", "optional", "int", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "clamp_min", False, "required", "all") \
|
||||
.input(2, "clamp_max", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.output(1, "clamp_min_mask", False, "required", "all") \
|
||||
.output(2, "clamp_max_mask", False, "required", "all") \
|
||||
.output(3, "x_clamped_loss", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.BOOL_Default, DataType.BOOL_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.BOOL_Default, DataType.BOOL_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(acts_ulq_op_info)
|
||||
def _acts_ulq_tbe():
|
||||
"""ActsULQ TBE register"""
|
||||
return
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ActsULQInputGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
acts_ulq_input_grad_op_info = TBERegOp("ActsULQInputGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("acts_ulq_input_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("acts_ulq_input_grad") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "y_grad", False, "required", "all") \
|
||||
.input(1, "clamp_min_mask", False, "required", "all") \
|
||||
.input(2, "clamp_max_mask", False, "required", "all") \
|
||||
.output(0, "x_grad", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.BOOL_Default, DataType.BOOL_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.BOOL_Default, DataType.BOOL_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(acts_ulq_input_grad_op_info)
|
||||
def _acts_ulq_input_grad_tbe():
|
||||
"""ActsULQInputGrad TBE register"""
|
||||
return
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""WtsARQ op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
wts_arq_op_info = TBERegOp("WtsARQ") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("wts_arq.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("wts_arq") \
|
||||
.partial_flag(True) \
|
||||
.attr("num_bits", "optional", "int", "all") \
|
||||
.attr("offset_flag", "optional", "bool", "all") \
|
||||
.input(0, "w", False, "required", "all") \
|
||||
.input(1, "w_min", False, "required", "all") \
|
||||
.input(2, "w_max", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(wts_arq_op_info)
|
||||
def _wts_arq_tbe():
|
||||
"""WtsARQ TBE register"""
|
||||
return
|
|
@ -1197,3 +1197,205 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
|
|||
def infer_dtype(self, dout_type, x_type):
|
||||
validator.check("dout type", dout_type, "x type", x_type)
|
||||
return dout_type, dout_type
|
||||
|
||||
|
||||
class ActsULQ(PrimitiveWithInfer):
|
||||
"""
|
||||
The ActsULQ(Activation universal learnable quantization).
|
||||
|
||||
Args:
|
||||
fixed_min (bool): whether fix clamp min to zero.
|
||||
num_bits (int): The bits num used for quantize.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - A Tensor of feature map. With float16 or float32 data type.
|
||||
- **clamp_min** (Tensor) - A Tensor of clamp min with the same type as x.
|
||||
- **clamp_max** (Tensor) - A Tensor of clamp max with the same type as x.
|
||||
|
||||
Outputs:
|
||||
- **y** (Tensor) - A tensor of fake quant of feature map with the same type as `w`.
|
||||
- **clamp_min** (Tensor) - A tensor of boolean masks if data in feature map >= clamp_min.
|
||||
- **clamp_max** (Tensor) - A tensor of boolean masks if data in feature map <= clamp_max.
|
||||
- **x_clamped_loss** (Tensor) - A tensor of clamped loss.
|
||||
|
||||
Examples:
|
||||
>>> data_type = np.float32
|
||||
>>> x= np.random.uniform(-10, 10, (32, 120)).astype(data_type)
|
||||
>>> clamp_max = 0.7 * np.max(x)
|
||||
>>> clamp_min = 0.7 * np.min(x)
|
||||
>>> clamp_max = np.array([clamp_max], dtype=data_type)
|
||||
>>> clamp_min = np.array([clamp_min], dtype=data_type)
|
||||
>>> acts_ulq = Q.ActsULQ(fixed_mini=True, num_bits=8)
|
||||
>>> quant_x, clamp_min_mask, clamp_max_mask, x_clamped_loss = acts_ulq(Tensor(x), Tensor( clamp_min),
|
||||
Tensor(clamp_max))
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, fixed_min=False, num_bits=8):
|
||||
validator.check_value_type("fixed_min", fixed_min, [bool], self.name)
|
||||
validator.check_value_type("num_bits", num_bits, [int], self.name)
|
||||
validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name)
|
||||
|
||||
def infer_shape(self, x_shape, clamp_min_shape, clamp_max_shape):
|
||||
"""infer shape of primitive"""
|
||||
validator.check_int(len(clamp_min_shape), len(x_shape), Rel.EQ, "dims of clamp_min", self.name)
|
||||
validator.check_int(len(clamp_max_shape), len(x_shape), Rel.EQ, "dims of clamp_max", self.name)
|
||||
|
||||
x_shape_len = len(x_shape)
|
||||
for i in range(x_shape_len):
|
||||
validator.check_int(clamp_min_shape[i], 1, Rel.EQ, "dims of clamp_min", self.name)
|
||||
validator.check_int(clamp_max_shape[i], 1, Rel.EQ, "dims of clamp_max", self.name)
|
||||
|
||||
return x_shape, x_shape, x_shape, x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, clamp_min_dtype, clamp_max_dtype):
|
||||
"""infer dtype of primitive"""
|
||||
valid_types = [mstype.float32, mstype.float16]
|
||||
validator.check_tensor_type_same({"x": x_dtype}, valid_types, self.name)
|
||||
validator.check_tensor_type_same({"clamp_min": clamp_min_dtype}, valid_types, self.name)
|
||||
validator.check_tensor_type_same({"clamp_max": clamp_max_dtype}, valid_types, self.name)
|
||||
|
||||
return x_dtype, mstype.bool_, mstype.bool_, x_dtype
|
||||
|
||||
|
||||
class ActsULQInputGrad(PrimitiveWithInfer):
|
||||
"""
|
||||
The ActsULQInputGrad(grad of ActsULQ).
|
||||
|
||||
Inputs:
|
||||
- **y_grad** (Tensor) - A Tensor of grad. With float16 or float32 data type.
|
||||
|
||||
Outputs:
|
||||
- **x_grad** (Tensor) - A tensor of data grad with the same type as `y_grad`.
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def infer_shape(self, y_grad_shape, clamp_min_mask_shape, clamp_max_mask_shape):
|
||||
return y_grad_shape
|
||||
|
||||
def infer_dtype(self, y_grad_type, clamp_min_mask_type, clamp_max_mask_type):
|
||||
valid_types = [mstype.float32, mstype.float16]
|
||||
validator.check_tensor_type_same({"y_grad": y_grad_type}, valid_types, self.name)
|
||||
return y_grad_type
|
||||
|
||||
|
||||
class ActULQClampMinGrad(PrimitiveWithInfer):
|
||||
"""
|
||||
The ActULQClampMinGrad(Activation Universal Linear Quantization on Clamp Minimum Gradient)
|
||||
|
||||
Inputs:
|
||||
- **y_grad** (Tensor) - A tensor of gradient, with float16 or float32 type.
|
||||
- **clamp_min_mask** - A tensor of mask, only support int8 type.
|
||||
- **x_clamped_loss** - A tensor of loss, with the same type as "y_grad".
|
||||
|
||||
Outputs:
|
||||
- **clamp_min_grad** - A tensor of clamp minimum gradient, with the same type as "y_grad".
|
||||
The length of tensor is 1.
|
||||
|
||||
Examples:
|
||||
>>> data_type = np.float32
|
||||
>>> y_grad = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
|
||||
>>> clamp_min_mask = np.where(np.random.rand(32, 120) >= 0.5, 1, 0)
|
||||
>>> x_clamped_loss = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
|
||||
>>> act_ulq_clamp_min_grad = Q.ActULQClampMinGrad()
|
||||
>>> clamp_min_grad = act_ulq_clamp_min_grad(Tensor(y_grad), Tensor(clamp_min_mask, mindspore.bool_),
|
||||
Tensor(x_clamped_loss))
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def infer_shape(self, input_x, input_y, input_z):
|
||||
input_x_len = len(input_x)
|
||||
output_shape = []
|
||||
for _ in range(input_x_len):
|
||||
output_shape.append(1)
|
||||
return tuple(output_shape)
|
||||
|
||||
def infer_dtype(self, input_x, input_y, input_z):
|
||||
return input_x
|
||||
|
||||
|
||||
class ActULQClampMaxGrad(PrimitiveWithInfer):
|
||||
"""
|
||||
The ActULQClampMaxGrad(Activation Universal Linear Quantization on Clamp Maximum Gradient)
|
||||
|
||||
Inputs:
|
||||
- **y_grad** (Tensor) - A tensor of gradient, with float16 or float32 type.
|
||||
- **clamp_max_mask** - A tensor of mask, only support int8 type.
|
||||
- **x_clamped_loss** - A tensor of loss, with the same type as "y_grad".
|
||||
|
||||
Outputs:
|
||||
- **clamp_max_grad** - A tensor of clamp maximum gradient, with the same type as "y_grad".
|
||||
The length of tensor is 1.
|
||||
|
||||
Examples:
|
||||
>>> data_type = np.float32
|
||||
>>> y_grad = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
|
||||
>>> clamp_max_mask = np.where(np.random.rand(32, 120) >= 0.5, 1, 0)
|
||||
>>> x_clamped_loss = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
|
||||
>>> act_ulq_clamp_max_grad = Q.ActULQClampMaxGrad()
|
||||
>>> clamp_max_grad = act_ulq_clamp_max_grad(Tensor(y_grad), Tensor(clamp_max_mask, mindspore.bool_),
|
||||
Tensor(x_clamped_loss))
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def infer_shape(self, input_x, input_y, input_z):
|
||||
input_x_len = len(input_x)
|
||||
output_shape = []
|
||||
for _ in range(input_x_len):
|
||||
output_shape.append(1)
|
||||
return tuple(output_shape)
|
||||
|
||||
def infer_dtype(self, input_x, input_y, input_z):
|
||||
return input_x
|
||||
|
||||
|
||||
class WtsARQ(PrimitiveWithInfer):
|
||||
"""
|
||||
The WtsARQ(Weights Adaptive Range Quantization).
|
||||
|
||||
Args:
|
||||
axes (list): Specify channels for ARQ algorithm.
|
||||
num_bits (int): The bits num used for quantize.
|
||||
offset_flag (bool): Whether use offset for quantize.
|
||||
|
||||
Inputs:
|
||||
- **w** (Tensor) - A Tensor of weights. With float16 or float32 data type.
|
||||
|
||||
Outputs:
|
||||
- **scale** (Tensor) - A tensor of optimal scale, has the same type as `w`.
|
||||
- **offset** (Tensor) - A tensor of optimal offset, has the same type as `w`.
|
||||
- If axis is [],
|
||||
the shape of scale and offset is :math:`(1, )`.
|
||||
- If axis is [0],
|
||||
the shape of scale and offset is :math:`(w_1, )`.
|
||||
- If axis is [1],
|
||||
the shape of scale and offset is :math:`(w_2, )`.
|
||||
- **y** (Tensor) - A tensor of fakequant weights, has the same type and shape as `w`.
|
||||
|
||||
Examples:
|
||||
>>> data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32))
|
||||
>>> wts_arq = Q.WtsARQ(axes=[0], num_bits=8, offset_flag=False)
|
||||
>>> scale, offset, y = wts_arq(data)
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, num_bits, offset_flag):
|
||||
validator.check_value_type("num_bits", num_bits, [int], self.name)
|
||||
validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name)
|
||||
validator.check_value_type("offset_flag", offset_flag, [bool], self.name)
|
||||
|
||||
def infer_shape(self, w_shape, w_min_shape, w_max_shape):
|
||||
validator.check_int(len(w_min_shape), len(w_shape), Rel.EQ, "dims of w_min", self.name)
|
||||
validator.check_int(len(w_max_shape), len(w_shape), Rel.EQ, "dims of w_max", self.name)
|
||||
return w_shape
|
||||
|
||||
def infer_dtype(self, w_dtype, w_min_dtype, w_max_dtype):
|
||||
valid_types = [mstype.float32, mstype.float16]
|
||||
validator.check_tensor_type_same({"w": w_dtype}, valid_types, self.name)
|
||||
validator.check_tensor_type_same({"w_min": w_min_dtype}, valid_types, self.name)
|
||||
validator.check_tensor_type_same({"w_max": w_max_dtype}, valid_types, self.name)
|
||||
return w_dtype
|
||||
|
|
Loading…
Reference in New Issue