forked from OSSInnovation/mindspore
!3888 fix anf exporter
Merge pull request !3888 from wangchangkai/master
This commit is contained in:
commit
282c5415f1
|
@ -241,6 +241,9 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta
|
|||
nodeIdMap[valueNode->fullname_with_scope()] = meta_graph->allTensors.size();
|
||||
fbNode->inputIndex.emplace_back(meta_graph->allTensors.size());
|
||||
meta_graph->allTensors.emplace_back(std::move(paramTensor));
|
||||
} else if (value->isa<mindspore::ValueSequeue>()) {
|
||||
MS_LOG(INFO) << "Value type is ValueSequence.";
|
||||
break;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Not support value type , need add support.";
|
||||
}
|
||||
|
|
|
@ -21,6 +21,10 @@
|
|||
#include "ir/primitive.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
constexpr int kReduceInputNum = 3;
|
||||
constexpr int kReduceInputIndex = 2;
|
||||
}
|
||||
int mindspore::lite::AnfReduceMeanPopulater::Parse(CNodePtr cnodePtr, schema::CNodeT *node,
|
||||
std::vector<schema::TensorT *> *outputs) {
|
||||
auto p = GetCNodePrimitive(cnodePtr);
|
||||
|
@ -28,7 +32,25 @@ int mindspore::lite::AnfReduceMeanPopulater::Parse(CNodePtr cnodePtr, schema::CN
|
|||
attr->mode = schema::ReduceMode_ReduceMean;
|
||||
|
||||
attr->keepDims = GetValue<bool>(p->GetAttr("keep_dims"));
|
||||
// attr->axes = GetValue<std::vector<int>>(p->GetAttr("shape"));
|
||||
if (cnodePtr->inputs().size() == kReduceInputNum) {
|
||||
auto inputNode = cnodePtr->input(kReduceInputIndex);
|
||||
MS_ASSERT(inputNode != nullptr);
|
||||
if (inputNode->isa<ValueNode>()) {
|
||||
auto valueNode = inputNode->cast<ValueNodePtr>();
|
||||
MS_ASSERT(valueNode != nullptr);
|
||||
auto value = valueNode->value();
|
||||
MS_ASSERT(value != nullptr);
|
||||
if (value->isa<ValueTuple>()) {
|
||||
auto valTuplPtr = dyn_cast<ValueTuple>(value);
|
||||
MS_ASSERT(valTuplPtr != nullptr);
|
||||
for (size_t i = 0; i < valTuplPtr->size(); i++) {
|
||||
auto elem = dyn_cast<Int32Imm>((*valTuplPtr)[i]);
|
||||
MS_ASSERT(elem != nullptr);
|
||||
attr->axes.emplace_back(elem->value());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
node->nodeType = schema::NodeType_CNode;
|
||||
node->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
|
|
Loading…
Reference in New Issue