forked from mindspore-Ecosystem/mindspore
Modify implementation of constant_of_shape.
Fix bug of implementation of cast and reduce and infershape of slice and topk.
This commit is contained in:
parent
c962ccbe07
commit
fef75213ed
|
@ -657,6 +657,22 @@ int ElementAddRelu6(float *input0, float *input1, float *output, int element_siz
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementAddInt(int *input0, int *input1, int *output, int element_size) {
|
||||
int index = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
for (; index <= element_size - 4; index += C4NUM) {
|
||||
int32x4_t vin0 = vld1q_s32(input0 + index);
|
||||
int32x4_t vin1 = vld1q_s32(input1 + index);
|
||||
int32x4_t vout = vaddq_s32(vin0, vin1);
|
||||
vst1q_s32(output + index, vout);
|
||||
}
|
||||
#endif
|
||||
for (; index < element_size; index++) {
|
||||
output[index] = input0[index] + input1[index];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementAddInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = input0[i] + input1[i];
|
||||
|
|
|
@ -54,6 +54,7 @@ int BroadcastMul(float *input0, float *input1, float *tile_input0, float *tile_i
|
|||
int ElementAdd(float *input0, float *input1, float *output, int element_size);
|
||||
int ElementAddRelu(float *input0, float *input1, float *output, int element_size);
|
||||
int ElementAddRelu6(float *input0, float *input1, float *output, int element_size);
|
||||
int ElementAddInt(int *input0, int *input1, int *output, int element_size);
|
||||
int BroadcastAdd(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
|
||||
ArithmeticParameter *param);
|
||||
int BroadcastAddInt8(int8_t *input0, int8_t *input1, int8_t *tile_input0, int8_t *tile_input1, int8_t *output,
|
||||
|
|
|
@ -26,3 +26,14 @@ int ConstantOfShape(float *output, int tid, ConstantOfShapeParameter *param) {
|
|||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ConstantOfShapeInt(int32_t *output, int tid, ConstantOfShapeParameter *param) {
|
||||
int size = param->unit_;
|
||||
float data = param->value_;
|
||||
int ind_st = MSMIN(tid * size, param->element_sz_);
|
||||
int ind_end = MSMIN(param->element_sz_, (tid + 1) * size);
|
||||
for (int i = ind_st; i < ind_end; ++i) {
|
||||
output[i] = data;
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
typedef struct ConstantOfShapeParameter {
|
||||
OpParameter op_parameter_;
|
||||
float value_;
|
||||
int data_type_;
|
||||
int unit_;
|
||||
int element_sz_;
|
||||
} ConstantOfShapeParameter;
|
||||
|
@ -33,6 +34,7 @@ typedef struct ConstantOfShapeParameter {
|
|||
extern "C" {
|
||||
#endif
|
||||
int ConstantOfShape(float *output, int tid, ConstantOfShapeParameter *param);
|
||||
int ConstantOfShapeInt(int32_t *output, int tid, ConstantOfShapeParameter *param);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -123,6 +123,27 @@ int ReduceMin(const int outer_size, const int inner_size, const int axis_size, c
|
|||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
int IntReduceMin(const int outer_size, const int inner_size, const int axis_size, const int *src_data, int *dst_data,
|
||||
const int tid, const int thread_num) {
|
||||
if (src_data == NULL || dst_data == NULL) {
|
||||
return NNACL_NULL_PTR;
|
||||
}
|
||||
int i, j, k;
|
||||
for (j = tid; j < outer_size; j += thread_num) {
|
||||
const int *outer_src = src_data + j * axis_size * inner_size;
|
||||
int *outer_dst = dst_data + j * inner_size;
|
||||
for (k = 0; k < inner_size; k++) {
|
||||
const int *inner_src = outer_src + k;
|
||||
int *inner_dst = outer_dst + k;
|
||||
int tmp = INT32_MAX;
|
||||
for (i = 0; i < axis_size; i++) {
|
||||
tmp = tmp < inner_src[i * inner_size] ? tmp : inner_src[i * inner_size];
|
||||
}
|
||||
*inner_dst = tmp;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
int ReduceProd(const int outer_size, const int inner_size, const int axis_size, const float *src_data, float *dst_data,
|
||||
const int tid, const int thread_num) {
|
||||
if (src_data == NULL || dst_data == NULL) {
|
||||
|
|
|
@ -30,6 +30,8 @@ int ReduceMax(const int outer_size, const int inner_size, const int axis_size, c
|
|||
const int tid, const int thread_num);
|
||||
int ReduceMin(const int outer_size, const int inner_size, const int axis_size, const float *src_data, float *dst_data,
|
||||
const int tid, const int thread_num);
|
||||
int IntReduceMin(const int outer_size, const int inner_size, const int axis_size, const int *src_data, int *dst_data,
|
||||
const int tid, const int thread_num);
|
||||
int ReduceProd(const int outer_size, const int inner_size, const int axis_size, const float *src_data, float *dst_data,
|
||||
const int tid, const int thread_num);
|
||||
int IntReduceProd(const int outer_size, const int inner_size, const int axis_size, const int *src_data, int *dst_data,
|
||||
|
|
|
@ -308,7 +308,8 @@ table Shape {
|
|||
}
|
||||
|
||||
table ConstantOfShape{
|
||||
value: float = 0;
|
||||
dataType: int;
|
||||
value: [float];
|
||||
}
|
||||
|
||||
table Nchw2Nhwc {
|
||||
|
|
|
@ -28,9 +28,9 @@ constexpr int kShapeInputNum = 1;
|
|||
constexpr int kShapeOutputNum = 1;
|
||||
} // namespace
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
float ConstantOfShape::GetValue() const { return this->primitive_->value.AsConstantOfShape()->value; }
|
||||
std::vector<float> ConstantOfShape::GetValue() const { return this->primitive_->value.AsConstantOfShape()->value; }
|
||||
|
||||
void ConstantOfShape::SetValue(float value) { this->primitive_->value.AsConstantOfShape()->value = value; }
|
||||
int ConstantOfShape::GetDataType() const { return this->primitive_->value.AsConstantOfShape()->dataType; }
|
||||
|
||||
#else
|
||||
int ConstantOfShape::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
|
@ -41,12 +41,22 @@ int ConstantOfShape::UnPackToFlatBuilder(const schema::Primitive *primitive, fla
|
|||
MS_LOG(ERROR) << "value_as_ConstantOfShape return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateConstantOfShape(*fbb, attr->value());
|
||||
std::vector<float> value;
|
||||
if (attr->value() != nullptr) {
|
||||
for (int i = 0; i < static_cast<int>(attr->value()->size()); i++) {
|
||||
value.push_back(attr->value()->data()[i]);
|
||||
}
|
||||
}
|
||||
auto val_offset = schema::CreateConstantOfShapeDirect(*fbb, attr->dataType(), &value);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ConstantOfShape, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
float ConstantOfShape::GetValue() const { return this->primitive_->value_as_ConstantOfShape()->value(); }
|
||||
std::vector<float> ConstantOfShape::GetValue() const {
|
||||
auto fb_vector = this->primitive_->value_as_ConstantOfShape()->value();
|
||||
return std::vector<float>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
int ConstantOfShape::GetDataType() const { return this->primitive_->value_as_ConstantOfShape()->dataType(); }
|
||||
|
||||
PrimitiveC *ConstantOfShapeCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<ConstantOfShape>(primitive);
|
||||
|
@ -70,7 +80,7 @@ int ConstantOfShape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tenso
|
|||
}
|
||||
auto in_tensor = inputs_.front();
|
||||
auto out_tensor = outputs_.front();
|
||||
out_tensor->set_data_type(kNumberTypeFloat32);
|
||||
out_tensor->set_data_type(static_cast<TypeId>(GetDataType()));
|
||||
out_tensor->SetFormat(in_tensor->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
|
|
|
@ -30,14 +30,14 @@ class ConstantOfShape : public PrimitiveC {
|
|||
MS_DECLARE_PARENT(ConstantOfShape, PrimitiveC);
|
||||
ConstantOfShape() = default;
|
||||
explicit ConstantOfShape(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
void SetValue(float value);
|
||||
#else
|
||||
ConstantOfShape() = default;
|
||||
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
float GetValue() const;
|
||||
std::vector<float> GetValue() const;
|
||||
int GetDataType() const;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,7 +34,13 @@ OpParameter *PopulateConstantOfShapeParameter(const mindspore::lite::PrimitiveC
|
|||
}
|
||||
memset(param, 0, sizeof(ConstantOfShapeParameter));
|
||||
param->op_parameter_.type_ = primitive->Type();
|
||||
param->value_ = attr->GetValue();
|
||||
auto value = attr->GetValue();
|
||||
if (value.empty() || value.size() > 1) {
|
||||
MS_LOG(ERROR) << "The value of constant of shape is empty or more than 1.";
|
||||
} else {
|
||||
param->value_ = attr->GetValue()[0];
|
||||
}
|
||||
param->data_type_ = attr->GetDataType();
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
}
|
||||
Registry ConstantOfShapeParameterRegistry(schema::PrimitiveType_ConstantOfShape, PopulateConstantOfShapeParameter);
|
||||
|
|
|
@ -28,6 +28,7 @@ namespace lite {
|
|||
namespace {
|
||||
constexpr int kSliceInputNum = 1;
|
||||
constexpr int kSliceOutputNum = 1;
|
||||
constexpr int kSliceMaxInputNum = 5;
|
||||
} // namespace
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int Slice::GetFormat() const { return this->primitive_->value.AsSlice()->format; }
|
||||
|
@ -175,6 +176,29 @@ int Slice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tens
|
|||
std::vector<int32_t> slice_size(GetSize());
|
||||
std::vector<int32_t> slice_axes(GetAxes());
|
||||
std::vector<int32_t> output_shape(input_shape.size());
|
||||
if (inputs.size() == kSliceMaxInputNum) {
|
||||
if (slice_begin.empty() && inputs.at(1)->data_c() != nullptr) {
|
||||
for (int i = 0; i < inputs.at(1)->ElementsNum(); i++) {
|
||||
slice_begin.emplace_back(static_cast<int *>(inputs.at(1)->data_c())[i]);
|
||||
}
|
||||
}
|
||||
if (slice_size.empty() && inputs.at(2)->data_c() != nullptr) {
|
||||
for (int i = 0; i < inputs.at(2)->ElementsNum(); i++) {
|
||||
auto end = static_cast<int *>(inputs.at(2)->data_c())[i];
|
||||
auto size = end < 0 ? end : (end == INT32_MAX ? -1 : end - slice_begin[i]);
|
||||
slice_size.emplace_back(size);
|
||||
}
|
||||
}
|
||||
if (slice_axes.empty() && inputs.at(3)->data_c() != nullptr) {
|
||||
for (int i = 0; i < inputs.at(3)->ElementsNum(); i++) {
|
||||
slice_axes.emplace_back(static_cast<int *>(inputs.at(3)->data_c())[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (slice_begin.empty() || slice_size.empty() || slice_axes.empty()) {
|
||||
MS_LOG(ERROR) << "Infershape failed.";
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
begin.assign(input_shape.size(), 0);
|
||||
size.assign(input_shape.size(), -1);
|
||||
for (size_t i = 0; i < slice_axes.size(); ++i) {
|
||||
|
|
|
@ -54,7 +54,7 @@ Registry TopKRegistry(schema::PrimitiveType_TopK, TopKCreator);
|
|||
|
||||
int TopK::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
if (inputs_.size() != kSingleNum || outputs_.size() != kDoubleNum) {
|
||||
if ((inputs_.size() != kSingleNum && inputs_.size() != kDoubleNum) || outputs_.size() != kDoubleNum) {
|
||||
MS_LOG(ERROR) << "input size: " << inputs_.size() << ", output size: " << outputs_.size();
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
|
@ -74,6 +74,9 @@ int TopK::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
|
|||
MS_ASSERT(topk_prim != nullptr);
|
||||
auto out_shape = input->shape();
|
||||
out_shape[out_shape.size() - 1] = GetK();
|
||||
if (inputs_.size() == kDoubleNum && inputs_.at(1)->data_c() != nullptr) {
|
||||
out_shape[out_shape.size() - 1] = reinterpret_cast<int *>(inputs_.at(1)->data_c())[0];
|
||||
}
|
||||
output0->set_shape(out_shape);
|
||||
output1->set_shape(out_shape);
|
||||
return RET_OK;
|
||||
|
|
|
@ -104,9 +104,7 @@ int Unsqueeze::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> o
|
|||
out_shape.emplace_back(1);
|
||||
ax_itr++;
|
||||
} else {
|
||||
if (in_shape[in_itr] > 1) {
|
||||
out_shape.emplace_back(in_shape[in_itr]);
|
||||
}
|
||||
out_shape.emplace_back(in_shape[in_itr]);
|
||||
in_itr++;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -83,6 +83,7 @@ class ArithmeticCPUKernel : public LiteKernel {
|
|||
break;
|
||||
default:
|
||||
arithmetic_run_ = ElementAdd;
|
||||
arithmetic_run_int_ = ElementAddInt;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
|
|
|
@ -77,6 +77,9 @@ int CastCPUKernel::DoCast(int thread_id) {
|
|||
} else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeFloat16) {
|
||||
Float32ToFp16(reinterpret_cast<float *>(input->data_c()) + offset,
|
||||
reinterpret_cast<uint16_t *>(output_data) + offset, data_num);
|
||||
} else if (input_data_type == kNumberTypeInt32 &&
|
||||
(output_data_type == kNumberTypeInt32 || output_data_type == kNumberTypeInt64)) {
|
||||
memcpy(output_data, input->data_c(), data_num * sizeof(int32_t));
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type;
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -33,7 +33,18 @@ int ConstantOfShapeCPUKernel::Init() { return RET_OK; }
|
|||
int ConstantOfShapeCPUKernel::ReSize() { return RET_OK; }
|
||||
|
||||
int ConstantOfShapeCPUKernel::DoExecute(int task_id) {
|
||||
int ret = ConstantOfShape(out_ptr_, task_id, param_);
|
||||
int ret = RET_ERROR;
|
||||
switch (param_->data_type_) {
|
||||
case kNumberTypeFloat32:
|
||||
ret = ConstantOfShape(reinterpret_cast<float *>(out_ptr_), task_id, param_);
|
||||
break;
|
||||
case kNumberTypeInt32:
|
||||
ret = ConstantOfShapeInt(reinterpret_cast<int32_t *>(out_ptr_), task_id, param_);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Constant of shape does not support the output data type.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConstantOfShapeRun error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||
return ret;
|
||||
|
@ -56,7 +67,17 @@ int ConstantOfShapeCPUKernel::Run() {
|
|||
int thread_num = MSMIN(param_->op_parameter_.thread_num_, param_->element_sz_);
|
||||
param_->unit_ = UP_DIV(param_->element_sz_, thread_num);
|
||||
param_->op_parameter_.thread_num_ = thread_num;
|
||||
out_ptr_ = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
|
||||
switch (param_->data_type_) {
|
||||
case kNumberTypeFloat32:
|
||||
out_ptr_ = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
|
||||
break;
|
||||
case kNumberTypeInt32:
|
||||
out_ptr_ = reinterpret_cast<int32_t *>(out_tensors_.front()->MutableData());
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "Constant of shape does not support the output data type.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ret = ParallelLaunch(this->context_->thread_pool_, ConstantOfShapeRun, this, thread_num);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ConstantOfShapeRun error error_code[" << ret << "]";
|
||||
|
@ -93,4 +114,5 @@ kernel::LiteKernel *CpuConstantOfShapeFp32KernelCreator(const std::vector<lite::
|
|||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ConstantOfShape, CpuConstantOfShapeFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ConstantOfShape, CpuConstantOfShapeFp32KernelCreator)
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -41,7 +41,7 @@ class ConstantOfShapeCPUKernel : public LiteKernel {
|
|||
|
||||
private:
|
||||
ConstantOfShapeParameter *param_;
|
||||
float *out_ptr_;
|
||||
void *out_ptr_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -74,8 +74,8 @@ int NonMaxSuppressionCPUKernel::GetParams() {
|
|||
max_output_per_class_ = 0;
|
||||
if (in_tensors_.size() >= 3) {
|
||||
auto max_output_tensor = in_tensors_.at(kMaxOutputNumTensorIndex);
|
||||
if (max_output_tensor != nullptr && reinterpret_cast<int64_t *>(max_output_tensor->data_c()) != nullptr) {
|
||||
max_output_per_class_ = *(reinterpret_cast<int64_t *>(max_output_tensor->data_c()));
|
||||
if (max_output_tensor != nullptr && reinterpret_cast<int32_t *>(max_output_tensor->data_c()) != nullptr) {
|
||||
max_output_per_class_ = *(reinterpret_cast<int32_t *>(max_output_tensor->data_c()));
|
||||
}
|
||||
}
|
||||
iou_threshold_ = 0.0f;
|
||||
|
|
|
@ -61,6 +61,7 @@ int ReduceCPUKernel::Init() {
|
|||
}
|
||||
case static_cast<int>(ReduceMode_ReduceMin): {
|
||||
reducer_ = ReduceMin;
|
||||
int_reducer_ = IntReduceMin;
|
||||
break;
|
||||
}
|
||||
case static_cast<int>(ReduceMode_ReduceProd): {
|
||||
|
|
|
@ -51,6 +51,14 @@ int TopKCPUKernel::Run() {
|
|||
|
||||
MS_ASSERT(context_->allocator != nullptr);
|
||||
TopkParameter *parameter = reinterpret_cast<TopkParameter *>(op_parameter_);
|
||||
if (in_tensors_.size() == lite::kDoubleNum) {
|
||||
auto input_k = reinterpret_cast<int *>(in_tensors_.at(1)->MutableData());
|
||||
parameter->k_ = input_k[0];
|
||||
}
|
||||
if (parameter->k_ > in_tensors_.at(0)->ElementsNum()) {
|
||||
MS_LOG(ERROR) << "The k value is out of the data size range.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
parameter->topk_node_list_ = context_->allocator->Malloc(sizeof(TopkNode) * parameter->last_dim_size_);
|
||||
if (parameter->topk_node_list_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Memory allocation failed";
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "tools/converter/parser/onnx/onnx_constant_of_shape_parser.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -41,13 +42,25 @@ STATUS OnnxConstantOfShapeParser::Parse(const onnx::GraphProto &onnx_graph, cons
|
|||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "value") {
|
||||
if (onnx_node_attr.type() == onnx::AttributeProto_AttributeType_TENSOR) {
|
||||
auto tensor = onnx_node_attr.t();
|
||||
if (tensor.data_type() == onnx::AttributeProto_AttributeType_FLOAT) {
|
||||
attr->value = onnx_node_attr.f();
|
||||
} else if (tensor.data_type() == onnx::AttributeProto_AttributeType_INT) {
|
||||
attr->value = static_cast<int32_t>(onnx_node_attr.i());
|
||||
}
|
||||
switch (onnx_node_attr.type()) {
|
||||
case onnx::AttributeProto_AttributeType_FLOAT:
|
||||
attr->dataType = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_FLOAT);
|
||||
attr->value.push_back(onnx_node_attr.f());
|
||||
break;
|
||||
case onnx::AttributeProto_AttributeType_INT:
|
||||
attr->dataType = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32);
|
||||
attr->value.push_back(static_cast<float>(onnx_node_attr.i()));
|
||||
break;
|
||||
case onnx::AttributeProto_AttributeType_TENSOR: {
|
||||
auto tensor = onnx_node_attr.t();
|
||||
auto ret = GetTensorDataFromOnnx(tensor, &attr->value, &attr->dataType);
|
||||
if (ret != RET_OK) {
|
||||
return ret;
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "The data type is not supported.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -445,11 +445,8 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
|
|||
}
|
||||
for (size_t i = 0; i < data_count; ++i) {
|
||||
if (in_data[i] > static_cast<int64_t>(INT32_MAX) || in_data[i] < static_cast<int64_t>(INT32_MIN)) {
|
||||
if (llabs(in_data[i]) == INT64_MAX || in_data[i] == INT64_MIN) {
|
||||
buffer[i] = in_data[i] > 0 ? INT32_MAX : INT32_MIN;
|
||||
}
|
||||
MS_LOG(ERROR) << "int64 data " << in_data[i] << "too big to fit into int32";
|
||||
return RET_ERROR;
|
||||
MS_LOG(WARNING) << "int64 data " << in_data[i] << "too big to fit into int32";
|
||||
buffer[i] = in_data[i] > 0 ? INT32_MAX : INT32_MIN;
|
||||
} else {
|
||||
buffer[i] = static_cast<int>(in_data[i]);
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -34,6 +35,36 @@ schema::PadMode OnnxNodeParser::GetOnnxPadMode(const onnx::AttributeProto &onnx_
|
|||
}
|
||||
}
|
||||
|
||||
STATUS OnnxNodeParser::GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value,
|
||||
int *type) {
|
||||
size_t data_count = 1;
|
||||
std::for_each(onnx_tensor.dims().begin(), onnx_tensor.dims().end(), [&data_count](int dim) { data_count *= dim; });
|
||||
switch (onnx_tensor.data_type()) {
|
||||
case onnx::TensorProto_DataType_FLOAT:
|
||||
*type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_FLOAT);
|
||||
for (size_t i = 0; i < data_count; i++) {
|
||||
value->push_back(reinterpret_cast<const float *>(onnx_tensor.raw_data().data())[i]);
|
||||
}
|
||||
break;
|
||||
case onnx::TensorProto_DataType_INT32:
|
||||
*type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32);
|
||||
for (size_t i = 0; i < data_count; i++) {
|
||||
value->push_back(static_cast<float>(reinterpret_cast<const int32_t *>(onnx_tensor.raw_data().data())[i]));
|
||||
}
|
||||
break;
|
||||
case onnx::TensorProto_DataType_INT64:
|
||||
*type = OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType_INT32);
|
||||
for (size_t i = 0; i < data_count; i++) {
|
||||
value->push_back(static_cast<float>(reinterpret_cast<const int64_t *>(onnx_tensor.raw_data().data())[i]));
|
||||
}
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "The data type is not supported.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void OnnxNodeParser::Split(const std::string &src_str, std::vector<std::string> *dst_str, const std::string &chr) {
|
||||
std::string ::size_type p1 = 0, p2 = src_str.find(chr);
|
||||
while (std::string::npos != p2) {
|
||||
|
|
|
@ -35,6 +35,8 @@ class OnnxNodeParser {
|
|||
|
||||
virtual STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) = 0;
|
||||
|
||||
STATUS GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tensor, std::vector<float> *value, int *type);
|
||||
|
||||
protected:
|
||||
schema::PadMode GetOnnxPadMode(const onnx::AttributeProto &onnx_node_attr);
|
||||
|
||||
|
|
Loading…
Reference in New Issue