pr to master #8
|
@ -171,10 +171,12 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::MetaGraphT *graph, TensorCache *tensor_cache) {
|
||||
void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node, schema::MetaGraphT *graph,
|
||||
TensorCache *tensor_cache, const QuantType &quant_type) {
|
||||
std::unique_ptr<schema::CNodeT> dst_op_1 = std::make_unique<schema::CNodeT>();
|
||||
dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0);
|
||||
dst_op_1->quantType = quant_type;
|
||||
ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get());
|
||||
auto matmul_output_id = "Gemm_MatMul_" + onnx_node.output(0);
|
||||
std::vector<string> matmul_inputs{onnx_node.input(0), onnx_node.input(1)};
|
||||
|
@ -185,6 +187,7 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons
|
|||
|
||||
std::unique_ptr<schema::CNodeT> dst_op_2 = std::make_unique<schema::CNodeT>();
|
||||
dst_op_2->name = "Gemm_BiasAdd_" + onnx_node.output(0);
|
||||
dst_op_2->quantType = quant_type;
|
||||
ParseOnnxNodeAttr(onnx_graph, onnx_node, "BiasAdd", dst_op_2.get());
|
||||
std::vector<string> biasadd_inputs{matmul_output_id, onnx_node.input(2)};
|
||||
std::vector<string> biasadd_outputs{onnx_node.output(0)};
|
||||
|
@ -343,8 +346,6 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const
|
|||
}
|
||||
if (findQuantParams == needQuantParams) {
|
||||
dst_op->quantType = schema::QuantType_AwareTraining;
|
||||
} else {
|
||||
dst_op->quantType = schema::QuantType_QUANT_NONE;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -520,7 +521,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
|
|||
}
|
||||
if (onnx_node.op_type() == "Gemm") {
|
||||
if (status == RET_OK) {
|
||||
ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache);
|
||||
ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache, quantType);
|
||||
}
|
||||
continue;
|
||||
} else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") {
|
||||
|
|
|
@ -65,7 +65,7 @@ class OnnxModelParser : public ModelParser {
|
|||
const QuantType &quantType);
|
||||
|
||||
void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
||||
schema::MetaGraphT *graph, TensorCache *tensor_cache);
|
||||
schema::MetaGraphT *graph, TensorCache *tensor_cache, const QuantType &quant_type);
|
||||
|
||||
STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache);
|
||||
|
||||
|
|
|
@ -47,14 +47,11 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *m
|
|||
|
||||
STATUS TfliteModelParser::CopyConstTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
|
||||
const tflite::TensorT *tflite_tensor, schema::TensorT *tensor) {
|
||||
auto count = 1;
|
||||
std::for_each(tflite_tensor->shape.begin(), tflite_tensor->shape.end(), [&](int32_t sha) { count *= sha; });
|
||||
auto data_size = count * GetDataTypeSize(TypeId(tensor->dataType));
|
||||
auto buffer_idx = tflite_tensor->buffer;
|
||||
if (!tflite_model_buffer[buffer_idx]->data.empty()) {
|
||||
auto data_size = tflite_model_buffer[buffer_idx]->data.size();
|
||||
tensor->data.resize(data_size);
|
||||
if (memcpy_s(tensor->data.data(), tensor->data.size(), tflite_model_buffer[buffer_idx]->data.data(),
|
||||
tflite_model_buffer[buffer_idx]->data.size())) {
|
||||
if (memcpy_s(tensor->data.data(), data_size, tflite_model_buffer[buffer_idx]->data.data(), data_size) != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy tensor data failed";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
|
|
|
@ -119,6 +119,9 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{
|
|||
{tflite::BuiltinOperator_CUSTOM, "Custom"},
|
||||
{tflite::BuiltinOperator_MIRROR_PAD, "MirrorPad"},
|
||||
{tflite::BuiltinOperator_NEG, "Neg"},
|
||||
{tflite::BuiltinOperator_HASHTABLE_LOOKUP, "HashtableLookup"},
|
||||
{tflite::BuiltinOperator_LSH_PROJECTION, "LshProjection"},
|
||||
{tflite::BuiltinOperator_SKIP_GRAM, "SKipGram"},
|
||||
};
|
||||
|
||||
std::map<tflite::ActivationFunctionType, schema::ActivationType> tfMsActivationFunctionMap{
|
||||
|
|
|
@ -117,7 +117,7 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, co
|
|||
}
|
||||
} else {
|
||||
if (EOK != memcpy_s(add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) {
|
||||
MS_LOG(ERROR) << "memset_s conv_bias_data failed";
|
||||
MS_LOG(ERROR) << "memcpy_s conv_bias_data failed";
|
||||
delete[] add_bias_data;
|
||||
return lite::RET_MEMORY_FAILED;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue