!26676 MSLITE][DEVELOP] delete memory copy in infershape pass

Merge pull request !26676 from yangruoqi713/master
This commit is contained in:
i-robot 2021-12-02 07:51:12 +00:00 committed by Gitee
commit 7ebfbb0278
17 changed files with 117 additions and 72 deletions

View File

@ -236,7 +236,7 @@ int AnfExporter::CreateNewTensorForParameter(const std::unique_ptr<schema::MetaG
lite::DataInfo data_info;
auto param_node = input->cast<ParameterPtr>();
MS_CHECK_TRUE_MSG(param_node != nullptr, RET_NULL_PTR, "cast ptr failed");
if (FetchFromDefaultParam(param_node, converter::FmkType(meta_graphT->fmkType), &data_info) != RET_OK) {
if (FetchFromDefaultParam(param_node, converter::FmkType(meta_graphT->fmkType), &data_info, true) != RET_OK) {
MS_LOG(ERROR) << "FetchFromDefaultParam failed.";
return RET_ERROR;
}
@ -733,8 +733,8 @@ int AnfExporter::ConvertInputParameter(const CNodePtr &cnode, size_t index, cons
return RET_OK;
}
DataInfo data_info;
if (FetchDataFromParameterNode(cnode, index, converter::FmkType(meta_graphT->fmkType), train_flag_, &data_info) !=
RET_OK) {
if (FetchDataFromParameterNode(cnode, index, converter::FmkType(meta_graphT->fmkType), train_flag_, &data_info,
true) != RET_OK) {
MS_LOG(ERROR) << "parse const node failed.";
return RET_ERROR;
}
@ -762,7 +762,8 @@ int AnfExporter::ConvertInputValueNode(const CNodePtr &cnode, size_t index, cons
const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
schema::CNodeT *op_node) {
DataInfo data_info;
auto status = FetchDataFromValueNode(cnode, index, converter::FmkType(meta_graphT->fmkType), train_flag_, &data_info);
auto status =
FetchDataFromValueNode(cnode, index, converter::FmkType(meta_graphT->fmkType), train_flag_, &data_info, true);
if (status == RET_NO_CHANGE) {
return RET_OK;
}

View File

@ -117,7 +117,7 @@ STATUS GetDataTypeAndShape(const ParameterPtr &param_node, TypeId *data_type, Sh
}
int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, converter::FmkType fmk_type,
bool train_flag, DataInfo *data_info) {
bool train_flag, DataInfo *data_info, bool copy_data) {
MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
auto valueAbstract = value_node->abstract();
MS_CHECK_TRUE_MSG(valueAbstract != nullptr, RET_ERROR, "valueAbstract is nullptr");
@ -139,16 +139,20 @@ int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &pri
MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr");
auto data = value->cast<tensor::TensorPtr>();
MS_CHECK_TRUE_MSG(data != nullptr, RET_ERROR, "data is invalid");
data_info->data_.resize(data->Size());
if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
return RET_ERROR;
}
// process weight tensor
if (data->Size() > 0 && memcpy_s(data_info->data_.data(), data->Size(), data->data_c(), data->Size()) != EOK) {
MS_LOG(ERROR) << "memcpy_s error.";
return RET_ERROR;
if (copy_data) {
data_info->data_.resize(data->Size());
if (data->Size() > 0 && memcpy_s(data_info->data_.data(), data->Size(), data->data_c(), data->Size()) != EOK) {
MS_LOG(ERROR) << "memcpy_s error.";
return RET_ERROR;
}
} else {
data_info->data_ptr_ = data->data_c();
}
return RET_OK;
}
@ -236,7 +240,8 @@ int FetchFromSequenceValue(const ValueNodePtr &value_node, const PrimitivePtr &p
}
} // namespace
int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkType &fmk_type, DataInfo *data_info) {
int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkType &fmk_type, DataInfo *data_info,
bool copy_data) {
MS_ASSERT(param_node != nullptr && data_info != nullptr);
ShapeVector shape_vector;
TypeId data_type = kTypeUnknown;
@ -258,7 +263,8 @@ int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkTy
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
data_info->shape_ = dims;
if (tensor_info != nullptr && tensor_info->Size() != 0) {
if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) {
// tensor_list tensor
if (data_type == kObjectTypeTensorType && tensor_info->Size() >= kTensorListMinSize) {
data_info->data_.resize(tensor_info->Size() - offset);
if (EOK != common::huge_memcpy_s(data_info->data_.data(), data_info->data_.size(),
static_cast<uint8_t *>(tensor_info->data_c()) + offset,
@ -267,6 +273,20 @@ int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkTy
return RET_ERROR;
}
}
// common node with const data
if (data_type != kObjectTypeTensorType) {
if (copy_data) {
data_info->data_.resize(tensor_info->Size() - offset);
if (EOK != common::huge_memcpy_s(data_info->data_.data(), data_info->data_.size(),
static_cast<uint8_t *>(tensor_info->data_c()) + offset,
tensor_info->Size() - offset)) {
MS_LOG(ERROR) << "memcpy_s failed.";
return RET_ERROR;
}
} else {
data_info->data_ptr_ = static_cast<uint8_t *>(tensor_info->data_c()) + offset;
}
}
}
data_info->format_ = NHWC;
@ -274,11 +294,11 @@ int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkTy
}
int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
DataInfo *data_info) {
DataInfo *data_info, bool copy_data) {
MS_ASSERT(cnode != nullptr && data_info != nullptr);
auto param_node = cnode->input(index)->cast<ParameterPtr>();
MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "input node is not parameter node.");
if (FetchFromDefaultParam(param_node, fmk_type, data_info) != RET_OK) {
if (FetchFromDefaultParam(param_node, fmk_type, data_info, copy_data) != RET_OK) {
MS_LOG(ERROR) << "fetch information from default param failed.";
return RET_ERROR;
}
@ -304,7 +324,7 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
}
int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
DataInfo *data_info) {
DataInfo *data_info, bool copy_data) {
MS_ASSERT(cnode != nullptr && data_info != nullptr);
auto value_node = cnode->input(index)->cast<ValueNodePtr>();
MS_CHECK_TRUE_MSG(value_node != nullptr, RET_ERROR, "input node is not value node.");
@ -314,7 +334,7 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "prim is nullptr");
if (value->isa<tensor::Tensor>()) {
ret = FetchFromTensorValue(value_node, prim, fmk_type, train_flag, data_info);
ret = FetchFromTensorValue(value_node, prim, fmk_type, train_flag, data_info, copy_data);
if (index == kNumWeightIndex && prim->GetAttr(mindspore::ops::kFormat) != nullptr) {
data_info->format_ = GetValue<int64_t>(prim->GetAttr(mindspore::ops::kFormat));
}

View File

@ -35,17 +35,22 @@ struct DataInfo {
int node_type_;
std::vector<int> shape_;
std::vector<uint8_t> data_;
DataInfo() : enable_huffman_code_(false), format_(0), data_type_(0), node_type_{0} {}
void *data_ptr_;
DataInfo() : enable_huffman_code_(false), format_(0), data_type_(0), node_type_{0}, data_ptr_(nullptr) {}
};
int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkType &fmk_type, DataInfo *data_info);
int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkType &fmk_type, DataInfo *data_info,
bool copy_data);
int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
DataInfo *data_info);
DataInfo *data_info, bool copy_data);
int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
DataInfo *data_info);
DataInfo *data_info, bool copy_data);
int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
DataInfo *data_info);
int RemoveIfDepend(const CNodePtr &cnode);
int RemoveIfMakeTuple(const CNodePtr &cnode);

