forked from mindspore-Ecosystem/mindspore
!18009 [MS][LITE]mindRT support control flow
Merge pull request !18009 from mengyuanli/add_control_flow
This commit is contained in:
commit
515e46e4b7
|
@ -110,40 +110,6 @@ STATUS GetDataTypeAndShape(const ParameterPtr ¶m_node, TypeId *data_type, Sh
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int FetchFromDefaultParam(const ParameterPtr ¶m_node, DataInfo *data_info) {
|
||||
MS_ASSERT(param_node != nullptr && data_info != nullptr);
|
||||
ShapeVector shape_vector;
|
||||
TypeId data_type;
|
||||
auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "get data type and shape from param node failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
data_info->data_type_ = data_type;
|
||||
auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
|
||||
size_t offset = 0;
|
||||
if (!shape_vector.empty() && data_type == kObjectTypeString) {
|
||||
status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "get shape vector from string tensor failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
|
||||
data_info->shape_ = dims;
|
||||
if (tensor_info != nullptr && tensor_info->Size() != 0) {
|
||||
if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) {
|
||||
data_info->data_.resize(tensor_info->Size() - offset);
|
||||
if (EOK != memcpy_s(data_info->data_.data(), data_info->data_.size(),
|
||||
static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) {
|
||||
MS_LOG(ERROR) << "memcpy_s failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, converter::FmkType fmk_type,
|
||||
bool train_flag, DataInfo *data_info) {
|
||||
MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
|
||||
|
@ -260,6 +226,50 @@ int FetchFromSequenceValue(const ValueNodePtr &value_node, const PrimitivePtr &p
|
|||
}
|
||||
} // namespace
|
||||
|
||||
int FetchFromDefaultParam(const ParameterPtr ¶m_node, const converter::FmkType &fmk_type, DataInfo *data_info) {
|
||||
MS_ASSERT(param_node != nullptr && data_info != nullptr);
|
||||
ShapeVector shape_vector;
|
||||
TypeId data_type;
|
||||
auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "get data type and shape from param node failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
data_info->data_type_ = data_type;
|
||||
auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
|
||||
size_t offset = 0;
|
||||
if (!shape_vector.empty() && data_type == kObjectTypeString) {
|
||||
status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "get shape vector from string tensor failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
|
||||
data_info->shape_ = dims;
|
||||
if (tensor_info != nullptr && tensor_info->Size() != 0) {
|
||||
if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) {
|
||||
data_info->data_.resize(tensor_info->Size() - offset);
|
||||
if (EOK != memcpy_s(data_info->data_.data(), data_info->data_.size(),
|
||||
static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) {
|
||||
MS_LOG(ERROR) << "memcpy_s failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
data_info->format_ = GetFormatByFmk(fmk_type);
|
||||
if (data_info->format_ < 0) {
|
||||
MS_LOG(ERROR) << "don't support current fmk: " << fmk_type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
|
||||
MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
|
||||
DataInfo *data_info) {
|
||||
MS_ASSERT(cnode != nullptr && data_info != nullptr);
|
||||
|
@ -268,13 +278,9 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
|
|||
MS_LOG(ERROR) << "input node is not parameter node.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
data_info->format_ = GetFormatByFmk(fmk_type);
|
||||
if (data_info->format_ < 0) {
|
||||
MS_LOG(ERROR) << "don't support current fmk: " << fmk_type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
|
||||
MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
|
||||
|
||||
if (FetchFromDefaultParam(param_node, fmk_type, data_info) != RET_OK) {
|
||||
MS_LOG(ERROR) << "fetch information from default param failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
|
@ -286,10 +292,7 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F
|
|||
(index == 2 && prim->GetAttr(ops::kFormat) != nullptr)) {
|
||||
data_info->format_ = mindspore::KHWC;
|
||||
}
|
||||
if (FetchFromDefaultParam(param_node, data_info) != RET_OK) {
|
||||
MS_LOG(ERROR) << "fetch information from default param failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
QuantParamHolderPtr quant_param_holder =
|
||||
prim->GetAttr("quant_params") == nullptr ? nullptr : prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
|
||||
if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() &&
|
||||
|
|
|
@ -35,6 +35,9 @@ struct DataInfo {
|
|||
std::vector<uint8_t> data_;
|
||||
DataInfo() : enable_huffman_code_(false), format_(0), data_type_(0) {}
|
||||
};
|
||||
|
||||
int FetchFromDefaultParam(const ParameterPtr ¶m_node, const converter::FmkType &fmk_type, DataInfo *data_info);
|
||||
|
||||
int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
|
||||
DataInfo *data_info);
|
||||
int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
|
||||
|
|
|
@ -49,18 +49,6 @@ class AnfTransform {
|
|||
|
||||
static int RunConstFoldPass(const FuncGraphPtr &olde_graph, const converter::Flags *config);
|
||||
|
||||
static int RunPrecedingPass(const FuncGraphPtr &old_graph, const converter::Flags &config);
|
||||
|
||||
static int RunConv1DAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
||||
|
||||
static int RunAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
||||
|
||||
static int RunMindirAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
||||
|
||||
static int RunOnnxAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
||||
|
||||
static int RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
||||
|
||||
static int RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
||||
|
||||
static STATUS RunPluginPass(const FuncGraphPtr &old_graph, int position);
|
||||
|
|
|
@ -625,8 +625,8 @@ STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, con
|
|||
|
||||
void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) {
|
||||
MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
|
||||
sub_inputs_map_[sub_graph] = {};
|
||||
auto sub_inputs = sub_graph->get_inputs();
|
||||
sub_inputs_map_[sub_graph] = sub_inputs;
|
||||
for (auto &node : sub_inputs) {
|
||||
auto param_node = node->cast<ParameterPtr>();
|
||||
MS_ASSERT(param_node != nullptr);
|
||||
|
@ -649,7 +649,6 @@ void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr
|
|||
if (utils::isa<ParameterPtr>(cnode->input(index))) {
|
||||
if (cnode->input(index)->cast<ParameterPtr>()->has_default()) {
|
||||
param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param());
|
||||
sub_inputs_map_[sub_graph].push_back(param_node);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
@ -664,7 +663,6 @@ void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr
|
|||
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec,
|
||||
data_info.data_.data(), data_info.data_.size()));
|
||||
}
|
||||
sub_inputs_map_[sub_graph].push_back(param_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -679,8 +677,11 @@ void UnifyFormatPass::ResetSubGraphInput() {
|
|||
auto param_node = sub_graph->add_parameter();
|
||||
MS_ASSERT(param_node != nullptr);
|
||||
param_node->set_abstract(sub_input->abstract()->Clone());
|
||||
param_node->set_name(sub_input->name());
|
||||
param_node->set_name(sub_input->fullname_with_scope());
|
||||
manager->Replace(sub_input, param_node);
|
||||
auto sub_param_input = sub_input->cast<ParameterPtr>();
|
||||
MS_ASSERT(sub_param_input != nullptr);
|
||||
sub_param_input->set_default_param(nullptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -73,7 +73,7 @@ class UnifyFormatPass : public Pass {
|
|||
TransposeStrategy transpose_strategy_;
|
||||
std::set<AnfNodePtr> pre_insert_trans_;
|
||||
std::set<AnfNodePtr> post_insert_trans_;
|
||||
std::unordered_map<FuncGraphPtr, std::vector<ParameterPtr>> sub_inputs_map_;
|
||||
std::unordered_map<FuncGraphPtr, std::vector<AnfNodePtr>> sub_inputs_map_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue