From 3ca93b9cf267779485a7d8a676f7b85b0cbe816a Mon Sep 17 00:00:00 2001 From: yeyunpeng Date: Tue, 18 Aug 2020 21:32:26 +0800 Subject: [PATCH] change ops getter --- mindspore/lite/src/ops/argmax.cc | 9 ++++----- mindspore/lite/src/ops/argmin.cc | 11 +++++------ mindspore/lite/src/ops/broadcast_to.cc | 14 ++++++-------- mindspore/lite/src/ops/cast.cc | 6 +++--- mindspore/lite/src/ops/concat.cc | 4 ++-- mindspore/lite/src/ops/crop.cc | 1 - mindspore/lite/src/ops/deconv2d.cc | 1 - mindspore/lite/src/ops/dedepthwise_conv2d.cc | 2 +- mindspore/lite/src/ops/depth_to_space.cc | 8 ++++---- mindspore/lite/src/ops/expand_dims.cc | 3 +-- mindspore/lite/src/ops/gather.cc | 6 +++--- mindspore/lite/src/ops/lstm.cc | 6 +++--- mindspore/lite/src/ops/matmul.cc | 6 +++--- mindspore/lite/src/ops/mean.cc | 10 +++++----- mindspore/lite/src/ops/one_hot.cc | 7 ++----- mindspore/lite/src/ops/pad.cc | 13 ++++--------- mindspore/lite/src/ops/prior_box.cc | 11 +++++------ mindspore/lite/src/ops/quant_dtype_cast.cc | 3 +-- mindspore/lite/src/ops/range.cc | 4 ++-- mindspore/lite/src/ops/reduce.cc | 10 +++++----- mindspore/lite/src/ops/roi_pooling.cc | 6 +++--- mindspore/lite/src/ops/squeeze.cc | 8 +++----- mindspore/lite/src/ops/stack.cc | 6 +++--- mindspore/lite/src/ops/strided_slice.cc | 4 ---- mindspore/lite/src/ops/topk.cc | 3 +-- mindspore/lite/src/ops/unsqueeze.cc | 6 +++--- mindspore/lite/src/ops/unstack.cc | 6 +++--- 27 files changed, 75 insertions(+), 99 deletions(-) diff --git a/mindspore/lite/src/ops/argmax.cc b/mindspore/lite/src/ops/argmax.cc index be50f47acbc..703e970ffd4 100644 --- a/mindspore/lite/src/ops/argmax.cc +++ b/mindspore/lite/src/ops/argmax.cc @@ -61,18 +61,17 @@ int ArgMax::InferShape(std::vector inputs_, std::vectorprimitive->value_as_ArgMax(); std::vector output_shape(input->shape()); auto input_shape_size = input->shape().size(); - int axis = argmax_prim->axis() < 0 ? argmax_prim->axis() + input_shape_size : argmax_prim->axis(); + int axis = GetAxis() < 0 ? GetAxis() + input_shape_size : GetAxis(); if (axis >= input_shape_size || axis < 0) { - MS_LOG(ERROR) << "Invalid axis " << argmax_prim->axis() << ", input shape size: " << input_shape_size; + MS_LOG(ERROR) << "Invalid axis " << GetAxis() << ", input shape size: " << input_shape_size; return RET_PARAM_INVALID; } - if (argmax_prim->topK() == 1 && !argmax_prim->keepDims()) { + if (GetTopK() == 1 && !GetKeepDims()) { output_shape.erase(output_shape.begin() + axis); } else { - output_shape[axis] = argmax_prim->topK(); + output_shape[axis] = GetTopK(); } output->set_shape(output_shape); diff --git a/mindspore/lite/src/ops/argmin.cc b/mindspore/lite/src/ops/argmin.cc index 0d9940f0414..349418a0b47 100644 --- a/mindspore/lite/src/ops/argmin.cc +++ b/mindspore/lite/src/ops/argmin.cc @@ -46,7 +46,7 @@ void ArgMin::SetKeepDims(bool keep_dims) {} void ArgMin::SetAxisType(int axis_type) {} #endif -int ArgMin::InferShape(std::vector inputs_, std::vector outputs_) { +int ArgMin::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive != nullptr); auto input = inputs_.front(); MS_ASSERT(input != nullptr); @@ -60,18 +60,17 @@ int ArgMin::InferShape(std::vector inputs_, std::vectorprimitive->value_as_ArgMin(); auto input_shape_size = input->shape().size(); - int axis = argmin_prim->axis() < 0 ? argmin_prim->axis() + input_shape_size : argmin_prim->axis(); + int axis = GetAxis() < 0 ? GetAxis() + input_shape_size : GetAxis(); if (axis >= input_shape_size || axis < 0) { - MS_LOG(ERROR) << "Invalid axis " << argmin_prim->axis() << ", input shape size: " << input_shape_size; + MS_LOG(ERROR) << "Invalid axis " << GetAxis() << ", input shape size: " << input_shape_size; return RET_PARAM_INVALID; } std::vector output_shape(input->shape()); - if (argmin_prim->topK() == 1 && !argmin_prim->keepDims()) { + if (GetTopK() == 1 && !GetKeepDims()) { output_shape.erase(output_shape.begin() + axis); } else { - output_shape[axis] = argmin_prim->topK(); + output_shape[axis] = GetTopK(); } output->set_shape(output_shape); diff --git a/mindspore/lite/src/ops/broadcast_to.cc b/mindspore/lite/src/ops/broadcast_to.cc index eb7ef3dc7b5..38b3cc64518 100644 --- a/mindspore/lite/src/ops/broadcast_to.cc +++ b/mindspore/lite/src/ops/broadcast_to.cc @@ -39,11 +39,10 @@ constexpr int kBroadcastToInputNum = 1; constexpr int kBroadcastToOutputNum = 1; } // namespace -int BroadcastTo::InferShape(std::vector inputs, std::vector outputs) { - MS_ASSERT(this->primitive != nullptr); +int BroadcastTo::InferShape(std::vector inputs, std::vector outputs) { if (inputs.size() != kBroadcastToInputNum || outputs.size() != kBroadcastToOutputNum) { MS_LOG(ERROR) << "input size:" << inputs.size() << ", output size:" << outputs.size(); - return 1; + return RET_PARAM_INVALID; } auto input = inputs.at(0); outputs[0]->SetFormat(input->GetFormat()); @@ -51,27 +50,26 @@ int BroadcastTo::InferShape(std::vector inputs, std::vec if (!GetInferFlag()) { return RET_OK; } - std::vector dst_shape(this->primitive->value_as_BroadcastTo()->dst_shape()->begin(), - this->primitive->value_as_BroadcastTo()->dst_shape()->end()); + std::vector dst_shape(GetDstShape().begin(), GetDstShape().end()); auto input_shape = input->shape(); std::vector shape(dst_shape.size()); int input_shape_index = input_shape.size() - 1; if (input_shape.size() > dst_shape.size()) { MS_LOG(ERROR) << "input shape size " << input_shape.size() << " should <= broadcast to shape size " << dst_shape.size() << "!"; - return 1; + return RET_PARAM_INVALID; } for (int i = dst_shape.size() - 1; i >= 0; --i) { if (dst_shape[i] < 0) { MS_LOG(ERROR) << "shape[" << i << "] = " << dst_shape[i] << " ] should be > 0!"; - return 1; + return RET_PARAM_INVALID; } if (input_shape_index >= 0) { auto dim = input_shape[input_shape_index]; if (dim != dst_shape[i] && dim != 1) { MS_LOG(ERROR) << "Invalid broadcast shape!"; - return 1; + return RET_PARAM_INVALID; } } shape[i] = dst_shape[i]; diff --git a/mindspore/lite/src/ops/cast.cc b/mindspore/lite/src/ops/cast.cc index fa7850bccf9..48d167f1a38 100644 --- a/mindspore/lite/src/ops/cast.cc +++ b/mindspore/lite/src/ops/cast.cc @@ -45,14 +45,14 @@ int Cast::InferShape(std::vector inputs_, std::vectorSetFormat(input->GetFormat()); - auto cast_prim = this->primitive->value_as_Cast(); + MS_ASSERT(cast_prim != nullptr); - output->set_data_type(static_cast(cast_prim->dstT())); + output->set_data_type(static_cast(GetDstT())); if (!GetInferFlag()) { return RET_OK; } - if (input->data_type() != cast_prim->srcT()) { + if (input->data_type() != GetSrcT()) { MS_LOG(ERROR) << "input dataType is error"; return RET_INPUT_TENSOR_ERROR; } diff --git a/mindspore/lite/src/ops/concat.cc b/mindspore/lite/src/ops/concat.cc index 2899f148259..3e3d7aa2ec4 100644 --- a/mindspore/lite/src/ops/concat.cc +++ b/mindspore/lite/src/ops/concat.cc @@ -55,10 +55,10 @@ int Concat::InferShape(std::vector inputs_, std::vectorprimitive->value_as_Concat(); + MS_ASSERT(concat_prim != nullptr); auto input0_shape = inputs_.at(0)->shape(); - int axis = concat_prim->axis() < 0 ? concat_prim->axis() + input0_shape.size() : concat_prim->axis(); + int axis = GetAxis() < 0 ? GetAxis() + input0_shape.size() : GetAxis(); if (axis < 0 || axis >= input0_shape.size()) { MS_LOG(ERROR) << "Invalid axis: " << axis; return RET_PARAM_INVALID; diff --git a/mindspore/lite/src/ops/crop.cc b/mindspore/lite/src/ops/crop.cc index 71c07885564..7e2c32167c8 100644 --- a/mindspore/lite/src/ops/crop.cc +++ b/mindspore/lite/src/ops/crop.cc @@ -41,7 +41,6 @@ constexpr int kCropOutputNum = 1; constexpr int kCropInputNum = 2; } // namespace int Crop::InferShape(std::vector inputs, std::vector outputs) { - MS_ASSERT(this->primitive != nullptr); if (outputs.size() != kCropOutputNum || inputs.size() != kCropInputNum) { MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size(); return RET_PARAM_INVALID; diff --git a/mindspore/lite/src/ops/deconv2d.cc b/mindspore/lite/src/ops/deconv2d.cc index 8aaa7d41350..91a5f31a869 100644 --- a/mindspore/lite/src/ops/deconv2d.cc +++ b/mindspore/lite/src/ops/deconv2d.cc @@ -139,7 +139,6 @@ int DeConv2D::InferShape(std::vector inputs_, std::vecto } else { MS_LOG(ERROR) << "unsupported pad mode for deconv"; } - std::vector out_shape = {output_n, output_h, output_w, output_c}; output->set_shape(out_shape); return 0; diff --git a/mindspore/lite/src/ops/dedepthwise_conv2d.cc b/mindspore/lite/src/ops/dedepthwise_conv2d.cc index c07e521fe55..53881c51a37 100644 --- a/mindspore/lite/src/ops/dedepthwise_conv2d.cc +++ b/mindspore/lite/src/ops/dedepthwise_conv2d.cc @@ -154,7 +154,7 @@ int DeDepthwiseConv2D::InferShape(std::vector inputs_, out_shape.at(2) = output_w; if (GetChannelMultiplier() * input_channel != weight->shape()[0]) { MS_LOG(ERROR) << "Conv depthwise only support group equals output channel."; - return 1; + return RET_ERROR; } out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel diff --git a/mindspore/lite/src/ops/depth_to_space.cc b/mindspore/lite/src/ops/depth_to_space.cc index ab0edb0f208..db24a90e3ed 100644 --- a/mindspore/lite/src/ops/depth_to_space.cc +++ b/mindspore/lite/src/ops/depth_to_space.cc @@ -42,13 +42,13 @@ int DepthToSpace::InferShape(std::vector inputs, std::ve MS_ASSERT(this->primitive != nullptr); if (outputs.size() != kDepthToSpaceOutputNum || inputs.size() != kDepthToSpaceInputNum) { MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size(); - return 1; + return RET_PARAM_INVALID; } auto input = inputs.at(0); if (input->GetFormat() != schema::Format_NHWC) { MS_LOG(ERROR) << "depth_to_space only support NHWC now!"; - return 1; + return RET_FORMAT_ERR; } outputs[0]->set_data_type(input->data_type()); outputs[0]->SetFormat(input->GetFormat()); @@ -58,14 +58,14 @@ int DepthToSpace::InferShape(std::vector inputs, std::ve auto input_shape = input->shape(); if (input_shape.size() != kDimension_4d) { MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d; - return 1; + return RET_PARAM_INVALID; } int32_t block_size = GetBlockSize(); if (input_shape[NHWC_C] % (block_size * block_size) != 0 || input_shape[NHWC_C] == 0) { MS_LOG(ERROR) << "input dimension c size " << input_shape[NHWC_C] << " should be mulitple of block_size(" << block_size << ") * block_size)!"; - return 1; + return RET_PARAM_INVALID; } std::vector output_shape(input_shape.size()); output_shape[NHWC_N] = input_shape[NHWC_N]; diff --git a/mindspore/lite/src/ops/expand_dims.cc b/mindspore/lite/src/ops/expand_dims.cc index 0cdff13698b..a05ea724399 100644 --- a/mindspore/lite/src/ops/expand_dims.cc +++ b/mindspore/lite/src/ops/expand_dims.cc @@ -47,8 +47,7 @@ int ExpandDims::InferShape(std::vector inputs_, std::vectorprimitive->value_as_ExpandDims(); - int dim = expand_dims_prim->dim(); + int dim = GetDim(); if (dim < 0) { dim += input->shape().size() + 1; } diff --git a/mindspore/lite/src/ops/gather.cc b/mindspore/lite/src/ops/gather.cc index a7730db82f9..9f7c5ab4355 100644 --- a/mindspore/lite/src/ops/gather.cc +++ b/mindspore/lite/src/ops/gather.cc @@ -58,10 +58,10 @@ int Gather::InferShape(std::vector inputs_, std::vectorprimitive->value_as_Gather(); + MS_ASSERT(gather_prim != nullptr); - int axis = gather_prim->axis(); - int batch_dims = gather_prim->batchDims(); + int axis = GetAxis(); + int batch_dims = GetBatchDims(); if (axis < 0) { axis += input->shape().size(); } diff --git a/mindspore/lite/src/ops/lstm.cc b/mindspore/lite/src/ops/lstm.cc index 7c7ed2391d6..8c6b27f7ba2 100644 --- a/mindspore/lite/src/ops/lstm.cc +++ b/mindspore/lite/src/ops/lstm.cc @@ -58,18 +58,18 @@ int Lstm::InferShape(std::vector inputs_, std::vectorprimitive->value_as_Lstm(); + int hidden_size = w_shape[1] / 4; // set output std::vector out_shape(in_shape); out_shape[2] = hidden_size; - if (lstm_prim->bidirection()) { + if (GetBidirection()) { out_shape.insert(out_shape.begin() + 1, 2); } output->set_shape(out_shape); // set hidden state, cell state std::vector state_shape(in_shape); - state_shape[0] = lstm_prim->bidirection() ? 2 : 1; + state_shape[0] = GetBidirection() ? 2 : 1; state_shape[2] = hidden_size; outputs_[1]->set_shape(state_shape); outputs_[2]->set_shape(state_shape); diff --git a/mindspore/lite/src/ops/matmul.cc b/mindspore/lite/src/ops/matmul.cc index ba19adb38cc..67fc4fee27f 100644 --- a/mindspore/lite/src/ops/matmul.cc +++ b/mindspore/lite/src/ops/matmul.cc @@ -62,11 +62,11 @@ int MatMul::InferShape(std::vector inputs_, std::vectorprimitive->value_as_MatMul(); - if (matmul_prim->transposeA()) { + + if (GetTransposeA()) { std::swap(a_shape[a_shape.size() - 1], a_shape[a_shape.size() - 2]); } - if (matmul_prim->transposeB()) { + if (GetTransposeB()) { std::swap(b_shape[b_shape.size() - 1], b_shape[b_shape.size() - 2]); } std::vector c_shape(a_shape); diff --git a/mindspore/lite/src/ops/mean.cc b/mindspore/lite/src/ops/mean.cc index a2901b8890e..a68fd1f3338 100644 --- a/mindspore/lite/src/ops/mean.cc +++ b/mindspore/lite/src/ops/mean.cc @@ -58,12 +58,12 @@ int Mean::InferShape(std::vector inputs_, std::vectorprimitive == nullptr) { return RET_NULL_PTR; } - auto mean_prim = this->primitive->value_as_Mean(); - bool keep_dims = static_cast(mean_prim->keepDims()); + + bool keep_dims = static_cast(GetKeepDims()); std::vector in_shape = input->shape(); std::vector out_shape; - const auto &axes = mean_prim->axis(); - auto num_axes = axes->size(); + const auto &axes = GetAxis(); + auto num_axes = axes.size(); // reduce on all axes if (num_axes == 0) { if (keep_dims) { @@ -79,7 +79,7 @@ int Mean::InferShape(std::vector inputs_, std::vector((*axes)[idx]) == i) { + if (static_cast(axes[idx]) == i) { reduce_axis = true; break; } diff --git a/mindspore/lite/src/ops/one_hot.cc b/mindspore/lite/src/ops/one_hot.cc index 848f763583e..7398361adea 100644 --- a/mindspore/lite/src/ops/one_hot.cc +++ b/mindspore/lite/src/ops/one_hot.cc @@ -37,11 +37,8 @@ int OneHot::InferShape(std::vector inputs, std::vectorprimitive == nullptr) { return RET_NULL_PTR; } - auto one_hot_prim = this->primitive->value_as_OneHot(); - if (one_hot_prim == nullptr) { - return RET_NULL_PTR; - } - int axis = one_hot_prim->axis(); + + int axis = GetAxis(); // indices, depth, on_value, off_value if (inputs.size() != kOneHotInputNum) { MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum; diff --git a/mindspore/lite/src/ops/pad.cc b/mindspore/lite/src/ops/pad.cc index fef616abf62..2d20028b696 100644 --- a/mindspore/lite/src/ops/pad.cc +++ b/mindspore/lite/src/ops/pad.cc @@ -49,14 +49,9 @@ int Pad::InferShape(std::vector inputs, std::vectorprimitive == nullptr) { return RET_NULL_PTR; } - auto pad_prim = this->primitive->value_as_Pad(); - if (pad_prim == nullptr) { - return RET_NULL_PTR; - } - auto paddings = pad_prim->paddings(); - if (paddings == nullptr) { - return RET_NULL_PTR; - } + + auto paddings = GetPaddings(); + auto input = inputs.front(); if (input == nullptr) { return RET_NULL_PTR; @@ -75,7 +70,7 @@ int Pad::InferShape(std::vector inputs, std::vectorshape().size() <= kInputRank); for (size_t i = 0; i < input_shape.size(); i++) { auto paddings_index = i + kInputRank - input_shape.size(); - auto shape = input_shape[i] + (*paddings)[2 * paddings_index] + (*paddings)[2 * paddings_index + 1]; + auto shape = input_shape[i] + paddings[2 * paddings_index] + paddings[2 * paddings_index + 1]; output_shape.push_back(shape); } diff --git a/mindspore/lite/src/ops/prior_box.cc b/mindspore/lite/src/ops/prior_box.cc index eedb6ec2b74..1b972e18c0e 100644 --- a/mindspore/lite/src/ops/prior_box.cc +++ b/mindspore/lite/src/ops/prior_box.cc @@ -97,7 +97,6 @@ constexpr int kPriorBoxW = 1; constexpr int kPriorBoxC = 2; } // namespace int PriorBox::InferShape(std::vector inputs_, std::vector outputs_) { - auto param = this->primitive->value_as_PriorBox(); MS_ASSERT(param != nullptr); auto input = inputs_.at(0); MS_ASSERT(input != nullptr); @@ -109,20 +108,20 @@ int PriorBox::InferShape(std::vector inputs_, std::vector different_aspect_ratios{1.0f}; - auto aspect_ratios = param->aspect_ratios(); + auto aspect_ratios = GetAspectRatios(); MS_ASSERT(aspect_ratios != nullptr); - for (auto i = 0; i < aspect_ratios->size(); i++) { - float ratio = (*aspect_ratios)[i]; + for (auto i = 0; i < aspect_ratios.size(); i++) { + float ratio = aspect_ratios[i]; bool exist = std::any_of(different_aspect_ratios.begin(), different_aspect_ratios.end(), [&](float v) { return abs(ratio - v) < 1e-6; }); if (!exist) { different_aspect_ratios.emplace_back(ratio); - if (param->flip()) { + if (GetFlip()) { different_aspect_ratios.emplace_back(1.0f / ratio); } } } - int32_t num_priors_box = param->min_sizes()->size() * different_aspect_ratios.size() + param->max_sizes()->size(); + int32_t num_priors_box = GetMinSizes().size() * different_aspect_ratios.size() + GetMaxSizes().size(); int32_t h = input->Height() * input->Width() * num_priors_box * kPriorBoxPoints; std::vector output_shape{kPriorBoxN, h, kPriorBoxW, kPriorBoxC}; output->set_shape(output_shape); diff --git a/mindspore/lite/src/ops/quant_dtype_cast.cc b/mindspore/lite/src/ops/quant_dtype_cast.cc index ddbb1eadd50..f3852021041 100644 --- a/mindspore/lite/src/ops/quant_dtype_cast.cc +++ b/mindspore/lite/src/ops/quant_dtype_cast.cc @@ -40,9 +40,8 @@ int QuantDTypeCast::InferShape(std::vector inputs_, std::vecto MS_ASSERT(input != nullptr); auto output = outputs_.front(); MS_ASSERT(output != nullptr); - auto param = primitive->value_as_QuantDTypeCast(); MS_ASSERT(input->data_type() == param->srcT); - output->set_data_type(static_cast(param->dstT())); + output->set_data_type(static_cast(GetDstT())); output->SetFormat(input->GetFormat()); if (!GetInferFlag()) { return RET_OK; diff --git a/mindspore/lite/src/ops/range.cc b/mindspore/lite/src/ops/range.cc index 08e7f89728a..b246af74bc4 100644 --- a/mindspore/lite/src/ops/range.cc +++ b/mindspore/lite/src/ops/range.cc @@ -48,7 +48,7 @@ int Range::InferShape(std::vector inputs_, std::vectorprimitive->value_as_Range(); + MS_ASSERT(range_prim != nullptr); output->set_data_type(input->data_type()); @@ -57,7 +57,7 @@ int Range::InferShape(std::vector inputs_, std::vector(range_prim->limit() - range_prim->start()) / range_prim->delta()); + int shape_size = std::ceil(static_cast(GetLimit() - GetStart()) / GetDelta()); std::vector in_shape(1); in_shape.push_back(shape_size); output->set_shape(in_shape); diff --git a/mindspore/lite/src/ops/reduce.cc b/mindspore/lite/src/ops/reduce.cc index 652d962ade7..2fa00dd9666 100644 --- a/mindspore/lite/src/ops/reduce.cc +++ b/mindspore/lite/src/ops/reduce.cc @@ -62,12 +62,12 @@ int Reduce::InferShape(std::vector inputs_, std::vectorprimitive == nullptr) { return RET_NULL_PTR; } - auto reduce_prim = this->primitive->value_as_Reduce(); - bool keep_dims = static_cast(reduce_prim->keepDims()); + + bool keep_dims = static_cast(GetKeepDims()); std::vector in_shape = input->shape(); std::vector out_shape; - const auto &axes = reduce_prim->axes(); - auto num_axes = axes->size(); + const auto &axes = GetAxes(); + auto num_axes = axes.size(); // reduce on all axes if (num_axes == 0) { if (keep_dims) { @@ -83,7 +83,7 @@ int Reduce::InferShape(std::vector inputs_, std::vector((*axes)[idx]) == i || static_cast((*axes)[idx] + in_shape.size()) == i) { + if (static_cast(axes[idx]) == i || static_cast(axes[idx] + in_shape.size()) == i) { reduce_axis = true; break; } diff --git a/mindspore/lite/src/ops/roi_pooling.cc b/mindspore/lite/src/ops/roi_pooling.cc index f4853425988..afd2dde720a 100644 --- a/mindspore/lite/src/ops/roi_pooling.cc +++ b/mindspore/lite/src/ops/roi_pooling.cc @@ -61,9 +61,9 @@ int ROIPooling::InferShape(std::vector inputs_, std::vectorprimitive->value_as_ROIPooling(); - auto new_h = ROIPooling->pooledH(); - auto new_w = ROIPooling->pooledW(); + + auto new_h = GetPooledH(); + auto new_w = GetPooledW(); auto shape_data = roi->shape(); std::vector output_shape; output_shape.push_back(shape_data[0]); diff --git a/mindspore/lite/src/ops/squeeze.cc b/mindspore/lite/src/ops/squeeze.cc index dcc3e4c07e6..6669429a427 100644 --- a/mindspore/lite/src/ops/squeeze.cc +++ b/mindspore/lite/src/ops/squeeze.cc @@ -55,12 +55,10 @@ int Squeeze::InferShape(std::vector inputs_, std::vectorshape(); std::vector out_shape; - // todo: getAxis - auto squeeze_prim = this->primitive->value_as_Squeeze(); - MS_EXCEPTION_IF_NULL(squeeze_prim); - auto axis = squeeze_prim->axis(); + + auto axis = GetAxis(); std::vector axes_; - for (auto iter = axis->begin(); iter != axis->end(); iter++) { + for (auto iter = axis.begin(); iter != axis.end(); iter++) { axes_.push_back(*iter); } if (axes_.size() == 0) { diff --git a/mindspore/lite/src/ops/stack.cc b/mindspore/lite/src/ops/stack.cc index 3ac6c465f14..2134ac03efa 100644 --- a/mindspore/lite/src/ops/stack.cc +++ b/mindspore/lite/src/ops/stack.cc @@ -62,11 +62,11 @@ int Stack::InferShape(std::vector inputs, std::vectorshape(); - auto stack_prim = this->primitive->value_as_Stack(); + std::vector output_shape = input_shape; - int axis = stack_prim->axis() < 0 ? stack_prim->axis() + input_shape.size() : stack_prim->axis(); + int axis = GetAxis() < 0 ? GetAxis() + input_shape.size() : GetAxis(); if (axis < 0 || axis > input_shape.size()) { - MS_LOG(ERROR) << "Invalid axis " << stack_prim->axis(); + MS_LOG(ERROR) << "Invalid axis " << GetAxis(); return RET_PARAM_INVALID; } schema::Format input0_format = input->GetFormat(); diff --git a/mindspore/lite/src/ops/strided_slice.cc b/mindspore/lite/src/ops/strided_slice.cc index 7ec9fd1af38..dc2dda089c4 100644 --- a/mindspore/lite/src/ops/strided_slice.cc +++ b/mindspore/lite/src/ops/strided_slice.cc @@ -174,10 +174,6 @@ int StridedSlice::InferShape(std::vector inputs, std::ve std::vector output_shape; ndim_ = static_cast(GetBegin().size()); - MS_ASSERT(ndim_ == static_cast(strided_slice_prim->end()->size())); - MS_ASSERT(ndim_ == static_cast(strided_slice_prim->stride()->size())); - MS_ASSERT(ndim_ == static_cast(input_shape.size())); - for (int i = 0; i < ndim_; i++) { in_shape_.emplace_back(input_shape.at(i)); begins_.emplace_back((GetBegin())[i]); diff --git a/mindspore/lite/src/ops/topk.cc b/mindspore/lite/src/ops/topk.cc index 03463686d0d..0f3abac581f 100644 --- a/mindspore/lite/src/ops/topk.cc +++ b/mindspore/lite/src/ops/topk.cc @@ -53,10 +53,9 @@ int TopK::InferShape(std::vector inputs_, std::vectorprimitive->value_as_TopK(); MS_ASSERT(topk_prim != nullptr); auto out_shape = input->shape(); - out_shape[out_shape.size() - 1] = topk_prim->k(); + out_shape[out_shape.size() - 1] = GetK(); output0->set_shape(out_shape); output1->set_shape(out_shape); return RET_OK; diff --git a/mindspore/lite/src/ops/unsqueeze.cc b/mindspore/lite/src/ops/unsqueeze.cc index 140368e2e44..ab384aa833a 100644 --- a/mindspore/lite/src/ops/unsqueeze.cc +++ b/mindspore/lite/src/ops/unsqueeze.cc @@ -53,11 +53,11 @@ int Unsqueeze::InferShape(std::vector inputs_, std::vectorprimitive->value_as_Unsqueeze(); - auto dims = unsqueeze_prim->axis()->data(); + + auto dims = GetAxis().data(); auto in_shape = input->shape(); auto in_rank = in_shape.size(); - auto dim_rank = unsqueeze_prim->axis()->size(); + auto dim_rank = GetAxis().size(); std::vector out_shape; if (dim_rank == 0) { for (auto d : in_shape) { diff --git a/mindspore/lite/src/ops/unstack.cc b/mindspore/lite/src/ops/unstack.cc index 490e9285b9a..0b3f737db20 100644 --- a/mindspore/lite/src/ops/unstack.cc +++ b/mindspore/lite/src/ops/unstack.cc @@ -38,10 +38,10 @@ int Unstack::InferShape(std::vector inputs, std::vectorshape(); - auto prim = this->primitive->value_as_Unstack(); - int axis = prim->axis() < 0 ? prim->axis() + input_shape.size() : prim->axis(); + + int axis = GetAxis() < 0 ? GetAxis() + input_shape.size() : GetAxis(); if (axis < 0 || axis >= input_shape.size()) { - MS_LOG(ERROR) << "Invalid axis " << prim->axis(); + MS_LOG(ERROR) << "Invalid axis " << GetAxis(); return RET_PARAM_INVALID; } for (auto &out : outputs) {