forked from mindspore-Ecosystem/mindspore
!7121 [MSLITE] Fix bug of onxx cast parser.
Merge pull request !7121 from wangshaocong/bugfix_master
This commit is contained in:
commit
79b974eb82
|
@ -95,7 +95,7 @@ int Cast::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (input->data_type() != GetSrcT()) {
|
if (GetSrcT() != 0 && input->data_type() != GetSrcT()) {
|
||||||
MS_LOG(ERROR) << "input dataType is error";
|
MS_LOG(ERROR) << "input dataType is error";
|
||||||
return RET_INPUT_TENSOR_ERROR;
|
return RET_INPUT_TENSOR_ERROR;
|
||||||
}
|
}
|
||||||
|
|
|
@ -131,6 +131,7 @@
|
||||||
#include "src/ops/custom_predict.h"
|
#include "src/ops/custom_predict.h"
|
||||||
#include "src/ops/custom_normalize.h"
|
#include "src/ops/custom_normalize.h"
|
||||||
#include "src/ops/custom_extract_features.h"
|
#include "src/ops/custom_extract_features.h"
|
||||||
|
#include "src/ops/upsample.h"
|
||||||
#ifdef PRIMITIVE_WRITEABLE
|
#ifdef PRIMITIVE_WRITEABLE
|
||||||
#include "tools/converter/quantizer/quantize_util.h"
|
#include "tools/converter/quantizer/quantize_util.h"
|
||||||
#endif
|
#endif
|
||||||
|
@ -692,6 +693,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
|
||||||
return new CustomNormalize(primitive);
|
return new CustomNormalize(primitive);
|
||||||
case schema::PrimitiveType_CustomExtractFeatures:
|
case schema::PrimitiveType_CustomExtractFeatures:
|
||||||
return new CustomExtractFeatures(primitive);
|
return new CustomExtractFeatures(primitive);
|
||||||
|
case schema::PrimitiveType_Upsample:
|
||||||
|
return new Upsample(primitive);
|
||||||
|
|
||||||
#ifdef SUPPORT_TRAIN
|
#ifdef SUPPORT_TRAIN
|
||||||
case schema::PrimitiveType_ActivationGrad:
|
case schema::PrimitiveType_ActivationGrad:
|
||||||
|
@ -960,6 +963,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) {
|
||||||
return NewPrimitiveC<CustomNormalize>(primitive);
|
return NewPrimitiveC<CustomNormalize>(primitive);
|
||||||
case schema::PrimitiveType_CustomExtractFeatures:
|
case schema::PrimitiveType_CustomExtractFeatures:
|
||||||
return NewPrimitiveC<CustomExtractFeatures>(primitive);
|
return NewPrimitiveC<CustomExtractFeatures>(primitive);
|
||||||
|
case schema::PrimitiveType_Upsample:
|
||||||
|
return NewPrimitiveC<Upsample>(primitive);
|
||||||
|
|
||||||
#ifdef SUPPORT_TRAIN
|
#ifdef SUPPORT_TRAIN
|
||||||
case schema::PrimitiveType_ActivationGrad:
|
case schema::PrimitiveType_ActivationGrad:
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "tools/converter/parser/onnx/onnx_cast_parser.h"
|
#include "tools/converter/parser/onnx/onnx_cast_parser.h"
|
||||||
|
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -40,7 +41,8 @@ STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
||||||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||||
const auto &attribute_name = onnx_node_attr.name();
|
const auto &attribute_name = onnx_node_attr.name();
|
||||||
if (attribute_name == "to") {
|
if (attribute_name == "to") {
|
||||||
attr->dstT = static_cast<int32_t>(onnx_node_attr.i());
|
attr->dstT = static_cast<int32_t>(
|
||||||
|
OnnxModelParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_node_attr.i())));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -43,9 +43,9 @@ class OnnxModelParser : public ModelParser {
|
||||||
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
|
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
|
||||||
const QuantType &quantType = QuantType_QUANT_NONE) override;
|
const QuantType &quantType = QuantType_QUANT_NONE) override;
|
||||||
|
|
||||||
private:
|
static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
|
||||||
TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
|
|
||||||
|
|
||||||
|
private:
|
||||||
std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value);
|
std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value);
|
||||||
|
|
||||||
STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache);
|
STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache);
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "tools/converter/parser/onnx/onnx_unsample_parser.h"
|
#include "tools/converter/parser/onnx/onnx_upsample_parser.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
Loading…
Reference in New Issue