!18009 [MS][LITE] mindRT support control flow

Merge pull request !18009 from mengyuanli/add_control_flow
i-robot 2021-06-08 22:12:58 +08:00 committed by Gitee
commit 515e46e4b7
5 changed files with 57 additions and 62 deletions


@@ -110,40 +110,6 @@ STATUS GetDataTypeAndShape(const ParameterPtr &param_node, TypeId *data_type, Sh
return RET_OK;
}
int FetchFromDefaultParam(const ParameterPtr &param_node, DataInfo *data_info) {
MS_ASSERT(param_node != nullptr && data_info != nullptr);
ShapeVector shape_vector;
TypeId data_type;
auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector);
if (status != RET_OK) {
MS_LOG(ERROR) << "get data type and shape from param node failed.";
return RET_ERROR;
}
data_info->data_type_ = data_type;
auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
size_t offset = 0;
if (!shape_vector.empty() && data_type == kObjectTypeString) {
status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset);
if (status != RET_OK) {
MS_LOG(ERROR) << "get shape vector from string tensor failed.";
return RET_ERROR;
}
}
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
data_info->shape_ = dims;
if (tensor_info != nullptr && tensor_info->Size() != 0) {
if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) {
data_info->data_.resize(tensor_info->Size() - offset);
if (EOK != memcpy_s(data_info->data_.data(), data_info->data_.size(),
static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) {
MS_LOG(ERROR) << "memcpy_s failed.";
return RET_ERROR;
}
}
}
return RET_OK;
}
int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, converter::FmkType fmk_type,
bool train_flag, DataInfo *data_info) {
MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
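Note: the hunk above removes the two-argument FetchFromDefaultParam; the next hunk reintroduces it with a converter::FmkType parameter, moving format resolution into the helper. The signature change at a glance (both lines taken from this diff):

// before: the caller derived and validated data_info->format_ itself
int FetchFromDefaultParam(const ParameterPtr &param_node, DataInfo *data_info);
// after: the helper derives the format from the framework type and validates it
int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkType &fmk_type, DataInfo *data_info);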
@@ -260,6 +226,50 @@ int FetchFromSequenceValue(const ValueNodePtr &value_node, const PrimitivePtr &p
}
} // namespace
int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkType &fmk_type, DataInfo *data_info) {
MS_ASSERT(param_node != nullptr && data_info != nullptr);
ShapeVector shape_vector;
TypeId data_type;
auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector);
if (status != RET_OK) {
MS_LOG(ERROR) << "get data type and shape from param node failed.";
return RET_ERROR;
}
data_info->data_type_ = data_type;
auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
size_t offset = 0;
if (!shape_vector.empty() && data_type == kObjectTypeString) {
status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset);
if (status != RET_OK) {
MS_LOG(ERROR) << "get shape vector from string tensor failed.";
return RET_ERROR;
}
}
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
data_info->shape_ = dims;
if (tensor_info != nullptr && tensor_info->Size() != 0) {
if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) {
data_info->data_.resize(tensor_info->Size() - offset);
if (EOK != memcpy_s(data_info->data_.data(), data_info->data_.size(),
static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) {
MS_LOG(ERROR) << "memcpy_s failed.";
return RET_ERROR;
}
}
}
data_info->format_ = GetFormatByFmk(fmk_type);
if (data_info->format_ < 0) {
MS_LOG(ERROR) << "don't support current fmk: " << fmk_type;
return RET_ERROR;
}
if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
return RET_ERROR;
}
return RET_OK;
}
int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
DataInfo *data_info) {
MS_ASSERT(cnode != nullptr && data_info != nullptr);
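The reintroduced helper now ends by deriving data_info->format_ from the framework type and rejecting anything other than NHWC/NCHW. Below is a stand-alone sketch of that pattern; the FmkType/Format enums and the GetFormatByFmk mapping here are illustrative stand-ins, not the converter's real definitions:

#include <iostream>

enum FmkType { kFmkTF, kFmkTFLite, kFmkCaffe, kFmkOnnx };
enum Format { NHWC = 0, NCHW = 1 };

// Hypothetical mapping: TF-family models default to NHWC, the rest to NCHW.
int GetFormatByFmk(FmkType fmk) {
  switch (fmk) {
    case kFmkTF:
    case kFmkTFLite:
      return NHWC;
    case kFmkCaffe:
    case kFmkOnnx:
      return NCHW;
    default:
      return -1;  // unsupported framework
  }
}

int main() {
  int format = GetFormatByFmk(kFmkTF);
  if (format < 0) {
    std::cerr << "unsupported framework\n";
    return 1;
  }
  // Mirror of the new guard in FetchFromDefaultParam: only NHWC/NCHW
  // are accepted as default-parameter formats at this point.
  if (format != NHWC && format != NCHW) {
    std::cerr << "unexpected tensor format " << format << "\n";
    return 1;
  }
  std::cout << "format ok: " << format << "\n";
  return 0;
}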
@@ -268,13 +278,9 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
MS_LOG(ERROR) << "input node is not parameter node.";
return RET_ERROR;
}
data_info->format_ = GetFormatByFmk(fmk_type);
if (data_info->format_ < 0) {
MS_LOG(ERROR) << "don't support current fmk: " << fmk_type;
return RET_ERROR;
}
if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
if (FetchFromDefaultParam(param_node, fmk_type, data_info) != RET_OK) {
MS_LOG(ERROR) << "fetch information from default param failed.";
return RET_ERROR;
}
@@ -286,10 +292,7 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
(index == 2 && prim->GetAttr(ops::kFormat) != nullptr)) {
data_info->format_ = mindspore::KHWC;
}
if (FetchFromDefaultParam(param_node, data_info) != RET_OK) {
MS_LOG(ERROR) << "fetch information from default param failed.";
return RET_ERROR;
}
QuantParamHolderPtr quant_param_holder =
prim->GetAttr("quant_params") == nullptr ? nullptr : prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() &&
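With the format logic folded into FetchFromDefaultParam, the last two hunks simplify FetchDataFromParameterNode: the default parameter is fetched first (which sets and checks format_), and the KHWC override for weight inputs still runs afterwards. A rough reconstruction of the updated flow, abridged from the hunks above and not a verbatim copy (the param_node extraction is assumed from context):

int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type,
                               bool train_flag, DataInfo *data_info) {
  MS_ASSERT(cnode != nullptr && data_info != nullptr);
  auto param_node = cnode->input(index)->cast<ParameterPtr>();
  if (param_node == nullptr) {
    MS_LOG(ERROR) << "input node is not parameter node.";
    return RET_ERROR;
  }
  // format_ is now derived and validated inside the helper
  if (FetchFromDefaultParam(param_node, fmk_type, data_info) != RET_OK) {
    MS_LOG(ERROR) << "fetch information from default param failed.";
    return RET_ERROR;
  }
  // the KHWC override for certain weight inputs and the quant-param
  // handling shown above still follow here
  return RET_OK;
}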


@@ -35,6 +35,9 @@ struct DataInfo {
std::vector<uint8_t> data_;
DataInfo() : enable_huffman_code_(false), format_(0), data_type_(0) {}
};
int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkType &fmk_type, DataInfo *data_info);
int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
DataInfo *data_info);
int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
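Exporting FetchFromDefaultParam from this header presumably lets other passes populate a DataInfo from a parameter's default tensor directly. A hedged usage sketch, reusing names from this diff (the surrounding setup is assumed):

DataInfo data_info;
if (FetchFromDefaultParam(param_node, fmk_type, &data_info) != RET_OK) {
  MS_LOG(ERROR) << "fetch information from default param failed.";
  return;
}
ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
// rebuild a tensor from the fetched buffer, as SetSubGraphInput does below
auto tensor = std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec,
                                               data_info.data_.data(), data_info.data_.size());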


@@ -49,18 +49,6 @@ class AnfTransform {
static int RunConstFoldPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
static int RunPrecedingPass(const FuncGraphPtr &old_graph, const converter::Flags &config);
static int RunConv1DAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
static int RunAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
static int RunMindirAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
static int RunOnnxAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
static int RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
static int RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
static STATUS RunPluginPass(const FuncGraphPtr &old_graph, int position);


@@ -625,8 +625,8 @@ STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, con
void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
sub_inputs_map_[sub_graph] = {};
auto sub_inputs = sub_graph->get_inputs();
sub_inputs_map_[sub_graph] = sub_inputs;
for (auto &node : sub_inputs) {
auto param_node = node->cast<ParameterPtr>();
MS_ASSERT(param_node != nullptr);
@@ -649,7 +649,6 @@ void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr
if (utils::isa<ParameterPtr>(cnode->input(index))) {
if (cnode->input(index)->cast<ParameterPtr>()->has_default()) {
param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
sub_inputs_map_[sub_graph].push_back(param_node);
}
continue;
}
@@ -664,7 +663,6 @@ void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec,
data_info.data_.data(), data_info.data_.size()));
}
sub_inputs_map_[sub_graph].push_back(param_node);
}
}
}
@@ -679,8 +677,11 @@ void UnifyFormatPass::ResetSubGraphInput() {
auto param_node = sub_graph->add_parameter();
MS_ASSERT(param_node != nullptr);
param_node->set_abstract(sub_input->abstract()->Clone());
param_node->set_name(sub_input->name());
param_node->set_name(sub_input->fullname_with_scope());
manager->Replace(sub_input, param_node);
auto sub_param_input = sub_input->cast<ParameterPtr>();
MS_ASSERT(sub_param_input != nullptr);
sub_param_input->set_default_param(nullptr);
}
}
}
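Two behavioral changes land in this pass: SetSubGraphInput now snapshots the subgraph's existing inputs once via get_inputs() rather than pushing parameters into the map piecemeal (the deleted push_back lines), and ResetSubGraphInput clears the default_param attached during the pass and names replacement parameters with fullname_with_scope() instead of name(). A stand-alone sketch of that snapshot/restore bookkeeping, with Graph/Node as stand-ins for FuncGraphPtr/AnfNodePtr:

#include <memory>
#include <unordered_map>
#include <vector>

struct Node {
  std::shared_ptr<int> default_param;  // stand-in for a parameter's default tensor
};
struct Graph {
  std::vector<std::shared_ptr<Node>> inputs;
};

int main() {
  auto graph = std::make_shared<Graph>();
  graph->inputs = {std::make_shared<Node>(), std::make_shared<Node>()};

  // SetSubGraphInput: snapshot the subgraph's current inputs up front
  // instead of appending entries one by one while mutating the graph.
  std::unordered_map<std::shared_ptr<Graph>, std::vector<std::shared_ptr<Node>>> sub_inputs_map;
  sub_inputs_map[graph] = graph->inputs;

  // ...the pass runs and may attach default params to these inputs...
  sub_inputs_map[graph][0]->default_param = std::make_shared<int>(42);

  // ResetSubGraphInput: when restoring, drop any default param attached
  // during the pass so the input becomes a plain graph input again.
  for (auto &input : sub_inputs_map[graph]) {
    input->default_param = nullptr;
  }
  return 0;
}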


@@ -73,7 +73,7 @@ class UnifyFormatPass : public Pass {
TransposeStrategy transpose_strategy_;
std::set<AnfNodePtr> pre_insert_trans_;
std::set<AnfNodePtr> post_insert_trans_;
std::unordered_map<FuncGraphPtr, std::vector<ParameterPtr>> sub_inputs_map_;
std::unordered_map<FuncGraphPtr, std::vector<AnfNodePtr>> sub_inputs_map_;
};
} // namespace opt
} // namespace mindspore
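The map's value type widens from ParameterPtr to AnfNodePtr to match what sub_graph->get_inputs() returns in the snapshot above; consumers needing parameter-specific operations cast back, as the .cc changes do:

auto param_node = node->cast<ParameterPtr>();
MS_ASSERT(param_node != nullptr);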