forked from mindspore-Ecosystem/mindspore
!7121 [MSLITE] Fix bug of onxx cast parser.
Merge pull request !7121 from wangshaocong/bugfix_master
This commit is contained in:
commit
79b974eb82
|
@ -95,7 +95,7 @@ int Cast::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
if (input->data_type() != GetSrcT()) {
|
||||
if (GetSrcT() != 0 && input->data_type() != GetSrcT()) {
|
||||
MS_LOG(ERROR) << "input dataType is error";
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
|
|
|
@ -131,6 +131,7 @@
|
|||
#include "src/ops/custom_predict.h"
|
||||
#include "src/ops/custom_normalize.h"
|
||||
#include "src/ops/custom_extract_features.h"
|
||||
#include "src/ops/upsample.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#endif
|
||||
|
@ -692,6 +693,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
|
|||
return new CustomNormalize(primitive);
|
||||
case schema::PrimitiveType_CustomExtractFeatures:
|
||||
return new CustomExtractFeatures(primitive);
|
||||
case schema::PrimitiveType_Upsample:
|
||||
return new Upsample(primitive);
|
||||
|
||||
#ifdef SUPPORT_TRAIN
|
||||
case schema::PrimitiveType_ActivationGrad:
|
||||
|
@ -960,6 +963,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) {
|
|||
return NewPrimitiveC<CustomNormalize>(primitive);
|
||||
case schema::PrimitiveType_CustomExtractFeatures:
|
||||
return NewPrimitiveC<CustomExtractFeatures>(primitive);
|
||||
case schema::PrimitiveType_Upsample:
|
||||
return NewPrimitiveC<Upsample>(primitive);
|
||||
|
||||
#ifdef SUPPORT_TRAIN
|
||||
case schema::PrimitiveType_ActivationGrad:
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_cast_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -40,7 +41,8 @@ STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
||||
const auto &attribute_name = onnx_node_attr.name();
|
||||
if (attribute_name == "to") {
|
||||
attr->dstT = static_cast<int32_t>(onnx_node_attr.i());
|
||||
attr->dstT = static_cast<int32_t>(
|
||||
OnnxModelParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_node_attr.i())));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -43,9 +43,9 @@ class OnnxModelParser : public ModelParser {
|
|||
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
|
||||
const QuantType &quantType = QuantType_QUANT_NONE) override;
|
||||
|
||||
private:
|
||||
TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
|
||||
static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
|
||||
|
||||
private:
|
||||
std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value);
|
||||
|
||||
STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache);
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_unsample_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_upsample_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
Loading…
Reference in New Issue