forked from mindspore-Ecosystem/mindspore
fix mindspore models runtime on_device
This commit is contained in:
parent
6a5c00ff7a
commit
b3468fab89
|
@ -385,9 +385,11 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::vector<schem
|
|||
int i = 0;
|
||||
for (auto outputTensor : outputTensors) {
|
||||
std::string name = cnodeName + "_o:" + std::to_string(i);
|
||||
auto msTensor = new schema::TensorT();
|
||||
msTensor->nodeType = schema::NodeType_Parameter;
|
||||
nodeIdMap[name] = graph->allTensors.size();
|
||||
fbnode->outputIndex.emplace_back(graph->allTensors.size());
|
||||
graph->allTensors.emplace_back(outputTensor);
|
||||
graph->allTensors.emplace_back(msTensor);
|
||||
i++;
|
||||
}
|
||||
return;
|
||||
|
|
|
@ -23,10 +23,28 @@
|
|||
namespace mindspore::lite {
|
||||
int mindspore::lite::AnfReshapePopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
|
||||
std::vector<schema::TensorT *> *outputs) {
|
||||
auto attr = std::make_unique<schema::FlattenT>();
|
||||
auto attr = std::make_unique<schema::ReshapeT>();
|
||||
MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree);
|
||||
auto inputNode = cnodePtr->input(kAnfPopulaterTwo);
|
||||
if (inputNode->isa<ValueNode>()) {
|
||||
auto valueNode = inputNode->cast<ValueNodePtr>();
|
||||
MS_ASSERT(valueNode != nullptr);
|
||||
auto val = valueNode->value();
|
||||
MS_ASSERT(val != nullptr);
|
||||
if (val->isa<ValueTuple>()) {
|
||||
auto tuple = val->cast<ValueTuplePtr>();
|
||||
MS_ASSERT(tuple != nullptr);
|
||||
for (size_t i = 0; i < tuple->size(); ++i) {
|
||||
auto elem = tuple->value()[i]->cast<Int32ImmPtr>();
|
||||
MS_ASSERT(elem != nullptr);
|
||||
attr->shape.emplace_back(static_cast<int>(elem->value()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
node->nodeType = schema::NodeType_CNode;
|
||||
node->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
node->primitive->value.type = schema::PrimitiveType_Flatten;
|
||||
node->primitive->value.type = schema::PrimitiveType_Reshape;
|
||||
node->primitive->value.value = attr.release();
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -639,7 +639,7 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc
|
|||
BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
|
||||
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
|
||||
void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \
|
||||
|
@ -1108,6 +1108,7 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc
|
|||
BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
|
|
|
@ -77,7 +77,7 @@ class AnfImporterFromProtobuf : public AnfImporter {
|
|||
const onnx::TensorProto &attr_tensor);
|
||||
std::unordered_map<std::string, abstract::AbstractTensorPtr>
|
||||
GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
|
||||
#endif
|
||||
#else
|
||||
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
|
||||
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
|
||||
bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto);
|
||||
|
@ -100,6 +100,7 @@ class AnfImporterFromProtobuf : public AnfImporter {
|
|||
bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
|
||||
abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
private:
|
||||
|
|
|
@ -232,7 +232,7 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in
|
|||
if (type == kCKHW2HWCK) {
|
||||
p2Buff =
|
||||
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
|
||||
} else if (type == kKCHW2KHWC) {
|
||||
} else if (type == kCKHW2KHWC) {
|
||||
p2Buff =
|
||||
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterK) + (w * filterC) + (c));
|
||||
} else {
|
||||
|
|
|
@ -350,6 +350,9 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
|
|||
// todo(00445839): consider varible weight condition
|
||||
}
|
||||
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be CKHW
|
||||
if (graphNode->subGraph->fmkType == converter::FmkType_MS) {
|
||||
weightTensor->format = schema::Format_CKHW;
|
||||
}
|
||||
if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms
|
||||
status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC);
|
||||
} else if (weightTensor->format == schema::Format_KCHW) {
|
||||
|
@ -362,7 +365,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
|
|||
}
|
||||
if (status == 0) {
|
||||
node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC;
|
||||
weightTensor->format = schema::Format_CKHW;
|
||||
weightTensor->format = schema::Format_KHWC;
|
||||
} else {
|
||||
MS_LOG(WARNING) << "TransFilter HWCKToCKHW failed, node : " << node->name.c_str();
|
||||
// todo(00445839): consider varible weight condition
|
||||
|
|
Loading…
Reference in New Issue