!9425 [MS][list][x86] add new tensorlist ops

From: @lzkcode
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-12-08 19:27:55 +08:00 committed by Gitee
commit 6550769104
32 changed files with 1747 additions and 456 deletions

View File

@ -26,6 +26,7 @@ set(LITE_SRC
${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_api.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/thread_pool.c
${CMAKE_CURRENT_SOURCE_DIR}/tensor.cc
${CMAKE_CURRENT_SOURCE_DIR}/tensorlist.cc
${CMAKE_CURRENT_SOURCE_DIR}/executor.cc
${CMAKE_CURRENT_SOURCE_DIR}/inner_context.cc
${CMAKE_CURRENT_SOURCE_DIR}/model_common.cc

View File

@ -17,7 +17,7 @@
#include "nnacl/tensorlist_parameter.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "src/ops/tensor_list.h"
#include "src/ops/tensorlistfromtensor.h"
namespace mindspore {
namespace lite {
@ -31,8 +31,8 @@ OpParameter *PopulateTensorListFromTensorParameter(const mindspore::lite::Primit
TensorList_param->op_parameter_.type_ = primitive->Type();
auto tensorList =
reinterpret_cast<mindspore::lite::TensorListFromTensor *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
TensorList_param->shape_type_ = tensorList->GetShapeType();
TensorList_param->element_dtype_ = tensorList->GetElementDType();
TensorList_param->shape_type_ = (TypeId)(tensorList->GetShapeType());
TensorList_param->element_dtype_ = (TypeId)(tensorList->GetElementDType());
return reinterpret_cast<OpParameter *>(TensorList_param);
}
Registry TensorListFromTensorParameterRegistry(schema::PrimitiveType_TensorListFromTensor,

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "src/ops/tensor_list.h"
#include "src/ops/tensorlistgetitem.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/tensorlist_parameter.h"

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "src/ops/tensor_list.h"
#include "src/ops/tensorlistreserve.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/tensorlist_parameter.h"

View File

@ -0,0 +1,41 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/tensorlistsetitem.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/tensorlist_parameter.h"
namespace mindspore {
namespace lite {
// Builds the TensorListParameter used by the TensorListSetItem kernel from its primitive.
// Returns nullptr if allocation fails; caller takes ownership of the returned OpParameter.
OpParameter *PopulateTensorListSetItemParameter(const mindspore::lite::PrimitiveC *primitive) {
  auto *param = reinterpret_cast<TensorListParameter *>(malloc(sizeof(TensorListParameter)));
  if (param == nullptr) {
    MS_LOG(ERROR) << "malloc TensorListParameter failed.";
    return nullptr;
  }
  memset(param, 0, sizeof(TensorListParameter));
  param->op_parameter_.type_ = primitive->Type();
  auto *set_item_prim =
    reinterpret_cast<mindspore::lite::TensorListSetItem *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
  param->element_dtype_ = set_item_prim->GetElementDType();
  return reinterpret_cast<OpParameter *>(param);
}
Registry TensorListSetItemParameterRegistry(schema::PrimitiveType_TensorListSetItem,
                                            PopulateTensorListSetItemParameter);
} // namespace lite
} // namespace mindspore

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "src/ops/tensor_list.h"
#include "src/ops/tensorliststack.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/tensorlist_parameter.h"

View File

@ -1,314 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "src/ops/tensor_list.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
using mindspore::schema::Format_NC;
namespace mindspore {
namespace lite {
// Infers output shapes for TensorListFromTensor in the legacy multi-tensor
// representation: the source tensor is split along dim0 into dim0 element
// tensors, plus two metadata outputs (list size/dtype and element_shape).
int TensorListFromTensor::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
  // inputs0:tensor
  // inputs1: element_shape
  // outputs0: vector<tensor>.size() dtype
  // outputs1: element_shape
  // outputs2-n: vector<tensor>
  auto input = inputs_.at(0);
  MS_ASSERT(input != nullptr);
  std::vector<int> in_shape = input->shape();
  int dim0 = in_shape.at(0);
  if (dim0 <= 0) {
    MS_LOG(ERROR) << "inputs_[0] dim0:" << dim0 << " must greater than 0";
    return RET_ERROR;
  }
  // Each element tensor keeps the trailing dimensions of the source tensor.
  std::vector<int> out_shape(in_shape.begin() + 1, in_shape.end());
  int out_vec_size = outputs_.size() - 2;
  if (out_vec_size != dim0) {
    MS_LOG(ERROR) << "outputs_.size() - 2:" << out_vec_size << "must be equal to dim0:" << dim0;
    return RET_ERROR;
  }
  for (int i = 0; i < dim0; ++i) {
    auto output = outputs_.at(i + 2);
    MS_ASSERT(output != nullptr);
    output->set_data_type(input->data_type());
    output->set_shape(out_shape);
  }
  auto output = outputs_.at(0);  // vector<tensor>.size(), tensorlist.dtype
  MS_ASSERT(output != nullptr);
  output->set_data_type(kNumberTypeInt);
  output->set_shape(std::vector<int>(1, 2));  // one element.value = 2
  output = outputs_.at(1);  // element_shape tensor
  MS_ASSERT(output != nullptr);
  // The element_shape output mirrors inputs_[1] exactly (dtype/format/shape).
  output->set_data_type(inputs_.at(1)->data_type());
  output->set_format(inputs_.at(1)->format());
  output->set_shape(inputs_.at(1)->shape());
  return RET_OK;
}
// Returns true iff every dimension in `shape` is known (i.e. non-negative).
bool TensorListGetItem::IsFullyDefined(const std::vector<int> &shape) const {
  for (auto dim : shape) {
    if (dim < 0) {
      return false;
    }
  }
  return true;
}
// Infers the shape of the tensor extracted from a tensorlist (legacy
// multi-tensor representation).
// NOTE(review): inputs_[0] is assumed to hold an int buffer whose first
// element is the list length; inputs_[1..len] are the element tensors,
// inputs_[len+2] the index and inputs_[len+3] the element_shape — offsets
// below rely on that layout; confirm against the graph exporter.
int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
  int in_vec_size = inputs_.size();
  auto input0 = inputs_.at(0);
  MS_ASSERT(input0 != nullptr);
  auto in0_ptr = reinterpret_cast<int *>(input0->data_c());
  if (in_vec_size != in0_ptr[0] + 4) {
    MS_LOG(ERROR) << "inputs_.size():" << in_vec_size << " must be equal to:" << in0_ptr[0] + 4;
    return RET_ERROR;
  }
  auto get_index = inputs_.at(in0_ptr[0] + 2);
  MS_ASSERT(get_index != nullptr);
  index_ = reinterpret_cast<int *>(get_index->data_c())[0];
  if (index_ < 0 || index_ > in0_ptr[0]) {
    MS_LOG(ERROR) << "index_:" << index_ << "must in [0, " << in0_ptr[0] << "]";
    return RET_ERROR;
  }
  auto input_index = inputs_.at(index_ + 2);
  MS_ASSERT(input_index != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  if (input_index->data_type() != kTypeUnknown) {
    // The selected element already has a concrete type/shape: forward it.
    output->set_format(input_index->format());
    output->set_data_type(input_index->data_type());
    output->set_shape(input_index->shape());
  } else {
    // Element is still unknown: reconstruct its shape from the element_shape
    // tensor, filling -1 dims from any sibling element with a known shape.
    auto ele_shape_tensor = inputs_.at(in0_ptr[0] + 3);
    MS_ASSERT(ele_shape_tensor != nullptr);
    auto ele_shape_type = ele_shape_tensor->data_type();
    if (ele_shape_type != kNumberTypeInt) {
      MS_LOG(ERROR) << "ele_shape_tensor.data_type():" << ele_shape_type
                    << " must be \"kNumberTypeInt\":" << kNumberTypeInt;
      return RET_ERROR;
    }
    auto shape_ptr = reinterpret_cast<int *>(ele_shape_tensor->data_c());
    for (int i = 0; i < ele_shape_tensor->ElementsNum(); ++i) {
      element_shape_.push_back(shape_ptr[i]);
    }
    if (!IsFullyDefined(element_shape_)) {
      for (int i = 0; i < in0_ptr[0]; ++i) {
        auto input = inputs_.at(i + 2);
        if (input->data_type() != kTypeUnknown) {
          std::vector<int> tmp = input->shape();
          for (size_t j = 0; j < tmp.size(); ++j) {
            // Keep already-known dims; adopt the sibling's dim otherwise.
            element_shape_.at(j) = element_shape_.at(j) >= 0 ? element_shape_.at(j) : tmp.at(j);
          }
        }
      }
    }
    if (!IsFullyDefined(element_shape_)) {
      MS_LOG(ERROR) << "ele_shape_tensor Is Not FullyDefined!";
      return RET_ERROR;
    }
    element_dtype_ = GetElementDType();
    output->set_data_type(element_dtype_);
    output->set_shape(element_shape_);
  }
  return RET_OK;
}
#ifdef PRIMITIVE_WRITEABLE
// Attribute getters backed by the mutable flatbuffer object (PrimitiveT).
TypeId TensorListFromTensor::GetElementDType() const {
  return (TypeId)(this->primitive_->value.AsTensorListFromTensor()->elementDType);
}
TypeId TensorListFromTensor::GetShapeType() const {
  return (TypeId)(this->primitive_->value.AsTensorListFromTensor()->shapeType);
}
TypeId TensorListGetItem::GetElementDType() const {
  return (TypeId)(this->primitive_->value.AsTensorListGetItem()->elementDType);
}
TypeId TensorListReserve::GetElementDType() const {
  return (TypeId)(this->primitive_->value.AsTensorListReserve()->elementDType);
}
TypeId TensorListStack::GetElementDType() const {
  return (TypeId)(this->primitive_->value.AsTensorListStack()->elementDType);
}
int TensorListStack::GetNumElements() const { return this->primitive_->value.AsTensorListStack()->numElements; }
#else
// Attribute getters backed by the read-only generated flatbuffer accessors.
TypeId TensorListFromTensor::GetElementDType() const {
  return (TypeId)(this->primitive_->value_as_TensorListFromTensor()->elementDType());
}
TypeId TensorListFromTensor::GetShapeType() const {
  return (TypeId)(this->primitive_->value_as_TensorListFromTensor()->shapeType());
}
TypeId TensorListGetItem::GetElementDType() const {
  return (TypeId)(this->primitive_->value_as_TensorListGetItem()->elementDType());
}
TypeId TensorListReserve::GetElementDType() const {
  return (TypeId)(this->primitive_->value_as_TensorListReserve()->elementDType());
}
TypeId TensorListStack::GetElementDType() const {
  return (TypeId)(this->primitive_->value_as_TensorListStack()->elementDType());
}
int TensorListStack::GetNumElements() const { return this->primitive_->value_as_TensorListStack()->numElements(); }
#endif
// Infers outputs for TensorListReserve (legacy multi-tensor representation):
// creates num_elements placeholder tensors of unknown type, plus the two
// metadata outputs (list size/dtype and element_shape).
int TensorListReserve::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
  // input0: element_shape_tensor
  // input1: num_elements
  auto input0 = inputs_.front();
  MS_ASSERT(input0 != nullptr);
  auto ele_shape_type = input0->data_type();
  if (ele_shape_type != kNumberTypeInt) {
    MS_LOG(ERROR) << "ele_shape_tensor.data_type():" << ele_shape_type
                  << " must be \"kNumberTypeInt\":" << kNumberTypeInt;
    return RET_ERROR;
  }
  auto input1 = inputs_.at(1);
  MS_ASSERT(input1 != nullptr);
  auto num_ele_type = input1->data_type();
  if (num_ele_type != kNumberTypeInt) {
    MS_LOG(ERROR) << "num_ele_tensor.data_type():" << num_ele_type << " must be \"kNumberTypeInt\":" << kNumberTypeInt;
    return RET_ERROR;
  }
  int num_elements = reinterpret_cast<int *>(input1->data_c())[0];
  auto out_vec_size = outputs_.size();
  if (out_vec_size != (size_t)(num_elements + 2)) {
    MS_LOG(ERROR) << "outputs_.size():" << out_vec_size << " must be equal to:" << num_elements + 2;
    return RET_ERROR;
  }
  // Reserved slots have no data yet: dtype unknown, shape [0].
  for (int i = 0; i < num_elements; ++i) {
    auto output = outputs_.at(i + 2);
    MS_ASSERT(output != nullptr);
    output->set_data_type(kTypeUnknown);
    output->set_shape(std::vector<int>(1, 0));  // shape = [0]
  }
  auto output = outputs_.at(0);  // vector<tensor>.size(), tensorlist.dtype
  MS_ASSERT(output != nullptr);
  output->set_data_type(kNumberTypeInt);
  output->set_shape(std::vector<int>(1, 2));  // one element.value = 2
  output = outputs_.at(1);  // element_shape tensor
  MS_ASSERT(output != nullptr);
  // element_shape output mirrors input0 exactly.
  output->set_data_type(input0->data_type());
  output->set_format(input0->format());
  output->set_shape(input0->shape());
  return RET_OK;
}
// Returns true iff every dimension in `shape` is known (i.e. non-negative).
bool TensorListStack::IsFullyDefined(const std::vector<int> &shape) const {
  for (auto dim : shape) {
    if (dim < 0) {
      return false;
    }
  }
  return true;
}
// Infers the (flattened) output shape for TensorListStack in the legacy
// multi-tensor representation.
// NOTE(review): inputs_[0] is assumed to hold an int buffer whose layout is
// [list_len, element_dtype, ...]; inputs_[1] the tensorlist shape tensor;
// inputs_[2..len+1] the element tensors; inputs_[len+2] the requested
// element_shape — confirm against the graph exporter.
int TensorListStack::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
  // input0: tensorlist
  // input[inputs_.size() - 1]: element_shape
  auto input0 = inputs_.front();
  MS_ASSERT(input0 != nullptr);
  auto input0_ptr = reinterpret_cast<int *>(input0->data_c());
  int vec_in_size = inputs_.size();
  if (vec_in_size != input0_ptr[0] + 3) {
    MS_LOG(ERROR) << "inputs_.size():" << vec_in_size << " must be equal:" << input0_ptr[0] + 3;
    return RET_ERROR;
  }
  auto ele_shape = inputs_.at(input0_ptr[0] + 2);  // element shape
  MS_ASSERT(ele_shape != nullptr);
  auto ele_shape_ptr = reinterpret_cast<int *>(ele_shape->data_c());
  // Fix: the loop condition was "ele_shape->ElementsNum()" (always true for a
  // non-empty tensor), which never terminated and read past the buffer.
  for (int i = 0; i < ele_shape->ElementsNum(); ++i) {
    output_shape_.push_back(ele_shape_ptr[i]);
  }
  std::vector<int> tensorlist_shape;
  MS_ASSERT(inputs_.at(1) != nullptr);
  auto input1_ptr = reinterpret_cast<int *>(inputs_.at(1)->data_c());
  for (int i = 0; i < inputs_.at(1)->ElementsNum(); ++i) {
    tensorlist_shape.push_back(input1_ptr[i]);
  }
  // Merge the stored tensorlist shape into the requested element shape.
  auto status = MergeShape(tensorlist_shape);
  if (status == RET_ERROR) {
    MS_LOG(ERROR) << "Merge tensorlist_shape is error!";
    return RET_ERROR;
  }
  if (!IsFullyDefined(output_shape_)) {
    MS_LOG(ERROR) << "element_shape Is Not FullyDefined!";
    return RET_ERROR;
  }
  if (!IsFullyDefined(tensorlist_shape)) {
    for (int i = 0; i < input0_ptr[0]; ++i) {  // get tensorlist every tensor
      auto tensor_tmp = inputs_.at(i + 2);
      MS_ASSERT(tensor_tmp != nullptr);
      if (tensor_tmp->data_type() != kTypeUnknown) {
        status = MergeShape(tensor_tmp->shape());
        if (status == RET_ERROR) {
          MS_LOG(ERROR) << "Merge inputs_[" << i + 2 << "] is error!";
          return RET_ERROR;
        }
      }
    }
  }
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  output->set_format(Format_NC);
  output->set_data_type(static_cast<TypeId>(input0_ptr[1]));
  // Output is flattened to 1-D: list_len * product(element dims).
  output->set_shape(std::vector<int>(
    1, input0_ptr[0] * std::accumulate(output_shape_.begin(), output_shape_.end(), 1LL, std::multiplies<int>())));
  return RET_OK;
}
int TensorListStack::MergeShape(const std::vector<int> &shape) {
size_t dim0 = shape.size();
size_t dim1 = output_shape_.size();
if (dim1 >= unKnownRank_) {
output_shape_ = shape;
return RET_OK;
}
if (dim1 != dim0) {
MS_LOG(ERROR) << "shape.size():" << dim1 << " must be equal output_shape_.size():" << dim0;
return RET_ERROR;
}
for (size_t i = 0; i < dim0; ++i) {
int dim0_size = shape.at(i);
int dim1_size = output_shape_.at(i);
if (dim0_size >= 0 && dim1_size >= 0 && dim0_size != dim1_size) {
MS_LOG(ERROR) << "shape[" << i << "]:" << dim0_size << " is incompatible with output_shape_[" << i
<< "]:" << dim1_size;
return RET_ERROR;
}
int tmp_size = dim1_size >= 0 ? dim1_size : dim0_size;
output_shape_.at(i) = tmp_size;
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,141 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "src/ops/tensorlistfromtensor.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// Attribute accessors backed by the mutable flatbuffer object (PrimitiveT).
int TensorListFromTensor::GetElementDType() const {
  return this->primitive_->value.AsTensorListFromTensor()->elementDType;
}
int TensorListFromTensor::GetShapeType() const { return this->primitive_->value.AsTensorListFromTensor()->shapeType; }
void TensorListFromTensor::SetElementDType(int type) {
  this->primitive_->value.AsTensorListFromTensor()->elementDType = type;
}
void TensorListFromTensor::SetShapeType(int type) {
  this->primitive_->value.AsTensorListFromTensor()->shapeType = type;
}
// Unpacks the elementDType/shapeType attributes from an ANF Primitive into
// this->primitive_ (schema::PrimitiveT). Allocates primitive_ and its attr
// value on demand; on any failure both are cleaned up and RET_ERROR returned.
int TensorListFromTensor::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_TensorListFromTensor;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_TensorListFromTensor) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    delete this->primitive_;
    this->primitive_ = nullptr;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = new (std::nothrow) schema::TensorListFromTensorT();
    if (attr == nullptr) {
      delete this->primitive_;
      this->primitive_ = nullptr;
      MS_LOG(ERROR) << "new TensorListFromTensorT value failed";
      return RET_ERROR;
    }
    // Both attributes are mandatory; missing either aborts the unpack.
    if (prim.GetAttr("elementDType") == nullptr) {
      delete this->primitive_;
      this->primitive_ = nullptr;
      delete attr;
      MS_LOG(ERROR) << "TensorListFromTensorT's attr elementDType is not set";
      return RET_ERROR;
    } else {
      attr->elementDType = CastToInt(prim.GetAttr("elementDType")).front();
    }
    if (prim.GetAttr("shapeType") == nullptr) {
      delete this->primitive_;
      this->primitive_ = nullptr;
      delete attr;
      MS_LOG(ERROR) << "TensorListFromTensorT's attr shapeType is not set";
      return RET_ERROR;
    } else {
      attr->shapeType = CastToInt(prim.GetAttr("shapeType")).front();
    }
    // primitive_ takes ownership of attr from here on.
    this->primitive_->value.value = attr;
  }
  return RET_OK;
}
#else
// Attribute accessors backed by the read-only generated flatbuffer view.
int TensorListFromTensor::GetElementDType() const {
  return this->primitive_->value_as_TensorListFromTensor()->elementDType();
}
int TensorListFromTensor::GetShapeType() const {
  return this->primitive_->value_as_TensorListFromTensor()->shapeType();
}
// Re-serializes this primitive's attributes into `fbb` as a schema::Primitive.
int TensorListFromTensor::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_TensorListFromTensor();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "value_as_TensorListFromTensor return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateTensorListFromTensor(*fbb, attr->elementDType(), attr->shapeType());
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_TensorListFromTensor, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
// Factory + registry entry so the op can be created from a schema::Primitive.
PrimitiveC *TensorListFromTensorCreator(const schema::Primitive *primitive) {
  return PrimitiveC::NewPrimitiveC<TensorListFromTensor>(primitive);
}
Registry TensorListFromTensorRegistry(schema::PrimitiveType_TensorListFromTensor, TensorListFromTensorCreator);
#endif
// Infers the output TensorList for TensorListFromTensor: input0 is split
// along dim0 into dim0 element tensors whose shapes are input0's trailing
// dims; input1 supplies the declared element_shape.
int TensorListFromTensor::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
  auto input0 = inputs_[0];
  MS_ASSERT(input0 != nullptr);
  std::vector<int> input0_shape = input0->shape();
  if (input0_shape.size() < 1) {
    MS_LOG(ERROR) << "input0_shape.size():" << input0_shape.size() << " must be greater than 0!";
    return RET_ERROR;
  }
  int dim0 = input0_shape[0];
  if (dim0 < 0) {
    MS_LOG(ERROR) << "inputs_[0] dim0:" << dim0 << " must greater than or equal to 0";
    return RET_ERROR;
  }
  auto input1 = inputs_[1];
  MS_ASSERT(input1 != nullptr);
  // NOTE(review): assumes input1 holds int32 element-shape data — no dtype
  // check is done here; confirm upstream guarantees it.
  auto ele_shape_ptr = reinterpret_cast<int *>(input1->data_c());
  // outputs_[0] is expected to actually be a TensorList; the raw cast is
  // unchecked here.
  auto output = reinterpret_cast<TensorList *>(outputs_[0]);
  MS_ASSERT(output != nullptr);
  // Each element tensor takes the trailing dims of input0.
  std::vector<std::vector<int> > tensor_shape(dim0, std::vector<int>(input0_shape.begin() + 1, input0_shape.end()));
  output->set_element_shape(std::vector<int>(ele_shape_ptr, ele_shape_ptr + input1->ElementsNum()));
  output->set_shape(std::vector<int>(1, dim0));
  output->set_data_type(kObjectTypeTensorType);
  // NOTE(review): MallocTensorListData's return value is ignored — an
  // allocation failure would go unreported; consider checking it.
  output->MallocTensorListData(input0->data_type(), tensor_shape);
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "src/ops/primitive_c.h"
#include "src/tensorlist.h"
#ifndef LITE_MINDSPORE_LITE_C_OPS_TENSORLISTFROMTENSOR_H_
#define LITE_MINDSPORE_LITE_C_OPS_TENSORLISTFROMTENSOR_H_
namespace mindspore {
namespace lite {
// Op wrapper for schema::PrimitiveType_TensorListFromTensor: converts a dense
// tensor into a TensorList by splitting along its first dimension.
class TensorListFromTensor : public PrimitiveC {
 public:
  TensorListFromTensor() = default;
  ~TensorListFromTensor() = default;
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(TensorListFromTensor, PrimitiveC);
  void SetElementDType(int type);
  void SetShapeType(int type);
  explicit TensorListFromTensor(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Populates primitive_ attributes from an ANF Primitive (converter side).
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
  // Re-serializes attributes into a flatbuffer (runtime side).
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  int GetElementDType() const;
  int GetShapeType() const;
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_TENSORLISTFROMTENSOR_H_

View File

@ -0,0 +1,171 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "src/ops/tensorlistgetitem.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// Attribute accessors backed by the mutable flatbuffer object (PrimitiveT).
TypeId TensorListGetItem::GetElementDType() const {
  return (TypeId)(this->primitive_->value.AsTensorListGetItem()->elementDType);
}
void TensorListGetItem::SetElementDType(int type) {
  this->primitive_->value.AsTensorListGetItem()->elementDType = type;
}
// Unpacks the elementDType attribute from an ANF Primitive into
// this->primitive_ (schema::PrimitiveT). Allocates primitive_ and the attr
// value on demand; on any failure both are cleaned up and RET_ERROR returned.
int TensorListGetItem::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_TensorListGetItem;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_TensorListGetItem) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    delete this->primitive_;
    this->primitive_ = nullptr;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = new (std::nothrow) schema::TensorListGetItemT();
    if (attr == nullptr) {
      delete this->primitive_;
      this->primitive_ = nullptr;
      MS_LOG(ERROR) << "new TensorListGetItemT value failed";
      return RET_ERROR;
    }
    // elementDType is mandatory; missing it aborts the unpack.
    if (prim.GetAttr("elementDType") == nullptr) {
      delete this->primitive_;
      this->primitive_ = nullptr;
      delete attr;
      MS_LOG(ERROR) << "TensorListGetItem's attr elementDType is not set";
      return RET_ERROR;
    } else {
      attr->elementDType = CastToInt(prim.GetAttr("elementDType")).front();
    }
    // primitive_ takes ownership of attr from here on.
    this->primitive_->value.value = attr;
  }
  return RET_OK;
}
#else
// Attribute accessor backed by the read-only generated flatbuffer view.
TypeId TensorListGetItem::GetElementDType() const {
  return (TypeId)(this->primitive_->value_as_TensorListGetItem()->elementDType());
}
// Re-serializes this primitive's attributes into `fbb` as a schema::Primitive.
int TensorListGetItem::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_TensorListGetItem();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "value_as_TensorListGetItem return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateTensorListGetItem(*fbb, attr->elementDType());
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_TensorListGetItem, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
// Factory + registry entry so the op can be created from a schema::Primitive.
PrimitiveC *TensorListGetItemCreator(const schema::Primitive *primitive) {
  return PrimitiveC::NewPrimitiveC<TensorListGetItem>(primitive);
}
Registry TensorListGetItemRegistry(schema::PrimitiveType_TensorListGetItem, TensorListGetItemCreator);
#endif
// Returns true iff every dimension in `shape` is known (i.e. non-negative).
bool TensorListGetItem::IsFullyDefined(const std::vector<int> &shape) const {
  for (auto dim : shape) {
    if (dim < 0) {
      return false;
    }
  }
  return true;
}
int TensorListGetItem::MergeShape(const std::vector<int> &tmp) {
if (element_shape_.size() != tmp.size()) {
MS_LOG(ERROR) << "element_shape_.size():" << element_shape_.size() << " must be equal to tmp.size():" << tmp.size();
return RET_ERROR;
}
for (size_t j = 0; j < tmp.size(); ++j) {
if (element_shape_[j] >= 0 && tmp[j] >= 0 && element_shape_[j] != tmp[j]) {
MS_LOG(ERROR) << "element_shape_[" << j << "]:" << element_shape_[j] << " must be equal to tmp[" << j
<< "]:" << tmp[j];
return RET_ERROR;
}
element_shape_[j] = element_shape_[j] >= 0 ? element_shape_[j] : tmp[j];
}
return RET_OK;
}
// Infers the shape/dtype of the element extracted from a TensorList.
// inputs_[0]: the TensorList, inputs_[1]: scalar index, inputs_[2]:
// element_shape tensor (used only when the stored element is still unknown).
int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
  // inputs_[0] is expected to actually be a TensorList; the raw cast is
  // unchecked here.
  auto input0 = reinterpret_cast<TensorList *>(inputs_[0]);
  auto get_index = inputs_[1];
  MS_ASSERT(get_index != nullptr);
  if (get_index->ElementsNum() != 1) {
    MS_LOG(ERROR) << "get_index->ElementsNum():" << get_index->ElementsNum() << " must be equal to 1!";
    return RET_ERROR;
  }
  index_ = reinterpret_cast<int *>(get_index->data_c())[0];
  if (index_ < 0 || index_ > (input0->ElementsNum() - 1)) {
    MS_LOG(ERROR) << "index_:" << index_ << "must in [0, " << input0->ElementsNum() - 1 << "]";
    return RET_ERROR;
  }
  auto tensor_index = input0->GetTensorIndex(index_);
  MS_ASSERT(tensor_index != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  if (tensor_index->data_type() != kTypeUnknown) {
    // The selected element already has a concrete type/shape: forward it.
    output->set_data_type(tensor_index->data_type());
    output->set_shape(tensor_index->shape());
  } else {
    // Element is still unknown: reconstruct its shape by merging the
    // requested element_shape, the list's declared element_shape, and any
    // sibling element with a known shape.
    auto input2 = inputs_[2];
    auto ele_shape_data = reinterpret_cast<int *>(input2->data_c());
    for (int i = 0; i < input2->ElementsNum(); ++i) {
      element_shape_.push_back(ele_shape_data[i]);
    }
    auto status = MergeShape(input0->element_shape());
    if (status != RET_OK) {
      return RET_ERROR;
    }
    if (!IsFullyDefined(element_shape_)) {
      for (int i = 0; i < input0->ElementsNum(); ++i) {
        auto input = input0->GetTensorIndex(i);
        MS_ASSERT(input != nullptr);
        if (input->data_type() != kTypeUnknown) {
          status = MergeShape(input->shape());
          if (status != RET_OK) {
            return RET_ERROR;
          }
        }
      }
    }
    if (!IsFullyDefined(element_shape_)) {
      MS_LOG(ERROR) << "element_shape_ is not fullyDefined!";
      return RET_ERROR;
    }
    output->set_data_type(GetElementDType());
    output->set_shape(element_shape_);
  }
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "src/ops/primitive_c.h"
#include "src/tensorlist.h"
#include "ir/dtype/type_id.h"
#ifndef LITE_MINDSPORE_LITE_C_OPS_TENSORLISTGETITEM_H_
#define LITE_MINDSPORE_LITE_C_OPS_TENSORLISTGETITEM_H_
namespace mindspore {
namespace lite {
// Op wrapper for schema::PrimitiveType_TensorListGetItem: extracts one
// element tensor from a TensorList by index.
class TensorListGetItem : public PrimitiveC {
 public:
  TensorListGetItem() = default;
  ~TensorListGetItem() = default;
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(TensorListGetItem, PrimitiveC);
  void SetElementDType(int type);
  explicit TensorListGetItem(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Populates primitive_ attributes from an ANF Primitive (converter side).
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
  // Re-serializes attributes into a flatbuffer (runtime side).
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  TypeId GetElementDType() const;
  int MergeShape(const std::vector<int> &tmp);
  bool IsFullyDefined(const std::vector<int> &shape) const;
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;

 private:
  int index_ = -1;               // element index resolved during InferShape
  std::vector<int> element_shape_;  // merged element shape built during InferShape
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_TENSORLISTGETITEM_H_

View File

@ -0,0 +1,131 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "src/ops/tensorlistreserve.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// Attribute accessors backed by the mutable flatbuffer object (PrimitiveT).
TypeId TensorListReserve::GetElementDType() const {
  return (TypeId)(this->primitive_->value.AsTensorListReserve()->elementDType);
}
void TensorListReserve::SetElementDType(int type) {
  this->primitive_->value.AsTensorListReserve()->elementDType = type;
}
// Unpacks the elementDType attribute from an ANF Primitive into
// this->primitive_ (schema::PrimitiveT). Allocates primitive_ and the attr
// value on demand; on any failure both are cleaned up and RET_ERROR returned.
int TensorListReserve::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_TensorListReserve;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_TensorListReserve) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    delete this->primitive_;
    this->primitive_ = nullptr;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = new (std::nothrow) schema::TensorListReserveT();
    if (attr == nullptr) {
      delete this->primitive_;
      this->primitive_ = nullptr;
      MS_LOG(ERROR) << "new TensorListReserveT value failed";
      return RET_ERROR;
    }
    // elementDType is mandatory; missing it aborts the unpack.
    if (prim.GetAttr("elementDType") == nullptr) {
      delete this->primitive_;
      this->primitive_ = nullptr;
      delete attr;
      MS_LOG(ERROR) << "TensorListReserve's attr elementDType is not set";
      return RET_ERROR;
    } else {
      attr->elementDType = CastToInt(prim.GetAttr("elementDType")).front();
    }
    // primitive_ takes ownership of attr from here on.
    this->primitive_->value.value = attr;
  }
  return RET_OK;
}
#else
// Attribute accessor backed by the read-only generated flatbuffer view.
TypeId TensorListReserve::GetElementDType() const {
  return (TypeId)(this->primitive_->value_as_TensorListReserve()->elementDType());
}
// Re-serializes this primitive's attributes into `fbb` as a schema::Primitive.
int TensorListReserve::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(primitive != nullptr);
  MS_ASSERT(fbb != nullptr);
  auto attr = primitive->value_as_TensorListReserve();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "value_as_TensorListReserve return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateTensorListReserve(*fbb, attr->elementDType());
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_TensorListReserve, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
// Factory + registry entry so the op can be created from a schema::Primitive.
PrimitiveC *TensorListReserveCreator(const schema::Primitive *primitive) {
  return PrimitiveC::NewPrimitiveC<TensorListReserve>(primitive);
}
Registry TensorListReserveRegistry(schema::PrimitiveType_TensorListReserve, TensorListReserveCreator);
#endif
// Infers the output TensorList for TensorListReserve.
// inputs_[0]: element_shape tensor (int32), inputs_[1]: num_elements scalar (int32).
// The output becomes a TensorList of num_elements entries, each with an
// unknown shape/dtype until elements are written.
int TensorListReserve::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
  auto input0 = inputs_.front();
  MS_ASSERT(input0 != nullptr);
  auto ele_shape_type = input0->data_type();
  if (ele_shape_type != kNumberTypeInt) {
    MS_LOG(ERROR) << "ele_shape_tensor.data_type():" << ele_shape_type
                  << " must be \"kNumberTypeInt\":" << kNumberTypeInt;
    return RET_ERROR;
  }
  auto ele_shape_ptr = reinterpret_cast<int *>(input0->data_c());
  if (ele_shape_ptr == nullptr) {
    MS_LOG(ERROR) << "input0->data_c() is nullptr";
    return RET_ERROR;
  }
  auto input1 = inputs_[1];
  MS_ASSERT(input1 != nullptr);
  auto num_ele_type = input1->data_type();
  if (num_ele_type != kNumberTypeInt) {
    MS_LOG(ERROR) << "num_ele_tensor.data_type():" << num_ele_type << " must be \"kNumberTypeInt\":" << kNumberTypeInt;
    return RET_ERROR;
  }
  if (input1->ElementsNum() != 1) {
    MS_LOG(ERROR) << "input1->ElementsNum() must be equal to 1";
    return RET_ERROR;
  }
  auto num_ele_ptr = reinterpret_cast<int *>(input1->data_c());
  if (num_ele_ptr == nullptr) {
    MS_LOG(ERROR) << "input1->data_c() is nullptr";
    return RET_ERROR;
  }
  int num_elements = num_ele_ptr[0];
  // A negative count would wrap around when converted to a vector size below.
  if (num_elements < 0) {
    MS_LOG(ERROR) << "num_elements:" << num_elements << " must be greater than or equal to 0";
    return RET_ERROR;
  }
  auto output = reinterpret_cast<TensorList *>(outputs_[0]);
  MS_ASSERT(output != nullptr);
  output->set_data_type(kObjectTypeTensorType);
  // Each element starts with an empty (unknown) shape; dtype resolved later.
  std::vector<std::vector<int> > tmp_shape(num_elements, std::vector<int>());
  output->set_element_shape(std::vector<int>(ele_shape_ptr, ele_shape_ptr + input0->ElementsNum()));
  output->set_shape(std::vector<int>(1, num_elements));
  output->MallocTensorListData(kTypeUnknown, tmp_shape);
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "src/ops/primitive_c.h"
#include "src/tensorlist.h"
#include "ir/dtype/type_id.h"
#ifndef LITE_MINDSPORE_LITE_C_OPS_TENSORLISTRESERVE_H_
#define LITE_MINDSPORE_LITE_C_OPS_TENSORLISTRESERVE_H_
namespace mindspore {
namespace lite {
// Primitive wrapper for the TensorListReserve op: reserves a TensorList of
// `num_elements` empty tensors with a given element shape/dtype.
class TensorListReserve : public PrimitiveC {
 public:
  TensorListReserve() = default;
  ~TensorListReserve() = default;
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(TensorListReserve, PrimitiveC);
  // Sets the element data type attribute (schema TypeId value).
  void SetElementDType(int type);
  explicit TensorListReserve(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  // Returns the element data type attribute.
  TypeId GetElementDType() const;
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_TENSORLISTRESERVE_H_

View File

@ -0,0 +1,140 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "src/ops/tensorlistsetitem.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// Reads the element dtype attribute from the writable (convert/train) primitive.
TypeId TensorListSetItem::GetElementDType() const {
  return (TypeId)(this->primitive_->value.AsTensorListSetItem()->elementDType);
}
// Writes the element dtype attribute on the writable primitive.
void TensorListSetItem::SetElementDType(int type) {
  this->primitive_->value.AsTensorListSetItem()->elementDType = type;
}
// Populates the writable primitive from the ANF Primitive `prim`.
// Allocates this->primitive_ and its TensorListSetItemT attr on demand and
// reads the required "elementDType" attribute; every failure path deletes
// whatever was allocated so far so nothing is leaked.
int TensorListSetItem::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_TensorListSetItem;
  }
  // A pre-existing primitive of a different op type is a caller error.
  if (this->primitive_->value.type != schema::PrimitiveType_TensorListSetItem) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    delete this->primitive_;
    this->primitive_ = nullptr;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = new (std::nothrow) schema::TensorListSetItemT();
    if (attr == nullptr) {
      delete this->primitive_;
      this->primitive_ = nullptr;
      MS_LOG(ERROR) << "new TensorListSetItemT value failed";
      return RET_ERROR;
    }
    if (prim.GetAttr("elementDType") == nullptr) {
      delete this->primitive_;
      this->primitive_ = nullptr;
      delete attr;
      MS_LOG(ERROR) << "TensorListSetItem's attr elementDType is not set";
      return RET_ERROR;
    } else {
      attr->elementDType = CastToInt(prim.GetAttr("elementDType")).front();
    }
    // Ownership of attr transfers to the primitive.
    this->primitive_->value.value = attr;
  }
  return RET_OK;
}
#else
// Reads the element dtype attribute from the flatbuffer-backed primitive
// (inference build).
TypeId TensorListSetItem::GetElementDType() const {
  return (TypeId)(this->primitive_->value_as_TensorListSetItem()->elementDType());
}
// Re-serializes the TensorListSetItem attribute from a flatbuffer Primitive
// into a fresh flatbuffer `fbb` (inference build).
int TensorListSetItem::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(primitive != nullptr);
  MS_ASSERT(fbb != nullptr);
  auto src_attr = primitive->value_as_TensorListSetItem();
  if (src_attr == nullptr) {
    MS_LOG(ERROR) << "value_as_TensorListSetItem return nullptr";
    return RET_ERROR;
  }
  auto value_offset = schema::CreateTensorListSetItem(*fbb, src_attr->elementDType());
  auto primitive_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_TensorListSetItem, value_offset.o);
  fbb->Finish(primitive_offset);
  return RET_OK;
}
// Factory used by the op registry (inference build) to build a
// TensorListSetItem primitive from a flatbuffer Primitive.
PrimitiveC *TensorListSetItemCreator(const schema::Primitive *primitive) {
  return PrimitiveC::NewPrimitiveC<TensorListSetItem>(primitive);
}
Registry TensorListSetItemRegistry(schema::PrimitiveType_TensorListSetItem, TensorListSetItemCreator);
#endif
// Infers the output TensorList for TensorListSetItem.
// inputs_[0]: source TensorList, inputs_[1]: scalar int32 index,
// inputs_[2]: value tensor. The output copies input0's metadata and the
// element at `index` adopts the value tensor's shape.
int TensorListSetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
  auto input0 = reinterpret_cast<TensorList *>(inputs_[0]);
  MS_ASSERT(input0 != nullptr);
  auto get_index = inputs_[1];
  MS_ASSERT(get_index != nullptr);
  if (get_index->data_type() != kNumberTypeInt) {
    MS_LOG(ERROR) << "inputs_[1]->data_type():" << get_index->data_type()
                  << " must be equal to \"kNumberTypeInt\":" << kNumberTypeInt;
    return RET_ERROR;
  }
  if (get_index->ElementsNum() != 1) {
    MS_LOG(ERROR) << "inputs_[1].ElementsNum():" << get_index->ElementsNum() << " must be equal to 1!";
    return RET_ERROR;
  }
  if (get_index->data_c() == nullptr) {
    MS_LOG(ERROR) << "inputs_[1]->data_c() is nullptr";
    return RET_ERROR;
  }
  int index = reinterpret_cast<int *>(get_index->data_c())[0];
  if (index < 0 || index > (input0->ElementsNum() - 1)) {
    // Fixed message: previously logged "index_:<n>must in [0, ...]".
    MS_LOG(ERROR) << "index:" << index << " must be in [0, " << input0->ElementsNum() - 1 << "]";
    return RET_ERROR;
  }
  auto value_tensor = inputs_[2];
  MS_ASSERT(value_tensor != nullptr);
  auto output0 = reinterpret_cast<TensorList *>(outputs_[0]);
  MS_ASSERT(output0 != nullptr);
  output0->set_element_shape(input0->element_shape());
  output0->set_max_elements_num(input0->max_elements_num());
  output0->set_shape(input0->shape());
  output0->set_data_type(input0->data_type());
  // Collect per-element shapes; shape-unknown elements keep an empty shape.
  std::vector<std::vector<int> > out_shape;
  for (int i = 0; i < input0->ElementsNum(); ++i) {
    auto src_ptr = input0->GetTensorIndex(i);
    if (src_ptr == nullptr) {
      MS_LOG(ERROR) << "input0->tensors_[" << i << "] is nullptr!";
      return RET_ERROR;
    }
    if (src_ptr->data_type() != kTypeUnknown) {
      out_shape.push_back(src_ptr->shape());
    } else {
      out_shape.push_back(std::vector<int>());
    }
  }
  // The element being set takes the shape of the new value tensor.
  out_shape[index] = value_tensor->shape();
  output0->MallocTensorListData(input0->tensors_data_type(), out_shape);
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "src/ops/primitive_c.h"
#include "src/tensorlist.h"
#include "ir/dtype/type_id.h"
#ifndef LITE_MINDSPORE_LITE_C_OPS_TENSORLISTSETITEM_H_
#define LITE_MINDSPORE_LITE_C_OPS_TENSORLISTSETITEM_H_
namespace mindspore {
namespace lite {
// Primitive wrapper for the TensorListSetItem op: replaces one element of a
// TensorList with a given value tensor.
class TensorListSetItem : public PrimitiveC {
 public:
  TensorListSetItem() = default;
  ~TensorListSetItem() = default;
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(TensorListSetItem, PrimitiveC);
  // Sets the element data type attribute (schema TypeId value).
  void SetElementDType(int type);
  explicit TensorListSetItem(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  // Returns the element data type attribute.
  TypeId GetElementDType() const;
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_TENSORLISTSETITEM_H_

View File

@ -0,0 +1,188 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "src/ops/tensorliststack.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// Attribute accessors for the writable (convert/train) primitive.
TypeId TensorListStack::GetElementDType() const {
  return (TypeId)(this->primitive_->value.AsTensorListStack()->elementDType);
}
// Expected number of elements to stack (-1 means any count, per the op's contract).
int TensorListStack::GetNumElements() const { return this->primitive_->value.AsTensorListStack()->numElements; }
void TensorListStack::SetElementDType(int type) { this->primitive_->value.AsTensorListStack()->elementDType = type; }
void TensorListStack::SetNumElements(int num_elements) {
  this->primitive_->value.AsTensorListStack()->numElements = num_elements;
}
// Populates the writable primitive from the ANF Primitive `prim`.
// Allocates this->primitive_ and its TensorListStackT attr on demand and reads
// the required "elementDType" and "numElements" attributes; every failure path
// deletes whatever was allocated so far so nothing is leaked.
int TensorListStack::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_TensorListStack;
  }
  // A pre-existing primitive of a different op type is a caller error.
  if (this->primitive_->value.type != schema::PrimitiveType_TensorListStack) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    delete this->primitive_;
    this->primitive_ = nullptr;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = new (std::nothrow) schema::TensorListStackT();
    if (attr == nullptr) {
      delete this->primitive_;
      this->primitive_ = nullptr;
      MS_LOG(ERROR) << "new TensorListStackT value failed";
      return RET_ERROR;
    }
    if (prim.GetAttr("elementDType") == nullptr) {
      delete this->primitive_;
      this->primitive_ = nullptr;
      delete attr;
      MS_LOG(ERROR) << "TensorListStack's attr elementDType is not set";
      return RET_ERROR;
    } else {
      attr->elementDType = CastToInt(prim.GetAttr("elementDType")).front();
    }
    if (prim.GetAttr("numElements") == nullptr) {
      delete this->primitive_;
      this->primitive_ = nullptr;
      delete attr;
      MS_LOG(ERROR) << "TensorListStack's attr numElements is not set";
      return RET_ERROR;
    } else {
      attr->numElements = CastToInt(prim.GetAttr("numElements")).front();
    }
    // Ownership of attr transfers to the primitive.
    this->primitive_->value.value = attr;
  }
  return RET_OK;
}
#else
// Flatbuffer-backed attribute accessors (inference build).
TypeId TensorListStack::GetElementDType() const {
  return (TypeId)(this->primitive_->value_as_TensorListStack()->elementDType());
}
int TensorListStack::GetNumElements() const { return this->primitive_->value_as_TensorListStack()->numElements(); }
// Re-serializes the TensorListStack attributes (elementDType, numElements)
// from a flatbuffer Primitive into a new flatbuffer `fbb` (inference build).
int TensorListStack::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_TensorListStack();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "value_as_TensorListStack return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateTensorListStack(*fbb, attr->elementDType(), attr->numElements());
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_TensorListStack, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
// Factory used by the op registry (inference build) to build a
// TensorListStack primitive from a flatbuffer Primitive.
PrimitiveC *TensorListStackCreator(const schema::Primitive *primitive) {
  return PrimitiveC::NewPrimitiveC<TensorListStack>(primitive);
}
Registry TensorListStackRegistry(schema::PrimitiveType_TensorListStack, TensorListStackCreator);
#endif
// A shape is fully defined when every dimension is known (non-negative).
bool TensorListStack::IsFullyDefined(const std::vector<int> &shape) const {
  for (auto dim : shape) {
    if (dim < 0) {
      return false;
    }
  }
  return true;
}
// Infers the (flattened) output shape for TensorListStack.
// inputs_[0]: TensorList to stack, inputs_[1]: requested element shape tensor.
// Merges the requested shape with the list's element shape (and, if still not
// fully defined, with each stored tensor's shape). The output is a 1-D tensor
// of ElementsNum() * prod(element_shape) values.
int TensorListStack::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
  auto input0 = reinterpret_cast<TensorList *>(inputs_.front());
  MS_ASSERT(input0 != nullptr);
  if (input0->ElementsNum() == 0) {
    MS_LOG(ERROR) << "Try to stack a empty tensorlist!";
    return RET_ERROR;
  }
  auto ele_shape = inputs_[1];  // element shape
  MS_ASSERT(ele_shape != nullptr);
  auto ele_shape_ptr = reinterpret_cast<int *>(ele_shape->data_c());
  if (ele_shape_ptr == nullptr) {
    MS_LOG(ERROR) << "inputs_[1]->data_c() is nullptr";
    return RET_ERROR;
  }
  // InferShape may run more than once; start from a clean merged shape.
  output_shape_.clear();
  // Bug fix: the loop condition previously lacked "i <" and never terminated,
  // reading past the end of the element-shape tensor.
  for (int i = 0; i < ele_shape->ElementsNum(); ++i) {
    output_shape_.push_back(ele_shape_ptr[i]);
  }
  auto status = MergeShape(input0->element_shape());
  if (status == RET_ERROR) {
    MS_LOG(ERROR) << "Merge element_shape is error!";
    return RET_ERROR;
  }
  if (!IsFullyDefined(output_shape_)) {
    MS_LOG(ERROR) << "output_shape_ Is Not FullyDefined!";
    return RET_ERROR;
  }
  if (!IsFullyDefined(input0->element_shape())) {
    // Stored tensors with known shapes can pin down the remaining dimensions.
    for (int i = 0; i < input0->ElementsNum(); ++i) {
      auto tensor_ele = input0->GetTensorIndex(i);
      MS_ASSERT(tensor_ele != nullptr);
      if (tensor_ele->data_type() != kTypeUnknown) {
        status = MergeShape(tensor_ele->shape());
        if (status == RET_ERROR) {
          MS_LOG(ERROR) << "Merge input0->tensors_[" << i << "] is error!";
          return RET_ERROR;
        }
      }
    }
  }
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  output->set_data_type(input0->tensors_data_type());
  output->set_shape(std::vector<int>(
    1,
    input0->ElementsNum() * std::accumulate(output_shape_.begin(), output_shape_.end(), 1LL, std::multiplies<int>())));
  return RET_OK;
}
int TensorListStack::MergeShape(const std::vector<int> &shape) {
size_t dim0 = shape.size();
size_t dim1 = output_shape_.size();
if (dim1 >= unKnownRank_) {
output_shape_ = shape;
return RET_OK;
}
if (dim1 != dim0) {
MS_LOG(ERROR) << "shape.size():" << dim1 << " must be equal output_shape_.size():" << dim0;
return RET_ERROR;
}
for (size_t i = 0; i < dim0; ++i) {
int dim0_size = shape[i];
int dim1_size = output_shape_[i];
if (dim0_size >= 0 && dim1_size >= 0 && dim0_size != dim1_size) {
MS_LOG(ERROR) << "shape[" << i << "]:" << dim0_size << " is incompatible with output_shape_[" << i
<< "]:" << dim1_size;
return RET_ERROR;
}
output_shape_[i] = dim1_size >= 0 ? dim1_size : dim0_size;
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -16,49 +16,27 @@
#include <vector>
#include <functional>
#include "src/ops/primitive_c.h"
#include "src/tensorlist.h"
#include "ir/dtype/type_id.h"
using mindspore::schema::Format;
using mindspore::schema::Format_NC;
#ifndef LITE_MINDSPORE_LITE_C_OPS_TENSORLISTFROMTENSOR_H_
#define LITE_MINDSPORE_LITE_C_OPS_TENSORLISTFROMTENSOR_H_
#ifndef LITE_MINDSPORE_LITE_C_OPS_TENSORLISTSTACK_H_
#define LITE_MINDSPORE_LITE_C_OPS_TENSORLISTSTACK_H_
namespace mindspore {
namespace lite {
class TensorListFromTensor : public PrimitiveC {
public:
TypeId GetElementDType() const;
TypeId GetShapeType() const;
TensorListFromTensor() = default;
bool IsCompatibleShape(std::vector<lite::Tensor *> inputs_);
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
class TensorListReserve : public PrimitiveC {
public:
TensorListReserve() = default;
TypeId GetElementDType() const;
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
class TensorListGetItem : public PrimitiveC {
public:
TensorListGetItem() = default;
TypeId GetElementDType() const;
bool IsFullyDefined(const std::vector<int> &shape) const;
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
private:
int index_ = -1;
TypeId element_dtype_ = kTypeUnknown;
std::vector<int> element_shape_;
};
class TensorListStack : public PrimitiveC {
public:
// tensor:input, element_dtype, num_elements(default=-1:reprent any tensor dim0), element_shape
TensorListStack() = default;
~TensorListStack() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(TensorListStack, PrimitiveC);
void SetElementDType(int type);
void SetNumElements(int num_elements);
explicit TensorListStack(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
TypeId GetElementDType() const;
int GetNumElements() const;
bool IsFullyDefined(const std::vector<int> &shape) const;
@ -72,4 +50,4 @@ class TensorListStack : public PrimitiveC {
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_TENSORLISTFROMTENSOR_H_
#endif // LITE_MINDSPORE_LITE_C_OPS_TENSORLISTSTACK_H_

View File

@ -27,63 +27,67 @@ using mindspore::schema::PrimitiveType_TensorListFromTensor;
namespace mindspore::kernel {
bool TensorListFromTensorCPUKernel::IsCompatibleShape() {
int TensorListFromTensorCPUKernel::IsCompatibleShape() {
if (input1_->data_type() != kNumberTypeInt) { // element_shape
MS_LOG(ERROR) << "in_tensors_[1] data type is must be \"kNumberTypeInt\", but now is:" << input1_->data_type();
return false;
return RET_ERROR;
}
int in1_ele_num = input1_->ElementsNum();
std::vector<int> tensor_shape = input0_->shape();
if (static_cast<int>(tensor_shape.size() - 1) != in1_ele_num) {
MS_LOG(ERROR) << "in_tensors_[0].shape() - 1:" << tensor_shape.size() - 1
MS_LOG(ERROR) << "in_tensors_[0].shape().size() - 1:" << tensor_shape.size() - 1
<< " must be equal in_tensors_[1].ElementsNum():" << in1_ele_num;
return false;
return RET_ERROR;
}
int *elements_shape = reinterpret_cast<int *>(input1_->data_c()); // element shape in tensor data
for (int i = 0; i < in1_ele_num; ++i) {
const int dim0 = tensor_shape.at(i + 1);
const int dim1 = *(elements_shape + i);
int dim0 = tensor_shape[i + 1];
int dim1 = elements_shape[i];
if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) {
MS_LOG(ERROR) << "input0_->shape()[" << i + 1 << "]:" << dim0 << " is not equal input1_->data_c()[" << i
<< "]:" << dim1;
return false;
return RET_ERROR;
}
}
return true;
return RET_OK;
}
int TensorListFromTensorCPUKernel::Init() {
input0_ = in_tensors_.at(0); // row tensor
input1_ = in_tensors_.at(1); // element_shape tensor
output0_ = out_tensors_.at(0);
output1_ = out_tensors_.at(1);
input0_ = in_tensors_[0]; // row tensor
input1_ = in_tensors_[1]; // element_shape tensor
output0_ = out_tensors_[0];
return IsCompatibleShape();
}
int TensorListFromTensorCPUKernel::ReSize() { return RET_OK; }
int TensorListFromTensorCPUKernel::Run() {
int dim0 = input0_->shape().at(0);
size_t devision_dim0 = input0_->ElementsNum() / dim0;
auto out0_ptr = reinterpret_cast<int *>(output0_->MutableData());
*out0_ptr = dim0;
*(out0_ptr + 1) = input0_->data_type();
auto status = output1_->CopyTensorData(*input1_);
if (status == RET_ERROR) {
MS_LOG(ERROR) << "copy tensor data failed!";
if (input0_->shape().size() == 0) {
MS_LOG(ERROR) << "input0_->shape().size():" << input0_->shape().size() << " must be greater than 0";
}
int dim0 = input0_->shape()[0];
if (dim0 <= 0) {
MS_LOG(ERROR) << "input0_->shape()[0]:" << dim0 << " must be greater than 0!";
return RET_ERROR;
}
if (dim0 != static_cast<int>(out_tensors_.size() - 2)) {
MS_LOG(ERROR) << "out_tensors_.size() - 2:[" << out_tensors_.size() - 2
<< "] must be equal in_tensors_[0].shape()[0]:[" << dim0 << "]";
auto output0 = reinterpret_cast<lite::TensorList *>(output0_);
if (dim0 != output0->ElementsNum()) {
MS_LOG(ERROR) << "output0_->ElementsNum():" << output0->ElementsNum() << " must be equal to dim0:" << dim0;
return RET_ERROR;
}
auto in_ptr = reinterpret_cast<float *>(input0_);
size_t index = 0;
int devision_dim0 = input0_->ElementsNum() / dim0;
auto in_ptr = reinterpret_cast<float *>(input0_->data_c());
// copy data from input0(tensor) to output(tensorlist) vector<*tensor>
for (int i = 0; i < dim0; ++i) {
auto out_ptr = reinterpret_cast<float *>(out_tensors_.at(i + 2)->MutableData());
memcpy(out_ptr, in_ptr + index, devision_dim0 * sizeof(float));
index += devision_dim0;
auto out_ptr = output0->GetTensorIndex(i);
MS_ASSERT(out_ptr != nullptr);
if (out_ptr->ElementsNum() != devision_dim0) {
MS_LOG(ERROR) << "tensors_[" << i << "].ElementsNum():" << out_ptr->ElementsNum()
<< " must be euqal to devision_dim0:" << devision_dim0;
return RET_ERROR;
}
memcpy(reinterpret_cast<float *>(out_ptr->MutableData()), in_ptr, devision_dim0 * sizeof(float));
in_ptr += devision_dim0;
}
return RET_OK;
}

View File

@ -19,18 +19,12 @@
#include <vector>
#include "src/lite_kernel.h"
#include "src/tensorlist.h"
#include "schema/model_generated.h"
namespace mindspore::kernel {
class TensorListFromTensorCPUKernel : public LiteKernel {
public:
/*
* input0:tensor
* input1:element_shape
* output0:tensorlist.size() and dty pe
* output2~n:tensor
* output1:element_shape(tensorlist shape)
*/
TensorListFromTensorCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
@ -40,12 +34,11 @@ class TensorListFromTensorCPUKernel : public LiteKernel {
int Init() override;
int ReSize() override;
int Run() override;
bool IsCompatibleShape();
int IsCompatibleShape();
private:
std::vector<int> output_shape_;
lite::Tensor *output0_ = nullptr;
lite::Tensor *output1_ = nullptr;
lite::Tensor *input0_ = nullptr;
lite::Tensor *input1_ = nullptr;
};

View File

@ -29,37 +29,43 @@ using mindspore::schema::PrimitiveType_TensorListGetItem;
namespace mindspore::kernel {
int TensorListGetItemCPUKernel::Init() {
auto input0 = reinterpret_cast<int *>(in_tensors_.at(0)->data_c());
size_t dim0 = *input0;
int in_dtype = *(input0 + 1);
if (dtype_ != in_dtype) {
MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors dtype:" << in_dtype;
auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_[0]);
if (dtype_ != input0->tensors_data_type()) {
MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0->tensors_data_type();
return RET_ERROR;
}
index_ = *(reinterpret_cast<int *>(in_tensors_.at(dim0 + 2)->data_c()));
if (index_ < 0) {
MS_LOG(ERROR) << "index tensor:[" << index_ << "] must be greater than or equal to 0";
if (in_tensors_[1]->ElementsNum() != 1) {
MS_LOG(ERROR) << "in_tensors_[1]->ElementsNum():" << in_tensors_[1]->ElementsNum() << " must be equal to 1!";
return RET_ERROR;
}
if (index_ > dim0) {
MS_LOG(ERROR) << "index tensor:[" << index_ << "] must be less than dim0:" << dim0;
index_ = reinterpret_cast<int *>(in_tensors_[1]->data_c())[0];
int dim0 = input0->ElementsNum() - 1;
if (index_ < 0 || index_ > dim0) {
MS_LOG(ERROR) << "index tensor:[" << index_ << "] must be in [0, " << dim0 << "]!";
return RET_ERROR;
}
index_ += 2;
return RET_OK;
}
int TensorListGetItemCPUKernel::Run() {
if (in_tensors_.at(index_)->data_type() != kTypeUnknown) {
auto status = out_tensors_.at(0)->CopyTensorData(*in_tensors_.at(index_)); // tensorlist shape
auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_[0]);
auto src_ptr = input0->GetTensorIndex(index_);
MS_ASSERT(src_ptr != nullptr);
if (src_ptr->data_type() != kTypeUnknown) {
if (src_ptr->ElementsNum() != out_tensors_[0]->ElementsNum()) {
MS_LOG(ERROR) << "src_ptr->ElementsNum():" << src_ptr->ElementsNum()
<< " must be equal to out_tensors_[0]->ElementsNum():" << out_tensors_[0]->ElementsNum();
return RET_ERROR;
}
auto status = out_tensors_[0]->CopyTensorData(*src_ptr);
if (status == RET_ERROR) {
MS_LOG(ERROR) << "copy tensor data failed!";
return RET_ERROR;
}
} else {
// reset 0 and dtype = dtype_
auto out_ptr = reinterpret_cast<char *>(out_tensors_.at(0)->MutableData());
memset(out_ptr, 0, lite::DataTypeSize(dtype_) * out_tensors_.at(0)->ElementsNum());
// TODO(DT_VARIANT): dtype = DT_VARIANT is not handle
memset(out_tensors_[0]->MutableData(), 0, out_tensors_[0]->Size());
}
return RET_OK;
}

View File

@ -19,6 +19,7 @@
#include <vector>
#include "src/lite_kernel.h"
#include "src/tensorlist.h"
#include "schema/model_generated.h"
#include "nnacl/tensorlist_parameter.h"
@ -37,7 +38,7 @@ class TensorListGetItemCPUKernel : public LiteKernel {
int Run() override;
private:
size_t index_ = 0;
int index_ = 0;
TypeId dtype_ = kTypeUnknown;
};
} // namespace mindspore::kernel

View File

@ -30,18 +30,8 @@ namespace mindspore::kernel {
int TensorListReserveCPUKernel::Init() { return RET_OK; }
int TensorListReserveCPUKernel::Run() {
auto out0_ptr = reinterpret_cast<int *>(out_tensors_.at(0)->MutableData()); // tensorlist size() and dtype
out0_ptr[0] = reinterpret_cast<int *>(in_tensors_.at(0)->data_c())[0]; // num_elements
out0_ptr[1] = element_dtype_;
auto status = out_tensors_.at(1)->CopyTensorData(*in_tensors_.at(1)); // elements_shape
if (status == RET_ERROR) {
MS_LOG(ERROR) << "copy tensor data failed!";
return RET_ERROR;
}
if (static_cast<int>(out_tensors_.size() - 2) != out0_ptr[0]) {
MS_LOG(ERROR) << "out_tensors_.size() - 2:" << out_tensors_.size() - 2
<< " must be equal num_elements:" << out0_ptr[0];
}
auto output = reinterpret_cast<lite::TensorList *>(out_tensors_[0]);
output->set_tensors_data_type(element_dtype_);
return RET_OK;
}

View File

@ -19,6 +19,7 @@
#include <vector>
#include "src/lite_kernel.h"
#include "src/tensorlist.h"
#include "schema/model_generated.h"
#include "nnacl/tensorlist_parameter.h"
@ -37,7 +38,7 @@ class TensorListReserveCPUKernel : public LiteKernel {
int Run() override;
private:
int element_dtype_ = 0;
TypeId element_dtype_ = kTypeUnknown;
};
} // namespace mindspore::kernel

View File

@ -0,0 +1,122 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/errorcode.h"
#include "include/ms_tensor.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/fp32/TensorListSetItem.h"
#include "src/runtime/runtime_api.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_NULL_PTR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_TensorListSetItem;
namespace mindspore::kernel {
int TensorListSetItemCPUKernel::Init() {
input0_ = reinterpret_cast<lite::TensorList *>(in_tensors_[0]);
if (dtype_ != input0_->data_type()) {
MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0_->data_type();
return RET_ERROR;
}
int dim0 = input0_->ElementsNum() - 1;
if (in_tensors_[1]->data_type() != kNumberTypeInt) {
MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type()
<< " must be equal to \"kNumberTypeInt\":" << kNumberTypeInt;
return RET_ERROR;
}
if (in_tensors_[1]->ElementsNum() != 1) {
MS_LOG(ERROR) << "in_tensors_[1]->ElementsNum():" << in_tensors_[1]->ElementsNum() << " must be equal to 1!";
return RET_ERROR;
}
index_ = reinterpret_cast<int *>(in_tensors_[1]->data_c())[0];
if (index_ < 0 || index_ > dim0) {
MS_LOG(ERROR) << "index tensor:[" << index_ << "] must be in [0, " << dim0 << "]!";
return RET_ERROR;
}
input2_ = in_tensors_[2];
MS_ASSERT(input2_ != nullptr);
if (!input0_->IsCompatibleShape(input2_->shape())) {
return RET_ERROR;
}
return RET_OK;
}
int TensorListSetItemCPUKernel::Run() {
output0_ = reinterpret_cast<lite::TensorList *>(out_tensors_[0]);
MS_ASSERT(output0_ != nullptr);
// copy each tensor in tensors_
for (int i = 0; i < output0_->ElementsNum(); ++i) {
auto dst = output0_->GetTensorIndex(i);
MS_ASSERT(dst != nullptr);
auto src = input0_->GetTensorIndex(i);
if (i == index_) {
// copy input2_ data buff
src = input2_;
}
MS_ASSERT(src != nullptr);
if (src->data_type() != kTypeUnknown) {
if (src->Size() != dst->Size()) {
MS_LOG(ERROR) << "src->Size():" << src->Size() << " must be equal to dst->Size():" << dst->Size();
return RET_ERROR;
}
auto ret = dst->CopyTensorData(*src);
if (ret != RET_OK) {
MS_LOG(ERROR) << "CopyTensorData[" << i << "] is failed!";
return RET_ERROR;
}
}
}
return RET_OK;
}
// No shape-dependent state is cached, so resize is a no-op.
int TensorListSetItemCPUKernel::ReSize() { return RET_OK; }
// Registry hook: validates parameter/context, constructs the kernel, and runs
// Init(); returns nullptr (releasing op_parameter) on any failure.
kernel::LiteKernel *CpuTensorListSetItemFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                                          const std::vector<lite::Tensor *> &outputs,
                                                          OpParameter *op_parameter, const lite::InnerContext *ctx,
                                                          const kernel::KernelKey &desc,
                                                          const mindspore::lite::PrimitiveC *primitive) {
  if (op_parameter == nullptr) {
    MS_LOG(ERROR) << "Input op_parameter is nullptr!";
    return nullptr;
  }
  if (ctx == nullptr) {
    MS_LOG(ERROR) << "Input context is nullptr!";
    free(op_parameter);
    return nullptr;
  }
  MS_ASSERT(desc.type == schema::PrimitiveType_TensorListSetItem);
  auto *kernel = new (std::nothrow) TensorListSetItemCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "new TensorListSetItemCPUKernel fail!";
    free(op_parameter);
    return nullptr;
  }
  auto ret = kernel->Init();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init kernel failed! name: " << op_parameter->name_ << ", type: "
                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
    // The kernel owns op_parameter once constructed; deleting it releases both.
    delete kernel;
    return nullptr;
  }
  return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListSetItem, CpuTensorListSetItemFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TENSORLISTSETITEM_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TENSORLISTSETITEM_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/tensorlist.h"
#include "schema/model_generated.h"
#include "nnacl/tensorlist_parameter.h"
namespace mindspore::kernel {
// CPU fp32 kernel for schema::PrimitiveType_TensorListSetItem: writes a value
// tensor into a TensorList at a given element index.
class TensorListSetItemCPUKernel : public LiteKernel {
 public:
  // dtype_ is taken from the op's TensorListParameter::element_dtype_.
  TensorListSetItemCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                             const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                             const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive),
        dtype_(reinterpret_cast<TensorListParameter *>(parameter)->element_dtype_) {}
  ~TensorListSetItemCPUKernel() = default;
  int Init() override;
  int ReSize() override;
  int Run() override;
 private:
  lite::TensorList *input0_ = nullptr;  // presumably in_tensors_[0], the list being updated — confirm in Init()
  lite::Tensor *input2_ = nullptr;      // presumably in_tensors_[2], the value tensor — confirm in Init()
  lite::TensorList *output0_ = nullptr; // out_tensors_[0] viewed as a TensorList
  int index_ = 0;                       // element index to overwrite
  TypeId dtype_ = kTypeUnknown;         // element dtype from TensorListParameter
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TENSORLISTSETITEM_H_

View File

@ -13,6 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <functional>
#include <vector>
#include "include/errorcode.h"
#include "ir/dtype/type_id.h"
@ -29,58 +31,139 @@ using mindspore::schema::PrimitiveType_TensorListStack;
namespace mindspore::kernel {
int TensorListStackCPUKernel::CheckParam() {
auto in0_dtype = in_tensors_.at(0)->data_type();
if (in0_dtype != kNumberTypeInt) {
MS_LOG(ERROR) << "in_tensors_[0]->data_type():" << in0_dtype
<< " must be equal \"kNumberTypeInt\":" << kNumberTypeInt;
}
auto in0_ptr = reinterpret_cast<int *>(in_tensors_.at(0)->data_c());
if (in0_ptr[1] != dtype_) {
MS_LOG(ERROR) << "in_tensors_[0].data_type:[" << in0_ptr[1] << "] must be equal "
if (input0_->tensors_data_type() != dtype_) {
MS_LOG(ERROR) << "in_tensors_[0].tensors_data_type:[" << input0_->tensors_data_type() << "] must be equal "
<< "param.data_type:[" << dtype_ << "]";
return RET_ERROR;
}
if (num_element_ != -1 && in0_ptr[0] != num_element_) {
MS_LOG(ERROR) << "in_tensors_[0].dim0:[" << in0_ptr[0] << "] must be equal "
if (num_element_ != -1 && input0_->ElementsNum() != num_element_) {
MS_LOG(ERROR) << "in_tensors_[0].ElementsNum():[" << input0_->ElementsNum() << "] must be equal "
<< "param.elements_num:[" << num_element_ << "]";
return RET_ERROR;
}
num_element_ = in0_ptr[0];
num_element_ = input0_->ElementsNum();
return RET_OK;
}
int TensorListStackCPUKernel::Init() {
output0_ = out_tensors_.at(0);
if (output0_->format() != schema::Format_NC) { // shape().size() = 2
MS_LOG(ERROR) << "out_tensor_[0] format must be \"Format:NC\", but now is:" << output0_->format();
input0_ = reinterpret_cast<lite::TensorList *>(in_tensors_[0]);
MS_ASSERT(input0_ != nullptr);
output0_ = out_tensors_[0];
MS_ASSERT(output0_ != nullptr);
if (output0_->shape().size() != 2) {
MS_LOG(ERROR) << "out_tensors_[0].shape().size():" << output0_->shape().size() << " must be equal to 2!";
return RET_ERROR;
}
int dim0 = output0_->shape().at(0);
int dim0 = output0_->shape()[0];
if (dim0 != 1) { // dim0 must be 1
MS_LOG(ERROR) << "out_tensor_[0] dim0 must be 1, but now is:" << dim0;
MS_LOG(ERROR) << "out_tensors_[0].shape()[0] must be 1, but now is:" << dim0;
return RET_ERROR;
}
return CheckParam();
}
int TensorListStackCPUKernel::Run() {
size_t in_ele_num = 0;
for (int i = 0; i < num_element_; ++i) {
in_ele_num += in_tensors_.at(i + 2)->ElementsNum();
bool TensorListStackCPUKernel::IsFullyDefined(const std::vector<int> &shape) const {
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] < 0) {
return false;
}
}
size_t out_ele_num = out_tensors_.at(0)->ElementsNum();
if (in_ele_num > out_ele_num) {
MS_LOG(ERROR) << "out_tensors_[0]->ElementsNum():" << out_ele_num << "must greater than or equal to in_ele_num"
<< in_ele_num;
return true;
}
int TensorListStackCPUKernel::MergeElementShape() {
  // Build output_shape_ by merging shape information from:
  //  1) the requested element shape stored as int data in in_tensors_[1],
  //  2) the TensorList's own element_shape(),
  //  3) if element_shape() is still partially unknown, the shapes of the
  //     concrete tensors held in the list.
  // Finally caches TypeUnknownSize, the per-element count used to zero-fill
  // slots whose tensor dtype is still unknown.
  // Fix: removed merge residue (`size_t index = 0;` and an `out_ptr` cast
  // belonging to the removed old Run()) interleaved with this code.
  MS_ASSERT(in_tensors_[1]);
  if (in_tensors_[1]->data_type() != kNumberTypeInt) {
    MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type()
                  << " must be \"kNumberTypeInt\":" << kNumberTypeInt;
    return RET_ERROR;
  }
  auto ele_shape_data = reinterpret_cast<int *>(in_tensors_[1]->data_c());
  for (int i = 0; i < in_tensors_[1]->ElementsNum(); ++i) {
    output_shape_.push_back(ele_shape_data[i]);
  }
  auto status = MergeSubShape(input0_->element_shape());
  if (status == RET_ERROR) {
    MS_LOG(ERROR) << "Merge element_shape is error!";
    return RET_ERROR;
  }
  if (!IsFullyDefined(output_shape_)) {
    MS_LOG(ERROR) << "output_shape_ Is Not FullyDefined!";
    return RET_ERROR;
  }
  if (!IsFullyDefined(input0_->element_shape())) {
    for (int i = 0; i < input0_->ElementsNum(); ++i) {  // merge the shape of every concrete tensor in the list
      auto tensor_ele = input0_->GetTensorIndex(i);
      MS_ASSERT(tensor_ele != nullptr);
      if (tensor_ele->data_type() != kTypeUnknown) {
        status = MergeSubShape(tensor_ele->shape());
        if (status == RET_ERROR) {
          MS_LOG(ERROR) << "Merge tensors_[" << i << "] is error!";
          return RET_ERROR;
        }
      }
    }
  }
  TypeUnknownSize = std::accumulate(output_shape_.begin(), output_shape_.end(), 1LL, std::multiplies<int>());
  return RET_OK;
}
int TensorListStackCPUKernel::MergeSubShape(const std::vector<int> &shape) {
size_t dim0 = shape.size();
size_t dim1 = output_shape_.size();
if (dim1 != dim0) {
MS_LOG(ERROR) << "shape.size():" << dim1 << " must be equal output_shape_.size():" << dim0;
return RET_ERROR;
}
for (size_t i = 0; i < dim0; ++i) {
int dim0_size = shape[i];
int dim1_size = output_shape_[i];
if (dim0_size >= 0 && dim1_size >= 0 && dim0_size != dim1_size) {
MS_LOG(ERROR) << "shape[" << i << "]:" << dim0_size << " is incompatible with output_shape_[" << i
<< "]:" << dim1_size;
return RET_ERROR;
}
output_shape_[i] = dim1_size >= 0 ? dim1_size : dim0_size;
}
return RET_OK;
}
int TensorListStackCPUKernel::Run() {
if (output0_->ElementsNum() == 0) {
return RET_OK;
}
size_t in_ele_num = 0;
for (int i = 0; i < num_element_; ++i) {
auto in_ptr = reinterpret_cast<float *>(in_tensors_.at(i + 2)->data_c());
size_t in_size = in_tensors_.at(i + 2)->ElementsNum();
memcpy(out_ptr + index, in_ptr, in_size * sizeof(float));
index += in_size;
auto tensor = input0_->GetTensorIndex(i);
MS_ASSERT(tensor != nullptr);
if (tensor->data_type() == kTypeUnknown) {
if (TypeUnknownSize == 0) {
TypeUnknownSize = MergeElementShape();
}
in_ele_num += TypeUnknownSize;
} else {
in_ele_num += std::accumulate(tensor->shape().begin(), tensor->shape().end(), 1LL, std::multiplies<int>());
}
}
size_t out_ele_num = output0_->ElementsNum();
if (in_ele_num > out_ele_num) {
MS_LOG(ERROR) << "out_tensors_[0]->ElementsNum():" << out_ele_num
<< "must be greater than or equal to in_ele_num:" << in_ele_num;
return RET_ERROR;
}
auto out_ptr = reinterpret_cast<float *>(output0_->MutableData());
for (int i = 0; i < num_element_; ++i) {
auto in_ptr = input0_->GetTensorIndex(i);
MS_ASSERT(in_ptr != nullptr);
if (in_ptr->data_type() != kTypeUnknown) {
int in_size = in_ptr->ElementsNum();
memcpy(out_ptr, in_ptr->data_c(), in_size * sizeof(float));
out_ptr += in_size;
} else {
memset(out_ptr, 0, TypeUnknownSize * sizeof(float));
out_ptr += TypeUnknownSize;
}
}
return RET_OK;
}

View File

@ -18,7 +18,9 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TENSORLISTSTACK_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/tensorlist.h"
#include "schema/model_generated.h"
#include "nnacl/tensorlist_parameter.h"
@ -37,11 +39,17 @@ class TensorListStackCPUKernel : public LiteKernel {
int ReSize() override;
int Run() override;
int CheckParam();
int MergeElementShape();
int MergeSubShape(const std::vector<int> &shape);
bool IsFullyDefined(const std::vector<int> &shape) const;
private:
size_t TypeUnknownSize = 0;
int num_element_ = -1;
TypeId dtype_ = kTypeUnknown;
lite::TensorList *input0_ = nullptr;
lite::Tensor *output0_ = nullptr;
std::vector<int> output_shape_;
};
} // namespace mindspore::kernel

View File

@ -92,15 +92,15 @@ class Tensor : public mindspore::tensor::MSTensor {
mindspore::lite::Allocator *allocator() const { return this->allocator_; }
int MallocData(const mindspore::lite::Allocator *allocator = nullptr);
virtual int MallocData(const mindspore::lite::Allocator *allocator = nullptr);
int FreeData();
virtual int FreeData();
void *MutableData() override;
void *data_c() const { return data_; }
virtual void *data_c() const { return data_; }
void set_data(void *data) { this->data_ = data; }
virtual void set_data(void *data) { this->data_ = data; }
Category category() { return this->category_; }
@ -189,6 +189,8 @@ inline size_t DataTypeSize(const TypeId type) {
return sizeof(bool);
case kObjectTypeString:
return sizeof(char);
case kObjectTypeTensorType:
return 0;
default:
MS_LOG(ERROR) << "Not support the type: " << type;
return 0;

View File

@ -0,0 +1,245 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/ms_tensor.h"
#include "src/common/log_adapter.h"
#include "schema/model_generated.h"
#include "src/tensor.h"
#include "src/tensorlist.h"
namespace mindspore {
namespace lite {
// Construct a TensorList. `shape` is the 1-D shape of the list itself (its
// element count); `element_shape` constrains the shape of every contained
// tensor, where negative dims mean "any size". The list's own data_type_ is
// always kObjectTypeTensorType.
TensorList::TensorList(std::vector<int> shape, std::vector<int> element_shape)
    : Tensor(kObjectTypeTensorType, shape), element_shape_(element_shape) {}
TensorList::~TensorList() {
  // Tear down contained tensors if any are still held: release their data
  // buffers first, then delete the tensor objects and clear the vector.
  if (this->tensors_.empty()) {
    return;
  }
  this->FreeData();
  this->FreeTensorListData();
}
TensorList &TensorList::operator=(const TensorList &src) {
  // Deep-copy assignment: duplicates every tensor held by `src`
  // (contrast with the copy constructor, which is a shallow copy).
  if (this != &src) {
    auto status = CopyTensorList(src, true);
    if (status == RET_ERROR) {
      MS_LOG(ERROR) << "CopyTensorList error!";
      MS_ASSERT(false);
    }
  }
  return *this;
}
int TensorList::CopyTensorList(const TensorList &src, bool copy_data) {
this->data_type_ = src.data_type_;
this->tensors_data_type_ = src.tensors_data_type_;
this->shape_ = src.shape_;
this->element_shape_ = src.element_shape_;
this->max_elements_num_ = src.max_elements_num_;
if (copy_data) {
auto ret = CopyTensorData(src);
if (ret != RET_OK) {
MS_LOG(ERROR) << "CopyTensorData error";
return RET_ERROR;
}
} else {
// each tensor in tensors_ will share the same memory space.
this->tensors_ = src.tensors_;
}
return RET_OK;
}
// Deep-copy every tensor of `src` into this->tensors_.
// NOTE(review): this appends without clearing tensors_ first, and iterates
// this->ElementsNum() rather than src's — it relies on CopyTensorList having
// already copied shape_ from src before calling; confirm no other caller
// invokes this with a stale shape_ or non-empty tensors_.
int TensorList::CopyTensorData(const TensorList &src) {
  for (int i = 0; i < this->ElementsNum(); ++i) {
    if (src.tensors_[i] == nullptr) {
      MS_LOG(ERROR) << "src tensors_[" << i << "] is nullptr!";
      return RET_ERROR;
    }
    auto dst_tensor = new (std::nothrow) Tensor;
    if (dst_tensor == nullptr) {
      MS_LOG(ERROR) << "CopyTensorData: new tensor[" << i << "] is failed!";
      return RET_ERROR;
    }
    // Tensor's copy assignment duplicates the source tensor's contents.
    *reinterpret_cast<Tensor *>(dst_tensor) = *src.tensors_[i];
    this->tensors_.push_back(dst_tensor);
  }
  return RET_OK;
}
// Rebuild tensors_ from scratch: one tensor per list element, with shape
// tensor_shape[i] and data type `dtype` (recorded as tensors_data_type_).
// Data buffers are NOT allocated here — call MallocData() afterwards to
// allocate each tensor's buffer.
int TensorList::MallocTensorListData(TypeId dtype, const std::vector<std::vector<int> > &tensor_shape) {
  if (!this->tensors_.empty()) {
    // tensors_ is not empty: discard the old tensors before rebuilding.
    auto ret = FreeTensorListData();
    if (ret != RET_OK) {
      return RET_ERROR;
    }
  }
  // The list's own shape must be one-dimensional; its single dim is the
  // element count, which must match the number of shapes supplied.
  if (this->shape().size() != 1) {
    MS_LOG(ERROR) << "tensorlist shape:" << this->shape().size() << " must be one-dimensional";
    return RET_ERROR;
  }
  if (static_cast<size_t>(this->ElementsNum()) != tensor_shape.size()) {
    MS_LOG(ERROR) << "tensorlist ElementsNum():" << this->ElementsNum()
                  << " must be equal to param2:tensor_shape.size():" << tensor_shape.size();
    return RET_ERROR;
  }
  this->tensors_data_type_ = dtype;
  for (int i = 0; i < this->ElementsNum(); ++i) {
    auto tensor_ptr = new (std::nothrow) Tensor(dtype, tensor_shape[i]);
    if (tensor_ptr == nullptr) {
      MS_LOG(ERROR) << "new tensors_[" << i << "] is failed!";
      return RET_ERROR;
    }
    this->tensors_.push_back(tensor_ptr);
  }
  return RET_OK;
}
int TensorList::MallocData(const mindspore::lite::Allocator *allocator) {
  // Allocate a data buffer for every tensor in tensors_. Tensors whose
  // data_type() is kTypeUnknown are skipped (their buffers stay unallocated).
  // NOTE: allocation uses this->allocator_; the `allocator` parameter is not
  // consulted, mirroring the base-class signature.
  for (int i = 0; i < this->ElementsNum(); ++i) {
    auto element = this->tensors_[i];
    if (element == nullptr) {
      MS_LOG(ERROR) << "tensors_[" << i << "] is nullptr!";
      return RET_ERROR;
    }
    if (element->data_type() == kTypeUnknown) {
      continue;  // unknown dtype: nothing to allocate yet
    }
    if (element->MallocData(this->allocator_) != RET_OK) {
      MS_LOG(ERROR) << "tensorlist malloc tensors_[:" << i << "] is failed!";
      return RET_ERROR;
    }
  }
  return RET_OK;
}
int TensorList::FreeData() {
  // Release the data buffer of every contained tensor; the tensor objects
  // themselves are kept (FreeTensorListData performs the full teardown).
  if (this->tensors_.empty()) {
    return RET_OK;
  }
  for (int i = 0; i < this->ElementsNum(); ++i) {
    auto *element = this->tensors_[i];
    if (element != nullptr) {
      element->FreeData();
    }
  }
  return RET_OK;
}
int TensorList::FreeTensorListData() {
// del each tensor in tensors_ and clear tensors_
if (this->tensors_.empty()) {
return RET_OK;
}
for (int i = 0; i < this->ElementsNum(); ++i) {
if (this->tensors_[i] != nullptr) {
delete this->tensors_[i];
this->tensors_[i] = nullptr;
}
}
tensors_.clear();
return RET_OK;
}
int TensorList::SetTensorIndex(int index, Tensor *src_tensor) {
  // Replace tensors_[index] with a deep copy of src_tensor.
  // The source's dtype must match the list's element dtype, and the index
  // must be within [0, ElementsNum() - 1].
  if (src_tensor->data_type() != this->tensors_data_type_) {
    MS_LOG(ERROR) << "src_tensor->data_type()" << src_tensor->data_type()
                  << " must be equal to tensors_data_type_:" << this->tensors_data_type_;
    return RET_ERROR;
  }
  if (index < 0 || index > (this->ElementsNum() - 1)) {
    MS_LOG(ERROR) << "index:" << index << " must in [0, " << this->ElementsNum() - 1 << "]!";
    return RET_ERROR;
  }
  // Drop whatever tensor currently occupies the slot (delete on nullptr is a no-op).
  delete this->tensors_[index];
  this->tensors_[index] = new (std::nothrow) Tensor;
  if (this->tensors_[index] == nullptr) {
    MS_LOG(ERROR) << "SetTensorIndex: new tensor is failed!";
    return RET_ERROR;
  }
  *this->tensors_[index] = *src_tensor;
  return RET_OK;
}
int TensorList::CheckTensorListParam() {
for (int i = 0; i < this->ElementsNum(); ++i) {
// each tensor in tensorlist must be not nullptr
if (this->tensors_[i] == nullptr) {
MS_LOG(ERROR) << "CheckTensorListParam: tensors_[" << i << "] is nullptr";
return RET_ERROR;
}
if (this->tensors_[i]->data_type() != this->tensors_data_type_) {
MS_LOG(ERROR) << "CheckTensorListParam: tensors_[i] data_type:" << this->tensors_[i]->data_type()
<< " is not equal to tensors_data_type_:" << this->tensors_data_type_;
return RET_ERROR;
}
}
return RET_OK;
}
Tensor *TensorList::GetTensorIndex(int index) {
  // Return a mutable pointer to tensors_[index] (callers may modify the
  // tensor in place), or nullptr when the index is out of range.
  const bool in_range = index >= 0 && index <= this->ElementsNum() - 1;
  if (!in_range) {
    MS_LOG(ERROR) << "index:" << index << " must in [0, " << this->ElementsNum() - 1 << "]!";
    return nullptr;
  }
  return this->tensors_[index];
}
bool TensorList::IsCompatibleShape(const std::vector<int> &shape) {
  // A candidate shape is compatible when it has the same rank as
  // element_shape_ and every pair of known (non-negative) dims matches;
  // negative dims act as wildcards on either side.
  if (shape.size() != this->element_shape_.size()) {
    return false;
  }
  for (size_t i = 0; i < shape.size(); ++i) {
    const int want = this->element_shape_[i];
    const int got = shape[i];
    if (want >= 0 && got >= 0 && want != got) {
      return false;
    }
  }
  return true;
}
bool TensorList::IsCompatibleShape(const Tensor *src) {
  // Same compatibility rule as the vector overload, but the candidate shape
  // is stored as the int data of `src` (one entry per dimension).
  if (static_cast<size_t>(src->ElementsNum()) != this->element_shape_.size()) {
    return false;
  }
  if (src->data_type() != kNumberTypeInt) {
    MS_LOG(ERROR) << "src tensor data_type:" << src->data_type()
                  << " must be equal to \"kNumberTypeInt\":" << kNumberTypeInt;
    return false;
  }
  auto src_ptr = reinterpret_cast<int *>(src->data_c());
  for (size_t i = 0; i < this->element_shape_.size(); ++i) {
    const int want = this->element_shape_[i];
    const int got = src_ptr[i];
    if (want >= 0 && got >= 0 && want != got) {
      return false;
    }
  }
  return true;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,129 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_TENSORLIST_H_
#define MINDSPORE_LITE_SRC_TENSORLIST_H_
#include <memory>
#include <vector>
#include "include/ms_tensor.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "schema/model_generated.h"
#include "src/tensor.h"
namespace mindspore {
namespace lite {
/**
* Tensorlist is a container of vector, in which each element is a tensor object.
* Member objects:
* 1.tensors_: tensors_ is a vector, where each element is a pointer to tensor type.
* 2.shape_: represents the size of the tensors_ and shape_.size() must be equal to 1.
* 3.element_shape_: element_shape_ represents the shape of each tensor in tensors_.
* Some dimensions can be negative, which means that the corresponding dimensions of each tensor in tensors_ can be
* different.
* 4.data_type_: indicates that the tensorlist is a tensor of type kObjectTypeTensorType, so it can only be
* "kObjectTypeTensorType"
* 5.tensors_data_type_: data_type_ of each tensor in tensors_
 * Usage:
 * std::vector<int> shape = {2};  // the list holds two tensors (shape_ must be 1-D)
 * std::vector<int> element_shape = {-1, 99};
 * // dim0 is arbitrary, but dim1 of each tensor.shape() in tensors_ must be 99
 * TensorList *tl = new TensorList(shape, element_shape);
 * std::vector<std::vector<int> > tensor_shape = {{5, 99}, {1, 99}};
 * // tensor_shape[0] and tensor_shape[1] differ in dim0, but their dim1 must equal 99
 * tl->MallocTensorListData(kNumberTypeFloat, tensor_shape);
 * tl->MallocData();
 * tl->...
 * ...
 * tl->FreeData();
 * tl->FreeTensorListData();
*
* See the code for other constructors.
*/
// A tensor of type kObjectTypeTensorType: a 1-D container whose elements are
// Tensor objects. See the file-header comment for the full usage contract.
class TensorList : public Tensor {
 public:
  TensorList() = default;
  ~TensorList() override;
  // **Note**: This is a shallow copy: src and dst share the tensor objects in
  // tensors_. Use "operator=" when you need an independent deep copy.
  TensorList(const TensorList &other)
      : Tensor(other.data_type_, other.shape()),
        tensors_(other.tensors_),
        tensors_data_type_(other.tensors_data_type_),
        element_shape_(other.element_shape_),
        max_elements_num_(other.max_elements_num_) {}
  // Deep-copy assignment: duplicates every contained tensor.
  TensorList &operator=(const TensorList &tl);
  TensorList(std::vector<int> shape, std::vector<int> element_shape);
  void set_element_shape(const std::vector<int> &shape) { element_shape_ = shape; }
  std::vector<int> &element_shape() { return element_shape_; }
  void set_max_elements_num(int ele_num) { max_elements_num_ = ele_num; }
  int max_elements_num() const { return max_elements_num_; }
  // Rebuilds tensors_ (shapes + dtype only); MallocData() allocates buffers.
  int MallocTensorListData(TypeId dtype, const std::vector<std::vector<int> > &tensor_shape);
  int MallocData(const mindspore::lite::Allocator *allocator = nullptr) override;
  int FreeTensorListData();
  int FreeData() override;
  int CopyTensorList(const TensorList &src, bool copy_data);
  int CopyTensorData(const TensorList &src);
  // Stores a deep copy of the tensor at `index`; returns RET_ERROR on dtype
  // mismatch or out-of-range index.
  int SetTensorIndex(int index, Tensor *);
  // Returns a mutable pointer to the tensor at `index`, nullptr if out of range.
  Tensor *GetTensorIndex(int index);
  void set_tensors_data_type(TypeId type) { tensors_data_type_ = type; }
  TypeId tensors_data_type() const { return tensors_data_type_; }
  std::vector<Tensor *> &tensors() { return tensors_; }
  int CheckTensorListParam();
  bool IsCompatibleShape(const std::vector<int> &shape);
  bool IsCompatibleShape(const Tensor *src);
 protected:
  // The following accessors are masked: a TensorList owns no flat data buffer.
  void set_data(void *data) override { return; }
  void *data_c() const override { return nullptr; }
  void *MutableData() override { return nullptr; }
  size_t Size() const override { return 0; }
  std::vector<Tensor *> tensors_;
  // Fix: initialize to kTypeUnknown — this TypeId member previously had no
  // initializer, so a default-constructed TensorList read an indeterminate
  // value from tensors_data_type().
  TypeId tensors_data_type_ = kTypeUnknown;
  std::vector<int> element_shape_;
  int max_elements_num_ = -1;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_TENSORLIST_H_

View File

@ -122,6 +122,7 @@ set(TEST_LITE_SRC
${LITE_DIR}/src/runtime/thread_pool.c
${LITE_DIR}/src/runtime/parallel_executor.cc
${LITE_DIR}/src/tensor.cc
${LITE_DIR}/src/tensorlist.cc
${LITE_DIR}/src/executor.cc
${LITE_DIR}/src/inner_context.cc
${LITE_DIR}/src/kernel_registry.cc

View File

@ -79,6 +79,7 @@ set(LITE_SRC
${SRC_DIR}/runtime/thread_pool.c
${SRC_DIR}/inner_context.cc
${SRC_DIR}/tensor.cc
${SRC_DIR}/tensorlist.cc
${SRC_DIR}/kernel_registry.cc
${SRC_DIR}/lite_kernel.cc
${SRC_DIR}/scheduler.cc