forked from mindspore-Ecosystem/mindspore
!24238 [MSLITE] Fix bug of constantOfShape.
Merge pull request !24238 from wangshaocong/bugfix_issue
This commit is contained in:
commit
989a640308
|
@ -28,8 +28,8 @@
|
|||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/version_manager.h"
|
||||
#include "src/cpu_info.h"
|
||||
#ifdef ENABLE_ARM64
|
||||
#include "src/common/utils.h"
|
||||
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
|
||||
#include "nnacl/constant_of_shape_parameter.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
@ -204,6 +204,20 @@ class CpuFp16SubGraph : public CpuSubGraph {
|
|||
if (output->data_type() == kNumberTypeFloat32) {
|
||||
output->set_data_type(kNumberTypeFloat16);
|
||||
}
|
||||
} else if (node->type() == schema::PrimitiveType_ConstantOfShape) {
|
||||
auto param = node->op_parameter();
|
||||
MS_ASSERT(param != nullptr);
|
||||
if (static_cast<TypeId>(reinterpret_cast<ConstantOfShapeParameter *>(param)->data_type_ ==
|
||||
kNumberTypeFloat32)) {
|
||||
reinterpret_cast<ConstantOfShapeParameter *>(param)->data_type_ = kNumberTypeFloat16;
|
||||
}
|
||||
auto outputs = node->out_tensors();
|
||||
MS_ASSERT(outputs.size() == 1);
|
||||
auto output = outputs.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
if (output->data_type() == kNumberTypeFloat32) {
|
||||
output->set_data_type(kNumberTypeFloat16);
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
|
|
Loading…
Reference in New Issue