View File

@ -90,9 +90,9 @@ AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const
DataInfo data_info;
STATUS status;
if (utils::isa<Parameter>(node)) {
status = FetchDataFromParameterNode(cnode, index, flags->fmk, flags->trainModel, &data_info);
status = FetchDataFromParameterNode(cnode, index, flags->fmk, flags->trainModel, &data_info, true);
} else if (utils::isa<ValueNode>(node)) {
status = FetchDataFromValueNode(cnode, index, flags->fmk, flags->trainModel, &data_info);
status = FetchDataFromValueNode(cnode, index, flags->fmk, flags->trainModel, &data_info, true);
} else {
status = RET_ERROR;
}

View File

@ -39,13 +39,13 @@ bool GetInOutDataTypeValue(const CNodePtr &cast_cnode, int *output_type_value, i
DataInfo data_info;
auto output_type_node = cast_cnode->input(opt::kInputIndexTwo);
if (utils::isa<ParameterPtr>(output_type_node)) {
if (FetchDataFromParameterNode(cast_cnode, opt::kInputIndexTwo, converter::kFmkTypeMs, false, &data_info) !=
if (FetchDataFromParameterNode(cast_cnode, opt::kInputIndexTwo, converter::kFmkTypeMs, false, &data_info, true) !=
lite::RET_OK) {
MS_LOG(ERROR) << "Fetch data from parameter node failed.";
return false;
}
} else if (utils::isa<ValueNodePtr>(output_type_node)) {
if (FetchDataFromValueNode(cast_cnode, opt::kInputIndexTwo, converter::kFmkTypeMs, false, &data_info) !=
if (FetchDataFromValueNode(cast_cnode, opt::kInputIndexTwo, converter::kFmkTypeMs, false, &data_info, true) !=
lite::RET_OK) {
MS_LOG(ERROR) << "Fetch data from value node failed.";
return false;

View File

@ -169,13 +169,13 @@ int ReplaceLstmNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func
MS_CHECK_TRUE_MSG(lstm_weight_node != nullptr, RET_ERROR, "lstm_weight_node is nullptr.");
lite::DataInfo data_info;
if (lstm_weight_node->isa<Parameter>()) {
auto ret = FetchDataFromParameterNode(lstm_cnode, kLSTMWeightIndex, converter::kFmkTypeMs, false, &data_info);
auto ret = FetchDataFromParameterNode(lstm_cnode, kLSTMWeightIndex, converter::kFmkTypeMs, false, &data_info, true);
if (ret != RET_OK) {
MS_LOG(ERROR) << "parse const node failed.";
return RET_ERROR;
}
} else if (lstm_weight_node->isa<ValueNode>()) {
auto ret = FetchDataFromValueNode(lstm_cnode, kLSTMWeightIndex, converter::kFmkTypeMs, false, &data_info);
auto ret = FetchDataFromValueNode(lstm_cnode, kLSTMWeightIndex, converter::kFmkTypeMs, false, &data_info, true);
if (ret != RET_OK) {
MS_LOG(ERROR) << "parse const node failed.";
return RET_ERROR;

View File

@ -164,9 +164,9 @@ STATUS GetTransposePerm(const CNodePtr &cnode, std::vector<int> *perm) {
lite::DataInfo data_info;
int status;
if (utils::isa<ParameterPtr>(cnode->input(kInputIndexTwo))) {
status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info, true);
} else {
status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info, true);
}
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "fetch transpose perm data failed.";

View File

@ -344,7 +344,7 @@ bool ConstFoldPass::CheckCanSpecialFold(const CNodePtr &cnode) const {
int ConstFoldPass::DoConstantFold(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const {
MS_ASSERT(func_graph != nullptr && cnode != nullptr);
std::vector<TensorPtr> inputs_ptr;
if (LiteTensorExtractor::GetCNodeInputTensors(cnode, &inputs_ptr, fmk_type_, train_flag_) != lite::RET_OK) {
if (LiteTensorExtractor::GetCNodeInputTensors(cnode, &inputs_ptr, fmk_type_, train_flag_, true) != lite::RET_OK) {
MS_LOG(ERROR) << "extract input tensor from cnode failed. " << cnode->fullname_with_scope();
return lite::RET_ERROR;
}

View File

@ -156,9 +156,11 @@ int ConvBiasaddFusion::DoFuison(const FuncGraphPtr &func_graph, const AnfNodePtr
lite::DataInfo add_bias_info;
int status = lite::RET_ERROR;
if (add_bias->isa<Parameter>()) {
status = lite::FetchDataFromParameterNode(add_cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &add_bias_info);
status =
lite::FetchDataFromParameterNode(add_cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &add_bias_info, true);
} else if (add_bias->isa<ValueNode>()) {
status = lite::FetchDataFromValueNode(add_cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &add_bias_info);
status =
lite::FetchDataFromValueNode(add_cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &add_bias_info, true);
}
if (status != lite::RET_OK) {
MS_LOG(DEBUG) << "conv and add do fusion failed, please check";
@ -170,11 +172,11 @@ int ConvBiasaddFusion::DoFuison(const FuncGraphPtr &func_graph, const AnfNodePtr
if (conv_cnode->size() > kInputSizeThree) {
auto conv_bias = conv_cnode->input(kInputIndexThree);
if (conv_bias->isa<Parameter>()) {
status =
lite::FetchDataFromParameterNode(conv_cnode, kInputIndexThree, converter::kFmkTypeMs, false, &conv_bias_info);
status = lite::FetchDataFromParameterNode(conv_cnode, kInputIndexThree, converter::kFmkTypeMs, false,
&conv_bias_info, true);
} else if (conv_bias->isa<ValueNode>()) {
status =
lite::FetchDataFromValueNode(conv_cnode, kInputIndexThree, converter::kFmkTypeMs, false, &conv_bias_info);
lite::FetchDataFromValueNode(conv_cnode, kInputIndexThree, converter::kFmkTypeMs, false, &conv_bias_info, true);
}
if (status != lite::RET_OK) {
MS_LOG(DEBUG) << "conv and add do fusion failed, please check";

View File

@ -160,9 +160,9 @@ int ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
if (!input_node->has_default()) {
return lite::RET_OK;
}
status = lite::FetchDataFromParameterNode(cnode, index, fmk_type, train_flag, &data_info);
status = lite::FetchDataFromParameterNode(cnode, index, fmk_type, train_flag, &data_info, true);
} else {
status = lite::FetchDataFromValueNode(cnode, index, fmk_type, train_flag, &data_info);
status = lite::FetchDataFromValueNode(cnode, index, fmk_type, train_flag, &data_info, true);
}
if (status != lite::RET_OK) {
return lite::RET_ERROR;
@ -409,6 +409,33 @@ STATUS DecreaseTransposeAlgo::InsertPostTransNode(const FuncGraphPtr &func_graph
return lite::RET_OK;
}
STATUS DecreaseTransposeAlgo::HandleGraphSingleNode(const FuncGraphPtr &func_graph, const TransTypePair &trans_info,
const CNodePtr &cnode) {
for (size_t i = 1; i < cnode->size(); ++i) {
auto status = ConvertTensorToNCOrNH(func_graph, cnode, i, fmk_type_, train_flag_, trans_info.post_);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "ConvertTensorToNCOrNH failed.";
return lite::RET_ERROR;
}
}
auto status = transpose_strategy_.ChangeOpAxis(func_graph, cnode, trans_info.post_);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "change op attr failed.";
return lite::RET_ERROR;
}
status = ModifyCNodeFormat(cnode, trans_info.post_);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "ModifyCNodeFormat failed.";
return lite::RET_ERROR;
}
status = node_infer_shape_.InferShape(cnode);
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
MS_LOG(ERROR) << "infer shape failed.";
return lite::RET_ERROR;
}
return RET_OK;
}
STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
std::set<CNodePtr> *visit_transposes) {
MS_ASSERT(func_graph != nullptr && cnode != nullptr && visit_transposes != nullptr);
@ -452,27 +479,10 @@ STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_grap
if (IsSpecialType(middle_cnode)) {
continue;
}
for (size_t i = 1; i < middle_cnode->size(); ++i) {
status = ConvertTensorToNCOrNH(func_graph, middle_cnode, i, fmk_type_, train_flag_, trans_info.post_);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "ConvertTensorToNCOrNH failed.";
return lite::RET_ERROR;
}
}
status = transpose_strategy_.ChangeOpAxis(func_graph, middle_cnode, trans_info.post_);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "change op attr failed.";
return lite::RET_ERROR;
}
status = ModifyCNodeFormat(middle_cnode, trans_info.post_);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "ModifyCNodeFormat failed.";
return lite::RET_ERROR;
}
status = node_infer_shape_.InferShape(middle_cnode);
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
MS_LOG(ERROR) << "infer shape failed.";
return lite::RET_ERROR;
status = HandleGraphSingleNode(func_graph, trans_info, middle_cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "Decrease transpose for op failed.";
return status;
}
}
return lite::RET_OK;
@ -515,7 +525,7 @@ int DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGra
}
continue;
}
auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info);
auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info, true);
if (status != lite::RET_OK) {
continue;
}

View File

@ -47,6 +47,7 @@ class DecreaseTransposeAlgo : public Pass {
bool DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph);
STATUS PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS HandleGraphSingleNode(const FuncGraphPtr &func_graph, const TransTypePair &trans_info, const CNodePtr &cnode);
STATUS HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
std::set<CNodePtr> *visit_transposes);
STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info);

View File

@ -268,7 +268,7 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt
}
continue;
}
auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info);
auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info, false);
if (status != lite::RET_OK) {
continue;
}

