forked from mindspore-Ecosystem/mindspore
!4864 1.Fix bugs of some InferShape. 2.Fix the bug of fc int8
Merge pull request !4864 from zhanyuan/dev
This commit is contained in:
commit
005ddb5580
|
@ -141,8 +141,8 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
|
|||
pad_u_ = GetPadUp();
|
||||
pad_d_ = GetPadDown();
|
||||
pad_r_ = GetPadRight();
|
||||
output_h = GetStrideH() * (input_h - 1) * GetKernelH() - pad_u_ - pad_d_;
|
||||
output_w = GetStrideW() * (input_w - 1) * GetKernelW() - pad_l_ - pad_r_;
|
||||
output_h = GetStrideH() * (input_h - 1) + GetKernelH() - pad_u_ - pad_d_;
|
||||
output_w = GetStrideW() * (input_w - 1) + GetKernelW() - pad_l_ - pad_r_;
|
||||
if ((output_h + GetPadUp() + GetPadDown() - GetKernelH()) % GetStrideH() != 0) {
|
||||
output_h += (output_h + GetPadLeft() + GetPadRight() - GetKernelH()) % GetStrideH();
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ void FullConnection::SetHasBias(bool has_bias) { this->primitive->value.AsFullCo
|
|||
void FullConnection::SetAxis(int axis) { this->primitive->value.AsFullConnection()->axis = axis; }
|
||||
void FullConnection::SetUseAxis(bool use_axis) { this->primitive->value.AsFullConnection()->useAxis = use_axis; }
|
||||
void FullConnection::SetActivationType(int activationType) {
|
||||
this->primitive->value.AsFullConnection()->activationType = (schema::ActivationType) activationType;
|
||||
this->primitive->value.AsFullConnection()->activationType = (schema::ActivationType)activationType;
|
||||
}
|
||||
#else
|
||||
|
||||
|
@ -47,43 +47,58 @@ int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
|
|||
MS_ASSERT(this->primitive != nullptr);
|
||||
auto input0 = inputs_.front();
|
||||
MS_ASSERT(input0 != nullptr);
|
||||
auto input1 = inputs_.at(1);
|
||||
auto input1 = inputs_[1];
|
||||
MS_ASSERT(input1 != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_data_type(input0->data_type());
|
||||
output->SetFormat(input0->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
if ((GetHasBias() && inputs_.size() != kMultiNum) || (!GetHasBias() && inputs_.size() != kDoubleNum)) {
|
||||
MS_LOG(ERROR) << "Input tensors num error";
|
||||
return 1;
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
if (GetAxis() < 1 || GetAxis() > static_cast<int>(input0->shape().size())) {
|
||||
if (GetUseAxis() && (GetAxis() < 1 || GetAxis() > static_cast<int>(input0->shape().size()))) {
|
||||
MS_LOG(ERROR) << "FullConnection axis invalid";
|
||||
return 1;
|
||||
return RET_ERROR;
|
||||
}
|
||||
int new_k = 1;
|
||||
for (size_t i = GetAxis(); i < input0->shape().size(); ++i) {
|
||||
new_k *= input0->shape().at(i);
|
||||
}
|
||||
if (new_k != input1->shape().at(1)) {
|
||||
MS_LOG(ERROR) << "Input1 size invalid";
|
||||
return 1;
|
||||
if (GetUseAxis()) {
|
||||
for (int i = GetAxis(); i < input0->shape().size(); ++i) {
|
||||
new_k *= input0->shape()[i];
|
||||
}
|
||||
if (new_k != input1->shape()[1]) {
|
||||
MS_LOG(ERROR) << "Input1 size invalid";
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
} else {
|
||||
new_k = input1->shape()[1];
|
||||
}
|
||||
if (GetHasBias()) {
|
||||
if (inputs_.at(2)->shape()[0] != input1->shape()[0]) {
|
||||
if (inputs_[2]->shape()[0] != input1->shape()[0]) {
|
||||
MS_LOG(ERROR) << "bias size invalid";
|
||||
return 1;
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
}
|
||||
std::vector<int> out_shape{inputs_[0]->shape()};
|
||||
out_shape.resize(GetAxis() + 1);
|
||||
out_shape[GetAxis()] = input1->shape()[0];
|
||||
if (GetUseAxis()) {
|
||||
out_shape.resize(GetAxis() + 1);
|
||||
out_shape[GetAxis()] = input1->shape()[0];
|
||||
} else {
|
||||
int total = 1;
|
||||
for (int i = 0; i < input0->shape().size(); ++i) {
|
||||
total *= input0->shape()[i];
|
||||
}
|
||||
out_shape.resize(2);
|
||||
auto batch_size = total / new_k;
|
||||
out_shape[0] = batch_size;
|
||||
out_shape[1] = input1->shape()[0];
|
||||
}
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input0->data_type());
|
||||
output->SetFormat(input0->GetFormat());
|
||||
|
||||
return 0;
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -91,8 +91,8 @@ int FullconnectionInt8CPUKernel::ReSize() {
|
|||
QuantizeRoundParameter(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift,
|
||||
&quant_params_.right_shift);
|
||||
CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6,
|
||||
quant_params_.output.zp_, quant_params_.output.scale_, &quant_params_.out_act_max,
|
||||
&quant_params_.out_act_min);
|
||||
quant_params_.output.zp_, quant_params_.output.scale_, &quant_params_.out_act_min,
|
||||
&quant_params_.out_act_max);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue