forked from OSSInnovation/mindspore
!6841 [MSLITE][Develop] remove mul op weight quant
Merge pull request !6841 from yangruoqi713/lite
This commit is contained in:
commit
6300fa822f
|
@ -192,41 +192,19 @@ kernel::LiteKernel *CpuScaleFp32KernelCreator(const std::vector<lite::Tensor *>
|
||||||
MS_LOG(ERROR) << "opParameter is nullptr";
|
MS_LOG(ERROR) << "opParameter is nullptr";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto *weight_tensor = inputs.at(kWeightIndex);
|
|
||||||
auto *restore_data = weight_tensor->MutableData();
|
|
||||||
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
|
||||||
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
|
|
||||||
if (dequant_weight == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
weight_tensor->SetData(dequant_weight);
|
|
||||||
}
|
|
||||||
auto *kernel = new (std::nothrow) ScaleCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
auto *kernel = new (std::nothrow) ScaleCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||||
if (kernel == nullptr) {
|
if (kernel == nullptr) {
|
||||||
MS_LOG(ERROR) << "New kernel fails.";
|
MS_LOG(ERROR) << "New kernel fails.";
|
||||||
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
|
||||||
weight_tensor->FreeData();
|
|
||||||
weight_tensor->SetData(restore_data);
|
|
||||||
}
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto ret = kernel->Init();
|
auto ret = kernel->Init();
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||||
delete kernel;
|
delete kernel;
|
||||||
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
|
||||||
weight_tensor->FreeData();
|
|
||||||
weight_tensor->SetData(restore_data);
|
|
||||||
}
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
|
||||||
weight_tensor->FreeData();
|
|
||||||
weight_tensor->SetData(restore_data);
|
|
||||||
}
|
|
||||||
return kernel;
|
return kernel;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -617,17 +617,9 @@ basepath=$(pwd)
|
||||||
echo ${basepath}
|
echo ${basepath}
|
||||||
#set -e
|
#set -e
|
||||||
|
|
||||||
# Example:sh run_benchmark_nets.sh -a /home/temp_test -c /home/temp_test -r /home/temp_test -m /home/temp_test/models -d "8KE5T19620002408"
|
# Example:sh run_benchmark_nets.sh -r /home/temp_test -m /home/temp_test/models -d "8KE5T19620002408"
|
||||||
while getopts "a:c:r:m:d:" opt; do
|
while getopts "r:m:d:" opt; do
|
||||||
case ${opt} in
|
case ${opt} in
|
||||||
a)
|
|
||||||
arm_path=${OPTARG}
|
|
||||||
echo "arm_path is ${OPTARG}"
|
|
||||||
;;
|
|
||||||
c)
|
|
||||||
converter_path=${OPTARG}
|
|
||||||
echo "converter_path is ${OPTARG}"
|
|
||||||
;;
|
|
||||||
r)
|
r)
|
||||||
release_path=${OPTARG}
|
release_path=${OPTARG}
|
||||||
echo "release_path is ${OPTARG}"
|
echo "release_path is ${OPTARG}"
|
||||||
|
@ -646,9 +638,6 @@ while getopts "a:c:r:m:d:" opt; do
|
||||||
esac
|
esac
|
||||||
done
|
done
|
||||||
|
|
||||||
echo ${arm_path}
|
|
||||||
echo ${converter_path}
|
|
||||||
|
|
||||||
mkdir train
|
mkdir train
|
||||||
arm64_path=${release_path}/android_aarch64
|
arm64_path=${release_path}/android_aarch64
|
||||||
mv ${arm64_path}/*runtime-*train* ./train
|
mv ${arm64_path}/*runtime-*train* ./train
|
||||||
|
|
|
@ -33,10 +33,10 @@ namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
namespace quant {
|
namespace quant {
|
||||||
const std::vector<schema::PrimitiveType> QuantStrategy::conv_types = {
|
const std::vector<schema::PrimitiveType> QuantStrategy::conv_types = {
|
||||||
schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
|
schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DeDepthwiseConv2D, schema::PrimitiveType_Conv2D,
|
||||||
schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D};
|
schema::PrimitiveType_DepthwiseConv2D};
|
||||||
const std::vector<schema::PrimitiveType> QuantStrategy::mul_types = {
|
const std::vector<schema::PrimitiveType> QuantStrategy::mul_types = {schema::PrimitiveType_MatMul,
|
||||||
schema::PrimitiveType_Mul, schema::PrimitiveType_MatMul, schema::PrimitiveType_FullConnection};
|
schema::PrimitiveType_FullConnection};
|
||||||
QuantStrategy::QuantStrategy(size_t weightSize, size_t convWeightQuantChannelThreshold)
|
QuantStrategy::QuantStrategy(size_t weightSize, size_t convWeightQuantChannelThreshold)
|
||||||
: mWeightSize(weightSize), mConvWeightQuantChannelThreshold(convWeightQuantChannelThreshold) {}
|
: mWeightSize(weightSize), mConvWeightQuantChannelThreshold(convWeightQuantChannelThreshold) {}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue