fix broadcast parser of onnx

This commit is contained in:
yankai 2020-10-28 14:33:14 +08:00
parent 116173d130
commit 4824727a63
2 changed files with 36 additions and 5 deletions

View File

@ -62,21 +62,32 @@ Registry BroadcastToRegistry(schema::PrimitiveType_BroadcastTo, BroadcastToCreat
namespace { namespace {
constexpr int kBroadcastToInputNum = 1; constexpr int kBroadcastToInputNum = 1;
constexpr int kBroadcastToOnnxInputNum = 2;
constexpr int kBroadcastToOutputNum = 1; constexpr int kBroadcastToOutputNum = 1;
} // namespace } // namespace
int BroadcastTo::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { int BroadcastTo::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
if (inputs.size() != kBroadcastToInputNum || outputs.size() != kBroadcastToOutputNum) { if (inputs.size() != kBroadcastToInputNum && inputs.size() != kBroadcastToOnnxInputNum) {
MS_LOG(ERROR) << "input size:" << inputs.size() << ", output size:" << outputs.size(); MS_LOG(ERROR) << "input size:" << inputs.size();
return RET_PARAM_INVALID; return RET_PARAM_INVALID;
} }
if (outputs.size() != kBroadcastToOutputNum) {
MS_LOG(ERROR) << "output size:" << outputs.size();
return RET_PARAM_INVALID;
}
auto input = inputs.at(0); auto input = inputs.at(0);
outputs[0]->SetFormat(input->GetFormat()); outputs[0]->SetFormat(input->GetFormat());
outputs[0]->set_data_type(input->data_type()); outputs[0]->set_data_type(input->data_type());
if (!GetInferFlag()) { if (!GetInferFlag()) {
return RET_OK; return RET_OK;
} }
std::vector<int32_t> dst_shape(GetDstShape().begin(), GetDstShape().end()); std::vector<int32_t> dst_shape(GetDstShape());
for (size_t i = 0; i < dst_shape.size(); ++i) {
if (dst_shape[i] == -1) {
dst_shape[i] = inputs[0]->shape()[i];
}
}
auto input_shape = input->shape(); auto input_shape = input->shape();
std::vector<int> shape(dst_shape.size()); std::vector<int> shape(dst_shape.size());
int input_shape_index = input_shape.size() - 1; int input_shape_index = input_shape.size() - 1;

View File

@ -32,13 +32,33 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
return RET_NULL_PTR; return RET_NULL_PTR;
} }
std::unique_ptr<schema::BroadcastT> attr = std::make_unique<schema::BroadcastT>(); std::unique_ptr<schema::BroadcastToT> attr = std::make_unique<schema::BroadcastToT>();
if (attr == nullptr) { if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed"; MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
op->primitive->value.type = schema::PrimitiveType_Broadcast; std::vector<int> dst_shape;
const auto &onnx_expand_power = onnx_node.input(1);
auto nodeIter =
std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(),
[onnx_expand_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_expand_power; });
if (nodeIter == onnx_graph.node().end()) {
MS_LOG(ERROR) << "can not find node: " << onnx_expand_power;
return RET_ERROR;
}
const int64_t *dataPtr = nullptr;
for (const auto &attrPower : nodeIter->attribute()) {
if (attrPower.name() == "value") {
const auto &t = attrPower.t();
dataPtr = reinterpret_cast<const int64_t *>(t.raw_data().data());
for (int i = 0; i < t.dims(0); ++i) {
dst_shape.emplace_back(dataPtr[i]);
}
}
}
attr->dst_shape = dst_shape;
op->primitive->value.type = schema::PrimitiveType_BroadcastTo;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;
} }