forked from mindspore-Ecosystem/mindspore
fix resnext_50 segment bug
This commit is contained in:
parent
7c3d64e0c9
commit
decff01dfe
|
@ -86,15 +86,8 @@ std::vector<size_t> GetLinkedPostNodeIdx(const lite::Model *model, const size_t
|
|||
return post_node_idxes;
|
||||
}
|
||||
|
||||
// only support op_type from current schema
|
||||
bool IsPackedOp(int op_type) {
|
||||
#ifdef ENABLE_V0
|
||||
static std::vector<int> v0_packed_ops = {
|
||||
schema::v0::PrimitiveType_Conv2D, schema::v0::PrimitiveType_DeConv2D, schema::v0::PrimitiveType_DepthwiseConv2D,
|
||||
schema::v0::PrimitiveType_DeDepthwiseConv2D, schema::v0::PrimitiveType_MatMul};
|
||||
if (VersionManager::GetInstance()->CheckV0Schema()) {
|
||||
return IsContain(v0_packed_ops, op_type);
|
||||
}
|
||||
#endif
|
||||
static std::vector<int> packed_ops = {schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_Conv2dTransposeFusion,
|
||||
schema::PrimitiveType_MatMul};
|
||||
return IsContain(packed_ops, op_type);
|
||||
|
|
|
@ -31,13 +31,14 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
using NODE_ID = std::string;
|
||||
|
||||
// only support op_type from current schema
|
||||
bool IsPackedOp(int op_type);
|
||||
|
||||
std::vector<size_t> GetGraphInputNodes(const lite::Model *model);
|
||||
|
||||
std::vector<size_t> GetGraphOutputNodes(const lite::Model *model);
|
||||
|
||||
std::vector<size_t> GetLinkedPostNodeIdx(const lite::Model *model, size_t tensor_idx);
|
||||
|
||||
bool IsPackedOp(int op_type);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -281,24 +281,26 @@ int CopyConstTensorData(const std::vector<Tensor *> &tensors, int op_type) {
|
|||
return RET_OK;
|
||||
}
|
||||
for (auto *tensor : tensors) {
|
||||
// only cast const tensor
|
||||
// tensorlist not support fp16 now
|
||||
if (!tensor->IsConst() || tensor->data_type() == kObjectTypeTensorType) {
|
||||
// only copy non-copied const tensor
|
||||
if (!tensor->IsConst() || tensor->own_data()) {
|
||||
continue;
|
||||
}
|
||||
if (tensor->own_data()) {
|
||||
continue;
|
||||
if (tensor->data_type() == kObjectTypeTensorType) {
|
||||
// tensorlist's data is nullptr since ConvertTensors
|
||||
// we never set or malloc data of tensorlist but malloc tensors in tensorlist
|
||||
MS_ASSERT(tensor->data_c() == nullptr);
|
||||
} else {
|
||||
auto copy_tensor = Tensor::CopyTensor(*tensor, true);
|
||||
if (copy_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Copy tensor failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
tensor->FreeData();
|
||||
tensor->set_data(copy_tensor->data_c());
|
||||
tensor->set_own_data(true);
|
||||
copy_tensor->set_data(nullptr);
|
||||
delete (copy_tensor);
|
||||
}
|
||||
auto copy_tensor = Tensor::CopyTensor(*tensor, true);
|
||||
if (copy_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Copy tensor failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
tensor->FreeData();
|
||||
tensor->set_data(copy_tensor->data_c());
|
||||
tensor->set_own_data(true);
|
||||
copy_tensor->set_data(nullptr);
|
||||
delete (copy_tensor);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
inception_v4_ms_r1.0
|
||||
mobilenet_v2_1.0_224_ms_r1.0
|
||||
# Add input shapes to fix Segmentation fault at Samsung phone
|
||||
resnext50_ms_r1.0;1,224,224,3
|
||||
resnext50_ms_r1.0
|
||||
|
|
Loading…
Reference in New Issue