fix anf exporter

This commit is contained in:
kai00 2020-08-03 17:14:55 +08:00
parent 02c921550e
commit 39ac3273a8
2 changed files with 26 additions and 1 deletions

View File

@ -241,6 +241,9 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta
nodeIdMap[valueNode->fullname_with_scope()] = meta_graph->allTensors.size();
fbNode->inputIndex.emplace_back(meta_graph->allTensors.size());
meta_graph->allTensors.emplace_back(std::move(paramTensor));
} else if (value->isa<mindspore::ValueSequeue>()) {
MS_LOG(INFO) << "Value type is ValueSequence.";
break;
} else {
MS_LOG(ERROR) << "Not support value type , need add support.";
}

View File

@ -21,6 +21,10 @@
#include "ir/primitive.h"
namespace mindspore::lite {
namespace {
constexpr int kReduceInputNum = 3;
constexpr int kReduceInputIndex = 2;
}
int mindspore::lite::AnfReduceMeanPopulater::Parse(CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto p = GetCNodePrimitive(cnodePtr);
@ -28,7 +32,25 @@ int mindspore::lite::AnfReduceMeanPopulater::Parse(CNodePtr cnodePtr, schema::CN
attr->mode = schema::ReduceMode_ReduceMean;
attr->keepDims = GetValue<bool>(p->GetAttr("keep_dims"));
// attr->axes = GetValue<std::vector<int>>(p->GetAttr("shape"));
if (cnodePtr->inputs().size() == kReduceInputNum) {
auto inputNode = cnodePtr->input(kReduceInputIndex);
MS_ASSERT(inputNode != nullptr);
if (inputNode->isa<ValueNode>()) {
auto valueNode = inputNode->cast<ValueNodePtr>();
MS_ASSERT(valueNode != nullptr);
auto value = valueNode->value();
MS_ASSERT(value != nullptr);
if (value->isa<ValueTuple>()) {
auto valTuplPtr = dyn_cast<ValueTuple>(value);
MS_ASSERT(valTuplPtr != nullptr);
for (size_t i = 0; i < valTuplPtr->size(); i++) {
auto elem = dyn_cast<Int32Imm>((*valTuplPtr)[i]);
MS_ASSERT(elem != nullptr);
attr->axes.emplace_back(elem->value());
}
}
}
}
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();