forked from mindspore-Ecosystem/mindspore
fix broadcast parser of onnx
This commit is contained in:
parent
116173d130
commit
4824727a63
|
@ -62,21 +62,32 @@ Registry BroadcastToRegistry(schema::PrimitiveType_BroadcastTo, BroadcastToCreat
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
constexpr int kBroadcastToInputNum = 1;
|
constexpr int kBroadcastToInputNum = 1;
|
||||||
|
constexpr int kBroadcastToOnnxInputNum = 2;
|
||||||
constexpr int kBroadcastToOutputNum = 1;
|
constexpr int kBroadcastToOutputNum = 1;
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
int BroadcastTo::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
|
int BroadcastTo::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
|
||||||
if (inputs.size() != kBroadcastToInputNum || outputs.size() != kBroadcastToOutputNum) {
|
if (inputs.size() != kBroadcastToInputNum && inputs.size() != kBroadcastToOnnxInputNum) {
|
||||||
MS_LOG(ERROR) << "input size:" << inputs.size() << ", output size:" << outputs.size();
|
MS_LOG(ERROR) << "input size:" << inputs.size();
|
||||||
return RET_PARAM_INVALID;
|
return RET_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
|
if (outputs.size() != kBroadcastToOutputNum) {
|
||||||
|
MS_LOG(ERROR) << "output size:" << outputs.size();
|
||||||
|
return RET_PARAM_INVALID;
|
||||||
|
}
|
||||||
|
|
||||||
auto input = inputs.at(0);
|
auto input = inputs.at(0);
|
||||||
outputs[0]->SetFormat(input->GetFormat());
|
outputs[0]->SetFormat(input->GetFormat());
|
||||||
outputs[0]->set_data_type(input->data_type());
|
outputs[0]->set_data_type(input->data_type());
|
||||||
if (!GetInferFlag()) {
|
if (!GetInferFlag()) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
std::vector<int32_t> dst_shape(GetDstShape().begin(), GetDstShape().end());
|
std::vector<int32_t> dst_shape(GetDstShape());
|
||||||
|
for (size_t i = 0; i < dst_shape.size(); ++i) {
|
||||||
|
if (dst_shape[i] == -1) {
|
||||||
|
dst_shape[i] = inputs[0]->shape()[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
auto input_shape = input->shape();
|
auto input_shape = input->shape();
|
||||||
std::vector<int> shape(dst_shape.size());
|
std::vector<int> shape(dst_shape.size());
|
||||||
int input_shape_index = input_shape.size() - 1;
|
int input_shape_index = input_shape.size() - 1;
|
||||||
|
|
|
@ -32,13 +32,33 @@ STATUS OnnxExpandParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<schema::BroadcastT> attr = std::make_unique<schema::BroadcastT>();
|
std::unique_ptr<schema::BroadcastToT> attr = std::make_unique<schema::BroadcastToT>();
|
||||||
if (attr == nullptr) {
|
if (attr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new op failed";
|
MS_LOG(ERROR) << "new op failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
op->primitive->value.type = schema::PrimitiveType_Broadcast;
|
std::vector<int> dst_shape;
|
||||||
|
const auto &onnx_expand_power = onnx_node.input(1);
|
||||||
|
auto nodeIter =
|
||||||
|
std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(),
|
||||||
|
[onnx_expand_power](const onnx::NodeProto &proto) { return proto.output(0) == onnx_expand_power; });
|
||||||
|
if (nodeIter == onnx_graph.node().end()) {
|
||||||
|
MS_LOG(ERROR) << "can not find node: " << onnx_expand_power;
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
const int64_t *dataPtr = nullptr;
|
||||||
|
for (const auto &attrPower : nodeIter->attribute()) {
|
||||||
|
if (attrPower.name() == "value") {
|
||||||
|
const auto &t = attrPower.t();
|
||||||
|
dataPtr = reinterpret_cast<const int64_t *>(t.raw_data().data());
|
||||||
|
for (int i = 0; i < t.dims(0); ++i) {
|
||||||
|
dst_shape.emplace_back(dataPtr[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr->dst_shape = dst_shape;
|
||||||
|
op->primitive->value.type = schema::PrimitiveType_BroadcastTo;
|
||||||
op->primitive->value.value = attr.release();
|
op->primitive->value.value = attr.release();
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue