forked from mindspore-Ecosystem/mindspore
!4678 change ops getter
Merge pull request !4678 from yeyunpeng2020/master_cops_3
This commit is contained in:
commit
2cbb280bea
|
@ -61,18 +61,17 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto argmax_prim = this->primitive->value_as_ArgMax();
|
||||
std::vector<int> output_shape(input->shape());
|
||||
auto input_shape_size = input->shape().size();
|
||||
int axis = argmax_prim->axis() < 0 ? argmax_prim->axis() + input_shape_size : argmax_prim->axis();
|
||||
int axis = GetAxis() < 0 ? GetAxis() + input_shape_size : GetAxis();
|
||||
if (axis >= input_shape_size || axis < 0) {
|
||||
MS_LOG(ERROR) << "Invalid axis " << argmax_prim->axis() << ", input shape size: " << input_shape_size;
|
||||
MS_LOG(ERROR) << "Invalid axis " << GetAxis() << ", input shape size: " << input_shape_size;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
if (argmax_prim->topK() == 1 && !argmax_prim->keepDims()) {
|
||||
if (GetTopK() == 1 && !GetKeepDims()) {
|
||||
output_shape.erase(output_shape.begin() + axis);
|
||||
} else {
|
||||
output_shape[axis] = argmax_prim->topK();
|
||||
output_shape[axis] = GetTopK();
|
||||
}
|
||||
|
||||
output->set_shape(output_shape);
|
||||
|
|
|
@ -46,7 +46,7 @@ void ArgMin::SetKeepDims(bool keep_dims) {}
|
|||
void ArgMin::SetAxisType(int axis_type) {}
|
||||
#endif
|
||||
|
||||
int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
|
||||
int ArgMin::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
auto input = inputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
|
@ -60,18 +60,17 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto argmin_prim = this->primitive->value_as_ArgMin();
|
||||
auto input_shape_size = input->shape().size();
|
||||
int axis = argmin_prim->axis() < 0 ? argmin_prim->axis() + input_shape_size : argmin_prim->axis();
|
||||
int axis = GetAxis() < 0 ? GetAxis() + input_shape_size : GetAxis();
|
||||
if (axis >= input_shape_size || axis < 0) {
|
||||
MS_LOG(ERROR) << "Invalid axis " << argmin_prim->axis() << ", input shape size: " << input_shape_size;
|
||||
MS_LOG(ERROR) << "Invalid axis " << GetAxis() << ", input shape size: " << input_shape_size;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
std::vector<int> output_shape(input->shape());
|
||||
if (argmin_prim->topK() == 1 && !argmin_prim->keepDims()) {
|
||||
if (GetTopK() == 1 && !GetKeepDims()) {
|
||||
output_shape.erase(output_shape.begin() + axis);
|
||||
} else {
|
||||
output_shape[axis] = argmin_prim->topK();
|
||||
output_shape[axis] = GetTopK();
|
||||
}
|
||||
|
||||
output->set_shape(output_shape);
|
||||
|
|
|
@ -39,11 +39,10 @@ constexpr int kBroadcastToInputNum = 1;
|
|||
constexpr int kBroadcastToOutputNum = 1;
|
||||
} // namespace
|
||||
|
||||
int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
int BroadcastTo::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
|
||||
if (inputs.size() != kBroadcastToInputNum || outputs.size() != kBroadcastToOutputNum) {
|
||||
MS_LOG(ERROR) << "input size:" << inputs.size() << ", output size:" << outputs.size();
|
||||
return 1;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
auto input = inputs.at(0);
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
|
@ -51,27 +50,26 @@ int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vec
|
|||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
std::vector<int32_t> dst_shape(this->primitive->value_as_BroadcastTo()->dst_shape()->begin(),
|
||||
this->primitive->value_as_BroadcastTo()->dst_shape()->end());
|
||||
std::vector<int32_t> dst_shape(GetDstShape().begin(), GetDstShape().end());
|
||||
auto input_shape = input->shape();
|
||||
std::vector<int> shape(dst_shape.size());
|
||||
int input_shape_index = input_shape.size() - 1;
|
||||
if (input_shape.size() > dst_shape.size()) {
|
||||
MS_LOG(ERROR) << "input shape size " << input_shape.size() << " should <= broadcast to shape size "
|
||||
<< dst_shape.size() << "!";
|
||||
return 1;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
|
||||
for (int i = dst_shape.size() - 1; i >= 0; --i) {
|
||||
if (dst_shape[i] < 0) {
|
||||
MS_LOG(ERROR) << "shape[" << i << "] = " << dst_shape[i] << " ] should be > 0!";
|
||||
return 1;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
if (input_shape_index >= 0) {
|
||||
auto dim = input_shape[input_shape_index];
|
||||
if (dim != dst_shape[i] && dim != 1) {
|
||||
MS_LOG(ERROR) << "Invalid broadcast shape!";
|
||||
return 1;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
shape[i] = dst_shape[i];
|
||||
|
|
|
@ -45,14 +45,14 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
output->SetFormat(input->GetFormat());
|
||||
auto cast_prim = this->primitive->value_as_Cast();
|
||||
|
||||
MS_ASSERT(cast_prim != nullptr);
|
||||
output->set_data_type(static_cast<TypeId>(cast_prim->dstT()));
|
||||
output->set_data_type(static_cast<TypeId>(GetDstT()));
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
if (input->data_type() != cast_prim->srcT()) {
|
||||
if (input->data_type() != GetSrcT()) {
|
||||
MS_LOG(ERROR) << "input dataType is error";
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
|
|
|
@ -55,10 +55,10 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto concat_prim = this->primitive->value_as_Concat();
|
||||
|
||||
MS_ASSERT(concat_prim != nullptr);
|
||||
auto input0_shape = inputs_.at(0)->shape();
|
||||
int axis = concat_prim->axis() < 0 ? concat_prim->axis() + input0_shape.size() : concat_prim->axis();
|
||||
int axis = GetAxis() < 0 ? GetAxis() + input0_shape.size() : GetAxis();
|
||||
if (axis < 0 || axis >= input0_shape.size()) {
|
||||
MS_LOG(ERROR) << "Invalid axis: " << axis;
|
||||
return RET_PARAM_INVALID;
|
||||
|
|
|
@ -41,7 +41,6 @@ constexpr int kCropOutputNum = 1;
|
|||
constexpr int kCropInputNum = 2;
|
||||
} // namespace
|
||||
int Crop::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
|
||||
MS_ASSERT(this->primitive != nullptr);
|
||||
if (outputs.size() != kCropOutputNum || inputs.size() != kCropInputNum) {
|
||||
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
|
||||
return RET_PARAM_INVALID;
|
||||
|
|
|
@ -139,7 +139,6 @@ int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vecto
|
|||
} else {
|
||||
MS_LOG(ERROR) << "unsupported pad mode for deconv";
|
||||
}
|
||||
|
||||
std::vector<int> out_shape = {output_n, output_h, output_w, output_c};
|
||||
output->set_shape(out_shape);
|
||||
return 0;
|
||||
|
|
|
@ -154,7 +154,7 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
|
|||
out_shape.at(2) = output_w;
|
||||
if (GetChannelMultiplier() * input_channel != weight->shape()[0]) {
|
||||
MS_LOG(ERROR) << "Conv depthwise only support group equals output channel.";
|
||||
return 1;
|
||||
return RET_ERROR;
|
||||
}
|
||||
out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel
|
||||
|
||||
|
|
|
@ -42,13 +42,13 @@ int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
|
|||
MS_ASSERT(this->primitive != nullptr);
|
||||
if (outputs.size() != kDepthToSpaceOutputNum || inputs.size() != kDepthToSpaceInputNum) {
|
||||
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
|
||||
return 1;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto input = inputs.at(0);
|
||||
if (input->GetFormat() != schema::Format_NHWC) {
|
||||
MS_LOG(ERROR) << "depth_to_space only support NHWC now!";
|
||||
return 1;
|
||||
return RET_FORMAT_ERR;
|
||||
}
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
|
@ -58,14 +58,14 @@ int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
|
|||
auto input_shape = input->shape();
|
||||
if (input_shape.size() != kDimension_4d) {
|
||||
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
|
||||
return 1;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
|
||||
int32_t block_size = GetBlockSize();
|
||||
if (input_shape[NHWC_C] % (block_size * block_size) != 0 || input_shape[NHWC_C] == 0) {
|
||||
MS_LOG(ERROR) << "input dimension c size " << input_shape[NHWC_C] << " should be mulitple of block_size("
|
||||
<< block_size << ") * block_size)!";
|
||||
return 1;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
std::vector<int32_t> output_shape(input_shape.size());
|
||||
output_shape[NHWC_N] = input_shape[NHWC_N];
|
||||
|
|
|
@ -47,8 +47,7 @@ int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
|
|||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto expand_dims_prim = this->primitive->value_as_ExpandDims();
|
||||
int dim = expand_dims_prim->dim();
|
||||
int dim = GetDim();
|
||||
if (dim < 0) {
|
||||
dim += input->shape().size() + 1;
|
||||
}
|
||||
|
|
|
@ -58,10 +58,10 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto gather_prim = this->primitive->value_as_Gather();
|
||||
|
||||
MS_ASSERT(gather_prim != nullptr);
|
||||
int axis = gather_prim->axis();
|
||||
int batch_dims = gather_prim->batchDims();
|
||||
int axis = GetAxis();
|
||||
int batch_dims = GetBatchDims();
|
||||
if (axis < 0) {
|
||||
axis += input->shape().size();
|
||||
}
|
||||
|
|
|
@ -58,18 +58,18 @@ int Lstm::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|||
MS_LOG(ERROR) << "OpLstm input dims should be 3.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto lstm_prim = this->primitive->value_as_Lstm();
|
||||
|
||||
int hidden_size = w_shape[1] / 4;
|
||||
// set output
|
||||
std::vector<int> out_shape(in_shape);
|
||||
out_shape[2] = hidden_size;
|
||||
if (lstm_prim->bidirection()) {
|
||||
if (GetBidirection()) {
|
||||
out_shape.insert(out_shape.begin() + 1, 2);
|
||||
}
|
||||
output->set_shape(out_shape);
|
||||
// set hidden state, cell state
|
||||
std::vector<int> state_shape(in_shape);
|
||||
state_shape[0] = lstm_prim->bidirection() ? 2 : 1;
|
||||
state_shape[0] = GetBidirection() ? 2 : 1;
|
||||
state_shape[2] = hidden_size;
|
||||
outputs_[1]->set_shape(state_shape);
|
||||
outputs_[2]->set_shape(state_shape);
|
||||
|
|
|
@ -62,11 +62,11 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
}
|
||||
auto matmul_prim = this->primitive->value_as_MatMul();
|
||||
if (matmul_prim->transposeA()) {
|
||||
|
||||
if (GetTransposeA()) {
|
||||
std::swap(a_shape[a_shape.size() - 1], a_shape[a_shape.size() - 2]);
|
||||
}
|
||||
if (matmul_prim->transposeB()) {
|
||||
if (GetTransposeB()) {
|
||||
std::swap(b_shape[b_shape.size() - 1], b_shape[b_shape.size() - 2]);
|
||||
}
|
||||
std::vector<int> c_shape(a_shape);
|
||||
|
|
|
@ -58,12 +58,12 @@ int Mean::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|||
if (this->primitive == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto mean_prim = this->primitive->value_as_Mean();
|
||||
bool keep_dims = static_cast<bool>(mean_prim->keepDims());
|
||||
|
||||
bool keep_dims = static_cast<bool>(GetKeepDims());
|
||||
std::vector<int> in_shape = input->shape();
|
||||
std::vector<int> out_shape;
|
||||
const auto &axes = mean_prim->axis();
|
||||
auto num_axes = axes->size();
|
||||
const auto &axes = GetAxis();
|
||||
auto num_axes = axes.size();
|
||||
// reduce on all axes
|
||||
if (num_axes == 0) {
|
||||
if (keep_dims) {
|
||||
|
@ -79,7 +79,7 @@ int Mean::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|||
for (size_t i = 0; i < in_shape.size(); i++) {
|
||||
bool reduce_axis = false;
|
||||
for (int idx = 0; idx < num_axes; ++idx) {
|
||||
if (static_cast<size_t>((*axes)[idx]) == i) {
|
||||
if (static_cast<size_t>(axes[idx]) == i) {
|
||||
reduce_axis = true;
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -37,11 +37,8 @@ int OneHot::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:
|
|||
if (this->primitive == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto one_hot_prim = this->primitive->value_as_OneHot();
|
||||
if (one_hot_prim == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
int axis = one_hot_prim->axis();
|
||||
|
||||
int axis = GetAxis();
|
||||
// indices, depth, on_value, off_value
|
||||
if (inputs.size() != kOneHotInputNum) {
|
||||
MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum;
|
||||
|
|
|
@ -49,14 +49,9 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te
|
|||
if (this->primitive == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto pad_prim = this->primitive->value_as_Pad();
|
||||
if (pad_prim == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto paddings = pad_prim->paddings();
|
||||
if (paddings == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
auto paddings = GetPaddings();
|
||||
|
||||
auto input = inputs.front();
|
||||
if (input == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
|
@ -75,7 +70,7 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te
|
|||
MS_ASSERT(input->shape().size() <= kInputRank);
|
||||
for (size_t i = 0; i < input_shape.size(); i++) {
|
||||
auto paddings_index = i + kInputRank - input_shape.size();
|
||||
auto shape = input_shape[i] + (*paddings)[2 * paddings_index] + (*paddings)[2 * paddings_index + 1];
|
||||
auto shape = input_shape[i] + paddings[2 * paddings_index] + paddings[2 * paddings_index + 1];
|
||||
output_shape.push_back(shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -97,7 +97,6 @@ constexpr int kPriorBoxW = 1;
|
|||
constexpr int kPriorBoxC = 2;
|
||||
} // namespace
|
||||
int PriorBox::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
|
||||
auto param = this->primitive->value_as_PriorBox();
|
||||
MS_ASSERT(param != nullptr);
|
||||
auto input = inputs_.at(0);
|
||||
MS_ASSERT(input != nullptr);
|
||||
|
@ -109,20 +108,20 @@ int PriorBox::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
|
|||
return RET_OK;
|
||||
}
|
||||
std::vector<float> different_aspect_ratios{1.0f};
|
||||
auto aspect_ratios = param->aspect_ratios();
|
||||
auto aspect_ratios = GetAspectRatios();
|
||||
MS_ASSERT(aspect_ratios != nullptr);
|
||||
for (auto i = 0; i < aspect_ratios->size(); i++) {
|
||||
float ratio = (*aspect_ratios)[i];
|
||||
for (auto i = 0; i < aspect_ratios.size(); i++) {
|
||||
float ratio = aspect_ratios[i];
|
||||
bool exist = std::any_of(different_aspect_ratios.begin(), different_aspect_ratios.end(),
|
||||
[&](float v) { return abs(ratio - v) < 1e-6; });
|
||||
if (!exist) {
|
||||
different_aspect_ratios.emplace_back(ratio);
|
||||
if (param->flip()) {
|
||||
if (GetFlip()) {
|
||||
different_aspect_ratios.emplace_back(1.0f / ratio);
|
||||
}
|
||||
}
|
||||
}
|
||||
int32_t num_priors_box = param->min_sizes()->size() * different_aspect_ratios.size() + param->max_sizes()->size();
|
||||
int32_t num_priors_box = GetMinSizes().size() * different_aspect_ratios.size() + GetMaxSizes().size();
|
||||
int32_t h = input->Height() * input->Width() * num_priors_box * kPriorBoxPoints;
|
||||
std::vector<int> output_shape{kPriorBoxN, h, kPriorBoxW, kPriorBoxC};
|
||||
output->set_shape(output_shape);
|
||||
|
|
|
@ -40,9 +40,8 @@ int QuantDTypeCast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
|
|||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
auto param = primitive->value_as_QuantDTypeCast();
|
||||
MS_ASSERT(input->data_type() == param->srcT);
|
||||
output->set_data_type(static_cast<TypeId>(param->dstT()));
|
||||
output->set_data_type(static_cast<TypeId>(GetDstT()));
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
|
|
|
@ -48,7 +48,7 @@ int Range::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
|
|||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
auto range_prim = this->primitive->value_as_Range();
|
||||
|
||||
MS_ASSERT(range_prim != nullptr);
|
||||
|
||||
output->set_data_type(input->data_type());
|
||||
|
@ -57,7 +57,7 @@ int Range::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int shape_size = std::ceil(static_cast<float>(range_prim->limit() - range_prim->start()) / range_prim->delta());
|
||||
int shape_size = std::ceil(static_cast<float>(GetLimit() - GetStart()) / GetDelta());
|
||||
std::vector<int> in_shape(1);
|
||||
in_shape.push_back(shape_size);
|
||||
output->set_shape(in_shape);
|
||||
|
|
|
@ -62,12 +62,12 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
if (this->primitive == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto reduce_prim = this->primitive->value_as_Reduce();
|
||||
bool keep_dims = static_cast<bool>(reduce_prim->keepDims());
|
||||
|
||||
bool keep_dims = static_cast<bool>(GetKeepDims());
|
||||
std::vector<int> in_shape = input->shape();
|
||||
std::vector<int> out_shape;
|
||||
const auto &axes = reduce_prim->axes();
|
||||
auto num_axes = axes->size();
|
||||
const auto &axes = GetAxes();
|
||||
auto num_axes = axes.size();
|
||||
// reduce on all axes
|
||||
if (num_axes == 0) {
|
||||
if (keep_dims) {
|
||||
|
@ -83,7 +83,7 @@ int Reduce::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
for (size_t i = 0; i < in_shape.size(); i++) {
|
||||
bool reduce_axis = false;
|
||||
for (int idx = 0; idx < num_axes; ++idx) {
|
||||
if (static_cast<size_t>((*axes)[idx]) == i || static_cast<size_t>((*axes)[idx] + in_shape.size()) == i) {
|
||||
if (static_cast<size_t>(axes[idx]) == i || static_cast<size_t>(axes[idx] + in_shape.size()) == i) {
|
||||
reduce_axis = true;
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -61,9 +61,9 @@ int ROIPooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
|
|||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto ROIPooling = this->primitive->value_as_ROIPooling();
|
||||
auto new_h = ROIPooling->pooledH();
|
||||
auto new_w = ROIPooling->pooledW();
|
||||
|
||||
auto new_h = GetPooledH();
|
||||
auto new_w = GetPooledW();
|
||||
auto shape_data = roi->shape();
|
||||
std::vector<int> output_shape;
|
||||
output_shape.push_back(shape_data[0]);
|
||||
|
|
|
@ -55,12 +55,10 @@ int Squeeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
|
|||
}
|
||||
auto in_shape = in_tensor->shape();
|
||||
std::vector<int> out_shape;
|
||||
// todo: getAxis
|
||||
auto squeeze_prim = this->primitive->value_as_Squeeze();
|
||||
MS_EXCEPTION_IF_NULL(squeeze_prim);
|
||||
auto axis = squeeze_prim->axis();
|
||||
|
||||
auto axis = GetAxis();
|
||||
std::vector<int> axes_;
|
||||
for (auto iter = axis->begin(); iter != axis->end(); iter++) {
|
||||
for (auto iter = axis.begin(); iter != axis.end(); iter++) {
|
||||
axes_.push_back(*iter);
|
||||
}
|
||||
if (axes_.size() == 0) {
|
||||
|
|
|
@ -62,11 +62,11 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
|
|||
return RET_OK;
|
||||
}
|
||||
auto input_shape = input->shape();
|
||||
auto stack_prim = this->primitive->value_as_Stack();
|
||||
|
||||
std::vector<int32_t> output_shape = input_shape;
|
||||
int axis = stack_prim->axis() < 0 ? stack_prim->axis() + input_shape.size() : stack_prim->axis();
|
||||
int axis = GetAxis() < 0 ? GetAxis() + input_shape.size() : GetAxis();
|
||||
if (axis < 0 || axis > input_shape.size()) {
|
||||
MS_LOG(ERROR) << "Invalid axis " << stack_prim->axis();
|
||||
MS_LOG(ERROR) << "Invalid axis " << GetAxis();
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
schema::Format input0_format = input->GetFormat();
|
||||
|
|
|
@ -174,10 +174,6 @@ int StridedSlice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
|
|||
std::vector<int> output_shape;
|
||||
ndim_ = static_cast<int>(GetBegin().size());
|
||||
|
||||
MS_ASSERT(ndim_ == static_cast<int>(strided_slice_prim->end()->size()));
|
||||
MS_ASSERT(ndim_ == static_cast<int>(strided_slice_prim->stride()->size()));
|
||||
MS_ASSERT(ndim_ == static_cast<int>(input_shape.size()));
|
||||
|
||||
for (int i = 0; i < ndim_; i++) {
|
||||
in_shape_.emplace_back(input_shape.at(i));
|
||||
begins_.emplace_back((GetBegin())[i]);
|
||||
|
|
|
@ -53,10 +53,9 @@ int TopK::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto topk_prim = this->primitive->value_as_TopK();
|
||||
MS_ASSERT(topk_prim != nullptr);
|
||||
auto out_shape = input->shape();
|
||||
out_shape[out_shape.size() - 1] = topk_prim->k();
|
||||
out_shape[out_shape.size() - 1] = GetK();
|
||||
output0->set_shape(out_shape);
|
||||
output1->set_shape(out_shape);
|
||||
return RET_OK;
|
||||
|
|
|
@ -53,11 +53,11 @@ int Unsqueeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
|
|||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto unsqueeze_prim = this->primitive->value_as_Unsqueeze();
|
||||
auto dims = unsqueeze_prim->axis()->data();
|
||||
|
||||
auto dims = GetAxis().data();
|
||||
auto in_shape = input->shape();
|
||||
auto in_rank = in_shape.size();
|
||||
auto dim_rank = unsqueeze_prim->axis()->size();
|
||||
auto dim_rank = GetAxis().size();
|
||||
std::vector<int> out_shape;
|
||||
if (dim_rank == 0) {
|
||||
for (auto d : in_shape) {
|
||||
|
|
|
@ -38,10 +38,10 @@ int Unstack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor
|
|||
auto input = inputs.at(0);
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto input_shape = input->shape();
|
||||
auto prim = this->primitive->value_as_Unstack();
|
||||
int axis = prim->axis() < 0 ? prim->axis() + input_shape.size() : prim->axis();
|
||||
|
||||
int axis = GetAxis() < 0 ? GetAxis() + input_shape.size() : GetAxis();
|
||||
if (axis < 0 || axis >= input_shape.size()) {
|
||||
MS_LOG(ERROR) << "Invalid axis " << prim->axis();
|
||||
MS_LOG(ERROR) << "Invalid axis " << GetAxis();
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
for (auto &out : outputs) {
|
||||
|
|
Loading…
Reference in New Issue