fix_concat_slice

This commit is contained in:
sunsuodong 2020-08-31 17:13:47 +08:00
parent 7371ceddde
commit c5e47b1136
2 changed files with 7 additions and 7 deletions

View File

@ -107,14 +107,8 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
}
auto input0_shape_without_axis = input0_shape;
input0_shape_without_axis.erase(input0_shape_without_axis.begin() + axis);
auto input0_data_type = inputs_.at(0)->data_type();
int output_axis_dim = input0_shape.at(axis);
for (size_t i = 1; i < inputs_.size(); ++i) {
if (inputs_.at(i)->data_type() != input0_data_type) {
MS_LOG(ERROR) << "All inputs should have the same data type!";
return RET_PARAM_INVALID;
}
auto shape_tmp = inputs_.at(i)->shape();
if (shape_tmp.size() != input0_shape.size()) {
MS_LOG(ERROR) << "All inputs should have the same dim num!";

View File

@ -60,6 +60,12 @@ int Slice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
return RET_ERROR;
}
std::vector<int32_t> axes;
if (attr->axes() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->axes()->size()); i++) {
axes.push_back(attr->axes()->data()[i]);
}
}
std::vector<int32_t> begin;
if (attr->begin() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->begin()->size()); i++) {
@ -73,7 +79,7 @@ int Slice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
}
}
auto val_offset = schema::CreateSliceDirect(*fbb, attr->format(), &begin, &size);
auto val_offset = schema::CreateSliceDirect(*fbb, attr->format(), &axes, &begin, &size);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Slice, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;