forked from mindspore-Ecosystem/mindspore
fix broadcast parser of onnx
This commit is contained in:
parent
116173d130
commit
4824727a63
|
@ -62,21 +62,32 @@ Registry BroadcastToRegistry(schema::PrimitiveType_BroadcastTo, BroadcastToCreat
|
|||
|
||||
namespace {
|
||||
constexpr int kBroadcastToInputNum = 1;
|
||||
constexpr int kBroadcastToOnnxInputNum = 2;
|
||||
constexpr int kBroadcastToOutputNum = 1;
|
||||
} // namespace
|
||||
|
||||
int BroadcastTo::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
|
||||
if (inputs.size() != kBroadcastToInputNum || outputs.size() != kBroadcastToOutputNum) {
|
||||
MS_LOG(ERROR) << "input size:" << inputs.size() << ", output size:" << outputs.size();
|
||||
if (inputs.size() != kBroadcastToInputNum && inputs.size() != kBroadcastToOnnxInputNum) {
|
||||
MS_LOG(ERROR) << "input size:" << inputs.size();
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
if (outputs.size() != kBroadcastToOutputNum) {
|
||||
MS_LOG(ERROR) << "output size:" << outputs.size();
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto input = inputs.at(0);
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
}
|
||||
std::vector<int32_t> dst_shape(GetDstShape().begin(), GetDstShape().end());
|
||||
std::vector<int32_t> dst_shape(GetDstShape());
|
||||
for (size_t i = 0; i < dst_shape.size(); ++i) {
|
||||
if (dst_shape[i] == -1) {
|
||||
dst_shape[i] = inputs[0]->shape()[i];
|
||||
}
|
||||
}
|
||||
auto input_shape = input->shape();
|
||||
std::vector<int> shape(dst_shape.size());
|
||||
int input_shape_index = input_shape.size() - 1;
|
||||
|
|
|
@ -32,13 +32,33 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::BroadcastT> attr = std::make_unique<schema::BroadcastT>();
|
||||
std::unique_ptr<schema::BroadcastToT> attr = std::make_unique<schema::BroadcastToT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Broadcast;
|
||||
std::vector<int> dst_shape;
|
||||
const auto &onnx_expand_power = onnx_node.input(1);
|
||||
auto nodeIter =
|
||||
std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(),
|
||||
[onnx_expand_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_expand_power; });
|
||||
if (nodeIter == onnx_graph.node().end()) {
|
||||
MS_LOG(ERROR) << "can not find node: " << onnx_expand_power;
|
||||
return RET_ERROR;
|
||||
}
|
||||
const int64_t *dataPtr = nullptr;
|
||||
for (const auto &attrPower : nodeIter->attribute()) {
|
||||
if (attrPower.name() == "value") {
|
||||
const auto &t = attrPower.t();
|
||||
dataPtr = reinterpret_cast<const int64_t *>(t.raw_data().data());
|
||||
for (int i = 0; i < t.dims(0); ++i) {
|
||||
dst_shape.emplace_back(dataPtr[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
attr->dst_shape = dst_shape;
|
||||
op->primitive->value.type = schema::PrimitiveType_BroadcastTo;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue