forked from mindspore-Ecosystem/mindspore
fix_concat_slice
This commit is contained in:
parent
7371ceddde
commit
c5e47b1136
|
@ -107,14 +107,8 @@ int Concat::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
|
|||
}
|
||||
auto input0_shape_without_axis = input0_shape;
|
||||
input0_shape_without_axis.erase(input0_shape_without_axis.begin() + axis);
|
||||
auto input0_data_type = inputs_.at(0)->data_type();
|
||||
int output_axis_dim = input0_shape.at(axis);
|
||||
for (size_t i = 1; i < inputs_.size(); ++i) {
|
||||
if (inputs_.at(i)->data_type() != input0_data_type) {
|
||||
MS_LOG(ERROR) << "All inputs should have the same data type!";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto shape_tmp = inputs_.at(i)->shape();
|
||||
if (shape_tmp.size() != input0_shape.size()) {
|
||||
MS_LOG(ERROR) << "All inputs should have the same dim num!";
|
||||
|
|
|
@ -60,6 +60,12 @@ int Slice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
|
|||
return RET_ERROR;
|
||||
}
|
||||
|
||||
std::vector<int32_t> axes;
|
||||
if (attr->axes() != nullptr) {
|
||||
for (int i = 0; i < static_cast<int>(attr->axes()->size()); i++) {
|
||||
axes.push_back(attr->axes()->data()[i]);
|
||||
}
|
||||
}
|
||||
std::vector<int32_t> begin;
|
||||
if (attr->begin() != nullptr) {
|
||||
for (int i = 0; i < static_cast<int>(attr->begin()->size()); i++) {
|
||||
|
@ -73,7 +79,7 @@ int Slice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
|
|||
}
|
||||
}
|
||||
|
||||
auto val_offset = schema::CreateSliceDirect(*fbb, attr->format(), &begin, &size);
|
||||
auto val_offset = schema::CreateSliceDirect(*fbb, attr->format(), &axes, &begin, &size);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Slice, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
|
|
Loading…
Reference in New Issue