forked from mindspore-Ecosystem/mindspore
adjust concat and remove unused attr
This commit is contained in:
parent
2bdc6198aa
commit
af56b2fe80
|
@ -168,7 +168,7 @@ table FlattenGrad {
|
|||
}
|
||||
table Concat {
|
||||
axis: int;
|
||||
n: int;
|
||||
n: int; // DEPRECATED
|
||||
}
|
||||
|
||||
table SoftMax {
|
||||
|
@ -822,6 +822,7 @@ table Gather {
|
|||
}
|
||||
|
||||
table GatherNd {
|
||||
batchDims: int; // DEPRECATED
|
||||
}
|
||||
|
||||
table Fill {
|
||||
|
|
|
@ -27,10 +27,8 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int Concat::GetAxis() const { return this->primitive_->value.AsConcat()->axis; }
|
||||
int Concat::GetN() const { return this->primitive_->value.AsConcat()->n; }
|
||||
|
||||
void Concat::SetAxis(int axis) { this->primitive_->value.AsConcat()->axis = axis; }
|
||||
void Concat::SetN(int n) { this->primitive_->value.AsConcat()->n = n; }
|
||||
|
||||
int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
|
@ -71,13 +69,12 @@ int Concat::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:
|
|||
MS_LOG(ERROR) << "value_as_Concat return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateConcat(*fbb, attr->axis(), attr->n());
|
||||
auto val_offset = schema::CreateConcat(*fbb, attr->axis());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Concat, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); }
|
||||
int Concat::GetN() const { return this->primitive_->value_as_Concat()->n(); }
|
||||
|
||||
PrimitiveC *ConcatCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Concat>(primitive); }
|
||||
Registry ConcatRegistry(schema::PrimitiveType_Concat, ConcatCreator);
|
||||
|
|
|
@ -33,13 +33,11 @@ class Concat : public PrimitiveC {
|
|||
explicit Concat(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
void SetAxis(int axis);
|
||||
void SetN(int n);
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
int GetAxis() const;
|
||||
int GetN() const;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -93,7 +93,6 @@ TEST_F(SchedulerTest, TestConstructSubGraphsTwoBranch) {
|
|||
concat->primitive->value.type = mindspore::schema::PrimitiveType_Concat;
|
||||
auto concat_primitive = new mindspore::schema::ConcatT;
|
||||
concat_primitive->axis = 3;
|
||||
concat_primitive->n = 2;
|
||||
concat->primitive->value.value = concat_primitive;
|
||||
concat->name = "concat";
|
||||
|
||||
|
@ -255,7 +254,6 @@ TEST_F(SchedulerTest, TestConstructSubGraphsThreeBranch) {
|
|||
concat->primitive->value.type = mindspore::schema::PrimitiveType_Concat;
|
||||
auto concat_primitive = new mindspore::schema::ConcatT;
|
||||
concat_primitive->axis = 3;
|
||||
concat_primitive->n = 2;
|
||||
concat->primitive->value.value = concat_primitive;
|
||||
concat->name = "concat";
|
||||
|
||||
|
|
|
@ -35,7 +35,6 @@ TEST_F(TestTfliteParserConcat, AttrValue) {
|
|||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConcat(), nullptr);
|
||||
auto val = meta_graph->nodes.front()->primitive->value.AsConcat();
|
||||
ASSERT_EQ(val->axis, 1);
|
||||
ASSERT_EQ(val->n, 2);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -60,7 +60,6 @@ STATUS CaffeConcatParser::Parse(const caffe::LayerParameter &proto, const caffe:
|
|||
MS_LOG(DEBUG) << "by default, set axis = 1";
|
||||
attr->axis = 1;
|
||||
}
|
||||
attr->n = proto.bottom_size();
|
||||
|
||||
op->name = proto.name();
|
||||
op->primitive->value.type = schema::PrimitiveType_Concat;
|
||||
|
|
|
@ -34,7 +34,6 @@ PrimitiveC *TfliteConcatParser::ParseLitePrimitive(const std::unique_ptr<tflite:
|
|||
return nullptr;
|
||||
}
|
||||
attr->axis = tfliteAttr->axis;
|
||||
attr->n = tflite_op->inputs.size();
|
||||
primitive->value.type = schema::PrimitiveType_Concat;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
|
|
Loading…
Reference in New Issue