code clean

zhang__sss 2021-07-09 14:42:23 +08:00
parent a27c7b2436
commit dd1f1c26d9
15 changed files with 234 additions and 215 deletions


@@ -16,7 +16,4 @@
Compression common module.
"""
from .constant import *
__all__ = []
__all__.extend(constant.__all__)
from .constant import QuantDtype


@@ -35,6 +35,7 @@ from ..quant.qat import _AddFakeQuantInput, _AddFakeQuantAfterSubCell
__all__ = ["ExportToQuantInferNetwork"]
class QuantBlock(Cell):
r"""
A quant block of Conv/Dense, activation layer for Ascend deploy.
@@ -180,6 +181,7 @@ class QuantMindirBlock(Cell):
s += f', activation={self.activation}'
return s
class ExportToQuantInferNetwork:
"""
Convert quantization aware network to infer network.
@@ -199,18 +201,13 @@ class ExportToQuantInferNetwork:
def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
network = Validator.check_isinstance('network', network, (nn.Cell,))
self.input_scale = 1 / std_dev
self.input_zero_point = round(mean)
self.data_type = mstype.int8
self.network = copy.deepcopy(network)
self.network_bk = copy.deepcopy(network)
self.all_parameters = {p.name: p for p in self.network.get_parameters()}
self.get_inputs_table(inputs)
self.mean = mean
self.std_dev = std_dev
self.is_mindir = is_mindir
self.upcell = None
self.upname = None
def get_inputs_table(self, inputs):
"""Get the input quantization parameters of quantization cell for quant export."""
@@ -220,6 +217,7 @@ class ExportToQuantInferNetwork:
def run(self):
"""Start to convert."""
self.network_bk = copy.deepcopy(self.network)
self.network.update_cell_prefix()
network = self.network
if isinstance(network, _AddFakeQuantInput):
@@ -229,7 +227,36 @@ class ExportToQuantInferNetwork:
def _get_quant_block(self, cell_core, activation, fake_quant_a_out):
"""convert network's quant subcell to deploy subcell"""
# Calculate the scale and zero point
scale_a_in, zp_a_in, scale_w, zp_w, param_dict = self.__get_quant_param(cell_core, fake_quant_a_out)
# Build the `Quant` and `Dequant` ops.
# Quant only supports the per-layer version; a check is needed here.
quant_op = inner.Quant(1 / float(scale_a_in), float(zp_a_in))
scale_deq = self.__get_dequant_scale(scale_a_in, scale_w)
dequant_op = inner.Dequant()
if isinstance(activation, _AddFakeQuantAfterSubCell):
activation = activation.subcell
elif hasattr(activation, "get_origin"):
activation = activation.get_origin()
# get op
if isinstance(cell_core, quant.DenseQuant):
op_core = P.MatMul()
else:
op_core = cell_core.conv
# get the `weight` and `bias`
weight, bias, weight_b, bias_b = self.__get_weight_bias(cell_core, scale_a_in, scale_w, zp_w)
if self.is_mindir:
block = QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict)
else:
block = QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
return block
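
For orientation, the block assembled here corresponds to the usual quantized-inference dataflow: quantize the input activation, run the core op on integers, then dequantize with the fused scale. A minimal numpy sketch of that dataflow, assuming a Dense-style layer with per-layer scales, a weight already transposed to (in, out) as done for DenseQuant further below, and a plain float dequant scale instead of the fused uint64 tensor; the helper is illustrative only, not MindSpore API:

import numpy as np

def quant_block_dataflow(x, weight_int8, bias_int32, scale_a_in, zp_a_in, scale_deq):
    # Quant: float activations -> int8, mirroring inner.Quant(1 / scale_a_in, zp_a_in)
    x_q = np.clip(np.round(x / scale_a_in + zp_a_in), -128, 127).astype(np.int8)
    # Integer core op (MatMul for Dense), accumulated in int32 with the rescaled bias
    acc = x_q.astype(np.int32) @ weight_int8.astype(np.int32) + bias_int32
    # Dequant: back to float with scale_deq = scale_a_in * scale_w
    return acc.astype(np.float32) * scale_deq
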
def __get_quant_param(self, cell_core, fake_quant_a_out):
"""get parameter for quant block"""
w_minq_name = cell_core.fake_quant_weight.minq.name
w_maxq_name = cell_core.fake_quant_weight.maxq.name
np_type = mstype.dtype_to_nptype(self.data_type)
@@ -262,7 +289,7 @@ class ExportToQuantInferNetwork:
_, minq_name = info
if minq_name == 'input':
scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
self.input_scale, self.input_zero_point, 'None', 'None'
(1 / self.std_dev), round(self.mean), 'None', 'None'
else:
fake_quant_a_in_prefix = minq_name[:-5]
cells = self.network_bk.cells_and_names()
@@ -270,26 +297,34 @@ class ExportToQuantInferNetwork:
if cell[0].endswith(fake_quant_a_in_prefix):
fake_quant_a_in = cell[1]
break
scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_in, np_type)
param_dict["input_narrow_range"] = fake_quant_a_in.narrow_range
else:
# skip quant layer
scale_a_in, zp_a_in = 1.0, 0.0
return scale_a_in, zp_a_in, scale_w, zp_w, param_dict
# Build the `Quant` and `Dequant` ops.
# Quant only supports the per-layer version; a check is needed here.
quant_op = inner.Quant(1 / float(scale_a_in), float(zp_a_in))
@staticmethod
def __get_dequant_scale(scale_a_in, scale_w):
"""Get dequant scale"""
scale_deq = scale_a_in * scale_w
dequant_op = inner.Dequant()
if isinstance(activation, _AddFakeQuantAfterSubCell):
activation = activation.subcell
elif hasattr(activation, "get_origin"):
activation = activation.get_origin()
# fuse parameter
# |--------|47:40|--------|39:32|--------|31:0|
# offset_w [8] shift_N [8] deq_scale [32]
float32_deq_scale = scale_deq.astype(np.float32)
uint32_deq_scale = np.frombuffer(float32_deq_scale, np.uint32)
scale_length = scale_deq.size # channel
dequant_param = np.zeros(scale_length, dtype=np.uint64)
for index in range(scale_length):
dequant_param[index] += uint32_deq_scale[index]
scale_deq = Tensor(dequant_param, mstype.uint64)
return scale_deq
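
As a standalone check of the fusion above (a single output channel assumed), the float32 dequant scale is bit-copied into the low 32 bits of a uint64 word, leaving shift_N (bits 39:32) and offset_w (bits 47:40) at zero:

import numpy as np

scale_deq = np.array([0.0021], dtype=np.float32)             # scale_a_in * scale_w
uint32_bits = np.frombuffer(scale_deq.tobytes(), np.uint32)  # reinterpret the bit pattern
fused = np.zeros(scale_deq.size, dtype=np.uint64)
fused += uint32_bits                                         # deq_scale lands in bits 31:0
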
# get the `weight` and `bias`
def __get_weight_bias(self, cell_core, scale_a_in, scale_w, zp_w):
"""Get weight and bias for quantizaiton"""
np_type = mstype.dtype_to_nptype(self.data_type)
weight = cell_core.weight.data.asnumpy()
bias = None
if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)):
@@ -302,37 +337,22 @@ class ExportToQuantInferNetwork:
weight_b = weight
bias_b = bias
# apply the quant
weight = quant_utils.weight2int(weight, scale_w, zp_w, np_type, cell_core.fake_quant_weight.num_bits,
cell_core.fake_quant_weight.narrow_range)
quant_min, quant_max = quant_utils.get_quant_min_max(np_type,
cell_core.fake_quant_weight.num_bits,
cell_core.fake_quant_weight.narrow_range)
weight = quant_utils.weight2int(weight, scale_w, zp_w, quant_min, quant_max)
if bias is not None:
bias = Tensor(bias / scale_a_in / scale_w, mstype.int32)
# fuse parameter
# |--------|47:40|--------|39:32|--------|31:0|
# offset_w [8] shift_N [8] deq_scale [32]
float32_deq_scale = scale_deq.astype(np.float32)
uint32_deq_scale = np.frombuffer(float32_deq_scale, np.uint32)
scale_length = scale_deq.size # channel
dequant_param = np.zeros(scale_length, dtype=np.uint64)
for index in range(scale_length):
dequant_param[index] += uint32_deq_scale[index]
scale_deq = Tensor(dequant_param, mstype.uint64)
# get op
if isinstance(cell_core, quant.DenseQuant):
op_core = P.MatMul()
weight = np.transpose(weight)
weight_b = np.transpose(weight_b)
else:
op_core = cell_core.conv
weight = Tensor(weight, self.data_type)
weight_b = Tensor(weight_b)
if bias_b is not None:
bias_b = Tensor(bias_b, mstype.float32)
if self.is_mindir:
block = QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict)
else:
block = QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
return block
return weight, bias, weight_b, bias_b
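
The bias handling above is the easiest part to misread: the bias is added inside the int32 accumulator, so it has to be expressed in units of scale_a_in * scale_w. A short numpy illustration with made-up scales (explicit rounding added here for clarity; the code above relies on the int32 cast):

import numpy as np

scale_a_in, scale_w = 0.02, 0.005                 # activation and weight scales (made up)
bias = np.array([0.3, -0.1], dtype=np.float32)
bias_int32 = np.round(bias / scale_a_in / scale_w).astype(np.int32)   # -> [3000, -1000]
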
def _add_output_min_max_for_op(self, origin_op, fake_quant_cell):
"""add output quant info for quant op for export mindir."""
@@ -343,6 +363,121 @@ class ExportToQuantInferNetwork:
origin_op.add_prim_attr('output_maxq', Tensor(maxq))
origin_op.add_prim_attr('output_minq', Tensor(minq))
def _convert_subcell(self, network, change, name, subcell):
"""Convert subcell to ant subcell."""
if subcell is not None and hasattr(subcell, "fake_quant_weight"):
new_subcell = self._get_quant_block(subcell, None, None)
prefix = subcell.param_prefix
new_subcell.update_parameters_name(prefix + '.')
self.upcell = new_subcell
network.insert_child_to_cell(name, new_subcell)
change = True
return network, change
def _convert_conv(self, network, change, name, subcell):
"""Convert subcell to ant subcell for conv."""
cell_core = subcell.conv
activation = subcell.activation
fake_quant_act = None
if hasattr(activation, 'fake_quant_act_before'):
fake_quant_act = activation.fake_quant_act_before
elif hasattr(activation, 'fake_quant_act'):
fake_quant_act = activation.fake_quant_act
if cell_core is not None and hasattr(cell_core, "fake_quant_weight"):
new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
self.upcell = None
prefix = subcell.param_prefix
new_subcell.update_parameters_name(prefix + '.')
network.insert_child_to_cell(name, new_subcell)
change = True
return network, change
def _convert_dense(self, network, change, name, subcell):
"""Convert subcell to ant subcell for dense."""
cell_core = subcell.dense
activation = subcell.activation
fake_quant_act = None
if hasattr(activation, 'fake_quant_act_before'):
fake_quant_act = activation.fake_quant_act_before
elif hasattr(activation, 'fake_quant_act'):
fake_quant_act = activation.fake_quant_act
if cell_core is not None and hasattr(cell_core, "fake_quant_weight"):
new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
prefix = subcell.param_prefix
new_subcell.update_parameters_name(prefix + '.')
network.insert_child_to_cell(name, new_subcell)
self.upcell = None
change = True
return network, change
def _convert_act(self, subcell):
"""Convert subcell to ant subcell for activation."""
activation = subcell.get_origin()
if isinstance(activation, nn.ReLU):
self._add_output_min_max_for_op(activation.relu, subcell.fake_quant_act)
elif isinstance(activation, nn.ReLU6):
self._add_output_min_max_for_op(activation.relu6, subcell.fake_quant_act)
if self.upcell:
self._add_output_min_max_for_op(self.upcell.core_op, subcell.fake_quant_act)
return activation
def _convert_add(self, subcell):
"""Convert subcell to ant subcell for add."""
if isinstance(subcell.add, _AddFakeQuantAfterSubCell):
add_op = subcell.add.subcell
subcell.__delattr__("add")
subcell.__setattr__("add", add_op)
add_op = subcell.add
self._add_output_min_max_for_op(add_op, subcell.fake_quant_act)
subcell.__delattr__("fake_quant_act")
subcell.__setattr__("fake_quant_act", P.identity())
def _convert_observer(self, network, name, subcell):
"""Convert subcell to ant subcell for FakeQuantWithMinMaxObserver."""
if self.upcell:
self._add_output_min_max_for_op(self.upcell.core_op, subcell)
network.__delattr__(name)
network.__setattr__(name, P.identity())
def _convert_fake_quant_after_cell(self, network, name, subcell):
"""Convert subcell to ant subcell for _AddFakeQuantAfterSubCell."""
op = subcell.subcell
self._add_output_min_max_for_op(op, subcell.fake_quant_act)
network.__delattr__(name)
network.__setattr__(name, op)
def _convert_core_quant_subcell(self, network, change, name, subcell):
"""Convert subcell to ant subcell for conv and dense."""
is_core_subcell = True
if isinstance(subcell, nn.Conv2dBnAct):
network, change = self._convert_conv(network, change, name, subcell)
elif isinstance(subcell, nn.DenseBnAct):
network, change = self._convert_dense(network, change, name, subcell)
elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnFoldQuantOneConv,
quant.Conv2dBnWithoutFoldQuant, quant.Conv2dQuant, quant.DenseQuant)):
network, change = self._convert_subcell(network, change, name, subcell)
else:
is_core_subcell = False
return is_core_subcell, network, change
def _convert_other_quant_subcell(self, network, change, name, subcell):
"""Convert subcell to ant subcell for cell except conv and dense."""
is_other_subcell = True
if isinstance(subcell, nn.ActQuant) and hasattr(subcell, "get_origin"):
activation = self._convert_act(subcell)
network.insert_child_to_cell(name, activation)
change = True
elif isinstance(subcell, nn.TensorAddQuant):
self._convert_add(subcell)
elif isinstance(subcell, quant.FakeQuantWithMinMaxObserver):
self._convert_observer(network, name, subcell)
elif isinstance(subcell, _AddFakeQuantAfterSubCell):
self._convert_fake_quant_after_cell(network, name, subcell)
change = True
else:
is_other_subcell = False
return is_other_subcell, network, change
def _convert_quant2deploy(self, network):
"""Convert network's all quant subcell to deploy subcell."""
cells = network.name_cells()
@@ -351,71 +486,11 @@ class ExportToQuantInferNetwork:
subcell = cells[name]
if subcell == network:
continue
if isinstance(subcell, nn.Conv2dBnAct):
network, change = self._convert_subcell(network, change, name, subcell)
elif isinstance(subcell, nn.DenseBnAct):
network, change = self._convert_subcell(network, change, name, subcell, conv=False)
elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnFoldQuantOneConv,
quant.Conv2dBnWithoutFoldQuant, quant.Conv2dQuant, quant.DenseQuant)):
network, change = self._convert_subcell(network, change, name, subcell, core=False)
elif isinstance(subcell, nn.ActQuant) and hasattr(subcell, "get_origin"):
activation = subcell.get_origin()
if isinstance(activation, nn.ReLU):
self._add_output_min_max_for_op(activation.relu, subcell.fake_quant_act)
elif isinstance(activation, nn.ReLU6):
self._add_output_min_max_for_op(activation.relu6, subcell.fake_quant_act)
if self.upcell:
self._add_output_min_max_for_op(self.upcell.core_op, subcell.fake_quant_act)
network.insert_child_to_cell(name, activation)
change = True
elif isinstance(subcell, nn.TensorAddQuant):
if isinstance(subcell.add, _AddFakeQuantAfterSubCell):
add_op = subcell.add.subcell
subcell.__delattr__("add")
subcell.__setattr__("add", add_op)
add_op = subcell.add
self._add_output_min_max_for_op(add_op, subcell.fake_quant_act)
subcell.__delattr__("fake_quant_act")
subcell.__setattr__("fake_quant_act", P.identity())
elif isinstance(subcell, quant.FakeQuantWithMinMaxObserver):
if self.upcell:
self._add_output_min_max_for_op(self.upcell.core_op, subcell)
network.__delattr__(name)
network.__setattr__(name, P.identity())
elif isinstance(subcell, _AddFakeQuantAfterSubCell):
op = subcell.subcell
self._add_output_min_max_for_op(op, subcell.fake_quant_act)
network.__delattr__(name)
network.__setattr__(name, op)
change = True
else:
self.upcell, self.upname = None, None
is_core_quant_subcell, network, change = self._convert_core_quant_subcell(network, change, name, subcell)
is_other_quant_subcell, network, change = self._convert_other_quant_subcell(network, change, name, subcell)
if not is_core_quant_subcell and not is_other_quant_subcell:
self.upcell = None
self._convert_quant2deploy(subcell)
if isinstance(network, nn.SequentialCell) and change:
network.cell_list = list(network.cells())
return network
def _convert_subcell(self, network, change, name, subcell, core=True, conv=True):
"""Convert subcell to ant subcell."""
new_subcell = None
fake_quant_act = None
if core:
cell_core = subcell.conv if conv else subcell.dense
activation = subcell.activation
if hasattr(activation, 'fake_quant_act_before'):
fake_quant_act = activation.fake_quant_act_before
elif hasattr(activation, 'fake_quant_act'):
fake_quant_act = activation.fake_quant_act
else:
cell_core = subcell
activation = None
if cell_core is not None and hasattr(cell_core, "fake_quant_weight"):
new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
if new_subcell:
prefix = subcell.param_prefix
new_subcell.update_parameters_name(prefix + '.')
self.upcell = None if core else new_subcell
self.upname = None if core else name
network.insert_child_to_cell(name, new_subcell)
change = True
return network, change
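
Taken as a whole, the split into _convert_core_quant_subcell / _convert_other_quant_subcell does not change how the exporter is driven. A hedged usage sketch; the QAT network builder, the mean/std values and the input shape are placeholders, not anything defined in this commit:

import numpy as np
from mindspore import Tensor

qat_net = build_qat_network()     # hypothetical helper returning a trained QAT nn.Cell
dummy_input = Tensor(np.ones([1, 3, 224, 224], dtype=np.float32))
exporter = ExportToQuantInferNetwork(qat_net, 127.5, 127.5, dummy_input, is_mindir=True)
deploy_net = exporter.run()       # converted network; QuantMindirBlock subcells when is_mindir=True
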


@@ -16,11 +16,6 @@
Compression quant module.
"""
from .quantizer import *
from .qat import *
from .quant_utils import *
__all__ = []
__all__.extend(qat.__all__)
__all__.extend(quantizer.__all__)
__all__.extend(quant_utils.__all__)
from .quantizer import OptimizeOption
from .qat import QuantizationAwareTraining, create_quant_config
from .quant_utils import load_nonquant_param_into_quant_net, query_quant_layers


@@ -23,22 +23,20 @@ __all__ = ["load_nonquant_param_into_quant_net", "query_quant_layers"]
def cal_quantization_params(input_min,
input_max,
quant_min,
quant_max,
data_type,
num_bits=8,
symmetric=False,
narrow_range=False,
neg_trunc=False):
symmetric=False):
r"""
Calculate quantization params for scale and zero point.
Args:
input_min (numpy.ndarray): The dimension of channel or 1.
input_max (numpy.ndarray): The dimension of channel or 1.
quant_min (int): The minimum quantization integer.
quant_max (int): The maximum quantization integer.
data_type (numpy type) : Can be numpy int8, numpy uint8.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
neg_trunc (bool): Whether the quantization algorithm uses negative truncation or not. Default: False.
Returns:
scale (numpy.ndarray): quantization param.
@@ -56,6 +54,24 @@ def cal_quantization_params(input_min,
if (input_max == input_min).all():
return np.ones(input_min.shape), np.zeros(input_min.shape)
# calculate scale
if symmetric:
input_max = np.maximum(-input_min, input_max)
input_min = -input_max
scale = (input_max - input_min) / (quant_max - quant_min)
# calculate zero point
if data_type == np.int8 and symmetric:
zp = np.zeros(input_min.shape)
else:
zp_double = quant_min - input_min / scale
zp = np.floor(zp_double + 0.5)
return scale, zp
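
A quick worked check of the new signature, exercising the asymmetric uint8 branch (the min/max values are arbitrary):

import numpy as np

input_min, input_max = np.array([-1.0]), np.array([3.0])
scale, zp = cal_quantization_params(input_min, input_max, 0, 255, np.uint8)
# scale = 4 / 255 ≈ 0.0157; zp = floor(0 - input_min / scale + 0.5) = 64
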
def get_quant_min_max(data_type, num_bits=8, narrow_range=False):
"""Calculate quantization params for minimum/maximum quantization integer"""
if data_type == np.int8:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
@@ -66,24 +82,10 @@ def cal_quantization_params(input_min,
raise ValueError("Unsupported datatype({})".format(data_type))
if narrow_range:
quant_min = quant_min + 1
# calculate scale
if symmetric and not neg_trunc:
input_max = np.maximum(-input_min, input_max)
input_min = -input_max
scale = (input_max - input_min) / (quant_max - quant_min)
# calculate zero point
if data_type == np.int8 and symmetric and not neg_trunc:
zp = np.zeros(input_min.shape)
else:
zp_double = quant_min - input_min / scale
zp = np.floor(zp_double + 0.5)
return scale, zp
return quant_min, quant_max
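
The extracted helper makes the integer ranges explicit; for reference, with the branches shown above:

import numpy as np

get_quant_min_max(np.int8)                      # -> (-128, 127)
get_quant_min_max(np.int8, narrow_range=True)   # -> (-127, 127)
get_quant_min_max(np.uint8)                     # -> (0, 255)
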
def weight2int(data, scale, zero_point, data_type, num_bits=8, narrow_range=False):
def weight2int(data, scale, zero_point, quant_min, quant_max):
r"""
Calculate the int8/uint8 weight from fp32. The formula is defined as:
@@ -94,9 +96,8 @@ def weight2int(data, scale, zero_point, data_type, num_bits=8, narrow_range=False):
data (numpy.ndarray): The dimension of channel or 1. Should be NCHW.
scale (numpy.ndarray): The dimension of channel or 1.
zero_point (numpy.ndarray): The dimension of channel or 1.
data_type (numpy type) : Can be numpy int8, numpy uint8.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
quant_min (int): The minimum quantization integer.
quant_max (int): The maximum quantization integer.
Returns:
weight (numpy.ndarray): The dimension of channel or 1.
@@ -120,17 +121,6 @@ def weight2int(data, scale, zero_point, data_type, num_bits=8, narrow_range=False):
else:
raise ValueError("Unsupported weight shape({})".format(data.shape))
if data_type == np.int8:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
elif data_type == np.uint8:
quant_min = 0
quant_max = 2 ** num_bits - 1
else:
raise ValueError("Unsupported weight datatype({})".format(data_type))
if narrow_range:
quant_min = quant_min + 1
weight_int = np.round((data / scale) + zero_point)
weight_int[weight_int > quant_max] = quant_max
weight_int[weight_int < quant_min] = quant_min
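
With the integer range passed in, weight2int no longer recomputes it from data_type and num_bits. A small numpy example of the per-layer case, including the clipping (weights and scale are made up):

import numpy as np

w = np.array([[0.20, -0.04], [0.007, 0.09]], dtype=np.float32)
scale_w, zp_w = np.array([0.001]), np.array([0.0])
quant_min, quant_max = get_quant_min_max(np.int8, narrow_range=True)   # (-127, 127)
w_int = weight2int(w, scale_w, zp_w, quant_min, quant_max)
# -> [[127., -40.], [7., 90.]]; 0.20 / 0.001 = 200 is clipped to quant_max = 127
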
@@ -145,54 +135,12 @@ def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
if cell.mode == 'LEARNED_SCALE':
maxq = np.abs(maxq)
minq = -np.abs(minq)
quant_min, quant_max = get_quant_min_max(data_type, num_bits=cell.num_bits, narrow_range=cell.narrow_range)
symmetric = cell.symmetric and not cell.neg_trunc
scale, zp = cal_quantization_params(
minq, maxq, data_type,
num_bits=cell.num_bits,
symmetric=cell.symmetric,
narrow_range=cell.narrow_range,
neg_trunc=cell.neg_trunc)
return scale, zp, maxq, minq
def scale_zp_from_data(op, minq, maxq, data_type):
r"""
Calculate quantization params for scale and zero point.
Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
Args:
op (Primitive): Fake quant primitive `mindspore.ops.operation.FakeQuantPerLayer` or
`mindspore.ops.operation.FakeQuantPerChannel`
minq (Parameter): Parameter `minq` of `mindspore.nn.layer.FakeQuantWithMinMax`
maxq (Parameter): Parameter `maxq` of `mindspore.nn.layer.FakeQuantWithMinMax`
data_type (numpy type): Can be `numpy.int8` or `numpy.uint8`.
Returns:
scale (numpy.ndarray): quantization param.
zero point (numpy.ndarray): quantization param.
"""
minq = minq.data.asnumpy()
maxq = maxq.data.asnumpy()
scale, zp = cal_quantization_params(
minq, maxq, data_type,
num_bits=op.num_bits,
symmetric=op.symmetric,
narrow_range=op.narrow_range)
return scale, zp
def scale_zp_max_min_from_data(op, minq, maxq, data_type):
"""Get calculate quantization params for scale, zero point, max and min."""
minq = minq.data.asnumpy()
maxq = maxq.data.asnumpy()
scale, zp = cal_quantization_params(
minq, maxq, data_type,
num_bits=op.num_bits,
symmetric=op.symmetric,
narrow_range=op.narrow_range)
minq, maxq,
quant_min, quant_max, data_type,
symmetric=symmetric)
return scale, zp, maxq, minq
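
Put together, computing a cell's quantization parameters is now an explicit two-step pipeline: derive the integer range, then the scale and zero point. A worked symmetric int8 example with made-up min/max (neg_trunc is False here, so symmetric stays True):

import numpy as np

minq, maxq = np.array([-0.5]), np.array([0.5])
quant_min, quant_max = get_quant_min_max(np.int8, num_bits=8, narrow_range=False)
scale, zp = cal_quantization_params(minq, maxq, quant_min, quant_max, np.int8, symmetric=True)
# scale = 1.0 / 255 ≈ 0.0039; zp = 0 (int8 + symmetric pins the zero point at 0)
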


@@ -46,6 +46,8 @@ using mindspore::ops::PrimitiveC;
namespace mindspore::lite {
namespace {
constexpr int BIT_NUM_8 = 8;
constexpr int BIT_NUM_16 = 16;
std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1);
auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> {
@@ -113,14 +115,14 @@ static STATUS CompressTensor(schema::TensorT *tensor_input, const std::unique_pt
auto repetition_packed = false;
MS_LOG(DEBUG) << dst_node->name;
if (dst_node->quantType == schema::QuantType_QUANT_WEIGHT) {
if (bit_num <= 8) {
if (bit_num <= BIT_NUM_8) {
repetition_packed = PackRepetition<int8_t>(bit_num, tensor_input);
} else {
repetition_packed = PackRepetition<int16_t>(bit_num, tensor_input);
}
}
if (bit_num != 8 && bit_num != 16 && !repetition_packed && dst_node->quantType != schema::QuantType_QUANT_NONE) {
if (bit_num != BIT_NUM_8 && bit_num != BIT_NUM_16 && !repetition_packed &&
dst_node->quantType != schema::QuantType_QUANT_NONE) {
auto status = DoBitPack(bit_num, tensor_input);
if (status != RET_OK) {
MS_LOG(ERROR) << "do bit pack failed. " << status;


@@ -181,7 +181,7 @@ void ModelProcess::DumpModelOutputResult(char *output_name) {
std::string fileName = std::string(output_name) + '_' + std::to_string(i) + ".bin";
std::string outputFileName = homePath + "/" + fileName;
FILE *outputFile = fopen(outputFileName.c_str(), "wb");
if (outputFile) {
if (outputFile != nullptr) {
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
void* data = aclGetDataBufferAddr(dataBuffer);
uint32_t len = aclGetDataBufferSizeV2(dataBuffer);


@@ -82,8 +82,8 @@ Result SampleProcess::InitResource(const char *aclConfigPath) {
}
void SampleProcess::GetAllFiles(std::string path, std::vector<string> *files) {
DIR *pDir;
struct dirent* ptr;
DIR *pDir = nullptr;
struct dirent* ptr = nullptr;
if (!(pDir = opendir(path.c_str())))
return;
while ((ptr = readdir(pDir)) != 0) {


@@ -181,7 +181,7 @@ void ModelProcess::DumpModelOutputResult(char *output_name) {
std::string fileName = std::string(output_name) + '_' + std::to_string(i) + ".bin";
std::string outputFileName = homePath + "/" + fileName;
FILE *outputFile = fopen(outputFileName.c_str(), "wb");
if (outputFile) {
if (outputFile != nullptr) {
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
void* data = aclGetDataBufferAddr(dataBuffer);
uint32_t len = aclGetDataBufferSizeV2(dataBuffer);


@@ -82,8 +82,8 @@ Result SampleProcess::InitResource(const char *aclConfigPath) {
}
void SampleProcess::GetAllFiles(std::string path, std::vector<string> *files) {
DIR *pDir;
struct dirent* ptr;
DIR *pDir = nullptr;
struct dirent* ptr = nullptr;
if (!(pDir = opendir(path.c_str())))
return;
while ((ptr = readdir(pDir)) != 0) {


@@ -181,7 +181,7 @@ void ModelProcess::DumpModelOutputResult(char *output_name) {
std::string fileName = std::string(output_name) + '_' + std::to_string(i) + ".bin";
std::string outputFileName = homePath + "/" + fileName;
FILE *outputFile = fopen(outputFileName.c_str(), "wb");
if (outputFile) {
if (outputFile != nullptr) {
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
void* data = aclGetDataBufferAddr(dataBuffer);
uint32_t len = aclGetDataBufferSizeV2(dataBuffer);


@@ -82,8 +82,8 @@ Result SampleProcess::InitResource(const char *aclConfigPath) {
}
void SampleProcess::GetAllFiles(std::string path, std::vector<string> *files) {
DIR *pDir;
struct dirent* ptr;
DIR *pDir = nullptr;
struct dirent* ptr = nullptr;
if (!(pDir = opendir(path.c_str())))
return;
while ((ptr = readdir(pDir)) != 0) {


@@ -181,7 +181,7 @@ void ModelProcess::DumpModelOutputResult(char *output_name) {
std::string fileName = std::string(output_name) + '_' + std::to_string(i) + ".bin";
std::string outputFileName = homePath + "/" + fileName;
FILE *outputFile = fopen(outputFileName.c_str(), "wb");
if (outputFile) {
if (outputFile != nullptr) {
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(output_, i);
void* data = aclGetDataBufferAddr(dataBuffer);
uint32_t len = aclGetDataBufferSizeV2(dataBuffer);


@@ -82,8 +82,8 @@ Result SampleProcess::InitResource(const char *aclConfigPath) {
}
void SampleProcess::GetAllFiles(std::string path, std::vector<string> *files) {
DIR *pDir;
struct dirent* ptr;
DIR *pDir = nullptr;
struct dirent* ptr = nullptr;
if (!(pDir = opendir(path.c_str())))
return;
while ((ptr = readdir(pDir)) != 0) {
@@ -127,11 +127,13 @@ Result SampleProcess::Process(const char *om_path, const char *input_folder) {
void *inputShapeBuffer = nullptr;
int mret = aclrtMalloc(&inputShapeBuffer, 8, ACL_MEM_MALLOC_NORMAL_ONLY);
if (mret != ACL_ERROR_NONE) {
aclrtFree(inputShape);
aclrtFree(inputShapeBuffer);
return FAILED;
}
mret = aclrtMemcpy(reinterpret_cast<uint8_t *>(inputShapeBuffer), 8, inputShape, 8, ACL_MEMCPY_HOST_TO_DEVICE);
if (mret != ACL_ERROR_NONE) {
aclrtFree(inputShape);
aclrtFree(inputShapeBuffer);
return FAILED;
}


@@ -32,7 +32,7 @@ eval_data_dir=$3
load_ckpt_path=$4
mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")"; pwd)
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0


@@ -33,7 +33,7 @@ eval_data_dir=$4
load_ckpt_path=$5
mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")"; pwd)
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0