!24238 [MSLITE] Fix bug of constantOfShape.

Merge pull request !24238 from wangshaocong/bugfix_issue
This commit is contained in:
i-robot 2021-09-30 09:09:09 +00:00 committed by Gitee
commit 989a640308
1 changed files with 16 additions and 2 deletions

View File

@ -28,8 +28,8 @@
#include "src/common/log_adapter.h"
#include "src/common/version_manager.h"
#include "src/cpu_info.h"
#ifdef ENABLE_ARM64
#include "src/common/utils.h"
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
#include "nnacl/constant_of_shape_parameter.h"
#endif
namespace mindspore::kernel {
@ -204,6 +204,20 @@ class CpuFp16SubGraph : public CpuSubGraph {
if (output->data_type() == kNumberTypeFloat32) {
output->set_data_type(kNumberTypeFloat16);
}
} else if (node->type() == schema::PrimitiveType_ConstantOfShape) {
auto param = node->op_parameter();
MS_ASSERT(param != nullptr);
if (static_cast<TypeId>(reinterpret_cast<ConstantOfShapeParameter *>(param)->data_type_ ==
kNumberTypeFloat32)) {
reinterpret_cast<ConstantOfShapeParameter *>(param)->data_type_ = kNumberTypeFloat16;
}
auto outputs = node->out_tensors();
MS_ASSERT(outputs.size() == 1);
auto output = outputs.front();
MS_ASSERT(output != nullptr);
if (output->data_type() == kNumberTypeFloat32) {
output->set_data_type(kNumberTypeFloat16);
}
}
}
return RET_OK;