!19866 compression warnning clean
Merge pull request !19866 from zhang_sss/r1.3_warnning_clean
This commit is contained in:
commit
13b94934c5
|
@ -16,7 +16,4 @@
|
||||||
Compression common module.
|
Compression common module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .constant import *
|
from .constant import QuantDtype
|
||||||
|
|
||||||
__all__ = []
|
|
||||||
__all__.extend(constant.__all__)
|
|
||||||
|
|
|
@ -35,6 +35,7 @@ from ..quant.qat import _AddFakeQuantInput, _AddFakeQuantAfterSubCell
|
||||||
|
|
||||||
__all__ = ["ExportToQuantInferNetwork"]
|
__all__ = ["ExportToQuantInferNetwork"]
|
||||||
|
|
||||||
|
|
||||||
class QuantBlock(Cell):
|
class QuantBlock(Cell):
|
||||||
r"""
|
r"""
|
||||||
A quant block of Conv/Dense, activation layer for Ascend deploy.
|
A quant block of Conv/Dense, activation layer for Ascend deploy.
|
||||||
|
@ -180,6 +181,7 @@ class QuantMindirBlock(Cell):
|
||||||
s += f', activation={self.activation}'
|
s += f', activation={self.activation}'
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
class ExportToQuantInferNetwork:
|
class ExportToQuantInferNetwork:
|
||||||
"""
|
"""
|
||||||
Convert quantization aware network to infer network.
|
Convert quantization aware network to infer network.
|
||||||
|
@ -199,18 +201,13 @@ class ExportToQuantInferNetwork:
|
||||||
|
|
||||||
def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
|
def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
|
||||||
network = Validator.check_isinstance('network', network, (nn.Cell,))
|
network = Validator.check_isinstance('network', network, (nn.Cell,))
|
||||||
self.input_scale = 1 / std_dev
|
|
||||||
self.input_zero_point = round(mean)
|
|
||||||
self.data_type = mstype.int8
|
self.data_type = mstype.int8
|
||||||
self.network = copy.deepcopy(network)
|
self.network = copy.deepcopy(network)
|
||||||
self.network_bk = copy.deepcopy(network)
|
|
||||||
self.all_parameters = {p.name: p for p in self.network.get_parameters()}
|
|
||||||
self.get_inputs_table(inputs)
|
self.get_inputs_table(inputs)
|
||||||
self.mean = mean
|
self.mean = mean
|
||||||
self.std_dev = std_dev
|
self.std_dev = std_dev
|
||||||
self.is_mindir = is_mindir
|
self.is_mindir = is_mindir
|
||||||
self.upcell = None
|
self.upcell = None
|
||||||
self.upname = None
|
|
||||||
|
|
||||||
def get_inputs_table(self, inputs):
|
def get_inputs_table(self, inputs):
|
||||||
"""Get the input quantization parameters of quantization cell for quant export."""
|
"""Get the input quantization parameters of quantization cell for quant export."""
|
||||||
|
@ -220,6 +217,7 @@ class ExportToQuantInferNetwork:
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""Start to convert."""
|
"""Start to convert."""
|
||||||
|
self.network_bk = copy.deepcopy(self.network)
|
||||||
self.network.update_cell_prefix()
|
self.network.update_cell_prefix()
|
||||||
network = self.network
|
network = self.network
|
||||||
if isinstance(network, _AddFakeQuantInput):
|
if isinstance(network, _AddFakeQuantInput):
|
||||||
|
@ -229,7 +227,36 @@ class ExportToQuantInferNetwork:
|
||||||
|
|
||||||
def _get_quant_block(self, cell_core, activation, fake_quant_a_out):
|
def _get_quant_block(self, cell_core, activation, fake_quant_a_out):
|
||||||
"""convert network's quant subcell to deploy subcell"""
|
"""convert network's quant subcell to deploy subcell"""
|
||||||
# Calculate the scale and zero point
|
scale_a_in, zp_a_in, scale_w, zp_w, param_dict = self.__get_quant_param(cell_core, fake_quant_a_out)
|
||||||
|
|
||||||
|
# Build the `Quant` `Dequant` op.
|
||||||
|
# Quant only support perlayer version. Need check here.
|
||||||
|
quant_op = inner.Quant(1 / float(scale_a_in), float(zp_a_in))
|
||||||
|
scale_deq = self.__get_dequant_scale(scale_a_in, scale_w)
|
||||||
|
dequant_op = inner.Dequant()
|
||||||
|
|
||||||
|
if isinstance(activation, _AddFakeQuantAfterSubCell):
|
||||||
|
activation = activation.subcell
|
||||||
|
elif hasattr(activation, "get_origin"):
|
||||||
|
activation = activation.get_origin()
|
||||||
|
|
||||||
|
# get op
|
||||||
|
if isinstance(cell_core, quant.DenseQuant):
|
||||||
|
op_core = P.MatMul()
|
||||||
|
else:
|
||||||
|
op_core = cell_core.conv
|
||||||
|
|
||||||
|
# get the `weight` and `bias`
|
||||||
|
weight, bias, weight_b, bias_b = self.__get_weight_bias(cell_core, scale_a_in, scale_w, zp_w)
|
||||||
|
|
||||||
|
if self.is_mindir:
|
||||||
|
block = QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict)
|
||||||
|
else:
|
||||||
|
block = QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
|
||||||
|
return block
|
||||||
|
|
||||||
|
def __get_quant_param(self, cell_core, fake_quant_a_out):
|
||||||
|
"""get parameter for quant block"""
|
||||||
w_minq_name = cell_core.fake_quant_weight.minq.name
|
w_minq_name = cell_core.fake_quant_weight.minq.name
|
||||||
w_maxq_name = cell_core.fake_quant_weight.maxq.name
|
w_maxq_name = cell_core.fake_quant_weight.maxq.name
|
||||||
np_type = mstype.dtype_to_nptype(self.data_type)
|
np_type = mstype.dtype_to_nptype(self.data_type)
|
||||||
|
@ -262,7 +289,7 @@ class ExportToQuantInferNetwork:
|
||||||
_, minq_name = info
|
_, minq_name = info
|
||||||
if minq_name == 'input':
|
if minq_name == 'input':
|
||||||
scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
|
scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
|
||||||
self.input_scale, self.input_zero_point, 'None', 'None'
|
(1 / self.std_dev), round(self.mean), 'None', 'None'
|
||||||
else:
|
else:
|
||||||
fake_quant_a_in_prefix = minq_name[:-5]
|
fake_quant_a_in_prefix = minq_name[:-5]
|
||||||
cells = self.network_bk.cells_and_names()
|
cells = self.network_bk.cells_and_names()
|
||||||
|
@ -270,26 +297,34 @@ class ExportToQuantInferNetwork:
|
||||||
if cell[0].endswith(fake_quant_a_in_prefix):
|
if cell[0].endswith(fake_quant_a_in_prefix):
|
||||||
fake_quant_a_in = cell[1]
|
fake_quant_a_in = cell[1]
|
||||||
break
|
break
|
||||||
|
|
||||||
scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
|
scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
|
||||||
quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_in, np_type)
|
quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_in, np_type)
|
||||||
param_dict["input_narrow_range"] = fake_quant_a_in.narrow_range
|
param_dict["input_narrow_range"] = fake_quant_a_in.narrow_range
|
||||||
else:
|
else:
|
||||||
# skip quant layer
|
# skip quant layer
|
||||||
scale_a_in, zp_a_in = 1.0, 0.0
|
scale_a_in, zp_a_in = 1.0, 0.0
|
||||||
|
return scale_a_in, zp_a_in, scale_w, zp_w, param_dict
|
||||||
|
|
||||||
# Build the `Quant` `Dequant` op.
|
@staticmethod
|
||||||
# Quant only support perlayer version. Need check here.
|
def __get_dequant_scale(scale_a_in, scale_w):
|
||||||
quant_op = inner.Quant(1 / float(scale_a_in), float(zp_a_in))
|
"""Get dequant scale"""
|
||||||
scale_deq = scale_a_in * scale_w
|
scale_deq = scale_a_in * scale_w
|
||||||
dequant_op = inner.Dequant()
|
|
||||||
|
|
||||||
if isinstance(activation, _AddFakeQuantAfterSubCell):
|
# fuse parameter
|
||||||
activation = activation.subcell
|
# |--------|47:40|--------|39:32|--------|31:0|
|
||||||
elif hasattr(activation, "get_origin"):
|
# offset_w [8] shift_N [8] deq_scale [32]
|
||||||
activation = activation.get_origin()
|
float32_deq_scale = scale_deq.astype(np.float32)
|
||||||
|
uint32_deq_scale = np.frombuffer(float32_deq_scale, np.uint32)
|
||||||
|
scale_length = scale_deq.size # channel
|
||||||
|
dequant_param = np.zeros(scale_length, dtype=np.uint64)
|
||||||
|
for index in range(scale_length):
|
||||||
|
dequant_param[index] += uint32_deq_scale[index]
|
||||||
|
scale_deq = Tensor(dequant_param, mstype.uint64)
|
||||||
|
return scale_deq
|
||||||
|
|
||||||
# get the `weight` and `bias`
|
def __get_weight_bias(self, cell_core, scale_a_in, scale_w, zp_w):
|
||||||
|
"""Get weight and bias for quantizaiton"""
|
||||||
|
np_type = mstype.dtype_to_nptype(self.data_type)
|
||||||
weight = cell_core.weight.data.asnumpy()
|
weight = cell_core.weight.data.asnumpy()
|
||||||
bias = None
|
bias = None
|
||||||
if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)):
|
if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)):
|
||||||
|
@ -302,37 +337,22 @@ class ExportToQuantInferNetwork:
|
||||||
weight_b = weight
|
weight_b = weight
|
||||||
bias_b = bias
|
bias_b = bias
|
||||||
# apply the quant
|
# apply the quant
|
||||||
weight = quant_utils.weight2int(weight, scale_w, zp_w, np_type, cell_core.fake_quant_weight.num_bits,
|
quant_min, quant_max = quant_utils.get_quant_min_max(np_type,
|
||||||
cell_core.fake_quant_weight.narrow_range)
|
cell_core.fake_quant_weight.num_bits,
|
||||||
|
cell_core.fake_quant_weight.narrow_range)
|
||||||
|
weight = quant_utils.weight2int(weight, scale_w, zp_w, quant_min, quant_max)
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
bias = Tensor(bias / scale_a_in / scale_w, mstype.int32)
|
bias = Tensor(bias / scale_a_in / scale_w, mstype.int32)
|
||||||
|
|
||||||
# fuse parameter
|
|
||||||
# |--------|47:40|--------|39:32|--------|31:0|
|
|
||||||
# offset_w [8] shift_N [8] deq_scale [32]
|
|
||||||
float32_deq_scale = scale_deq.astype(np.float32)
|
|
||||||
uint32_deq_scale = np.frombuffer(float32_deq_scale, np.uint32)
|
|
||||||
scale_length = scale_deq.size # channel
|
|
||||||
dequant_param = np.zeros(scale_length, dtype=np.uint64)
|
|
||||||
for index in range(scale_length):
|
|
||||||
dequant_param[index] += uint32_deq_scale[index]
|
|
||||||
scale_deq = Tensor(dequant_param, mstype.uint64)
|
|
||||||
# get op
|
|
||||||
if isinstance(cell_core, quant.DenseQuant):
|
if isinstance(cell_core, quant.DenseQuant):
|
||||||
op_core = P.MatMul()
|
|
||||||
weight = np.transpose(weight)
|
weight = np.transpose(weight)
|
||||||
weight_b = np.transpose(weight_b)
|
weight_b = np.transpose(weight_b)
|
||||||
else:
|
|
||||||
op_core = cell_core.conv
|
|
||||||
weight = Tensor(weight, self.data_type)
|
weight = Tensor(weight, self.data_type)
|
||||||
weight_b = Tensor(weight_b)
|
weight_b = Tensor(weight_b)
|
||||||
if bias_b is not None:
|
if bias_b is not None:
|
||||||
bias_b = Tensor(bias_b, mstype.float32)
|
bias_b = Tensor(bias_b, mstype.float32)
|
||||||
if self.is_mindir:
|
return weight, bias, weight_b, bias_b
|
||||||
block = QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict)
|
|
||||||
else:
|
|
||||||
block = QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
|
|
||||||
return block
|
|
||||||
|
|
||||||
def _add_output_min_max_for_op(self, origin_op, fake_quant_cell):
|
def _add_output_min_max_for_op(self, origin_op, fake_quant_cell):
|
||||||
"""add output quant info for quant op for export mindir."""
|
"""add output quant info for quant op for export mindir."""
|
||||||
|
@ -343,6 +363,121 @@ class ExportToQuantInferNetwork:
|
||||||
origin_op.add_prim_attr('output_maxq', Tensor(maxq))
|
origin_op.add_prim_attr('output_maxq', Tensor(maxq))
|
||||||
origin_op.add_prim_attr('output_minq', Tensor(minq))
|
origin_op.add_prim_attr('output_minq', Tensor(minq))
|
||||||
|
|
||||||
|
def _convert_subcell(self, network, change, name, subcell):
|
||||||
|
"""Convert subcell to ant subcell."""
|
||||||
|
if subcell is not None and hasattr(subcell, "fake_quant_weight"):
|
||||||
|
new_subcell = self._get_quant_block(subcell, None, None)
|
||||||
|
prefix = subcell.param_prefix
|
||||||
|
new_subcell.update_parameters_name(prefix + '.')
|
||||||
|
self.upcell = new_subcell
|
||||||
|
network.insert_child_to_cell(name, new_subcell)
|
||||||
|
change = True
|
||||||
|
return network, change
|
||||||
|
|
||||||
|
def _convert_conv(self, network, change, name, subcell):
|
||||||
|
"""Convert subcell to ant subcell for conv."""
|
||||||
|
cell_core = subcell.conv
|
||||||
|
activation = subcell.activation
|
||||||
|
fake_quant_act = None
|
||||||
|
if hasattr(activation, 'fake_quant_act_before'):
|
||||||
|
fake_quant_act = activation.fake_quant_act_before
|
||||||
|
elif hasattr(activation, 'fake_quant_act'):
|
||||||
|
fake_quant_act = activation.fake_quant_act
|
||||||
|
if cell_core is not None and hasattr(cell_core, "fake_quant_weight"):
|
||||||
|
new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
|
||||||
|
self.upcell = None
|
||||||
|
prefix = subcell.param_prefix
|
||||||
|
new_subcell.update_parameters_name(prefix + '.')
|
||||||
|
network.insert_child_to_cell(name, new_subcell)
|
||||||
|
change = True
|
||||||
|
return network, change
|
||||||
|
|
||||||
|
def _convert_dense(self, network, change, name, subcell):
|
||||||
|
"""Convert subcell to ant subcell for dense."""
|
||||||
|
cell_core = subcell.dense
|
||||||
|
activation = subcell.activation
|
||||||
|
fake_quant_act = None
|
||||||
|
if hasattr(activation, 'fake_quant_act_before'):
|
||||||
|
fake_quant_act = activation.fake_quant_act_before
|
||||||
|
elif hasattr(activation, 'fake_quant_act'):
|
||||||
|
fake_quant_act = activation.fake_quant_act
|
||||||
|
if cell_core is not None and hasattr(cell_core, "fake_quant_weight"):
|
||||||
|
new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
|
||||||
|
prefix = subcell.param_prefix
|
||||||
|
new_subcell.update_parameters_name(prefix + '.')
|
||||||
|
network.insert_child_to_cell(name, new_subcell)
|
||||||
|
self.upcell = None
|
||||||
|
change = True
|
||||||
|
return network, change
|
||||||
|
|
||||||
|
def _convert_act(self, subcell):
|
||||||
|
"""Convert subcell to ant subcell for activation."""
|
||||||
|
activation = subcell.get_origin()
|
||||||
|
if isinstance(activation, nn.ReLU):
|
||||||
|
self._add_output_min_max_for_op(activation.relu, subcell.fake_quant_act)
|
||||||
|
elif isinstance(activation, nn.ReLU6):
|
||||||
|
self._add_output_min_max_for_op(activation.relu6, subcell.fake_quant_act)
|
||||||
|
if self.upcell:
|
||||||
|
self._add_output_min_max_for_op(self.upcell.core_op, subcell.fake_quant_act)
|
||||||
|
return activation
|
||||||
|
|
||||||
|
def _convert_add(self, subcell):
|
||||||
|
"""Convert subcell to ant subcell for add."""
|
||||||
|
if isinstance(subcell.add, _AddFakeQuantAfterSubCell):
|
||||||
|
add_op = subcell.add.subcell
|
||||||
|
subcell.__delattr__("add")
|
||||||
|
subcell.__setattr__("add", add_op)
|
||||||
|
add_op = subcell.add
|
||||||
|
self._add_output_min_max_for_op(add_op, subcell.fake_quant_act)
|
||||||
|
subcell.__delattr__("fake_quant_act")
|
||||||
|
subcell.__setattr__("fake_quant_act", P.identity())
|
||||||
|
|
||||||
|
def _convert_observer(self, network, name, subcell):
|
||||||
|
"""Convert subcell to ant subcell for FakeQuantWithMinMaxObserver."""
|
||||||
|
if self.upcell:
|
||||||
|
self._add_output_min_max_for_op(self.upcell.core_op, subcell)
|
||||||
|
network.__delattr__(name)
|
||||||
|
network.__setattr__(name, P.identity())
|
||||||
|
|
||||||
|
def _convert_fake_quant_after_cell(self, network, name, subcell):
|
||||||
|
"""Convert subcell to ant subcell for _AddFakeQuantAfterSubCell."""
|
||||||
|
op = subcell.subcell
|
||||||
|
self._add_output_min_max_for_op(op, subcell.fake_quant_act)
|
||||||
|
network.__delattr__(name)
|
||||||
|
network.__setattr__(name, op)
|
||||||
|
|
||||||
|
def _convert_core_quant_subcell(self, network, change, name, subcell):
|
||||||
|
"""Convert subcell to ant subcell for conv and dense."""
|
||||||
|
is_core_subcell = True
|
||||||
|
if isinstance(subcell, nn.Conv2dBnAct):
|
||||||
|
network, change = self._convert_conv(network, change, name, subcell)
|
||||||
|
elif isinstance(subcell, nn.DenseBnAct):
|
||||||
|
network, change = self._convert_dense(network, change, name, subcell)
|
||||||
|
elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnFoldQuantOneConv,
|
||||||
|
quant.Conv2dBnWithoutFoldQuant, quant.Conv2dQuant, quant.DenseQuant)):
|
||||||
|
network, change = self._convert_subcell(network, change, name, subcell)
|
||||||
|
else:
|
||||||
|
is_core_subcell = False
|
||||||
|
return is_core_subcell, network, change
|
||||||
|
|
||||||
|
def _convert_other_quant_subcell(self, network, change, name, subcell):
|
||||||
|
"""Convert subcell to ant subcell for cell except conv and dense."""
|
||||||
|
is_other_subcell = True
|
||||||
|
if isinstance(subcell, nn.ActQuant) and hasattr(subcell, "get_origin"):
|
||||||
|
activation = self._convert_act(subcell)
|
||||||
|
network.insert_child_to_cell(name, activation)
|
||||||
|
change = True
|
||||||
|
elif isinstance(subcell, nn.TensorAddQuant):
|
||||||
|
self._convert_add(subcell)
|
||||||
|
elif isinstance(subcell, quant.FakeQuantWithMinMaxObserver):
|
||||||
|
self._convert_observer(network, name, subcell)
|
||||||
|
elif isinstance(subcell, _AddFakeQuantAfterSubCell):
|
||||||
|
self._convert_fake_quant_after_cell(network, name, subcell)
|
||||||
|
change = True
|
||||||
|
else:
|
||||||
|
is_other_subcell = False
|
||||||
|
return is_other_subcell, network, change
|
||||||
|
|
||||||
def _convert_quant2deploy(self, network):
|
def _convert_quant2deploy(self, network):
|
||||||
"""Convert network's all quant subcell to deploy subcell."""
|
"""Convert network's all quant subcell to deploy subcell."""
|
||||||
cells = network.name_cells()
|
cells = network.name_cells()
|
||||||
|
@ -351,71 +486,11 @@ class ExportToQuantInferNetwork:
|
||||||
subcell = cells[name]
|
subcell = cells[name]
|
||||||
if subcell == network:
|
if subcell == network:
|
||||||
continue
|
continue
|
||||||
if isinstance(subcell, nn.Conv2dBnAct):
|
is_core_quant_subcell, network, change = self._convert_core_quant_subcell(network, change, name, subcell)
|
||||||
network, change = self._convert_subcell(network, change, name, subcell)
|
is_other_quant_subcell, network, change = self._convert_other_quant_subcell(network, change, name, subcell)
|
||||||
elif isinstance(subcell, nn.DenseBnAct):
|
if not is_core_quant_subcell and not is_other_quant_subcell:
|
||||||
network, change = self._convert_subcell(network, change, name, subcell, conv=False)
|
self.upcell = None
|
||||||
elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnFoldQuantOneConv,
|
|
||||||
quant.Conv2dBnWithoutFoldQuant, quant.Conv2dQuant, quant.DenseQuant)):
|
|
||||||
network, change = self._convert_subcell(network, change, name, subcell, core=False)
|
|
||||||
elif isinstance(subcell, nn.ActQuant) and hasattr(subcell, "get_origin"):
|
|
||||||
activation = subcell.get_origin()
|
|
||||||
if isinstance(activation, nn.ReLU):
|
|
||||||
self._add_output_min_max_for_op(activation.relu, subcell.fake_quant_act)
|
|
||||||
elif isinstance(activation, nn.ReLU6):
|
|
||||||
self._add_output_min_max_for_op(activation.relu6, subcell.fake_quant_act)
|
|
||||||
if self.upcell:
|
|
||||||
self._add_output_min_max_for_op(self.upcell.core_op, subcell.fake_quant_act)
|
|
||||||
network.insert_child_to_cell(name, activation)
|
|
||||||
change = True
|
|
||||||
elif isinstance(subcell, nn.TensorAddQuant):
|
|
||||||
if isinstance(subcell.add, _AddFakeQuantAfterSubCell):
|
|
||||||
add_op = subcell.add.subcell
|
|
||||||
subcell.__delattr__("add")
|
|
||||||
subcell.__setattr__("add", add_op)
|
|
||||||
add_op = subcell.add
|
|
||||||
self._add_output_min_max_for_op(add_op, subcell.fake_quant_act)
|
|
||||||
subcell.__delattr__("fake_quant_act")
|
|
||||||
subcell.__setattr__("fake_quant_act", P.identity())
|
|
||||||
elif isinstance(subcell, quant.FakeQuantWithMinMaxObserver):
|
|
||||||
if self.upcell:
|
|
||||||
self._add_output_min_max_for_op(self.upcell.core_op, subcell)
|
|
||||||
network.__delattr__(name)
|
|
||||||
network.__setattr__(name, P.identity())
|
|
||||||
elif isinstance(subcell, _AddFakeQuantAfterSubCell):
|
|
||||||
op = subcell.subcell
|
|
||||||
self._add_output_min_max_for_op(op, subcell.fake_quant_act)
|
|
||||||
network.__delattr__(name)
|
|
||||||
network.__setattr__(name, op)
|
|
||||||
change = True
|
|
||||||
else:
|
|
||||||
self.upcell, self.upname = None, None
|
|
||||||
self._convert_quant2deploy(subcell)
|
self._convert_quant2deploy(subcell)
|
||||||
if isinstance(network, nn.SequentialCell) and change:
|
if isinstance(network, nn.SequentialCell) and change:
|
||||||
network.cell_list = list(network.cells())
|
network.cell_list = list(network.cells())
|
||||||
return network
|
return network
|
||||||
|
|
||||||
def _convert_subcell(self, network, change, name, subcell, core=True, conv=True):
|
|
||||||
"""Convert subcell to ant subcell."""
|
|
||||||
new_subcell = None
|
|
||||||
fake_quant_act = None
|
|
||||||
if core:
|
|
||||||
cell_core = subcell.conv if conv else subcell.dense
|
|
||||||
activation = subcell.activation
|
|
||||||
if hasattr(activation, 'fake_quant_act_before'):
|
|
||||||
fake_quant_act = activation.fake_quant_act_before
|
|
||||||
elif hasattr(activation, 'fake_quant_act'):
|
|
||||||
fake_quant_act = activation.fake_quant_act
|
|
||||||
else:
|
|
||||||
cell_core = subcell
|
|
||||||
activation = None
|
|
||||||
if cell_core is not None and hasattr(cell_core, "fake_quant_weight"):
|
|
||||||
new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
|
|
||||||
if new_subcell:
|
|
||||||
prefix = subcell.param_prefix
|
|
||||||
new_subcell.update_parameters_name(prefix + '.')
|
|
||||||
self.upcell = None if core else new_subcell
|
|
||||||
self.upname = None if core else name
|
|
||||||
network.insert_child_to_cell(name, new_subcell)
|
|
||||||
change = True
|
|
||||||
return network, change
|
|
||||||
|
|
|
@ -16,11 +16,6 @@
|
||||||
Compression quant module.
|
Compression quant module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .quantizer import *
|
from .quantizer import OptimizeOption
|
||||||
from .qat import *
|
from .qat import QuantizationAwareTraining, create_quant_config
|
||||||
from .quant_utils import *
|
from .quant_utils import load_nonquant_param_into_quant_net, query_quant_layers
|
||||||
|
|
||||||
__all__ = []
|
|
||||||
__all__.extend(qat.__all__)
|
|
||||||
__all__.extend(quantizer.__all__)
|
|
||||||
__all__.extend(quant_utils.__all__)
|
|
||||||
|
|
|
@ -23,22 +23,20 @@ __all__ = ["load_nonquant_param_into_quant_net", "query_quant_layers"]
|
||||||
|
|
||||||
def cal_quantization_params(input_min,
|
def cal_quantization_params(input_min,
|
||||||
input_max,
|
input_max,
|
||||||
|
quant_min,
|
||||||
|
quant_max,
|
||||||
data_type,
|
data_type,
|
||||||
num_bits=8,
|
symmetric=False):
|
||||||
symmetric=False,
|
|
||||||
narrow_range=False,
|
|
||||||
neg_trunc=False):
|
|
||||||
r"""
|
r"""
|
||||||
Calculate quantization params for scale and zero point.
|
Calculate quantization params for scale and zero point.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_min (numpy.ndarray): The dimension of channel or 1.
|
input_min (numpy.ndarray): The dimension of channel or 1.
|
||||||
input_max (numpy.ndarray): The dimension of channel or 1.
|
input_max (numpy.ndarray): The dimension of channel or 1.
|
||||||
|
quant_min (int): The minimum quantization integer.
|
||||||
|
quant_max (int): The maximum quantization integer.
|
||||||
data_type (numpy type) : Can be numpy int8, numpy uint8.
|
data_type (numpy type) : Can be numpy int8, numpy uint8.
|
||||||
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
|
|
||||||
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
|
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
|
||||||
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
|
|
||||||
neg_trunc (bool): Whether the quantization algorithm uses negative truncation or not. Default: False.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
scale (numpy.ndarray): quantization param.
|
scale (numpy.ndarray): quantization param.
|
||||||
|
@ -56,6 +54,24 @@ def cal_quantization_params(input_min,
|
||||||
if (input_max == input_min).all():
|
if (input_max == input_min).all():
|
||||||
return np.ones(input_min.shape), np.zeros(input_min.shape)
|
return np.ones(input_min.shape), np.zeros(input_min.shape)
|
||||||
|
|
||||||
|
# calculate scale
|
||||||
|
if symmetric:
|
||||||
|
input_max = np.maximum(-input_min, input_max)
|
||||||
|
input_min = -input_max
|
||||||
|
scale = (input_max - input_min) / (quant_max - quant_min)
|
||||||
|
|
||||||
|
# calculate zero point
|
||||||
|
if data_type == np.int8 and symmetric:
|
||||||
|
zp = np.zeros(input_min.shape)
|
||||||
|
else:
|
||||||
|
zp_double = quant_min - input_min / scale
|
||||||
|
zp = np.floor(zp_double + 0.5)
|
||||||
|
|
||||||
|
return scale, zp
|
||||||
|
|
||||||
|
|
||||||
|
def get_quant_min_max(data_type, num_bits=8, narrow_range=False):
|
||||||
|
"""Calculate quantization params for minimum/maximum quantization integer"""
|
||||||
if data_type == np.int8:
|
if data_type == np.int8:
|
||||||
quant_min = 0 - 2 ** (num_bits - 1)
|
quant_min = 0 - 2 ** (num_bits - 1)
|
||||||
quant_max = 2 ** (num_bits - 1) - 1
|
quant_max = 2 ** (num_bits - 1) - 1
|
||||||
|
@ -66,24 +82,10 @@ def cal_quantization_params(input_min,
|
||||||
raise ValueError("Unsupported datatype({})".format(data_type))
|
raise ValueError("Unsupported datatype({})".format(data_type))
|
||||||
if narrow_range:
|
if narrow_range:
|
||||||
quant_min = quant_min + 1
|
quant_min = quant_min + 1
|
||||||
|
return quant_min, quant_max
|
||||||
# calculate scale
|
|
||||||
if symmetric and not neg_trunc:
|
|
||||||
input_max = np.maximum(-input_min, input_max)
|
|
||||||
input_min = -input_max
|
|
||||||
scale = (input_max - input_min) / (quant_max - quant_min)
|
|
||||||
|
|
||||||
# calculate zero point
|
|
||||||
if data_type == np.int8 and symmetric and not neg_trunc:
|
|
||||||
zp = np.zeros(input_min.shape)
|
|
||||||
else:
|
|
||||||
zp_double = quant_min - input_min / scale
|
|
||||||
zp = np.floor(zp_double + 0.5)
|
|
||||||
|
|
||||||
return scale, zp
|
|
||||||
|
|
||||||
|
|
||||||
def weight2int(data, scale, zero_point, data_type, num_bits=8, narrow_range=False):
|
def weight2int(data, scale, zero_point, quant_min, quant_max):
|
||||||
r"""
|
r"""
|
||||||
Calculate int8/uint8 weight from fp32. the formula is defined as:
|
Calculate int8/uint8 weight from fp32. the formula is defined as:
|
||||||
|
|
||||||
|
@ -94,9 +96,8 @@ def weight2int(data, scale, zero_point, data_type, num_bits=8, narrow_range=Fals
|
||||||
data (numpy.ndarray): The dimension of channel or 1. Should be NCHW.
|
data (numpy.ndarray): The dimension of channel or 1. Should be NCHW.
|
||||||
scale (numpy.ndarray): The dimension of channel or 1.
|
scale (numpy.ndarray): The dimension of channel or 1.
|
||||||
zero_point (numpy.ndarray): The dimension of channel or 1.
|
zero_point (numpy.ndarray): The dimension of channel or 1.
|
||||||
data_type (numpy type) : Can be numpy int8, numpy uint8.
|
quant_min (int): The minimum quantization integer.
|
||||||
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
|
quant_max (int): The maximum quantization integer.
|
||||||
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
weight (numpy.ndarray): The dimension of channel or 1.
|
weight (numpy.ndarray): The dimension of channel or 1.
|
||||||
|
@ -120,17 +121,6 @@ def weight2int(data, scale, zero_point, data_type, num_bits=8, narrow_range=Fals
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported weight shape({})".format(data.shape))
|
raise ValueError("Unsupported weight shape({})".format(data.shape))
|
||||||
|
|
||||||
if data_type == np.int8:
|
|
||||||
quant_min = 0 - 2 ** (num_bits - 1)
|
|
||||||
quant_max = 2 ** (num_bits - 1) - 1
|
|
||||||
elif data_type == np.uint8:
|
|
||||||
quant_min = 0
|
|
||||||
quant_max = 2 ** num_bits - 1
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported weight datatype({})".format(data_type))
|
|
||||||
if narrow_range:
|
|
||||||
quant_min = quant_min + 1
|
|
||||||
|
|
||||||
weight_int = np.round((data / scale) + zero_point)
|
weight_int = np.round((data / scale) + zero_point)
|
||||||
weight_int[weight_int > quant_max] = quant_max
|
weight_int[weight_int > quant_max] = quant_max
|
||||||
weight_int[weight_int < quant_min] = quant_min
|
weight_int[weight_int < quant_min] = quant_min
|
||||||
|
@ -145,54 +135,12 @@ def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
|
||||||
if cell.mode == 'LEARNED_SCALE':
|
if cell.mode == 'LEARNED_SCALE':
|
||||||
maxq = np.abs(maxq)
|
maxq = np.abs(maxq)
|
||||||
minq = -np.abs(minq)
|
minq = -np.abs(minq)
|
||||||
|
quant_min, quant_max = get_quant_min_max(data_type, num_bits=cell.num_bits, narrow_range=cell.narrow_range)
|
||||||
|
symmetric = cell.symmetric and not cell.neg_trunc
|
||||||
scale, zp = cal_quantization_params(
|
scale, zp = cal_quantization_params(
|
||||||
minq, maxq, data_type,
|
minq, maxq,
|
||||||
num_bits=cell.num_bits,
|
quant_min, quant_max, data_type,
|
||||||
symmetric=cell.symmetric,
|
symmetric=symmetric)
|
||||||
narrow_range=cell.narrow_range,
|
|
||||||
neg_trunc=cell.neg_trunc)
|
|
||||||
return scale, zp, maxq, minq
|
|
||||||
|
|
||||||
|
|
||||||
def scale_zp_from_data(op, minq, maxq, data_type):
|
|
||||||
r"""
|
|
||||||
Get calculate quantization params for scale and zero point.
|
|
||||||
|
|
||||||
Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
op (Primitive): Fake quant primitive `mindspore.ops.operation.FakeQuantPerLayer` or
|
|
||||||
`mindspore.ops.operation.FakeQuantPerChannel`
|
|
||||||
minq (Parameter): Parameter `minq` of `mindspore.nn.layer.FakeQuantWithMinMax`
|
|
||||||
maxq (Parameter): Parameter `maxq` of `mindspore.nn.layer.FakeQuantWithMinMax`
|
|
||||||
data_type (numpy type): Can be `numpy.int8` or `numpy.uint8`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
scale (numpy.ndarray): quantization param.
|
|
||||||
zero point (numpy.ndarray): quantization param.
|
|
||||||
"""
|
|
||||||
minq = minq.data.asnumpy()
|
|
||||||
maxq = maxq.data.asnumpy()
|
|
||||||
|
|
||||||
scale, zp = cal_quantization_params(
|
|
||||||
minq, maxq, data_type,
|
|
||||||
num_bits=op.num_bits,
|
|
||||||
symmetric=op.symmetric,
|
|
||||||
narrow_range=op.narrow_range)
|
|
||||||
return scale, zp
|
|
||||||
|
|
||||||
|
|
||||||
def scale_zp_max_min_from_data(op, minq, maxq, data_type):
|
|
||||||
"""Get calculate quantization params for scale, zero point, max and min."""
|
|
||||||
minq = minq.data.asnumpy()
|
|
||||||
maxq = maxq.data.asnumpy()
|
|
||||||
|
|
||||||
scale, zp = cal_quantization_params(
|
|
||||||
minq, maxq, data_type,
|
|
||||||
num_bits=op.num_bits,
|
|
||||||
symmetric=op.symmetric,
|
|
||||||
narrow_range=op.narrow_range)
|
|
||||||
return scale, zp, maxq, minq
|
return scale, zp, maxq, minq
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -46,6 +46,8 @@ using mindspore::ops::PrimitiveC;
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
namespace {
|
namespace {
|
||||||
|
constexpr int BIT_NUM_8 = 8;
|
||||||
|
constexpr int BIT_NUM_16 = 16;
|
||||||
std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
|
std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
|
||||||
auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1);
|
auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1);
|
||||||
auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> {
|
auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> {
|
||||||
|
@ -113,14 +115,14 @@ static STATUS CompressTensor(schema::TensorT *tensor_input, const std::unique_pt
|
||||||
auto repetition_packed = false;
|
auto repetition_packed = false;
|
||||||
MS_LOG(DEBUG) << dst_node->name;
|
MS_LOG(DEBUG) << dst_node->name;
|
||||||
if (dst_node->quantType == schema::QuantType_QUANT_WEIGHT) {
|
if (dst_node->quantType == schema::QuantType_QUANT_WEIGHT) {
|
||||||
if (bit_num <= 8) {
|
if (bit_num <= BIT_NUM_8) {
|
||||||
repetition_packed = PackRepetition<int8_t>(bit_num, tensor_input);
|
repetition_packed = PackRepetition<int8_t>(bit_num, tensor_input);
|
||||||
} else {
|
} else {
|
||||||
repetition_packed = PackRepetition<int16_t>(bit_num, tensor_input);
|
repetition_packed = PackRepetition<int16_t>(bit_num, tensor_input);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (bit_num != BIT_NUM_8 && bit_num != BIT_NUM_16 && !repetition_packed &&
|
||||||
if (bit_num != 8 && bit_num != 16 && !repetition_packed && dst_node->quantType != schema::QuantType_QUANT_NONE) {
|
dst_node->quantType != schema::QuantType_QUANT_NONE) {
|
||||||
auto status = DoBitPack(bit_num, tensor_input);
|
auto status = DoBitPack(bit_num, tensor_input);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "do bit pack failed. " << status;
|
MS_LOG(ERROR) << "do bit pack failed. " << status;
|
||||||
|
|
|
@ -181,7 +181,7 @@ void ModelProcess::DumpModelOutputResult(char *output_name) {
|
||||||
std::string fileName = std::string(output_name) + '_' + std::to_string(i) + ".bin";
|
std::string fileName = std::string(output_name) + '_' + std::to_string(i) + ".bin";
|
||||||
std::string outputFileName = homePath + "/" + fileName;
|
std::string outputFileName = homePath + "/" + fileName;
|
||||||
FILE *outputFile = fopen(outputFileName.c_str(), "wb");
|
FILE *outputFile = fopen(outputFileName.c_str(), "wb");
|
||||||
if (outputFile) {
|
if (outputFile != nullptr) {
|
||||||
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
|
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
|
||||||
void* data = aclGetDataBufferAddr(dataBuffer);
|
void* data = aclGetDataBufferAddr(dataBuffer);
|
||||||
uint32_t len = aclGetDataBufferSizeV2(dataBuffer);
|
uint32_t len = aclGetDataBufferSizeV2(dataBuffer);
|
||||||
|
|
|
@ -82,8 +82,8 @@ Result SampleProcess::InitResource(const char *aclConfigPath) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void SampleProcess::GetAllFiles(std::string path, std::vector<string> *files) {
|
void SampleProcess::GetAllFiles(std::string path, std::vector<string> *files) {
|
||||||
DIR *pDir;
|
DIR *pDir = nullptr;
|
||||||
struct dirent* ptr;
|
struct dirent* ptr = nullptr;
|
||||||
if (!(pDir = opendir(path.c_str())))
|
if (!(pDir = opendir(path.c_str())))
|
||||||
return;
|
return;
|
||||||
while ((ptr = readdir(pDir)) != 0) {
|
while ((ptr = readdir(pDir)) != 0) {
|
||||||
|
|
|
@ -181,7 +181,7 @@ void ModelProcess::DumpModelOutputResult(char *output_name) {
|
||||||
std::string fileName = std::string(output_name) + '_' + std::to_string(i) + ".bin";
|
std::string fileName = std::string(output_name) + '_' + std::to_string(i) + ".bin";
|
||||||
std::string outputFileName = homePath + "/" + fileName;
|
std::string outputFileName = homePath + "/" + fileName;
|
||||||
FILE *outputFile = fopen(outputFileName.c_str(), "wb");
|
FILE *outputFile = fopen(outputFileName.c_str(), "wb");
|
||||||
if (outputFile) {
|
if (outputFile != nullptr) {
|
||||||
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
|
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
|
||||||
void* data = aclGetDataBufferAddr(dataBuffer);
|
void* data = aclGetDataBufferAddr(dataBuffer);
|
||||||
uint32_t len = aclGetDataBufferSizeV2(dataBuffer);
|
uint32_t len = aclGetDataBufferSizeV2(dataBuffer);
|
||||||
|
|
|
@ -82,8 +82,8 @@ Result SampleProcess::InitResource(const char *aclConfigPath) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void SampleProcess::GetAllFiles(std::string path, std::vector<string> *files) {
|
void SampleProcess::GetAllFiles(std::string path, std::vector<string> *files) {
|
||||||
DIR *pDir;
|
DIR *pDir = nullptr;
|
||||||
struct dirent* ptr;
|
struct dirent* ptr = nullptr;
|
||||||
if (!(pDir = opendir(path.c_str())))
|
if (!(pDir = opendir(path.c_str())))
|
||||||
return;
|
return;
|
||||||
while ((ptr = readdir(pDir)) != 0) {
|
while ((ptr = readdir(pDir)) != 0) {
|
||||||
|
|
|
@ -181,7 +181,7 @@ void ModelProcess::DumpModelOutputResult(char *output_name) {
|
||||||
std::string fileName = std::string(output_name) + '_' + std::to_string(i) + ".bin";
|
std::string fileName = std::string(output_name) + '_' + std::to_string(i) + ".bin";
|
||||||
std::string outputFileName = homePath + "/" + fileName;
|
std::string outputFileName = homePath + "/" + fileName;
|
||||||
FILE *outputFile = fopen(outputFileName.c_str(), "wb");
|
FILE *outputFile = fopen(outputFileName.c_str(), "wb");
|
||||||
if (outputFile) {
|
if (outputFile != nullptr) {
|
||||||
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
|
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
|
||||||
void* data = aclGetDataBufferAddr(dataBuffer);
|
void* data = aclGetDataBufferAddr(dataBuffer);
|
||||||
uint32_t len = aclGetDataBufferSizeV2(dataBuffer);
|
uint32_t len = aclGetDataBufferSizeV2(dataBuffer);
|
||||||
|
|
|
@ -82,8 +82,8 @@ Result SampleProcess::InitResource(const char *aclConfigPath) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void SampleProcess::GetAllFiles(std::string path, std::vector<string> *files) {
|
void SampleProcess::GetAllFiles(std::string path, std::vector<string> *files) {
|
||||||
DIR *pDir;
|
DIR *pDir = nullptr;
|
||||||
struct dirent* ptr;
|
struct dirent* ptr = nullptr;
|
||||||
if (!(pDir = opendir(path.c_str())))
|
if (!(pDir = opendir(path.c_str())))
|
||||||
return;
|
return;
|
||||||
while ((ptr = readdir(pDir)) != 0) {
|
while ((ptr = readdir(pDir)) != 0) {
|
||||||
|
|
|
@ -181,7 +181,7 @@ void ModelProcess::DumpModelOutputResult(char *output_name) {
|
||||||
std::string fileName = std::string(output_name) + '_' + std::to_string(i) + ".bin";
|
std::string fileName = std::string(output_name) + '_' + std::to_string(i) + ".bin";
|
||||||
std::string outputFileName = homePath + "/" + fileName;
|
std::string outputFileName = homePath + "/" + fileName;
|
||||||
FILE *outputFile = fopen(outputFileName.c_str(), "wb");
|
FILE *outputFile = fopen(outputFileName.c_str(), "wb");
|
||||||
if (outputFile) {
|
if (outputFile != nullptr) {
|
||||||
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
|
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
|
||||||
void* data = aclGetDataBufferAddr(dataBuffer);
|
void* data = aclGetDataBufferAddr(dataBuffer);
|
||||||
uint32_t len = aclGetDataBufferSizeV2(dataBuffer);
|
uint32_t len = aclGetDataBufferSizeV2(dataBuffer);
|
||||||
|
|
|
@ -82,8 +82,8 @@ Result SampleProcess::InitResource(const char *aclConfigPath) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void SampleProcess::GetAllFiles(std::string path, std::vector<string> *files) {
|
void SampleProcess::GetAllFiles(std::string path, std::vector<string> *files) {
|
||||||
DIR *pDir;
|
DIR *pDir = nullptr;
|
||||||
struct dirent* ptr;
|
struct dirent* ptr = nullptr;
|
||||||
if (!(pDir = opendir(path.c_str())))
|
if (!(pDir = opendir(path.c_str())))
|
||||||
return;
|
return;
|
||||||
while ((ptr = readdir(pDir)) != 0) {
|
while ((ptr = readdir(pDir)) != 0) {
|
||||||
|
@ -127,11 +127,13 @@ Result SampleProcess::Process(const char *om_path, const char *input_folder) {
|
||||||
void *inputShapeBuffer = nullptr;
|
void *inputShapeBuffer = nullptr;
|
||||||
int mret = aclrtMalloc(&inputShapeBuffer, 8, ACL_MEM_MALLOC_NORMAL_ONLY);
|
int mret = aclrtMalloc(&inputShapeBuffer, 8, ACL_MEM_MALLOC_NORMAL_ONLY);
|
||||||
if (mret != ACL_ERROR_NONE) {
|
if (mret != ACL_ERROR_NONE) {
|
||||||
|
aclrtFree(inputShape);
|
||||||
aclrtFree(inputShapeBuffer);
|
aclrtFree(inputShapeBuffer);
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
mret = aclrtMemcpy(reinterpret_cast<uint8_t *>(inputShapeBuffer), 8, inputShape, 8, ACL_MEMCPY_HOST_TO_DEVICE);
|
mret = aclrtMemcpy(reinterpret_cast<uint8_t *>(inputShapeBuffer), 8, inputShape, 8, ACL_MEMCPY_HOST_TO_DEVICE);
|
||||||
if (mret != ACL_ERROR_NONE) {
|
if (mret != ACL_ERROR_NONE) {
|
||||||
|
aclrtFree(inputShape);
|
||||||
aclrtFree(inputShapeBuffer);
|
aclrtFree(inputShapeBuffer);
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,7 +32,7 @@ eval_data_dir=$3
|
||||||
load_ckpt_path=$4
|
load_ckpt_path=$4
|
||||||
|
|
||||||
mkdir -p ms_log
|
mkdir -p ms_log
|
||||||
PROJECT_DIR=$(cd "$(dirname "$0")"; pwd)
|
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||||
CUR_DIR=`pwd`
|
CUR_DIR=`pwd`
|
||||||
export GLOG_log_dir=${CUR_DIR}/ms_log
|
export GLOG_log_dir=${CUR_DIR}/ms_log
|
||||||
export GLOG_logtostderr=0
|
export GLOG_logtostderr=0
|
||||||
|
|
|
@ -33,7 +33,7 @@ eval_data_dir=$4
|
||||||
load_ckpt_path=$5
|
load_ckpt_path=$5
|
||||||
|
|
||||||
mkdir -p ms_log
|
mkdir -p ms_log
|
||||||
PROJECT_DIR=$(cd "$(dirname "$0")"; pwd)
|
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||||
CUR_DIR=`pwd`
|
CUR_DIR=`pwd`
|
||||||
export GLOG_log_dir=${CUR_DIR}/ms_log
|
export GLOG_log_dir=${CUR_DIR}/ms_log
|
||||||
export GLOG_logtostderr=0
|
export GLOG_logtostderr=0
|
||||||
|
|
Loading…
Reference in New Issue