!17432 [MS][LITE][CPU] lsq quant_aware_training method test on mslite

From: @zhang__sss
Reviewed-by: @zlq2020,@hangangqiang
Signed-off-by: @hangangqiang
This commit is contained in:
mindspore-ci-bot 2021-06-02 09:18:21 +08:00 committed by Gitee
commit b96dac51bd
16 changed files with 252 additions and 183 deletions

View File

@ -23,8 +23,11 @@ from ..._checkparam import Validator
from ...common import Tensor
from ...common import dtype as mstype
from ...common.api import _executor
from ...common.parameter import Parameter
from ...nn import Cell
from ...nn.layer import quant
from ...ops import operations as P
from ...ops import functional as F
from ...ops.operations import _inner_ops as inner
from ..quant import quant_utils
from ..quant.qat import _AddFakeQuantInput, _AddFakeQuantAfterSubCell
@ -32,6 +35,151 @@ from ..quant.qat import _AddFakeQuantInput, _AddFakeQuantAfterSubCell
__all__ = ["ExportToQuantInferNetwork"]
class QuantBlock(Cell):
r"""
A quant block of Conv/Dense, activation layer for Ascend deploy.
Calculate Conv or Dense in Int8, with Quant and DeQuant.
Notes:
This block is only for deploy, and not trainable.
Args:
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): The regularization function applied to the output of the layer, eg. 'relu'. Default: None.
batchnorm (bool): Specifies to used batchnorm or not. Default: None.
activation (string): Specifies activation type. The optional values are as following:
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
Outputs:
Tensor of shape :math:`(N, out\_channels)`.
"""
def __init__(self,
core_op,
weight,
quant_op,
dequant_op,
dequant_scale,
bias=None,
activation=None):
super(QuantBlock, self).__init__()
self.core_op = core_op
self.weight = weight
self.quant = quant_op
self.dequant = dequant_op
self.dequant_scale = dequant_scale
self.bias = bias
self.has_bias = bias is not None
self.activation = activation
self.has_act = activation is not None
self.bias_add = P.BiasAdd()
self.sub = P.Sub()
self.weight_offset = Parameter(np.zeros(shape=weight.shape, dtype=np.int8), name='weight_offset')
def construct(self, x):
x = self.quant(x)
if self.has_bias:
weight = self.sub(self.weight, self.weight_offset)
x = self.core_op(x, weight)
x = self.bias_add(x, self.bias)
else:
x = self.core_op(x, self.weight)
x = self.dequant(x, self.dequant_scale)
x = F.cast(x, mstype.float32)
if self.has_act:
x = self.activation(x)
return x
def extend_repr(self):
s = f'quant={self.quant}, core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]'
if self.has_bias:
s += f', bias=shape[{self.bias.shape}]'
if self.has_act:
s += f', activation={self.activation}'
s += f', dequant={self.dequant}'
return s
class QuantMindirBlock(Cell):
"""A quant binary block of Conv/Dense, activation layer for export MINDIR model.
Args:
core_op (Cell): The operation cell.
weight (Tensor): The weight of the cell.
bias (Tensor): The bias of the cell. Default: None.
activation (str): The regularization function applied to the output of the layer, eg. 'relu'. Default: None.
param_dict (dict): The information of the cell.
"""
def __init__(self,
core_op,
weight,
bias=None,
activation=None,
param_dict=None):
super(QuantMindirBlock, self).__init__()
self.core_op = core_op
if activation is not None:
self.core_op.add_prim_attr("activation_name", activation.__class__.__name__)
self.core_op.add_prim_attr("filter_maxq", Tensor(param_dict["filter_maxq"]))
self.core_op.add_prim_attr("filter_minq", Tensor(param_dict["filter_minq"]))
if param_dict["output_maxq"] is not None:
self.core_op.add_prim_attr("output_maxq", Tensor(param_dict["output_maxq"]))
self.core_op.add_prim_attr("output_minq", Tensor(param_dict["output_minq"]))
self.core_op.add_prim_attr("symmetric", Tensor(param_dict["symmetric"]))
if hasattr(core_op, 'pad_mode'):
self.core_op.add_prim_attr("pad_mode", core_op.pad_mode)
self.core_op.add_prim_attr("act_num_bits", Tensor(8))
self.core_op.add_prim_attr("weight_num_bits", Tensor(param_dict["weight_num_bits"]))
self.core_op.add_prim_attr("weight_narrow_range", Tensor(param_dict["weight_narrow_range"]))
if param_dict["input_narrow_range"] is not None:
self.core_op.add_prim_attr("input_narrow_range", Tensor(param_dict["input_narrow_range"]))
if param_dict["output_narrow_range"] is not None:
self.core_op.add_prim_attr("output_narrow_range", Tensor(param_dict["output_narrow_range"]))
if param_dict["input_maxq"] == 'None':
self.core_op.add_prim_attr("mean", Tensor(param_dict["mean"]))
self.core_op.add_prim_attr("std_dev", Tensor(param_dict["std_dev"]))
elif param_dict["input_maxq"] is not None:
self.core_op.add_prim_attr("input_maxq", Tensor(param_dict["input_maxq"]))
self.core_op.add_prim_attr("input_minq", Tensor(param_dict["input_minq"]))
self.weight = weight
self.bias = bias
self.has_bias = bias is not None
self.activation = activation
self.has_act = activation is not None
self.bias_add = P.BiasAdd()
def construct(self, x):
if self.has_bias:
x = self.core_op(x, self.weight)
x = self.bias_add(x, self.bias)
else:
x = self.core_op(x, self.weight)
if self.has_act:
x = self.activation(x)
return x
def extend_repr(self):
s = f'core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]'
if self.has_bias:
s += f', bias=shape[{self.bias.shape}]'
if self.has_act:
s += f', activation={self.activation}'
return s
class ExportToQuantInferNetwork:
"""
Convert quantization aware network to infer network.
@ -92,15 +240,20 @@ class ExportToQuantInferNetwork:
param_dict["output_minq"] = None
param_dict["input_maxq"] = None
param_dict["input_minq"] = None
param_dict["input_narrow_range"] = None
param_dict["output_narrow_range"] = None
param_dict["weight_narrow_range"] = cell_core.fake_quant_weight.narrow_range
param_dict["mean"] = self.mean
param_dict["std_dev"] = self.std_dev
param_dict["symmetric"] = cell_core.fake_quant_weight.symmetric
param_dict["weight_num_bits"] = cell_core.fake_quant_weight.num_bits
scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \
quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
if fake_quant_a_out is not None:
_, _, param_dict["output_maxq"], param_dict["output_minq"] = \
quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
param_dict["output_narrow_range"] = fake_quant_a_out.narrow_range
info = self.quant_info_table.get(w_minq_name, None)
if not info:
@ -120,6 +273,7 @@ class ExportToQuantInferNetwork:
scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_in, np_type)
param_dict["input_narrow_range"] = fake_quant_a_in.narrow_range
else:
# skip quant layer
scale_a_in, zp_a_in = 1.0, 0.0
@ -175,9 +329,9 @@ class ExportToQuantInferNetwork:
if bias_b is not None:
bias_b = Tensor(bias_b, mstype.float32)
if self.is_mindir:
block = quant.QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict)
block = QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict)
else:
block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
block = QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
return block
def _add_output_min_max_for_op(self, origin_op, fake_quant_cell):

View File

@ -138,6 +138,10 @@ def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
"""Get calculate quantization params for scale, zero point, max and min from `FakeQuantWithMinMaxObserver`."""
minq = cell.minq.data.asnumpy()
maxq = cell.maxq.data.asnumpy()
# make sure maxq > 0 and minq <= 0
if cell.mode == 'LEARNED_SCALE':
maxq = np.abs(maxq)
minq = -np.abs(minq)
scale, zp = cal_quantization_params(
minq, maxq, data_type,

View File

@ -69,14 +69,10 @@ T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
const auto zeroPoint = quantParam->zeroPoint;
const auto numBit = quantParam->numBits;
const auto narrowRange = quantParam->narrowRange;
double maxLimitTemp = static_cast<float>((1 << (unsigned int)numBit) - 1);
const double maxLimit = static_cast<float>(maxLimitTemp - zeroPoint + std::numeric_limits<T>::min()) * scale;
double minLimit;
if (narrowRange) {
minLimit = static_cast<float>(std::numeric_limits<T>::min() + 1 - zeroPoint) * scale;
} else {
minLimit = static_cast<float>(std::numeric_limits<T>::min() - zeroPoint) * scale;
}
const int32_t quantMax = (1 << (unsigned int)(numBit - 1)) - 1;
const int32_t quantMin = -1 * (1 << (unsigned int)(numBit - 1)) + (narrowRange ? 1 : 0);
const double maxLimit = static_cast<float>(quantMax - zeroPoint) * scale;
const double minLimit = static_cast<float>(quantMin - zeroPoint) * scale;
return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] {
double tmp;

View File

@ -118,7 +118,7 @@ static STATUS CompressTensor(schema::TensorT *tensor_input, const std::unique_pt
}
}
if (bit_num != 8 && bit_num != 16 && !repetition_packed) {
if (bit_num != 8 && bit_num != 16 && !repetition_packed && dst_node->quantType != schema::QuantType_QUANT_NONE) {
auto status = DoBitPack(bit_num, tensor_input);
if (status != RET_OK) {
MS_LOG(ERROR) << "do bit pack failed. " << status;

View File

@ -27,7 +27,8 @@
namespace mindspore {
namespace lite {
namespace {
int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) {
int ConvertInputQuantParam(const PrimitivePtr &prim, bool input_narrow_range, bool weight_narrow_range,
int32_t act_numbits, int32_t weight_numbits) {
auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>();
std::vector<schema::QuantParamT> quants;
schema::QuantParamT quant_param;
@ -40,8 +41,8 @@ int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t
auto *max_buf = static_cast<float *>(input_max_ptr->data_c());
quant_param.min = *min_buf;
quant_param.max = *max_buf;
auto ret =
lite::quant::CalQuantizationParams(&quant_param, quant_param.min, quant_param.max, narrow_range, numbits);
auto ret = lite::quant::CalQuantizationParams(&quant_param, quant_param.min, quant_param.max, input_narrow_range,
act_numbits);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Can't calculate quant parameters";
return ret;
@ -64,8 +65,8 @@ int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t
schema::QuantParamT tmp_quant_param;
tmp_quant_param.min = *min_buf;
tmp_quant_param.max = *max_buf;
auto ret =
lite::quant::CalQuantizationParams(&tmp_quant_param, tmp_quant_param.min, tmp_quant_param.max, true, numbits);
auto ret = lite::quant::CalQuantizationParams(&tmp_quant_param, tmp_quant_param.min, tmp_quant_param.max,
weight_narrow_range, weight_numbits);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Can't calculate quant parameters";
return ret;
@ -104,39 +105,77 @@ int ConvertOutputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t
return lite::RET_OK;
}
int ConvertQuantParam(const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs) {
auto narrow_range = prim->GetAttr("narrow_range");
bool narrow_range_param = false;
int GetNarrowRange(const PrimitivePtr &prim, const std::string &narrow_range_str, bool *narrow_range_param) {
auto narrow_range = prim->GetAttr(narrow_range_str);
if (narrow_range != nullptr) {
if (utils::isa<tensor::TensorPtr>(narrow_range)) {
auto narrow_range_tensor = narrow_range->cast<tensor::TensorPtr>();
narrow_range_param = *reinterpret_cast<bool *>(narrow_range_tensor->data_c());
*narrow_range_param = *reinterpret_cast<bool *>(narrow_range_tensor->data_c());
} else if (utils::isa<ImmTraits<bool>::type>(narrow_range)) {
narrow_range_param = GetValue<bool>(narrow_range);
*narrow_range_param = GetValue<bool>(narrow_range);
} else {
MS_LOG(ERROR) << "valueptr is invalid.";
return lite::RET_ERROR;
}
}
auto num_bits = prim->GetAttr("num_bits");
int32_t num_bits_param = 8;
return lite::RET_OK;
}
int GetNumBits(const PrimitivePtr &prim, const std::string &num_bits_str, int *num_bits_param) {
auto num_bits = prim->GetAttr(num_bits_str);
if (num_bits != nullptr) {
if (utils::isa<tensor::TensorPtr>(num_bits)) {
auto num_bits_tensor = num_bits->cast<tensor::TensorPtr>();
num_bits_param = *reinterpret_cast<int64_t *>(num_bits_tensor->data_c());
*num_bits_param = *reinterpret_cast<int64_t *>(num_bits_tensor->data_c());
} else if (utils::isa<ImmTraits<int64_t>::type>(num_bits)) {
num_bits_param = GetValue<int64_t>(num_bits);
*num_bits_param = GetValue<int64_t>(num_bits);
} else {
MS_LOG(ERROR) << "valueptr is invalid.";
return lite::RET_ERROR;
}
}
auto status = ConvertInputQuantParam(prim, narrow_range_param, num_bits_param);
return lite::RET_OK;
}
int ConvertQuantParam(const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs) {
bool input_narrow_range_param = false;
auto status = GetNarrowRange(prim, "input_narrow_range", &input_narrow_range_param);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "get input narrow range failed.";
return status;
}
bool weight_narrow_range_param = true;
status = GetNarrowRange(prim, "weight_narrow_range", &weight_narrow_range_param);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "get weight narrow range failed.";
return status;
}
bool output_narrow_range_param = false;
status = GetNarrowRange(prim, "output_narrow_range", &output_narrow_range_param);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "get output narrow range failed.";
return status;
}
int32_t act_num_bits_param = 8;
status = GetNumBits(prim, "act_num_bits", &act_num_bits_param);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "get activation num_bits failed.";
return status;
}
int32_t weight_num_bits_param = 8;
status = GetNumBits(prim, "weight_num_bits", &weight_num_bits_param);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "get weight num_bits failed.";
return status;
}
status = ConvertInputQuantParam(prim, input_narrow_range_param, weight_narrow_range_param, act_num_bits_param,
weight_num_bits_param);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "compute int quant param failed.";
return status;
}
status = ConvertOutputQuantParam(prim, narrow_range_param, num_bits_param);
status = ConvertOutputQuantParam(prim, output_narrow_range_param, act_num_bits_param);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "compute output quant param failed.";
return status;

View File

@ -175,6 +175,11 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
MS_LOG(ERROR) << "compute tensor to int8 prechannel failed.";
return RET_ERROR;
}
int bit_num = tensor->quantParams.front()->numBits;
if (DoBitPack(bit_num, tensor.get()) != RET_OK) {
MS_LOG(ERROR) << "bit pack failed.";
return RET_ERROR;
}
index++;
continue;
}
@ -183,6 +188,11 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
if (quantParam->dstDtype == TypeId::kNumberTypeInt8 || quantParam->dstDtype == TypeId::kNumberTypeUInt8 ||
quantParam->dstDtype == TypeId::kNumberTypeFloat32 || quantParam->dstDtype == TypeId::kNumberTypeFloat) {
status = ComputeDataToInt8(tensor, index);
int bit_num = tensor->quantParams.front()->numBits;
if (DoBitPack(bit_num, tensor.get()) != RET_OK) {
MS_LOG(ERROR) << "bit pack failed.";
return RET_ERROR;
}
} else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) {
// quant bias data
status = ComputeDataToInt32(tensor);

View File

@ -297,8 +297,8 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
return RET_OK;
}
const int8_t quantMin = std::numeric_limits<int8_t>::min() + (narrowRange ? 1 : 0);
const int8_t quantMax = std::numeric_limits<int8_t>::max();
const int8_t quantMax = (1 << (unsigned int)(numBits - 1)) - 1;
const int8_t quantMin = -1 * (1 << (unsigned int)(numBits - 1)) + (narrowRange ? 1 : 0);
auto quantMinFloat = static_cast<double>(quantMin);
auto quantMaxFloat = static_cast<double>(quantMax);
if (fabs(quantMaxFloat - quantMinFloat) <= 0.0f) {

View File

@ -20,7 +20,6 @@ import numpy as np
import mindspore.common.dtype as mstype
from mindspore.ops.primitive import Primitive
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
@ -1638,144 +1637,3 @@ class MulQuant(Cell):
x = self.mul(x1, x2)
x = self.fake_quant_act(x)
return x
class QuantBlock(Cell):
r"""
A quant block of Conv/Dense, activation layer for Ascend deploy.
Calculate Conv or Dense in Int8, with Quant and DeQuant.
Notes:
This block is only for deploy, and not trainable.
Args:
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): The regularization function applied to the output of the layer, eg. 'relu'. Default: None.
batchnorm (bool): Specifies to used batchnorm or not. Default: None.
activation (string): Specifies activation type. The optional values are as following:
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
Outputs:
Tensor of shape :math:`(N, out\_channels)`.
"""
def __init__(self,
core_op,
weight,
quant_op,
dequant_op,
dequant_scale,
bias=None,
activation=None):
super(QuantBlock, self).__init__()
self.core_op = core_op
self.weight = weight
self.quant = quant_op
self.dequant = dequant_op
self.dequant_scale = dequant_scale
self.bias = bias
self.has_bias = bias is not None
self.activation = activation
self.has_act = activation is not None
self.bias_add = P.BiasAdd()
self.sub = P.Sub()
self.weight_offset = Parameter(np.zeros(shape=weight.shape, dtype=np.int8), name='weight_offset')
def construct(self, x):
x = self.quant(x)
if self.has_bias:
weight = self.sub(self.weight, self.weight_offset)
x = self.core_op(x, weight)
x = self.bias_add(x, self.bias)
else:
x = self.core_op(x, self.weight)
x = self.dequant(x, self.dequant_scale)
x = F.cast(x, mstype.float32)
if self.has_act:
x = self.activation(x)
return x
def extend_repr(self):
s = f'quant={self.quant}, core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]'
if self.has_bias:
s += f', bias=shape[{self.bias.shape}]'
if self.has_act:
s += f', activation={self.activation}'
s += f', dequant={self.dequant}'
return s
class QuantMindirBlock(Cell):
"""A quant binary block of Conv/Dense, activation layer for export MINDIR model.
Args:
core_op (Cell): The operation cell.
weight (Tensor): The weight of the cell.
bias (Tensor): The bias of the cell. Default: None.
activation (str): The regularization function applied to the output of the layer, eg. 'relu'. Default: None.
param_dict (dict): The information of the cell.
"""
def __init__(self,
core_op,
weight,
bias=None,
activation=None,
param_dict=None):
super(QuantMindirBlock, self).__init__()
self.core_op = core_op
if activation is not None:
self.core_op.add_prim_attr("activation_name", activation.__class__.__name__)
self.core_op.add_prim_attr("filter_maxq", Tensor(param_dict["filter_maxq"]))
self.core_op.add_prim_attr("filter_minq", Tensor(param_dict["filter_minq"]))
if param_dict["output_maxq"] is not None:
self.core_op.add_prim_attr("output_maxq", Tensor(param_dict["output_maxq"]))
self.core_op.add_prim_attr("output_minq", Tensor(param_dict["output_minq"]))
self.core_op.add_prim_attr("symmetric", Tensor(param_dict["symmetric"]))
if hasattr(core_op, 'pad_mode'):
self.core_op.add_prim_attr("pad_mode", core_op.pad_mode)
self.core_op.add_prim_attr("num_bits", Tensor(8))
self.core_op.add_prim_attr("narrow_range", Tensor(False))
if param_dict["input_maxq"] == 'None':
self.core_op.add_prim_attr("mean", Tensor(param_dict["mean"]))
self.core_op.add_prim_attr("std_dev", Tensor(param_dict["std_dev"]))
elif param_dict["input_maxq"] is not None:
self.core_op.add_prim_attr("input_maxq", Tensor(param_dict["input_maxq"]))
self.core_op.add_prim_attr("input_minq", Tensor(param_dict["input_minq"]))
self.weight = weight
self.bias = bias
self.has_bias = bias is not None
self.activation = activation
self.has_act = activation is not None
self.bias_add = P.BiasAdd()
def construct(self, x):
if self.has_bias:
x = self.core_op(x, self.weight)
x = self.bias_add(x, self.bias)
else:
x = self.core_op(x, self.weight)
if self.has_act:
x = self.activation(x)
return x
def extend_repr(self):
s = f'core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]'
if self.has_bias:
s += f', bias=shape[{self.bias.shape}]'
if self.has_act:
s += f', activation={self.activation}'
return s

View File

@ -21,17 +21,19 @@ set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SRC_ROOT}/out)
# Set include directory and library directory
set(FWKACL_LIB_DIR $ENV{ASCEND_HOME}/fwkacllib)
set(ACL_LIB_DIR $ENV{ASCEND_HOME}/acllib)
set(ATLAS_ACL_LIB_DIR $ENV{ASCEND_HOME}/ascend-toolkit/latest/acllib)
# Header path
include_directories(${ACL_LIB_DIR}/include/)
include_directories(${FWKACL_LIB_DIR}/include/)
include_directories(${ATLAS_ACL_LIB_DIR}/include/)
include_directories(${PROJECT_SRC_ROOT}/../inc)
# add host lib path
link_directories(${ACL_LIB_DIR})
find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
link_directories(${ACL_LIB_DIR} ${FWKACL_LIB_DIR})
find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${FWKACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
add_executable(main utils.cpp
SampleProcess.cpp

View File

@ -36,7 +36,7 @@ if [ $# == 4 ]; then
device_id=$4
fi
echo "mindir name: "$model
echo "air name: "$model
echo "dataset path: "$data_path
echo "label path: "$label_path
echo "device id: "$device_id

View File

@ -21,17 +21,19 @@ set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SRC_ROOT}/out)
# Set include directory and library directory
set(FWKACL_LIB_DIR $ENV{ASCEND_HOME}/fwkacllib)
set(ACL_LIB_DIR $ENV{ASCEND_HOME}/acllib)
set(ATLAS_ACL_LIB_DIR $ENV{ASCEND_HOME}/ascend-toolkit/latest/acllib)
# Header path
include_directories(${ACL_LIB_DIR}/include/)
include_directories(${FWKACL_LIB_DIR}/include/)
include_directories(${ATLAS_ACL_LIB_DIR}/include/)
include_directories(${PROJECT_SRC_ROOT}/../inc)
# add host lib path
link_directories(${ACL_LIB_DIR})
find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
link_directories(${ACL_LIB_DIR} ${FWKACL_LIB_DIR})
find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${FWKACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
add_executable(main utils.cpp
SampleProcess.cpp

View File

@ -36,7 +36,7 @@ if [ $# == 4 ]; then
device_id=$4
fi
echo "mindir name: "$model
echo "air name: "$model
echo "dataset path: "$data_path
echo "label path: "$label_path
echo "device id: "$device_id

View File

@ -21,17 +21,19 @@ set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SRC_ROOT}/out)
# Set include directory and library directory
set(FWKACL_LIB_DIR $ENV{ASCEND_HOME}/fwkacllib)
set(ACL_LIB_DIR $ENV{ASCEND_HOME}/acllib)
set(ATLAS_ACL_LIB_DIR $ENV{ASCEND_HOME}/ascend-toolkit/latest/acllib)
# Header path
include_directories(${ACL_LIB_DIR}/include/)
include_directories(${FWKACL_LIB_DIR}/include/)
include_directories(${ATLAS_ACL_LIB_DIR}/include/)
include_directories(${PROJECT_SRC_ROOT}/../inc)
# add host lib path
link_directories(${ACL_LIB_DIR})
find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
link_directories(${ACL_LIB_DIR} ${FWKACL_LIB_DIR})
find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${FWKACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
add_executable(main utils.cpp
SampleProcess.cpp

View File

@ -36,7 +36,7 @@ if [ $# == 4 ]; then
device_id=$4
fi
echo "mindir name: "$model
echo "air name: "$model
echo "dataset path: "$data_path
echo "label path: "$label_path
echo "device id: "$device_id

View File

@ -21,17 +21,19 @@ set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SRC_ROOT}/out)
# Set include directory and library directory
set(FWKACL_LIB_DIR $ENV{ASCEND_HOME}/fwkacllib)
set(ACL_LIB_DIR $ENV{ASCEND_HOME}/acllib)
set(ATLAS_ACL_LIB_DIR $ENV{ASCEND_HOME}/ascend-toolkit/latest/acllib)
# Header path
include_directories(${ACL_LIB_DIR}/include/)
include_directories(${FWKACL_LIB_DIR}/include/)
include_directories(${ATLAS_ACL_LIB_DIR}/include/)
include_directories(${PROJECT_SRC_ROOT}/../inc)
# add host lib path
link_directories(${ACL_LIB_DIR})
find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
link_directories(${ACL_LIB_DIR} ${FWKACL_LIB_DIR})
find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${FWKACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
add_executable(main utils.cpp
SampleProcess.cpp

View File

@ -38,7 +38,7 @@ if [ $# == 6 ]; then
device_id=$6
fi
echo "mindir name: "$model
echo "air name: "$model
echo "dataset path: "$data_path
echo "annotation path: "$anno_path
echo "image shape path: "$image_shape_path