forked from mindspore-Ecosystem/mindspore
support infer datatype and format when shape infer fail
This commit is contained in:
parent
82e8884eb5
commit
2f36b91ade
|
@ -43,6 +43,11 @@ int AddN::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
|
|||
MS_LOG(ERROR) << "input size" << inputs.size() << " is error!";
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
output->SetFormat(input->GetFormat());
|
||||
output->set_data_type(input->data_type());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
for (int i = 1; i < inputs.size(); ++i) {
|
||||
if (inputs.at(i)->shape() != inputs.at(0)->shape()) {
|
||||
MS_LOG(ERROR) << "AddN inputs shape is not equal!";
|
||||
|
@ -53,9 +58,8 @@ int AddN::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
|
|||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
}
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
output->set_shape(input->shape());
|
||||
output->set_data_type(input->data_type());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -55,6 +55,12 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
|
||||
MS_LOG(ERROR) << "tensor number is error.";
|
||||
}
|
||||
|
||||
output->SetFormat(input->GetFormat());
|
||||
output->set_data_type(input->data_type());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto argmax_prim = this->primitive->value_as_ArgMax();
|
||||
std::vector<int> output_shape(input->shape());
|
||||
auto input_shape_size = input->shape().size();
|
||||
|
@ -68,9 +74,8 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
} else {
|
||||
output_shape[axis] = argmax_prim->topK();
|
||||
}
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
output->set_shape(output_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -55,6 +55,11 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
|
||||
MS_LOG(ERROR) << "tensor number is error.";
|
||||
}
|
||||
output->SetFormat(input->GetFormat());
|
||||
output->set_data_type(input->data_type());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto argmin_prim = this->primitive->value_as_ArgMin();
|
||||
auto input_shape_size = input->shape().size();
|
||||
int axis = argmin_prim->axis() < 0 ? argmin_prim->axis() + input_shape_size : argmin_prim->axis();
|
||||
|
@ -68,9 +73,8 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
} else {
|
||||
output_shape[axis] = argmin_prim->topK();
|
||||
}
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
output->set_shape(output_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -46,6 +46,11 @@ int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vec
|
|||
return 1;
|
||||
}
|
||||
auto input = inputs.at(0);
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
std::vector<int32_t> dst_shape(this->primitive->value_as_BroadcastTo()->dst_shape()->begin(),
|
||||
this->primitive->value_as_BroadcastTo()->dst_shape()->end());
|
||||
auto input_shape = input->shape();
|
||||
|
@ -72,10 +77,8 @@ int BroadcastTo::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vec
|
|||
shape[i] = dst_shape[i];
|
||||
--input_shape_index;
|
||||
}
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
outputs[0]->set_shape(shape);
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
return 0;
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -44,8 +44,14 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|||
MS_LOG(ERROR) << "tensor number is error.";
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
output->SetFormat(input->GetFormat());
|
||||
auto cast_prim = this->primitive->value_as_Cast();
|
||||
MS_ASSERT(cast_prim != nullptr);
|
||||
output->set_data_type(static_cast<TypeId>(cast_prim->dstT()));
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
if (input->data_type() != cast_prim->srcT()) {
|
||||
MS_LOG(ERROR) << "input dataType is error";
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
|
@ -54,13 +60,8 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|||
MS_LOG(ERROR) << "Unsupported input data type " << input->data_type();
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
if (cast_prim->dstT() != kNumberTypeFloat && cast_prim->dstT() != kNumberTypeFloat32) {
|
||||
MS_LOG(ERROR) << "Invalid output datatype " << cast_prim->dstT();
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
output->set_shape(input->shape());
|
||||
output->set_data_type(TypeId::kNumberTypeFloat32);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -50,16 +50,19 @@ int ConstantOfShape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect
|
|||
return RET_ERROR;
|
||||
}
|
||||
auto in_tensor = inputs_.front();
|
||||
auto in_data = reinterpret_cast<int *>(in_tensor->Data());
|
||||
auto out_tensor = outputs_.front();
|
||||
out_tensor->set_data_type(kNumberTypeFloat32);
|
||||
out_tensor->SetFormat(in_tensor->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto in_data = reinterpret_cast<int *>(in_tensor->Data());
|
||||
int size = in_tensor->ElementsNum();
|
||||
std::vector<int> out_shape(size);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
out_shape[i] = in_data[i];
|
||||
}
|
||||
out_tensor->set_shape(out_shape);
|
||||
out_tensor->set_data_type(kNumberTypeFloat32);
|
||||
out_tensor->SetFormat(in_tensor->GetFormat());
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -46,9 +46,12 @@ int Crop::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::T
|
|||
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size();
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
outputs[0]->set_shape(inputs[1]->shape());
|
||||
outputs[0]->SetFormat(inputs[0]->GetFormat());
|
||||
outputs[0]->set_data_type(inputs[0]->data_type());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
outputs[0]->set_shape(inputs[1]->shape());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -103,7 +103,11 @@ int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vecto
|
|||
MS_ASSERT(weight != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
|
||||
output->SetFormat(input->GetFormat());
|
||||
output->set_data_type(input->data_type());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
int32_t input_h = input->Height();
|
||||
int32_t input_w = input->Width();
|
||||
|
||||
|
@ -138,8 +142,6 @@ int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vecto
|
|||
|
||||
std::vector<int> out_shape = {output_n, output_h, output_w, output_c};
|
||||
output->set_shape(out_shape);
|
||||
output->SetFormat(input->GetFormat());
|
||||
output->set_data_type(input->data_type());
|
||||
return 0;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -126,7 +126,11 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
|
|||
MS_ASSERT(weight != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
|
||||
output->SetFormat(input->GetFormat());
|
||||
output->set_data_type(input->data_type());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto in_shape = input->shape();
|
||||
int input_h = in_shape.at(1);
|
||||
int input_w = in_shape.at(2);
|
||||
|
@ -155,8 +159,6 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
|
|||
out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel
|
||||
|
||||
output->set_shape(out_shape);
|
||||
output->SetFormat(input->GetFormat());
|
||||
output->set_data_type(input->data_type());
|
||||
return 0;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -50,6 +50,11 @@ int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
|
|||
MS_LOG(ERROR) << "depth_to_space only support NHWC now!";
|
||||
return 1;
|
||||
}
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto input_shape = input->shape();
|
||||
if (input_shape.size() != kDimension_4d) {
|
||||
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
|
||||
|
@ -68,10 +73,7 @@ int DepthToSpace::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
|
|||
output_shape[NHWC_W] = input_shape[NHWC_W] * block_size;
|
||||
output_shape[NHWC_C] = input_shape[NHWC_C] / (block_size * block_size);
|
||||
outputs[0]->set_shape(output_shape);
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
|
||||
return 0;
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -120,7 +120,11 @@ int DepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
|
|||
MS_ASSERT(weight != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
|
||||
output->SetFormat(input->GetFormat());
|
||||
output->set_data_type(input->data_type());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto in_shape = input->shape();
|
||||
int input_h = in_shape.at(1);
|
||||
int input_w = in_shape.at(2);
|
||||
|
@ -158,8 +162,6 @@ int DepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
|
|||
out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel
|
||||
|
||||
output->set_shape(out_shape);
|
||||
output->SetFormat(input->GetFormat());
|
||||
output->set_data_type(input->data_type());
|
||||
return 0;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -46,6 +46,12 @@ int EmbeddingLookup::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect
|
|||
MS_ASSERT(ids != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->SetFormat(params_->GetFormat());
|
||||
output->set_data_type(params_->data_type());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
auto embedding_shape = params_->shape();
|
||||
embedding_shape.erase(embedding_shape.begin());
|
||||
std::vector<int> output_shape(ids->shape());
|
||||
|
@ -61,7 +67,6 @@ int EmbeddingLookup::InferShape(std::vector<tensor::Tensor *> inputs_, std::vect
|
|||
}
|
||||
}
|
||||
output->set_shape(output_shape);
|
||||
output->set_data_type(params_->data_type());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -42,6 +42,11 @@ int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
|
|||
if (outputs_.size() != kSingleNum) {
|
||||
MS_LOG(ERROR) << "output size is invalid";
|
||||
}
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto expand_dims_prim = this->primitive->value_as_ExpandDims();
|
||||
int dim = expand_dims_prim->dim();
|
||||
if (dim < 0) {
|
||||
|
@ -54,8 +59,6 @@ int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
|
|||
auto out_shape = input->shape();
|
||||
out_shape.insert(out_shape.begin() + dim, 1, 1);
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -45,6 +45,11 @@ int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|||
MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto fill_prim = this->primitive->value_as_Fill();
|
||||
if (fill_prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Fill primitive is null!";
|
||||
|
@ -53,8 +58,6 @@ int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|||
std::vector<int> output_shape;
|
||||
(void)output_shape.insert(output_shape.begin(), fill_prim->dims()->begin(), fill_prim->dims()->end());
|
||||
output->set_shape(output_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -31,6 +31,13 @@ int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
|
|||
MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
auto input_shape = input->shape();
|
||||
std::vector<int> output_shape(2);
|
||||
output_shape[0] = input_shape[0];
|
||||
|
@ -39,8 +46,6 @@ int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
|
|||
output_shape[1] *= input_shape[i];
|
||||
}
|
||||
output->set_shape(output_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -51,7 +51,11 @@ int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
|
|||
MS_ASSERT(input1 != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
|
||||
output->set_data_type(input0->data_type());
|
||||
output->SetFormat(input0->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
if ((GetHasBias() && inputs_.size() != kMultiNum) || (!GetHasBias() && inputs_.size() != kDoubleNum)) {
|
||||
MS_LOG(ERROR) << "Input tensors num error";
|
||||
return 1;
|
||||
|
@ -78,8 +82,6 @@ int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
|
|||
out_shape.resize(GetAxis() + 1);
|
||||
out_shape[GetAxis()] = input1->shape()[0];
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input0->data_type());
|
||||
output->SetFormat(input0->GetFormat());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -46,6 +46,12 @@ int GatherNd::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
|
|||
MS_ASSERT(indices != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto in_shape = input->shape();
|
||||
int in_rank = in_shape.size();
|
||||
auto indices_shape = indices->shape();
|
||||
|
@ -63,8 +69,6 @@ int GatherNd::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
|
|||
out_shape.emplace_back(in_shape[i]);
|
||||
}
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -44,6 +44,14 @@ int Lstm::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|||
MS_ASSERT(input0 != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
for (int i = 0; i < kLstmOutputNum; i++) {
|
||||
outputs_[i]->set_data_type(input->data_type());
|
||||
outputs_[i]->SetFormat(input->GetFormat());
|
||||
}
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
std::vector<int> in_shape = input->shape();
|
||||
std::vector<int> w_shape = weight_i->shape(); // layer, hidden_size * 4, input_size
|
||||
if (in_shape.size() != 3 || w_shape.size() != 3) {
|
||||
|
@ -65,10 +73,7 @@ int Lstm::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|||
state_shape[2] = hidden_size;
|
||||
outputs_[1]->set_shape(state_shape);
|
||||
outputs_[2]->set_shape(state_shape);
|
||||
for (int i = 0; i < kLstmOutputNum; i++) {
|
||||
outputs_[i]->set_data_type(input->data_type());
|
||||
outputs_[i]->SetFormat(input->GetFormat());
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -43,6 +43,13 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
MS_ASSERT(input1 != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
|
||||
output->set_data_type(input0->data_type());
|
||||
output->SetFormat(input0->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
std::vector<int> a_shape = input0->shape();
|
||||
std::vector<int> b_shape = input1->shape();
|
||||
if (a_shape.size() < 2 || b_shape.size() < 2) {
|
||||
|
@ -65,8 +72,6 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
std::vector<int> c_shape(a_shape);
|
||||
c_shape[c_shape.size() - 1] = b_shape[b_shape.size() - 1];
|
||||
output->set_shape(c_shape);
|
||||
output->set_data_type(input0->data_type());
|
||||
output->SetFormat(input0->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -50,6 +50,11 @@ int Mean::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|||
if (input == nullptr || output == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
if (this->primitive == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
@ -88,8 +93,6 @@ int Mean::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|||
}
|
||||
}
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -25,6 +25,11 @@ int Nchw2Nhwc::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
|
|||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->SetFormat(schema::Format_NHWC);
|
||||
output->set_data_type(input->data_type());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
std::vector<int> nchw_shape = input->shape();
|
||||
if (nchw_shape.size() != 4) {
|
||||
output->set_shape(nchw_shape);
|
||||
|
@ -36,8 +41,6 @@ int Nchw2Nhwc::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
|
|||
nhwc_shape[NHWC_C] = nchw_shape[NCHW_C];
|
||||
output->set_shape(nhwc_shape);
|
||||
}
|
||||
output->SetFormat(schema::Format_NHWC);
|
||||
output->set_data_type(input->data_type());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -25,6 +25,11 @@ int Nhwc2Nchw::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
|
|||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->SetFormat(schema::Format_NCHW);
|
||||
output->set_data_type(input->data_type());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
std::vector<int> nhwc_shape = input->shape();
|
||||
if (nhwc_shape.size() != 4) {
|
||||
output->set_shape(nhwc_shape);
|
||||
|
@ -36,8 +41,6 @@ int Nhwc2Nchw::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
|
|||
nchw_shape[NCHW_W] = nhwc_shape[NHWC_W];
|
||||
output->set_shape(nchw_shape);
|
||||
}
|
||||
output->SetFormat(schema::Format_NCHW);
|
||||
output->set_data_type(input->data_type());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -56,6 +56,19 @@ int OneHot::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:
|
|||
if (input == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto on_value = inputs.at(2);
|
||||
if (on_value == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto output = outputs.front();
|
||||
if (output == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
output->set_data_type(on_value->data_type());
|
||||
output->SetFormat(on_value->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
const auto input_shape = input->shape();
|
||||
int input_rank = static_cast<int>(input_shape.size());
|
||||
if (axis < 0) {
|
||||
|
@ -63,17 +76,7 @@ int OneHot::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:
|
|||
}
|
||||
std::vector<int> output_shape(input_shape);
|
||||
output_shape.insert(output_shape.cbegin() + axis, *depth);
|
||||
auto output = outputs.front();
|
||||
if (output == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
output->set_shape(output_shape);
|
||||
auto on_value = inputs.at(2);
|
||||
if (on_value == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
output->set_data_type(on_value->data_type());
|
||||
output->SetFormat(on_value->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -61,6 +61,15 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te
|
|||
if (input == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto output = outputs.front();
|
||||
if (output == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
output->SetFormat(input->GetFormat());
|
||||
output->set_data_type(input->data_type());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto input_shape = input->shape();
|
||||
std::vector<int> output_shape;
|
||||
MS_ASSERT(input->shape().size() <= kInputRank);
|
||||
|
@ -69,13 +78,8 @@ int Pad::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Te
|
|||
auto shape = input_shape[i] + (*paddings)[2 * paddings_index] + (*paddings)[2 * paddings_index + 1];
|
||||
output_shape.push_back(shape);
|
||||
}
|
||||
auto output = outputs.front();
|
||||
if (output == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
output->set_shape(output_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -95,6 +95,11 @@ int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
|
|||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(schema::Format_NHWC);
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
int input_h = input->shape().at(1);
|
||||
int input_w = input->shape().at(2);
|
||||
auto pooling_prim = this->primitive->value_as_Pooling();
|
||||
|
@ -137,9 +142,6 @@ int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
|
|||
input_shape.at(1) = output_h;
|
||||
input_shape.at(2) = output_w;
|
||||
output->set_shape(input_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
// todo: temp fix
|
||||
output->SetFormat(schema::Format_NHWC);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -49,15 +49,19 @@ int Power::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
|
|||
}
|
||||
auto output_tensor = outputs[0];
|
||||
MS_ASSERT(output_tensor != nullptr);
|
||||
output_tensor->set_data_type(x_tensor->data_type());
|
||||
output_tensor->SetFormat(x_tensor->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
if (exp_tensor != nullptr) {
|
||||
if (exp_tensor->shape() != x_tensor->shape() || exp_tensor->data_type() != x_tensor->data_type()) {
|
||||
MS_LOG(ERROR) << "Power inputs shape or type is not equal!";
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
}
|
||||
output_tensor->SetFormat(x_tensor->GetFormat());
|
||||
|
||||
output_tensor->set_shape(x_tensor->shape());
|
||||
output_tensor->set_data_type(x_tensor->data_type());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -99,6 +99,15 @@ constexpr int kPriorBoxC = 2;
|
|||
int PriorBox::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
|
||||
auto param = this->primitive->value_as_PriorBox();
|
||||
MS_ASSERT(param != nullptr);
|
||||
auto input = inputs_.at(0);
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.at(0);
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_data_type(kNumberTypeFloat32);
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
std::vector<float> different_aspect_ratios{1.0f};
|
||||
auto aspect_ratios = param->aspect_ratios();
|
||||
MS_ASSERT(aspect_ratios != nullptr);
|
||||
|
@ -114,15 +123,9 @@ int PriorBox::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
|
|||
}
|
||||
}
|
||||
int32_t num_priors_box = param->min_sizes()->size() * different_aspect_ratios.size() + param->max_sizes()->size();
|
||||
auto input = inputs_.at(0);
|
||||
MS_ASSERT(input != nullptr);
|
||||
int32_t h = input->Height() * input->Width() * num_priors_box * kPriorBoxPoints;
|
||||
std::vector<int> output_shape{kPriorBoxN, h, kPriorBoxW, kPriorBoxC};
|
||||
auto output = outputs_.at(0);
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_shape(output_shape);
|
||||
output->set_data_type(kNumberTypeFloat32);
|
||||
output->SetFormat(input->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -40,11 +40,14 @@ int QuantDTypeCast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
|
|||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_shape(input->shape());
|
||||
auto param = primitive->value_as_QuantDTypeCast();
|
||||
MS_ASSERT(input->data_type() == param->srcT);
|
||||
output->set_data_type(static_cast<TypeId>(param->dstT()));
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
output->set_shape(input->shape());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -50,12 +50,18 @@ int Range::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
|
|||
MS_ASSERT(output != nullptr);
|
||||
auto range_prim = this->primitive->value_as_Range();
|
||||
MS_ASSERT(range_prim != nullptr);
|
||||
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int shape_size = std::ceil(static_cast<float>(range_prim->limit() - range_prim->start()) / range_prim->delta());
|
||||
std::vector<int> in_shape(1);
|
||||
in_shape.push_back(shape_size);
|
||||
output->set_shape(in_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -25,10 +25,13 @@ int Rank::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
std::vector<int> in_shape(1, 1);
|
||||
output->set_shape(in_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
std::vector<int> in_shape(1, 1);
|
||||
output->set_shape(in_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -66,6 +66,11 @@ int Resize::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<
|
|||
if (output == nullptr) {
|
||||
return 1;
|
||||
}
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto new_height = GetNewHeight();
|
||||
auto new_width = GetNewWidth();
|
||||
|
||||
|
@ -75,10 +80,8 @@ int Resize::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<
|
|||
output_shape.push_back(new_width);
|
||||
output_shape.push_back(input->Channel());
|
||||
output->set_shape(output_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
return 0;
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -52,9 +52,13 @@ int ReverseSequence::InferShape(std::vector<tensor::Tensor *> inputs, std::vecto
|
|||
auto output = outputs.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_shape(input->shape());
|
||||
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
output->set_shape(input->shape());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -56,6 +56,11 @@ int ROIPooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
|
|||
if (output == nullptr) {
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto ROIPooling = this->primitive->value_as_ROIPooling();
|
||||
auto new_h = ROIPooling->pooledH();
|
||||
auto new_w = ROIPooling->pooledW();
|
||||
|
@ -66,8 +71,6 @@ int ROIPooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<te
|
|||
output_shape.push_back(new_w);
|
||||
output_shape.push_back(input->Channel());
|
||||
output->set_shape(output_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -51,11 +51,14 @@ int ScatterND::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
|
|||
return RET_ERROR;
|
||||
}
|
||||
auto output = outputs_.front();
|
||||
output->set_data_type(update->data_type());
|
||||
output->SetFormat(update->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto shape_data = reinterpret_cast<int *>(shape->Data());
|
||||
std::vector<int> out_shape(shape_data, shape_data + shape->DataSize());
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(update->data_type());
|
||||
output->SetFormat(update->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -63,6 +63,11 @@ int SpaceToBatch::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
|
|||
MS_LOG(ERROR) << "space_to_batch only support NHWC now!";
|
||||
return 1;
|
||||
}
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto input_shape = input->shape();
|
||||
if (input_shape.size() != kDimension_4d) {
|
||||
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
|
||||
|
@ -106,8 +111,7 @@ int SpaceToBatch::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
|
|||
output_shape[NHWC_W] = input_shape[NHWC_W] / block_sizes_[NHWC_H];
|
||||
output_shape[NHWC_C] = input_shape[NHWC_C];
|
||||
outputs[0]->set_shape(output_shape);
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
return 0;
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -51,6 +51,11 @@ int SpaceToDepth::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
|
|||
MS_LOG(ERROR) << "space_to_depth only support NHWC now!";
|
||||
return 1;
|
||||
}
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto input_shape = input->shape();
|
||||
if (input_shape.size() != kDimension_4d) {
|
||||
MS_LOG(ERROR) << "input shape dimension size should == " << kDimension_4d;
|
||||
|
@ -69,8 +74,7 @@ int SpaceToDepth::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
|
|||
output_shape[NHWC_W] = input_shape[NHWC_W] / block_size;
|
||||
output_shape[NHWC_C] = input_shape[NHWC_C] * (block_size * block_size);
|
||||
outputs[0]->set_shape(output_shape);
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
return 0;
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -66,6 +66,13 @@ int Split::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
|
|||
MS_LOG(ERROR) << "outputs number is not equal to " << number_split;
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (int i = 0; i < number_split; ++i) {
|
||||
outputs_[i]->set_data_type(input->data_type());
|
||||
outputs_[i]->SetFormat(input->GetFormat());
|
||||
}
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
int split_dim = spilt_prim->splitDim();
|
||||
std::vector<int> input_shape = input->shape();
|
||||
std::vector<int> size_split;
|
||||
|
|
|
@ -48,6 +48,11 @@ int Squeeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
|
|||
return -1;
|
||||
}
|
||||
auto *in_tensor = inputs_.front();
|
||||
outputs_.front()->set_data_type(in_tensor->data_type());
|
||||
outputs_.front()->SetFormat(in_tensor->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto in_shape = in_tensor->shape();
|
||||
std::vector<int> out_shape;
|
||||
// todo: getAxis
|
||||
|
@ -77,8 +82,6 @@ int Squeeze::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
|
|||
}
|
||||
}
|
||||
outputs_.front()->set_shape(out_shape);
|
||||
outputs_.front()->set_data_type(in_tensor->data_type());
|
||||
outputs_.front()->SetFormat(in_tensor->GetFormat());
|
||||
return 0;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -56,6 +56,11 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
|
|||
return RET_PARAM_INVALID;
|
||||
}
|
||||
auto input = inputs.at(0);
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto input_shape = input->shape();
|
||||
auto stack_prim = this->primitive->value_as_Stack();
|
||||
std::vector<int32_t> output_shape = input_shape;
|
||||
|
@ -84,8 +89,6 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
|
|||
}
|
||||
output_shape.insert(output_shape.begin() + axis, inputs.size());
|
||||
outputs[0]->set_shape(output_shape);
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -164,6 +164,11 @@ int StridedSlice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
|
|||
return RET_PARAM_INVALID;
|
||||
}
|
||||
auto input = inputs.at(0);
|
||||
outputs.front()->set_data_type(input->data_type());
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto input_shape = input->shape();
|
||||
std::vector<int> output_shape;
|
||||
|
@ -214,8 +219,6 @@ int StridedSlice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
|
|||
output_shape = ApplyShrinkMask(output_shape);
|
||||
|
||||
outputs.front()->set_shape(output_shape);
|
||||
outputs.front()->set_data_type(input->data_type());
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -40,6 +40,11 @@ int Tile::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto tile_prim = this->primitive->value_as_Tile();
|
||||
MS_ASSERT(tile_prim != nullptr);
|
||||
std::vector<int> out_shape;
|
||||
|
@ -49,9 +54,8 @@ int Tile::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|||
int tmp = input->shape()[i] * multiples[i];
|
||||
out_shape.push_back(tmp);
|
||||
}
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
output->set_shape(out_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -46,16 +46,19 @@ int TopK::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
|
|||
MS_ASSERT(output0 != nullptr);
|
||||
auto output1 = outputs_.at(1);
|
||||
MS_ASSERT(output1 != nullptr);
|
||||
output0->set_data_type(input->data_type());
|
||||
output0->SetFormat(input->GetFormat());
|
||||
output1->set_data_type(kNumberTypeInt32);
|
||||
output1->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto topk_prim = this->primitive->value_as_TopK();
|
||||
MS_ASSERT(topk_prim != nullptr);
|
||||
auto out_shape = input->shape();
|
||||
out_shape[out_shape.size() - 1] = topk_prim->k();
|
||||
output0->set_shape(out_shape);
|
||||
output0->set_data_type(input->data_type());
|
||||
output0->SetFormat(input->GetFormat());
|
||||
output1->set_shape(out_shape);
|
||||
output1->set_data_type(kNumberTypeInt32);
|
||||
output1->SetFormat(input->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -42,12 +42,15 @@ int Unique::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
MS_ASSERT(output0 != nullptr);
|
||||
auto &output1 = outputs_.at(1);
|
||||
MS_ASSERT(output1 != nullptr);
|
||||
output0->set_shape(input->shape());
|
||||
output0->set_data_type(input->data_type());
|
||||
output1->set_shape(input->shape());
|
||||
output1->set_data_type(kNumberTypeInt32);
|
||||
output1->SetFormat(input->GetFormat());
|
||||
output0->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
output0->set_shape(input->shape());
|
||||
output1->set_shape(input->shape());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -44,6 +44,14 @@ int Unstack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor
|
|||
MS_LOG(ERROR) << "Invalid axis " << prim->axis();
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
for (auto &out : outputs) {
|
||||
MS_ASSERT(out != nullptr);
|
||||
out->set_data_type(input->data_type());
|
||||
out->SetFormat(input->GetFormat());
|
||||
}
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
std::vector<int> output_shape;
|
||||
for (size_t i = 0; i < input_shape.size(); ++i) {
|
||||
if (i != axis) {
|
||||
|
@ -53,8 +61,6 @@ int Unstack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor
|
|||
for (auto &out : outputs) {
|
||||
MS_ASSERT(out != nullptr);
|
||||
out->set_shape(output_shape);
|
||||
out->set_data_type(input->data_type());
|
||||
out->SetFormat(input->GetFormat());
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -53,6 +53,11 @@ int Where::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
|
|||
auto input0 = inputs_.at(0);
|
||||
auto input1 = inputs_.at(1);
|
||||
auto input2 = inputs_.at(2);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
int num = input0->ElementsNum();
|
||||
int num1 = input1->ElementsNum();
|
||||
int num2 = input2->ElementsNum();
|
||||
|
@ -85,8 +90,6 @@ int Where::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
|
|||
auto output_shape = shape_tmp;
|
||||
output_shape[axisout] = nummax;
|
||||
outputs_[0]->set_shape(output_shape);
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -29,10 +29,12 @@ int ZerosLike::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vect
|
|||
<< ", output size: " << outputs_.size();
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
output->set_shape(input->shape());
|
||||
output->set_data_type(input->data_type());
|
||||
output->SetFormat(input->GetFormat());
|
||||
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
output->set_shape(input->shape());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -18,15 +18,29 @@
|
|||
#include <float.h>
|
||||
|
||||
int ArgCompareAscFp32(const void *a, const void *b) {
|
||||
return ((ArgElement *)a)->data_.f_data_ - ((ArgElement *)b)->data_.f_data_;
|
||||
float a_value = ((ArgElement *)a)->data_.f_data_;
|
||||
float b_value = ((ArgElement *)b)->data_.f_data_;
|
||||
if (b_value > a_value) {
|
||||
return -1;
|
||||
}
|
||||
if (b_value < a_value) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int ArgCompareDescFp32(const void *a, const void *b) {
|
||||
// cmp funtion of qsort must return int type
|
||||
auto b_value = ((ArgElement *)b)->data_.f_data_;
|
||||
auto a_value = ((ArgElement *)a)->data_.f_data_;
|
||||
int res = b_value > a_value ? 1 : -1;
|
||||
return res;
|
||||
float b_value = ((ArgElement *)b)->data_.f_data_;
|
||||
float a_value = ((ArgElement *)a)->data_.f_data_;
|
||||
if (b_value > a_value) {
|
||||
return 1;
|
||||
}
|
||||
if (b_value < a_value) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void ArgMaxDim0OutValue(const float *input, float *output, const int *in_shape, ArgMinMaxParameter *param) {
|
||||
|
|
Loading…
Reference in New Issue