fix mindspore models runtime on_device

This commit is contained in:
yankai 2020-08-11 16:49:43 +08:00
parent 6a5c00ff7a
commit b3468fab89
6 changed files with 32 additions and 7 deletions

View File

@ -385,9 +385,11 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::vector<schem
int i = 0;
for (auto outputTensor : outputTensors) {
std::string name = cnodeName + "_o:" + std::to_string(i);
auto msTensor = new schema::TensorT();
msTensor->nodeType = schema::NodeType_Parameter;
nodeIdMap[name] = graph->allTensors.size();
fbnode->outputIndex.emplace_back(graph->allTensors.size());
graph->allTensors.emplace_back(outputTensor);
graph->allTensors.emplace_back(msTensor);
i++;
}
return;

View File

@ -23,10 +23,28 @@
namespace mindspore::lite {
int mindspore::lite::AnfReshapePopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto attr = std::make_unique<schema::FlattenT>();
auto attr = std::make_unique<schema::ReshapeT>();
MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree);
auto inputNode = cnodePtr->input(kAnfPopulaterTwo);
if (inputNode->isa<ValueNode>()) {
auto valueNode = inputNode->cast<ValueNodePtr>();
MS_ASSERT(valueNode != nullptr);
auto val = valueNode->value();
MS_ASSERT(val != nullptr);
if (val->isa<ValueTuple>()) {
auto tuple = val->cast<ValueTuplePtr>();
MS_ASSERT(tuple != nullptr);
for (size_t i = 0; i < tuple->size(); ++i) {
auto elem = tuple->value()[i]->cast<Int32ImmPtr>();
MS_ASSERT(elem != nullptr);
attr->shape.emplace_back(static_cast<int>(elem->value()));
}
}
}
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Flatten;
node->primitive->value.type = schema::PrimitiveType_Reshape;
node->primitive->value.value = attr.release();
return 0;
}

View File

@ -639,7 +639,7 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc
BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
return true;
}
#endif
#else
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \
void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \
@ -1108,6 +1108,7 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc
BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr);
return true;
}
#endif
bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);

View File

@ -77,7 +77,7 @@ class AnfImporterFromProtobuf : public AnfImporter {
const onnx::TensorProto &attr_tensor);
std::unordered_map<std::string, abstract::AbstractTensorPtr>
GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
#endif
#else
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto);
@ -100,6 +100,7 @@ class AnfImporterFromProtobuf : public AnfImporter {
bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
#endif
private:

View File

@ -232,7 +232,7 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in
if (type == kCKHW2HWCK) {
p2Buff =
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else if (type == kKCHW2KHWC) {
} else if (type == kCKHW2KHWC) {
p2Buff =
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterK) + (w * filterC) + (c));
} else {

View File

@ -350,6 +350,9 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
// todo(00445839): consider variable weight condition
}
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be CKHW
if (graphNode->subGraph->fmkType == converter::FmkType_MS) {
weightTensor->format = schema::Format_CKHW;
}
if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms
status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC);
} else if (weightTensor->format == schema::Format_KCHW) {
@ -362,7 +365,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
}
if (status == 0) {
node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC;
weightTensor->format = schema::Format_CKHW;
weightTensor->format = schema::Format_KHWC;
} else {
MS_LOG(WARNING) << "TransFilter HWCKToCKHW failed, node : " << node->name.c_str();
// todo(00445839): consider variable weight condition