forked from mindspore-Ecosystem/mindspore
adjust onnx upsample
This commit is contained in:
parent
d616ce1a2a
commit
f8267c821d
|
@ -352,7 +352,7 @@ table FakeQuantWithMinMaxVars {
|
|||
}
|
||||
|
||||
table BiasAdd {
|
||||
axis: [int];
|
||||
axis: [int]; // DEPRECATED
|
||||
}
|
||||
|
||||
table ROIPooling {
|
||||
|
|
|
@ -24,10 +24,6 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
std::vector<int> BiasAdd::GetAxis() const { return this->primitive_->value.AsBiasAdd()->axis; }
|
||||
|
||||
void BiasAdd::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsBiasAdd()->axis = axis; }
|
||||
|
||||
int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
|
@ -67,21 +63,11 @@ int BiasAdd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
|
|||
MS_LOG(ERROR) << "value_as_BiasAdd return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<int32_t> axis;
|
||||
if (attr->axis() != nullptr) {
|
||||
for (int i = 0; i < static_cast<int>(attr->axis()->size()); i++) {
|
||||
axis.push_back(attr->axis()->data()[i]);
|
||||
}
|
||||
}
|
||||
auto val_offset = schema::CreateBiasAddDirect(*fbb, &axis);
|
||||
auto val_offset = schema::CreateBiasAddDirect(*fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BiasAdd, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
std::vector<int> BiasAdd::GetAxis() const {
|
||||
auto fb_vector = this->primitive_->value_as_BiasAdd()->axis();
|
||||
return std::vector<int>(fb_vector->begin(), fb_vector->end());
|
||||
}
|
||||
|
||||
PrimitiveC *BiasAddCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<BiasAdd>(primitive); }
|
||||
Registry BiasAddRegistry(schema::PrimitiveType_BiasAdd, BiasAddCreator);
|
||||
|
|
|
@ -33,11 +33,9 @@ class BiasAdd : public PrimitiveC {
|
|||
MS_DECLARE_PARENT(BiasAdd, PrimitiveC);
|
||||
explicit BiasAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
void SetAxis(const std::vector<int> &axis);
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
std::vector<int> GetAxis() const;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -38,22 +38,17 @@ STATUS OnnxUpsampleParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
attr->method = schema::ResizeMethod_NEAREST;
|
||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "mode") {
|
||||
if ("nearest" == onnx_node_attr.s()) {
|
||||
attr->method = schema::ResizeMethod_NEAREST;
|
||||
} else if ("bilinear" == onnx_node_attr.s()) {
|
||||
attr->method = schema::ResizeMethod_LINEAR;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Resize do not support upsample mode";
|
||||
return RET_ERROR;
|
||||
if (onnx_node_attr.s() != "nearest" && onnx_node_attr.s() != "linear") {
|
||||
MS_LOG(ERROR) << "the upsample mode don't support now.";
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
attr->method = onnx_node_attr.s() == "nearest" ? schema::ResizeMethod_NEAREST : schema::ResizeMethod_LINEAR;
|
||||
}
|
||||
}
|
||||
attr->newWidth = 1;
|
||||
attr->newHeight = 1;
|
||||
attr->alignCorners = false;
|
||||
op->primitive->value.type = schema::PrimitiveType_Resize;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
|
|
Loading…
Reference in New Issue