[MSLITE][Develop] remove mul op weight quant

yangruoqi713 2020-09-24 16:03:06 +08:00
parent 6fd4848a63
commit cd40cbbfb2
3 changed files with 7 additions and 40 deletions

View File

@@ -192,41 +192,19 @@ kernel::LiteKernel *CpuScaleFp32KernelCreator(const std::vector<lite::Tensor *>
     MS_LOG(ERROR) << "opParameter is nullptr";
     return nullptr;
   }
-  auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->MutableData();
-  if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
-    auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
-    if (dequant_weight == nullptr) {
-      MS_LOG(ERROR) << "dequant data is nullptr.";
-      return nullptr;
-    }
-    weight_tensor->SetData(dequant_weight);
-  }
   auto *kernel = new (std::nothrow) ScaleCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "New kernel fails.";
-    if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
-      weight_tensor->FreeData();
-      weight_tensor->SetData(restore_data);
-    }
     return nullptr;
   }
   auto ret = kernel->Init();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
                   << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
     delete kernel;
-    if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
-      weight_tensor->FreeData();
-      weight_tensor->SetData(restore_data);
-    }
     return nullptr;
   }
-  if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
-    weight_tensor->FreeData();
-    weight_tensor->SetData(restore_data);
-  }
   return kernel;
 }
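The hunk above drops the inline weight-dequant handling from the Scale FP32 kernel creator: the creator no longer dequantizes an int8 or weight-quantized tensor before constructing the kernel, and therefore no longer has to restore the original tensor data on every error path. The following standalone C++ sketch only illustrates that dequant-then-restore pattern in miniature; it is not MindSpore Lite code, and all names in it (MiniTensor, Dequant, Restore, the 0.5 scale) are hypothetical.

#include <cstdint>
#include <iostream>
#include <vector>

// Toy stand-in for a weight tensor that carries quantized int8 data.
struct MiniTensor {
  std::vector<int8_t> int8_data;  // original (quantized) payload
  std::vector<float> float_data;  // dequantized payload, filled on demand
  float scale = 0.5f;             // per-tensor quant scale (assumed value)
  bool dequantized = false;
};

// Dequantize with the simplest symmetric scheme: float = int8 * scale.
bool Dequant(MiniTensor *t) {
  if (t == nullptr) return false;
  t->float_data.resize(t->int8_data.size());
  for (size_t i = 0; i < t->int8_data.size(); ++i) {
    t->float_data[i] = static_cast<float>(t->int8_data[i]) * t->scale;
  }
  t->dequantized = true;
  return true;
}

// Restore: drop the dequantized copy so the tensor again holds only int8 data.
void Restore(MiniTensor *t) {
  t->float_data.clear();
  t->dequantized = false;
}

int main() {
  MiniTensor weight;
  weight.int8_data = {-4, 0, 6};
  if (!Dequant(&weight)) return 1;

  // Pretend kernel construction or Init() failed, as in the removed error paths:
  // every such path had to repeat the same restore step before returning.
  bool kernel_ok = false;
  if (!kernel_ok) {
    Restore(&weight);
  }
  std::cout << "dequantized after restore? " << std::boolalpha << weight.dequantized << std::endl;
  return 0;
}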

View File

@@ -617,17 +617,9 @@ basepath=$(pwd)
 echo ${basepath}
 #set -e
-# Example:sh run_benchmark_nets.sh -a /home/temp_test -c /home/temp_test -r /home/temp_test -m /home/temp_test/models -d "8KE5T19620002408"
-while getopts "a:c:r:m:d:" opt; do
+# Example:sh run_benchmark_nets.sh -r /home/temp_test -m /home/temp_test/models -d "8KE5T19620002408"
+while getopts "r:m:d:" opt; do
     case ${opt} in
-        a)
-            arm_path=${OPTARG}
-            echo "arm_path is ${OPTARG}"
-            ;;
-        c)
-            converter_path=${OPTARG}
-            echo "converter_path is ${OPTARG}"
-            ;;
         r)
             release_path=${OPTARG}
             echo "release_path is ${OPTARG}"
@@ -646,9 +638,6 @@ while getopts "a:c:r:m:d:" opt; do
     esac
 done
-echo ${arm_path}
-echo ${converter_path}
 mkdir train
 arm64_path=${release_path}/android_aarch64
 mv ${arm64_path}/*runtime-*train* ./train

View File

@@ -33,10 +33,10 @@ namespace mindspore {
 namespace lite {
 namespace quant {
 const std::vector<schema::PrimitiveType> QuantStrategy::conv_types = {
-  schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
-  schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D};
-const std::vector<schema::PrimitiveType> QuantStrategy::mul_types = {
-  schema::PrimitiveType_Mul, schema::PrimitiveType_MatMul, schema::PrimitiveType_FullConnection};
+  schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DeDepthwiseConv2D, schema::PrimitiveType_Conv2D,
+  schema::PrimitiveType_DepthwiseConv2D};
+const std::vector<schema::PrimitiveType> QuantStrategy::mul_types = {schema::PrimitiveType_MatMul,
+                                                                     schema::PrimitiveType_FullConnection};
 QuantStrategy::QuantStrategy(size_t weightSize, size_t convWeightQuantChannelThreshold)
     : mWeightSize(weightSize), mConvWeightQuantChannelThreshold(convWeightQuantChannelThreshold) {}
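The last hunk also removes schema::PrimitiveType_Mul from QuantStrategy::mul_types, so the weight-quant pass no longer treats Mul as a quantizable op; only MatMul and FullConnection remain on that list, while conv-like ops stay covered by conv_types. Below is a minimal sketch of how such a type list is typically used as an allow-list gate; plain strings stand in for the schema enums, and the function name is chosen here purely for illustration, not taken from quant_strategy code.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for the membership check: an op only gets its
// weights quantized if its type appears on the allow-list.
bool OnList(const std::vector<std::string> &types, const std::string &op) {
  return std::find(types.begin(), types.end(), op) != types.end();
}

int main() {
  // mul_types after this commit: Mul is gone, MatMul and FullConnection remain.
  const std::vector<std::string> mul_types = {"MatMul", "FullConnection"};
  for (const std::string &op : {"Mul", "MatMul", "FullConnection"}) {
    std::cout << op << (OnList(mul_types, op) ? ": weight quant" : ": keep float weights") << "\n";
  }
  return 0;
}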