View File

@ -115,6 +115,11 @@ int ConvertToLiteTensor(const std::vector<lite::DataInfo> &data_infos, std::vect
tensor->set_data(tensor_data);
}
}
if (tensor_size == 0 && data_info.data_ptr_ != nullptr) {
tensor->set_data(data_info.data_ptr_);
tensor->set_own_data(false);
}
tensors->emplace_back(tensor);
}
return lite::RET_OK;
@ -142,7 +147,7 @@ TensorPtr GetCNodeTensorListVarInput(const lite::DataInfo &data_info) {
}
int GetCNodeConstInput(const CNodePtr &cnode, std::vector<TensorPtr> *const_ms_inputs, converter::FmkType fmk_type,
bool train_flag) {
bool train_flag, bool copy_data) {
MS_ASSERT(cnode != nullptr && const_ms_inputs != nullptr);
std::vector<lite::DataInfo> data_infos;
for (size_t i = 1; i < cnode->size(); ++i) {
@ -152,9 +157,9 @@ int GetCNodeConstInput(const CNodePtr &cnode, std::vector<TensorPtr> *const_ms_i
STATUS status;
lite::DataInfo data_info;
if (utils::isa<ParameterPtr>(cnode->input(i))) {
status = lite::FetchDataFromParameterNode(cnode, i, fmk_type, train_flag, &data_info);
status = lite::FetchDataFromParameterNode(cnode, i, fmk_type, train_flag, &data_info, copy_data);
} else {
status = lite::FetchDataFromValueNode(cnode, i, fmk_type, train_flag, &data_info);
status = lite::FetchDataFromValueNode(cnode, i, fmk_type, train_flag, &data_info, copy_data);
}
if (status == lite::RET_NO_CHANGE) {
continue;
@ -213,7 +218,7 @@ int GetCNodeVarInput(const CNodePtr &cnode, std::vector<TensorPtr> *var_ms_input
} // namespace
int LiteTensorExtractor::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<TensorPtr> *inputs,
converter::FmkType fmk_type, bool train_flag) {
converter::FmkType fmk_type, bool train_flag, bool copy_data) {
MS_ASSERT(cnode != nullptr);
MS_ASSERT(inputs != nullptr);
auto origin_inputs = cnode->inputs();
@ -221,7 +226,7 @@ int LiteTensorExtractor::GetCNodeInputTensors(const CNodePtr &cnode, std::vector
lite::RemoveIfMakeTuple(cnode);
RemoveIfMonad(cnode);
std::vector<TensorPtr> const_inputs;
if (GetCNodeConstInput(cnode, &const_inputs, fmk_type, train_flag) != lite::RET_OK) {
if (GetCNodeConstInput(cnode, &const_inputs, fmk_type, train_flag, copy_data) != lite::RET_OK) {
MS_LOG(ERROR) << "get const inputs failed.";
cnode->set_inputs(origin_inputs);
return lite::RET_ERROR;

View File

@ -30,7 +30,7 @@ class LiteTensorExtractor {
LiteTensorExtractor() = default;
~LiteTensorExtractor() = default;
static int GetCNodeInputTensors(const CNodePtr &cnode, std::vector<TensorPtr> *inputs, converter::FmkType fmk_type,
bool train_flag);
bool train_flag, bool copy_data);
static int GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<TensorPtr> *outputs, bool train_flag);
};
} // namespace opt

View File

@ -86,7 +86,7 @@ STATUS NodeInferShape::InferShape(const CNodePtr &cnode) {
}
anf_prim->AddAttr(kInferDone, MakeValue<bool>(false));
std::vector<TensorPtr> inputs_ptr;
if (LiteTensorExtractor::GetCNodeInputTensors(cnode, &inputs_ptr, fmk_type_, train_flag_) != lite::RET_OK) {
if (LiteTensorExtractor::GetCNodeInputTensors(cnode, &inputs_ptr, fmk_type_, train_flag_, false) != lite::RET_OK) {
MS_LOG(ERROR) << "get inputs failed.";
return lite::RET_ERROR;
}
@ -172,9 +172,9 @@ std::vector<int> NodeInferShape::GetInputShape(const CNodePtr &cnode, size_t ind
if (utils::isa<CNode>(base_node->input(position))) {
status = lite::FetchDataFromCNode(base_node, position, fmk_type_, train_flag_, &data_info);
} else if (utils::isa<Parameter>(base_node->input(position))) {
status = lite::FetchDataFromParameterNode(base_node, position, fmk_type_, train_flag_, &data_info);
status = lite::FetchDataFromParameterNode(base_node, position, fmk_type_, train_flag_, &data_info, false);
} else if (utils::isa<ValueNodePtr>(base_node->input(position))) {
status = lite::FetchDataFromValueNode(base_node, position, fmk_type_, train_flag_, &data_info);
status = lite::FetchDataFromValueNode(base_node, position, fmk_type_, train_flag_, &data_info, false);
} else {
MS_LOG(ERROR) << "input node is invalid.";
return {};
@ -195,7 +195,8 @@ std::vector<int> NodeInferShape::GetIntVecInput(const CNodePtr &cnode, size_t in
std::vector<AnfNodePtr> specify_inputs = {origin_inputs[0], origin_inputs[index]};
cnode->set_inputs(specify_inputs);
std::vector<TensorPtr> specify_tensors;
if (LiteTensorExtractor::GetCNodeInputTensors(cnode, &specify_tensors, fmk_type_, train_flag_) != lite::RET_OK ||
if (LiteTensorExtractor::GetCNodeInputTensors(cnode, &specify_tensors, fmk_type_, train_flag_, false) !=
lite::RET_OK ||
specify_tensors.empty()) {
cnode->set_inputs(origin_inputs);
return {};

View File

@ -270,13 +270,13 @@ int RemoveRedundantOpPass::GetConstDataFromInputNode(const CNodePtr &cnode, lite
auto padding_node = cnode->input(kInputIndexTwo);
MS_ASSERT(padding_node != nullptr);
if (utils::isa<Parameter>(padding_node)) {
auto status = lite::FetchDataFromParameterNode(cnode, 2, converter::kFmkTypeMs, false, data_info);
auto status = lite::FetchDataFromParameterNode(cnode, 2, converter::kFmkTypeMs, false, data_info, true);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "fetch data from parameter node failed.";
return lite::RET_ERROR;
}
} else if (utils::isa<ValueNode>(padding_node)) {
auto status = lite::FetchDataFromValueNode(cnode, 2, converter::kFmkTypeMs, false, data_info);
auto status = lite::FetchDataFromValueNode(cnode, 2, converter::kFmkTypeMs, false, data_info, true);
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "fetch data from value node failed.";
return lite::RET_ERROR;

View File

@ -212,9 +212,9 @@ STATUS ChangeOpPad(const FuncGraphPtr &func_graph, const CNodePtr &cnode, Format
lite::DataInfo data_info;
int status;
if (utils::isa<Parameter>(second_input)) {
status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info, true);
} else if (utils::isa<ValueNode>(second_input)) {
status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info, true);
} else {
return lite::RET_NOT_SUPPORT;
}