!26676 [MSLITE][DEVELOP] delete memory copy in infershape pass
Merge pull request !26676 from yangruoqi713/master
commit 7ebfbb0278
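The commit threads a new `copy_data` flag through the converter's constant-fetch helpers (`FetchFromDefaultParam`, `FetchDataFromParameterNode`, `FetchDataFromValueNode`, `FetchFromTensorValue`, `LiteTensorExtractor::GetCNodeInputTensors`). When the flag is false, `DataInfo` records a raw pointer to the tensor's existing buffer in a new `data_ptr_` field instead of memcpy'ing the payload into `data_`; the infer-shape pass requests `false`, while every pre-existing call site passes `true` to keep its behavior unchanged. A minimal sketch of the pattern, with simplified stand-in types (not the verbatim MindSpore source):

```cpp
#include <cstring>
#include <vector>

// Simplified stand-ins for the real lite::DataInfo fields touched here.
struct DataInfo {
  std::vector<uint8_t> data_;  // owned copy of the payload (copy_data == true)
  void *data_ptr_ = nullptr;   // borrowed pointer into the tensor (copy_data == false)
};

// Sketch of the branch the commit adds to the fetch helpers.
inline int FetchSketch(void *src, size_t size, DataInfo *info, bool copy_data) {
  if (copy_data) {
    info->data_.resize(size);
    if (size > 0) {
      std::memcpy(info->data_.data(), src, size);  // old behavior: always duplicate
    }
  } else {
    info->data_ptr_ = src;  // new behavior: alias the existing buffer, no copy
  }
  return 0;
}
```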
@@ -236,7 +236,7 @@ int AnfExporter::CreateNewTensorForParameter(const std::unique_ptr<schema::MetaG
   lite::DataInfo data_info;
   auto param_node = input->cast<ParameterPtr>();
   MS_CHECK_TRUE_MSG(param_node != nullptr, RET_NULL_PTR, "cast ptr failed");
-  if (FetchFromDefaultParam(param_node, converter::FmkType(meta_graphT->fmkType), &data_info) != RET_OK) {
+  if (FetchFromDefaultParam(param_node, converter::FmkType(meta_graphT->fmkType), &data_info, true) != RET_OK) {
     MS_LOG(ERROR) << "FetchFromDefaultParam failed.";
     return RET_ERROR;
   }
@@ -733,8 +733,8 @@ int AnfExporter::ConvertInputParameter(const CNodePtr &cnode, size_t index, cons
     return RET_OK;
   }
   DataInfo data_info;
-  if (FetchDataFromParameterNode(cnode, index, converter::FmkType(meta_graphT->fmkType), train_flag_, &data_info) !=
-      RET_OK) {
+  if (FetchDataFromParameterNode(cnode, index, converter::FmkType(meta_graphT->fmkType), train_flag_, &data_info,
+                                 true) != RET_OK) {
     MS_LOG(ERROR) << "parse const node failed.";
     return RET_ERROR;
   }
@@ -762,7 +762,8 @@ int AnfExporter::ConvertInputValueNode(const CNodePtr &cnode, size_t index, cons
                                        const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
                                        schema::CNodeT *op_node) {
   DataInfo data_info;
-  auto status = FetchDataFromValueNode(cnode, index, converter::FmkType(meta_graphT->fmkType), train_flag_, &data_info);
+  auto status =
+    FetchDataFromValueNode(cnode, index, converter::FmkType(meta_graphT->fmkType), train_flag_, &data_info, true);
   if (status == RET_NO_CHANGE) {
     return RET_OK;
   }

@@ -117,7 +117,7 @@ STATUS GetDataTypeAndShape(const ParameterPtr &param_node, TypeId *data_type, Sh
 }
 
 int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, converter::FmkType fmk_type,
-                         bool train_flag, DataInfo *data_info) {
+                         bool train_flag, DataInfo *data_info, bool copy_data) {
   MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
   auto valueAbstract = value_node->abstract();
   MS_CHECK_TRUE_MSG(valueAbstract != nullptr, RET_ERROR, "valueAbstract is nullptr");
@@ -139,16 +139,20 @@ int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &pri
   MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
   auto data = value->cast<tensor::TensorPtr>();
   MS_CHECK_TRUE_MSG(data != nullptr, RET_ERROR, "data is invalid");
-  data_info->data_.resize(data->Size());
   if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
     MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
     return RET_ERROR;
   }
 
   // process weight tensor
-  if (data->Size() > 0 && memcpy_s(data_info->data_.data(), data->Size(), data->data_c(), data->Size()) != EOK) {
-    MS_LOG(ERROR) << "memcpy_s error.";
-    return RET_ERROR;
+  if (copy_data) {
+    data_info->data_.resize(data->Size());
+    if (data->Size() > 0 && memcpy_s(data_info->data_.data(), data->Size(), data->data_c(), data->Size()) != EOK) {
+      MS_LOG(ERROR) << "memcpy_s error.";
+      return RET_ERROR;
+    }
+  } else {
+    data_info->data_ptr_ = data->data_c();
   }
   return RET_OK;
 }
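The hunk above is the heart of the change: `FetchFromTensorValue` now either duplicates the weight payload or merely aliases it. The aliasing branch implies a lifetime contract — `data_ptr_` stays valid only while the source tensor is alive — which is presumably why only the transient infer-shape path asks for it. A standalone illustration of that contract (assumed, simplified types):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Toy tensor with the same data_c() accessor shape as the real one.
struct Tensor {
  std::vector<uint8_t> buf;
  void *data_c() { return buf.data(); }
};

int main() {
  Tensor weight{{1, 2, 3, 4}};
  void *borrowed = weight.data_c();  // copy_data == false: alias, no memcpy
  // Safe only while `weight` is alive; after it is destroyed the pointer dangles.
  std::cout << static_cast<int>(static_cast<uint8_t *>(borrowed)[0]) << "\n";
  return 0;
}
```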
@@ -236,7 +240,8 @@ int FetchFromSequenceValue(const ValueNodePtr &value_node, const PrimitivePtr &p
 }
 }  // namespace
 
-int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkType &fmk_type, DataInfo *data_info) {
+int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkType &fmk_type, DataInfo *data_info,
+                          bool copy_data) {
   MS_ASSERT(param_node != nullptr && data_info != nullptr);
   ShapeVector shape_vector;
   TypeId data_type = kTypeUnknown;
@@ -258,7 +263,8 @@ int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkTy
   std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
   data_info->shape_ = dims;
   if (tensor_info != nullptr && tensor_info->Size() != 0) {
-    if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) {
+    // tensor_list tensor
+    if (data_type == kObjectTypeTensorType && tensor_info->Size() >= kTensorListMinSize) {
       data_info->data_.resize(tensor_info->Size() - offset);
       if (EOK != common::huge_memcpy_s(data_info->data_.data(), data_info->data_.size(),
                                        static_cast<uint8_t *>(tensor_info->data_c()) + offset,
@@ -267,6 +273,20 @@ int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkTy
         return RET_ERROR;
       }
     }
+    // common node with const data
+    if (data_type != kObjectTypeTensorType) {
+      if (copy_data) {
+        data_info->data_.resize(tensor_info->Size() - offset);
+        if (EOK != common::huge_memcpy_s(data_info->data_.data(), data_info->data_.size(),
+                                         static_cast<uint8_t *>(tensor_info->data_c()) + offset,
+                                         tensor_info->Size() - offset)) {
+          MS_LOG(ERROR) << "memcpy_s failed.";
+          return RET_ERROR;
+        }
+      } else {
+        data_info->data_ptr_ = static_cast<uint8_t *>(tensor_info->data_c()) + offset;
+      }
+    }
   }
 
   data_info->format_ = NHWC;
@@ -274,11 +294,11 @@ int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkTy
 }
 
 int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
-                               DataInfo *data_info) {
+                               DataInfo *data_info, bool copy_data) {
   MS_ASSERT(cnode != nullptr && data_info != nullptr);
   auto param_node = cnode->input(index)->cast<ParameterPtr>();
   MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "input node is not parameter node.");
-  if (FetchFromDefaultParam(param_node, fmk_type, data_info) != RET_OK) {
+  if (FetchFromDefaultParam(param_node, fmk_type, data_info, copy_data) != RET_OK) {
     MS_LOG(ERROR) << "fetch information from default param failed.";
     return RET_ERROR;
   }
@@ -304,7 +324,7 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
 }
 
 int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
-                           DataInfo *data_info) {
+                           DataInfo *data_info, bool copy_data) {
   MS_ASSERT(cnode != nullptr && data_info != nullptr);
   auto value_node = cnode->input(index)->cast<ValueNodePtr>();
   MS_CHECK_TRUE_MSG(value_node != nullptr, RET_ERROR, "input node is not value node.");
@@ -314,7 +334,7 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy
   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
   MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "prim is nullptr");
   if (value->isa<tensor::Tensor>()) {
-    ret = FetchFromTensorValue(value_node, prim, fmk_type, train_flag, data_info);
+    ret = FetchFromTensorValue(value_node, prim, fmk_type, train_flag, data_info, copy_data);
     if (index == kNumWeightIndex && prim->GetAttr(mindspore::ops::kFormat) != nullptr) {
       data_info->format_ = GetValue<int64_t>(prim->GetAttr(mindspore::ops::kFormat));
     }

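In `FetchFromDefaultParam` the non-tensor-list branch applies the same `offset` in both modes: the copying branch starts its `huge_memcpy_s` at `data_c() + offset`, and the borrowing branch points `data_ptr_` at the same address. A sketch of that symmetry (hypothetical helper, simplified types):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the two branches in FetchFromDefaultParam:
// both skip `offset` leading bytes of the raw payload.
void FetchWithOffset(std::vector<uint8_t> &payload, size_t offset,
                     std::vector<uint8_t> *owned, void **borrowed, bool copy_data) {
  if (copy_data) {
    owned->assign(payload.begin() + offset, payload.end());  // huge_memcpy_s equivalent
  } else {
    *borrowed = payload.data() + offset;  // alias past the skipped header bytes
  }
}
```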
@@ -35,17 +35,22 @@ struct DataInfo {
   int node_type_;
   std::vector<int> shape_;
   std::vector<uint8_t> data_;
-  DataInfo() : enable_huffman_code_(false), format_(0), data_type_(0), node_type_{0} {}
+  void *data_ptr_;
+  DataInfo() : enable_huffman_code_(false), format_(0), data_type_(0), node_type_{0}, data_ptr_(nullptr) {}
 };
 
-int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkType &fmk_type, DataInfo *data_info);
+int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkType &fmk_type, DataInfo *data_info,
+                          bool copy_data);
 
 int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
-                               DataInfo *data_info);
+                               DataInfo *data_info, bool copy_data);
 
 int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
-                           DataInfo *data_info);
+                           DataInfo *data_info, bool copy_data);
 
 int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
                        DataInfo *data_info);
 
 int RemoveIfDepend(const CNodePtr &cnode);
 
 int RemoveIfMakeTuple(const CNodePtr &cnode);

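Because the new constructor initializes `data_ptr_` to `nullptr`, downstream code can tell the two states apart: an owned payload lives in `data_`, while a borrowed one is signaled by a non-null `data_ptr_` with an empty `data_`. A possible discriminating check (an assumption for illustration, not present in the diff):

```cpp
#include <cstdint>
#include <vector>

struct DataInfo {  // reduced to the two fields that matter here
  std::vector<uint8_t> data_;
  void *data_ptr_ = nullptr;
};

// True when the DataInfo merely borrows its payload (copy_data == false path).
bool HasBorrowedData(const DataInfo &info) {
  return info.data_.empty() && info.data_ptr_ != nullptr;
}
```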
@@ -90,9 +90,9 @@ AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const
   DataInfo data_info;
   STATUS status;
   if (utils::isa<Parameter>(node)) {
-    status = FetchDataFromParameterNode(cnode, index, flags->fmk, flags->trainModel, &data_info);
+    status = FetchDataFromParameterNode(cnode, index, flags->fmk, flags->trainModel, &data_info, true);
   } else if (utils::isa<ValueNode>(node)) {
-    status = FetchDataFromValueNode(cnode, index, flags->fmk, flags->trainModel, &data_info);
+    status = FetchDataFromValueNode(cnode, index, flags->fmk, flags->trainModel, &data_info, true);
   } else {
     status = RET_ERROR;
   }

@@ -39,13 +39,13 @@ bool GetInOutDataTypeValue(const CNodePtr &cast_cnode, int *output_type_value, i
   DataInfo data_info;
   auto output_type_node = cast_cnode->input(opt::kInputIndexTwo);
   if (utils::isa<ParameterPtr>(output_type_node)) {
-    if (FetchDataFromParameterNode(cast_cnode, opt::kInputIndexTwo, converter::kFmkTypeMs, false, &data_info) !=
+    if (FetchDataFromParameterNode(cast_cnode, opt::kInputIndexTwo, converter::kFmkTypeMs, false, &data_info, true) !=
        lite::RET_OK) {
      MS_LOG(ERROR) << "Fetch data from parameter node failed.";
      return false;
    }
  } else if (utils::isa<ValueNodePtr>(output_type_node)) {
-    if (FetchDataFromValueNode(cast_cnode, opt::kInputIndexTwo, converter::kFmkTypeMs, false, &data_info) !=
+    if (FetchDataFromValueNode(cast_cnode, opt::kInputIndexTwo, converter::kFmkTypeMs, false, &data_info, true) !=
        lite::RET_OK) {
      MS_LOG(ERROR) << "Fetch data from value node failed.";
      return false;

@@ -169,13 +169,13 @@ int ReplaceLstmNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func
   MS_CHECK_TRUE_MSG(lstm_weight_node != nullptr, RET_ERROR, "lstm_weight_node is nullptr.");
   lite::DataInfo data_info;
   if (lstm_weight_node->isa<Parameter>()) {
-    auto ret = FetchDataFromParameterNode(lstm_cnode, kLSTMWeightIndex, converter::kFmkTypeMs, false, &data_info);
+    auto ret = FetchDataFromParameterNode(lstm_cnode, kLSTMWeightIndex, converter::kFmkTypeMs, false, &data_info, true);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "parse const node failed.";
       return RET_ERROR;
     }
   } else if (lstm_weight_node->isa<ValueNode>()) {
-    auto ret = FetchDataFromValueNode(lstm_cnode, kLSTMWeightIndex, converter::kFmkTypeMs, false, &data_info);
+    auto ret = FetchDataFromValueNode(lstm_cnode, kLSTMWeightIndex, converter::kFmkTypeMs, false, &data_info, true);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "parse const node failed.";
       return RET_ERROR;

@@ -164,9 +164,9 @@ STATUS GetTransposePerm(const CNodePtr &cnode, std::vector<int> *perm) {
   lite::DataInfo data_info;
   int status;
   if (utils::isa<ParameterPtr>(cnode->input(kInputIndexTwo))) {
-    status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
+    status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info, true);
   } else {
-    status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
+    status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info, true);
   }
   if (status != lite::RET_OK) {
     MS_LOG(ERROR) << "fetch transpose perm data failed.";

@@ -344,7 +344,7 @@ bool ConstFoldPass::CheckCanSpecialFold(const CNodePtr &cnode) const {
 int ConstFoldPass::DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const {
   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
   std::vector<TensorPtr> inputs_ptr;
-  if (LiteTensorExtractor::GetCNodeInputTensors(cnode, &inputs_ptr, fmk_type_, train_flag_) != lite::RET_OK) {
+  if (LiteTensorExtractor::GetCNodeInputTensors(cnode, &inputs_ptr, fmk_type_, train_flag_, true) != lite::RET_OK) {
     MS_LOG(ERROR) << "extract input tensor from cnode failed. " << cnode->fullname_with_scope();
     return lite::RET_ERROR;
   }

@@ -156,9 +156,11 @@ int ConvBiasaddFusion::DoFuison(const FuncGraphPtr &func_graph, const AnfNodePtr
   lite::DataInfo add_bias_info;
   int status = lite::RET_ERROR;
   if (add_bias->isa<Parameter>()) {
-    status = lite::FetchDataFromParameterNode(add_cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &add_bias_info);
+    status =
+      lite::FetchDataFromParameterNode(add_cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &add_bias_info, true);
   } else if (add_bias->isa<ValueNode>()) {
-    status = lite::FetchDataFromValueNode(add_cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &add_bias_info);
+    status =
+      lite::FetchDataFromValueNode(add_cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &add_bias_info, true);
   }
   if (status != lite::RET_OK) {
     MS_LOG(DEBUG) << "conv and add do fusion failed, please check";
@@ -170,11 +172,11 @@ int ConvBiasaddFusion::DoFuison(const FuncGraphPtr &func_graph, const AnfNodePtr
   if (conv_cnode->size() > kInputSizeThree) {
     auto conv_bias = conv_cnode->input(kInputIndexThree);
     if (conv_bias->isa<Parameter>()) {
-      status =
-        lite::FetchDataFromParameterNode(conv_cnode, kInputIndexThree, converter::kFmkTypeMs, false, &conv_bias_info);
+      status = lite::FetchDataFromParameterNode(conv_cnode, kInputIndexThree, converter::kFmkTypeMs, false,
+                                                &conv_bias_info, true);
     } else if (conv_bias->isa<ValueNode>()) {
       status =
-        lite::FetchDataFromValueNode(conv_cnode, kInputIndexThree, converter::kFmkTypeMs, false, &conv_bias_info);
+        lite::FetchDataFromValueNode(conv_cnode, kInputIndexThree, converter::kFmkTypeMs, false, &conv_bias_info, true);
     }
     if (status != lite::RET_OK) {
       MS_LOG(DEBUG) << "conv and add do fusion failed, please check";

@@ -160,9 +160,9 @@ int ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
     if (!input_node->has_default()) {
       return lite::RET_OK;
     }
-    status = lite::FetchDataFromParameterNode(cnode, index, fmk_type, train_flag, &data_info);
+    status = lite::FetchDataFromParameterNode(cnode, index, fmk_type, train_flag, &data_info, true);
   } else {
-    status = lite::FetchDataFromValueNode(cnode, index, fmk_type, train_flag, &data_info);
+    status = lite::FetchDataFromValueNode(cnode, index, fmk_type, train_flag, &data_info, true);
   }
   if (status != lite::RET_OK) {
     return lite::RET_ERROR;
@@ -409,6 +409,33 @@ STATUS DecreaseTransposeAlgo::InsertPostTransNode(const FuncGraphPtr &func_graph
   return lite::RET_OK;
 }
 
+STATUS DecreaseTransposeAlgo::HandleGraphSingleNode(const FuncGraphPtr &func_graph, const TransTypePair &trans_info,
+                                                    const CNodePtr &cnode) {
+  for (size_t i = 1; i < cnode->size(); ++i) {
+    auto status = ConvertTensorToNCOrNH(func_graph, cnode, i, fmk_type_, train_flag_, trans_info.post_);
+    if (status != lite::RET_OK) {
+      MS_LOG(ERROR) << "ConvertTensorToNCOrNH failed.";
+      return lite::RET_ERROR;
+    }
+  }
+  auto status = transpose_strategy_.ChangeOpAxis(func_graph, cnode, trans_info.post_);
+  if (status != lite::RET_OK) {
+    MS_LOG(ERROR) << "change op attr failed.";
+    return lite::RET_ERROR;
+  }
+  status = ModifyCNodeFormat(cnode, trans_info.post_);
+  if (status != lite::RET_OK) {
+    MS_LOG(ERROR) << "ModifyCNodeFormat failed.";
+    return lite::RET_ERROR;
+  }
+  status = node_infer_shape_.InferShape(cnode);
+  if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
+    MS_LOG(ERROR) << "infer shape failed.";
+    return lite::RET_ERROR;
+  }
+  return RET_OK;
+}
+
 STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                    std::set<CNodePtr> *visit_transposes) {
   MS_ASSERT(func_graph != nullptr && cnode != nullptr && visit_transposes != nullptr);
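The new `HandleGraphSingleNode` extracts the four-step, fail-fast pipeline (convert const inputs to NC/NH, change the op axis, modify the cnode's format, re-infer its shape) that `HandleGraphMultiNode` previously inlined per node; the next hunk replaces that duplicated block with a single call. The control-flow shape, reduced to a generic sketch:

```cpp
#include <functional>
#include <vector>

// Generic fail-fast pipeline: run each step in order, stop at the first error.
// This mirrors the early-return structure of HandleGraphSingleNode above.
int RunNodePipeline(const std::vector<std::function<int()>> &steps) {
  for (const auto &step : steps) {
    if (int status = step(); status != 0) {
      return status;
    }
  }
  return 0;
}
```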
@@ -452,27 +479,10 @@ STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_grap
     if (IsSpecialType(middle_cnode)) {
       continue;
     }
-    for (size_t i = 1; i < middle_cnode->size(); ++i) {
-      status = ConvertTensorToNCOrNH(func_graph, middle_cnode, i, fmk_type_, train_flag_, trans_info.post_);
-      if (status != lite::RET_OK) {
-        MS_LOG(ERROR) << "ConvertTensorToNCOrNH failed.";
-        return lite::RET_ERROR;
-      }
-    }
-    status = transpose_strategy_.ChangeOpAxis(func_graph, middle_cnode, trans_info.post_);
-    if (status != lite::RET_OK) {
-      MS_LOG(ERROR) << "change op attr failed.";
-      return lite::RET_ERROR;
-    }
-    status = ModifyCNodeFormat(middle_cnode, trans_info.post_);
-    if (status != lite::RET_OK) {
-      MS_LOG(ERROR) << "ModifyCNodeFormat failed.";
-      return lite::RET_ERROR;
-    }
-    status = node_infer_shape_.InferShape(middle_cnode);
-    if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
-      MS_LOG(ERROR) << "infer shape failed.";
-      return lite::RET_ERROR;
+    status = HandleGraphSingleNode(func_graph, trans_info, middle_cnode);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "Decrease transpose for op failed.";
+      return status;
     }
   }
   return lite::RET_OK;
@@ -515,7 +525,7 @@ int DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGra
       }
       continue;
     }
-    auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info);
+    auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info, true);
     if (status != lite::RET_OK) {
       continue;
     }

@@ -47,6 +47,7 @@ class DecreaseTransposeAlgo : public Pass {
   bool DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph);
   STATUS PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
 
+  STATUS HandleGraphSingleNode(const FuncGraphPtr &func_graph, const TransTypePair &trans_info, const CNodePtr &cnode);
   STATUS HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                               std::set<CNodePtr> *visit_transposes);
   STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info);

@@ -268,7 +268,7 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt
       }
       continue;
     }
-    auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info);
+    auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info, false);
     if (status != lite::RET_OK) {
       continue;
     }

@@ -115,6 +115,11 @@ int ConvertToLiteTensor(const std::vector<lite::DataInfo> &data_infos, std::vect
         tensor->set_data(tensor_data);
       }
     }
+
+    if (tensor_size == 0 && data_info.data_ptr_ != nullptr) {
+      tensor->set_data(data_info.data_ptr_);
+      tensor->set_own_data(false);
+    }
     tensors->emplace_back(tensor);
   }
   return lite::RET_OK;
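`ConvertToLiteTensor` now covers the borrowed case: when no copied payload exists but `data_ptr_` is set, the lite tensor wraps that buffer, and `set_own_data(false)` tells it not to free memory it doesn't own. A sketch of what such an ownership flag typically guards (simplified, not the real lite::Tensor):

```cpp
#include <cstdlib>

// Toy tensor showing why set_own_data(false) matters: only owned buffers
// may be released in the destructor.
class TensorSketch {
 public:
  void set_data(void *data) { data_ = data; }
  void set_own_data(bool own) { own_data_ = own; }
  ~TensorSketch() {
    if (own_data_ && data_ != nullptr) {
      free(data_);  // borrowed buffers (own_data_ == false) are left alone
    }
  }

 private:
  void *data_ = nullptr;
  bool own_data_ = true;
};
```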
@@ -142,7 +147,7 @@ TensorPtr GetCNodeTensorListVarInput(const lite::DataInfo &data_info) {
 }
 
 int GetCNodeConstInput(const CNodePtr &cnode, std::vector<TensorPtr> *const_ms_inputs, converter::FmkType fmk_type,
-                       bool train_flag) {
+                       bool train_flag, bool copy_data) {
   MS_ASSERT(cnode != nullptr && const_ms_inputs != nullptr);
   std::vector<lite::DataInfo> data_infos;
   for (size_t i = 1; i < cnode->size(); ++i) {
@@ -152,9 +157,9 @@ int GetCNodeConstInput(const CNodePtr &cnode, std::vector<TensorPtr> *const_ms_i
     STATUS status;
     lite::DataInfo data_info;
     if (utils::isa<ParameterPtr>(cnode->input(i))) {
-      status = lite::FetchDataFromParameterNode(cnode, i, fmk_type, train_flag, &data_info);
+      status = lite::FetchDataFromParameterNode(cnode, i, fmk_type, train_flag, &data_info, copy_data);
     } else {
-      status = lite::FetchDataFromValueNode(cnode, i, fmk_type, train_flag, &data_info);
+      status = lite::FetchDataFromValueNode(cnode, i, fmk_type, train_flag, &data_info, copy_data);
     }
     if (status == lite::RET_NO_CHANGE) {
       continue;
@@ -213,7 +218,7 @@ int GetCNodeVarInput(const CNodePtr &cnode, std::vector<TensorPtr> *var_ms_input
 }  // namespace
 
 int LiteTensorExtractor::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<TensorPtr> *inputs,
-                                              converter::FmkType fmk_type, bool train_flag) {
+                                              converter::FmkType fmk_type, bool train_flag, bool copy_data) {
   MS_ASSERT(cnode != nullptr);
   MS_ASSERT(inputs != nullptr);
   auto origin_inputs = cnode->inputs();
@@ -221,7 +226,7 @@ int LiteTensorExtractor::GetCNodeInputTensors(const CNodePtr &cnode, std::vector
   lite::RemoveIfMakeTuple(cnode);
   RemoveIfMonad(cnode);
   std::vector<TensorPtr> const_inputs;
-  if (GetCNodeConstInput(cnode, &const_inputs, fmk_type, train_flag) != lite::RET_OK) {
+  if (GetCNodeConstInput(cnode, &const_inputs, fmk_type, train_flag, copy_data) != lite::RET_OK) {
     MS_LOG(ERROR) << "get const inputs failed.";
     cnode->set_inputs(origin_inputs);
     return lite::RET_ERROR;

@@ -30,7 +30,7 @@ class LiteTensorExtractor {
   LiteTensorExtractor() = default;
   ~LiteTensorExtractor() = default;
   static int GetCNodeInputTensors(const CNodePtr &cnode, std::vector<TensorPtr> *inputs, converter::FmkType fmk_type,
-                                  bool train_flag);
+                                  bool train_flag, bool copy_data);
   static int GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<TensorPtr> *outputs, bool train_flag);
 };
 }  // namespace opt

@@ -86,7 +86,7 @@ STATUS NodeInferShape::InferShape(const CNodePtr &cnode) {
   }
   anf_prim->AddAttr(kInferDone, MakeValue<bool>(false));
   std::vector<TensorPtr> inputs_ptr;
-  if (LiteTensorExtractor::GetCNodeInputTensors(cnode, &inputs_ptr, fmk_type_, train_flag_) != lite::RET_OK) {
+  if (LiteTensorExtractor::GetCNodeInputTensors(cnode, &inputs_ptr, fmk_type_, train_flag_, false) != lite::RET_OK) {
     MS_LOG(ERROR) << "get inputs failed.";
     return lite::RET_ERROR;
   }
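This is the hunk the commit title refers to: `NodeInferShape::InferShape` (and the shape/int-vector getters below) now request `copy_data = false`, so constant inputs are wrapped without a per-call memcpy. That is safe here because shape inference only reads the payload while the AnfNode graph, which owns the tensors, is still alive. A minimal illustration of that read-only use (assumed, simplified):

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Infer an element count from a borrowed shape payload without copying it.
int InferElementCount(const int32_t *shape_data, size_t rank) {
  return std::accumulate(shape_data, shape_data + rank, 1, std::multiplies<int>());
}

int main() {
  std::vector<int32_t> shape_tensor = {2, 3, 4};  // owned by the graph
  const int32_t *borrowed = shape_tensor.data();  // copy_data == false path
  return InferElementCount(borrowed, shape_tensor.size()) == 24 ? 0 : 1;
}
```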
@@ -172,9 +172,9 @@ std::vector<int> NodeInferShape::GetInputShape(const CNodePtr &cnode, size_t ind
   if (utils::isa<CNode>(base_node->input(position))) {
     status = lite::FetchDataFromCNode(base_node, position, fmk_type_, train_flag_, &data_info);
   } else if (utils::isa<Parameter>(base_node->input(position))) {
-    status = lite::FetchDataFromParameterNode(base_node, position, fmk_type_, train_flag_, &data_info);
+    status = lite::FetchDataFromParameterNode(base_node, position, fmk_type_, train_flag_, &data_info, false);
   } else if (utils::isa<ValueNodePtr>(base_node->input(position))) {
-    status = lite::FetchDataFromValueNode(base_node, position, fmk_type_, train_flag_, &data_info);
+    status = lite::FetchDataFromValueNode(base_node, position, fmk_type_, train_flag_, &data_info, false);
   } else {
     MS_LOG(ERROR) << "input node is invalid.";
     return {};
@@ -195,7 +195,8 @@ std::vector<int> NodeInferShape::GetIntVecInput(const CNodePtr &cnode, size_t in
   std::vector<AnfNodePtr> specify_inputs = {origin_inputs[0], origin_inputs[index]};
   cnode->set_inputs(specify_inputs);
   std::vector<TensorPtr> specify_tensors;
-  if (LiteTensorExtractor::GetCNodeInputTensors(cnode, &specify_tensors, fmk_type_, train_flag_) != lite::RET_OK ||
+  if (LiteTensorExtractor::GetCNodeInputTensors(cnode, &specify_tensors, fmk_type_, train_flag_, false) !=
+        lite::RET_OK ||
       specify_tensors.empty()) {
     cnode->set_inputs(origin_inputs);
     return {};

@@ -270,13 +270,13 @@ int RemoveRedundantOpPass::GetConstDataFromInputNode(const CNodePtr &cnode, lite
   auto padding_node = cnode->input(kInputIndexTwo);
   MS_ASSERT(padding_node != nullptr);
   if (utils::isa<Parameter>(padding_node)) {
-    auto status = lite::FetchDataFromParameterNode(cnode, 2, converter::kFmkTypeMs, false, data_info);
+    auto status = lite::FetchDataFromParameterNode(cnode, 2, converter::kFmkTypeMs, false, data_info, true);
     if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
       MS_LOG(ERROR) << "fetch data from parameter node failed.";
       return lite::RET_ERROR;
     }
   } else if (utils::isa<ValueNode>(padding_node)) {
-    auto status = lite::FetchDataFromValueNode(cnode, 2, converter::kFmkTypeMs, false, data_info);
+    auto status = lite::FetchDataFromValueNode(cnode, 2, converter::kFmkTypeMs, false, data_info, true);
     if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
       MS_LOG(ERROR) << "fetch data from value node failed.";
       return lite::RET_ERROR;

@@ -212,9 +212,9 @@ STATUS ChangeOpPad(const FuncGraphPtr &func_graph, const CNodePtr &cnode, Format
   lite::DataInfo data_info;
   int status;
   if (utils::isa<Parameter>(second_input)) {
-    status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
+    status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info, true);
   } else if (utils::isa<ValueNode>(second_input)) {
-    status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
+    status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info, true);
   } else {
     return lite::RET_NOT_SUPPORT;
   }