adjust concat and remove unused attr

This commit is contained in:
xuanyue 2020-11-20 11:00:31 +08:00
parent 2bdc6198aa
commit af56b2fe80
7 changed files with 3 additions and 12 deletions

View File

@ -168,7 +168,7 @@ table FlattenGrad {
} }
table Concat { table Concat {
axis: int; axis: int;
n: int; n: int; // DEPRECATED
} }
table SoftMax { table SoftMax {
@ -822,6 +822,7 @@ table Gather {
} }
table GatherNd { table GatherNd {
batchDims: int; // DEPRECATED
} }
table Fill { table Fill {

View File

@ -27,10 +27,8 @@ namespace mindspore {
namespace lite { namespace lite {
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
int Concat::GetAxis() const { return this->primitive_->value.AsConcat()->axis; } int Concat::GetAxis() const { return this->primitive_->value.AsConcat()->axis; }
int Concat::GetN() const { return this->primitive_->value.AsConcat()->n; }
void Concat::SetAxis(int axis) { this->primitive_->value.AsConcat()->axis = axis; } void Concat::SetAxis(int axis) { this->primitive_->value.AsConcat()->axis = axis; }
void Concat::SetN(int n) { this->primitive_->value.AsConcat()->n = n; }
int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) { if (this->primitive_ == nullptr) {
@ -71,13 +69,12 @@ int Concat::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:
MS_LOG(ERROR) << "value_as_Concat return nullptr"; MS_LOG(ERROR) << "value_as_Concat return nullptr";
return RET_ERROR; return RET_ERROR;
} }
auto val_offset = schema::CreateConcat(*fbb, attr->axis(), attr->n()); auto val_offset = schema::CreateConcat(*fbb, attr->axis());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Concat, val_offset.o); auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Concat, val_offset.o);
fbb->Finish(prim_offset); fbb->Finish(prim_offset);
return RET_OK; return RET_OK;
} }
int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); } int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); }
int Concat::GetN() const { return this->primitive_->value_as_Concat()->n(); }
PrimitiveC *ConcatCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Concat>(primitive); } PrimitiveC *ConcatCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Concat>(primitive); }
Registry ConcatRegistry(schema::PrimitiveType_Concat, ConcatCreator); Registry ConcatRegistry(schema::PrimitiveType_Concat, ConcatCreator);

View File

@ -33,13 +33,11 @@ class Concat : public PrimitiveC {
explicit Concat(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} explicit Concat(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetAxis(int axis); void SetAxis(int axis);
void SetN(int n);
#else #else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif #endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
int GetAxis() const; int GetAxis() const;
int GetN() const;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -93,7 +93,6 @@ TEST_F(SchedulerTest, TestConstructSubGraphsTwoBranch) {
concat->primitive->value.type = mindspore::schema::PrimitiveType_Concat; concat->primitive->value.type = mindspore::schema::PrimitiveType_Concat;
auto concat_primitive = new mindspore::schema::ConcatT; auto concat_primitive = new mindspore::schema::ConcatT;
concat_primitive->axis = 3; concat_primitive->axis = 3;
concat_primitive->n = 2;
concat->primitive->value.value = concat_primitive; concat->primitive->value.value = concat_primitive;
concat->name = "concat"; concat->name = "concat";
@ -255,7 +254,6 @@ TEST_F(SchedulerTest, TestConstructSubGraphsThreeBranch) {
concat->primitive->value.type = mindspore::schema::PrimitiveType_Concat; concat->primitive->value.type = mindspore::schema::PrimitiveType_Concat;
auto concat_primitive = new mindspore::schema::ConcatT; auto concat_primitive = new mindspore::schema::ConcatT;
concat_primitive->axis = 3; concat_primitive->axis = 3;
concat_primitive->n = 2;
concat->primitive->value.value = concat_primitive; concat->primitive->value.value = concat_primitive;
concat->name = "concat"; concat->name = "concat";

View File

@ -35,7 +35,6 @@ TEST_F(TestTfliteParserConcat, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConcat(), nullptr); ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConcat(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsConcat(); auto val = meta_graph->nodes.front()->primitive->value.AsConcat();
ASSERT_EQ(val->axis, 1); ASSERT_EQ(val->axis, 1);
ASSERT_EQ(val->n, 2);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -60,7 +60,6 @@ STATUS CaffeConcatParser::Parse(const caffe::LayerParameter &proto, const caffe:
MS_LOG(DEBUG) << "by default, set axis = 1"; MS_LOG(DEBUG) << "by default, set axis = 1";
attr->axis = 1; attr->axis = 1;
} }
attr->n = proto.bottom_size();
op->name = proto.name(); op->name = proto.name();
op->primitive->value.type = schema::PrimitiveType_Concat; op->primitive->value.type = schema::PrimitiveType_Concat;

View File

@ -34,7 +34,6 @@ PrimitiveC *TfliteConcatParser::ParseLitePrimitive(const std::unique_ptr<tflite:
return nullptr; return nullptr;
} }
attr->axis = tfliteAttr->axis; attr->axis = tfliteAttr->axis;
attr->n = tflite_op->inputs.size();
primitive->value.type = schema::PrimitiveType_Concat; primitive->value.type = schema::PrimitiveType_Concat;
primitive->value.value = attr.release(); primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release()); return PrimitiveC::Create(primitive.release());