forked from mindspore-Ecosystem/mindspore
adjust concat and remove unused attr
This commit is contained in:
parent
2bdc6198aa
commit
af56b2fe80
|
@ -168,7 +168,7 @@ table FlattenGrad {
|
||||||
}
|
}
|
||||||
table Concat {
|
table Concat {
|
||||||
axis: int;
|
axis: int;
|
||||||
n: int;
|
n: int; // DEPRECATED
|
||||||
}
|
}
|
||||||
|
|
||||||
table SoftMax {
|
table SoftMax {
|
||||||
|
@ -822,6 +822,7 @@ table Gather {
|
||||||
}
|
}
|
||||||
|
|
||||||
table GatherNd {
|
table GatherNd {
|
||||||
|
batchDims: int; // DEPRECATED
|
||||||
}
|
}
|
||||||
|
|
||||||
table Fill {
|
table Fill {
|
||||||
|
|
|
@ -27,10 +27,8 @@ namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
#ifdef PRIMITIVE_WRITEABLE
|
#ifdef PRIMITIVE_WRITEABLE
|
||||||
int Concat::GetAxis() const { return this->primitive_->value.AsConcat()->axis; }
|
int Concat::GetAxis() const { return this->primitive_->value.AsConcat()->axis; }
|
||||||
int Concat::GetN() const { return this->primitive_->value.AsConcat()->n; }
|
|
||||||
|
|
||||||
void Concat::SetAxis(int axis) { this->primitive_->value.AsConcat()->axis = axis; }
|
void Concat::SetAxis(int axis) { this->primitive_->value.AsConcat()->axis = axis; }
|
||||||
void Concat::SetN(int n) { this->primitive_->value.AsConcat()->n = n; }
|
|
||||||
|
|
||||||
int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||||
if (this->primitive_ == nullptr) {
|
if (this->primitive_ == nullptr) {
|
||||||
|
@ -71,13 +69,12 @@ int Concat::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:
|
||||||
MS_LOG(ERROR) << "value_as_Concat return nullptr";
|
MS_LOG(ERROR) << "value_as_Concat return nullptr";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto val_offset = schema::CreateConcat(*fbb, attr->axis(), attr->n());
|
auto val_offset = schema::CreateConcat(*fbb, attr->axis());
|
||||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Concat, val_offset.o);
|
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Concat, val_offset.o);
|
||||||
fbb->Finish(prim_offset);
|
fbb->Finish(prim_offset);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); }
|
int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); }
|
||||||
int Concat::GetN() const { return this->primitive_->value_as_Concat()->n(); }
|
|
||||||
|
|
||||||
PrimitiveC *ConcatCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Concat>(primitive); }
|
PrimitiveC *ConcatCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Concat>(primitive); }
|
||||||
Registry ConcatRegistry(schema::PrimitiveType_Concat, ConcatCreator);
|
Registry ConcatRegistry(schema::PrimitiveType_Concat, ConcatCreator);
|
||||||
|
|
|
@ -33,13 +33,11 @@ class Concat : public PrimitiveC {
|
||||||
explicit Concat(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
explicit Concat(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||||
void SetAxis(int axis);
|
void SetAxis(int axis);
|
||||||
void SetN(int n);
|
|
||||||
#else
|
#else
|
||||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||||
#endif
|
#endif
|
||||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||||
int GetAxis() const;
|
int GetAxis() const;
|
||||||
int GetN() const;
|
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -93,7 +93,6 @@ TEST_F(SchedulerTest, TestConstructSubGraphsTwoBranch) {
|
||||||
concat->primitive->value.type = mindspore::schema::PrimitiveType_Concat;
|
concat->primitive->value.type = mindspore::schema::PrimitiveType_Concat;
|
||||||
auto concat_primitive = new mindspore::schema::ConcatT;
|
auto concat_primitive = new mindspore::schema::ConcatT;
|
||||||
concat_primitive->axis = 3;
|
concat_primitive->axis = 3;
|
||||||
concat_primitive->n = 2;
|
|
||||||
concat->primitive->value.value = concat_primitive;
|
concat->primitive->value.value = concat_primitive;
|
||||||
concat->name = "concat";
|
concat->name = "concat";
|
||||||
|
|
||||||
|
@ -255,7 +254,6 @@ TEST_F(SchedulerTest, TestConstructSubGraphsThreeBranch) {
|
||||||
concat->primitive->value.type = mindspore::schema::PrimitiveType_Concat;
|
concat->primitive->value.type = mindspore::schema::PrimitiveType_Concat;
|
||||||
auto concat_primitive = new mindspore::schema::ConcatT;
|
auto concat_primitive = new mindspore::schema::ConcatT;
|
||||||
concat_primitive->axis = 3;
|
concat_primitive->axis = 3;
|
||||||
concat_primitive->n = 2;
|
|
||||||
concat->primitive->value.value = concat_primitive;
|
concat->primitive->value.value = concat_primitive;
|
||||||
concat->name = "concat";
|
concat->name = "concat";
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,6 @@ TEST_F(TestTfliteParserConcat, AttrValue) {
|
||||||
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConcat(), nullptr);
|
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsConcat(), nullptr);
|
||||||
auto val = meta_graph->nodes.front()->primitive->value.AsConcat();
|
auto val = meta_graph->nodes.front()->primitive->value.AsConcat();
|
||||||
ASSERT_EQ(val->axis, 1);
|
ASSERT_EQ(val->axis, 1);
|
||||||
ASSERT_EQ(val->n, 2);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -60,7 +60,6 @@ STATUS CaffeConcatParser::Parse(const caffe::LayerParameter &proto, const caffe:
|
||||||
MS_LOG(DEBUG) << "by default, set axis = 1";
|
MS_LOG(DEBUG) << "by default, set axis = 1";
|
||||||
attr->axis = 1;
|
attr->axis = 1;
|
||||||
}
|
}
|
||||||
attr->n = proto.bottom_size();
|
|
||||||
|
|
||||||
op->name = proto.name();
|
op->name = proto.name();
|
||||||
op->primitive->value.type = schema::PrimitiveType_Concat;
|
op->primitive->value.type = schema::PrimitiveType_Concat;
|
||||||
|
|
|
@ -34,7 +34,6 @@ PrimitiveC *TfliteConcatParser::ParseLitePrimitive(const std::unique_ptr<tflite:
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
attr->axis = tfliteAttr->axis;
|
attr->axis = tfliteAttr->axis;
|
||||||
attr->n = tflite_op->inputs.size();
|
|
||||||
primitive->value.type = schema::PrimitiveType_Concat;
|
primitive->value.type = schema::PrimitiveType_Concat;
|
||||||
primitive->value.value = attr.release();
|
primitive->value.value = attr.release();
|
||||||
return PrimitiveC::Create(primitive.release());
|
return PrimitiveC::Create(primitive.release());
|
||||||
|
|
Loading…
Reference in New Issue