sync code from inner

This commit is contained in:
mengyuanli 2022-11-07 17:24:41 +08:00
parent f7168c8ee7
commit 67983d767c
13 changed files with 32 additions and 27 deletions

View File

@ -428,10 +428,12 @@ int CastTensorListTensorData(TensorList *dst_tensorlist, TensorList *src_tensorl
}
if (!dst_tensorlist->shape().empty()) {
if (src_tensorlist->tensors_data_type() == kNumberTypeFloat16) {
dst_tensorlist->MallocTensorListData(kNumberTypeFloat32, tensors_shapes);
auto ret = dst_tensorlist->MallocTensorListData(kNumberTypeFloat32, tensors_shapes);
MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "dst_tensorlist MallocTensorListData failed.");
}
if (src_tensorlist->tensors_data_type() == kNumberTypeFloat32) {
dst_tensorlist->MallocTensorListData(kNumberTypeFloat16, tensors_shapes);
auto ret = dst_tensorlist->MallocTensorListData(kNumberTypeFloat16, tensors_shapes);
MS_CHECK_FALSE_MSG(ret != RET_OK, ret, "dst_tensorlist MallocTensorListData failed.");
}
}
dst_tensorlist->set_allocator(src_tensorlist->allocator());
@ -591,7 +593,8 @@ TensorList *MallocTensorListDataAccordingToTensorListC(Tensor *tensor, TensorLis
auto tensor_shape = std::vector<std::vector<int>>(
tensor_list_c->element_num_, std::vector<int>(tensor_list_c->element_shape_,
tensor_list_c->element_shape_ + tensor_list_c->element_shape_size_));
tensor_list->MallocTensorListData(static_cast<TypeId>(tensor_list_c->data_type_), tensor_shape);
auto ret = tensor_list->MallocTensorListData(static_cast<TypeId>(tensor_list_c->data_type_), tensor_shape);
MS_CHECK_FALSE_MSG(ret != RET_OK, nullptr, "tensor list MallocTensorListData");
return tensor_list;
}

View File

@ -294,11 +294,11 @@ void LiteSwitchOpActor::RunOpData(OpData<Tensor> *inputs, OpContext<Tensor> *con
ret = RunKernel(*(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_before_)),
*(reinterpret_cast<const KernelCallBack *>(context->kernel_call_back_after_)));
if (ret != RET_OK) {
input_op_datas_.erase(op_uuid);
(void)input_op_datas_.erase(op_uuid);
context->SetFailed(ret);
return;
}
input_op_datas_.erase(op_uuid);
(void)input_op_datas_.erase(op_uuid);
auto cond_ptr = reinterpret_cast<bool *>(switch_type_node_->in_tensors()[kSwitchCondTensorIndex]->data());
if (cond_ptr == nullptr) {

View File

@ -29,7 +29,6 @@ class LiteSwitchOpActor : public LiteOpActor {
public:
explicit LiteSwitchOpActor(kernel::KernelExec *kernel, lite::InnerContext *ctx) : LiteOpActor(kernel, ctx) {}
~LiteSwitchOpActor() override {
delete call_node_;
delete switch_type_node_;
for (auto &partial_node : partial_nodes_) {
delete partial_node;
@ -67,7 +66,6 @@ class LiteSwitchOpActor : public LiteOpActor {
std::vector<kernel::KernelExec *> partial_nodes_{};
kernel::KernelExec *switch_type_node_ = nullptr;
kernel::KernelExec *call_node_ = nullptr;
// each element is a set of output data which is going to be send to the next target actor.
std::vector<std::vector<OpDataPtr<Tensor>>> all_branchs_output_data_;

View File

@ -150,7 +150,7 @@ int ControlFlowScheduler::SplitSubGraphNodesIntoTwoParts(kernel::SubGraphKernel
auto nodes = subgraph_kernel->nodes();
// get the position of the last non-tail call op.
auto is_non_tail_call = [](kernel::KernelExec *node) { return kernel::KernelExecUtil::IsNonTailCall(node); };
auto is_non_tail_call = [](const kernel::KernelExec *node) { return kernel::KernelExecUtil::IsNonTailCall(node); };
auto last_non_tail_call_iter = std::find_if(nodes.rbegin(), nodes.rend(), is_non_tail_call);
auto distance = nodes.rend() - last_non_tail_call_iter;
if (distance == 0) {
@ -315,7 +315,8 @@ kernel::SubGraphKernel *ControlFlowScheduler::CreateEntranceSubGraph(kernel::Sub
}
src_tensors_->push_back(new_tensor);
new_input_tensors.push_back(new_tensor);
kernel::KernelExecUtil::ReplaceSubGraphNodesInTensor(subgraph, old_tensor, new_tensor);
auto ret = kernel::KernelExecUtil::ReplaceSubGraphNodesInTensor(subgraph, old_tensor, new_tensor);
MS_CHECK_FALSE_MSG(ret != RET_OK, nullptr, "ReplaceSubGraphNodesInTensor failed.");
subgraph->set_in_tensor(new_tensor, i);
}
auto entrance_subgraph = kernel::KernelExecUtil::CreateSubGraphKernel(

View File

@ -18,7 +18,7 @@
#include "src/tensor.h"
namespace mindspore::kernel {
int EntranceSubGraphKernel::Execute(const KernelCallBack &before, const KernelCallBack &after) { return lite::RET_OK; }
int EntranceSubGraphKernel::Execute(const KernelCallBack &, const KernelCallBack &) { return lite::RET_OK; }
SubGraphKernel *EntranceSubGraphKernel::Create(Kernel *kernel) {
auto sub_kernel = new kernel::EntranceSubGraphKernel(kernel);

View File

@ -47,9 +47,6 @@ class EntranceSubGraphKernel : public SubGraphKernel {
int Execute(const KernelCallBack &before, const KernelCallBack &after) override;
int ReSize() override { return RET_OK; };
protected:
int schema_version_ = lite::SCHEMA_VERSION::SCHEMA_CUR;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_CONTROL_FLOW_KERNEL_ENTRANCE_SUBGRAPH_KERNEL_H_

View File

@ -39,7 +39,7 @@ int TensorListSetItemCPUKernel::Prepare() {
return RET_OK;
}
int TensorListSetItemCPUKernel::CheckParam() {
int TensorListSetItemCPUKernel::CheckParam() const {
if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) {
MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type() << " must be int";
return RET_ERROR;

View File

@ -37,7 +37,7 @@ class TensorListSetItemCPUKernel : public LiteKernel {
int IncrementOutputSize(int origin_size);
private:
int CheckParam();
int CheckParam() const;
lite::TensorList *input0_ = nullptr;
lite::Tensor *input2_ = nullptr;
lite::TensorList *output0_ = nullptr;

View File

@ -499,7 +499,8 @@ bool KernelExecUtil::IsNonTailCallSubGraph(KernelExec *kernel) {
return false;
}
auto nodes = subgraph_kernel->nodes();
return std::any_of(nodes.begin(), nodes.end(), [](KernelExec *node) { return KernelExecUtil::IsNonTailCall(node); });
return std::any_of(nodes.begin(), nodes.end(),
[](const KernelExec *node) { return KernelExecUtil::IsNonTailCall(node); });
}
bool KernelExecUtil::IsTailCallSubGraph(KernelExec *kernel) {
@ -511,13 +512,13 @@ bool KernelExecUtil::IsTailCallSubGraph(KernelExec *kernel) {
return false;
}
auto output_nodes = subgraph_kernel->out_nodes();
if (std::any_of(output_nodes.begin(), output_nodes.end(), [](KernelExec *node) { return IsTailCall(node); })) {
if (std::any_of(output_nodes.begin(), output_nodes.end(), [](const KernelExec *node) { return IsTailCall(node); })) {
return true;
}
return false;
}
std::vector<KernelExec *> KernelExecUtil::GetCallInputPartials(KernelExec *call_node) {
std::vector<KernelExec *> KernelExecUtil::GetCallInputPartials(const KernelExec *call_node) {
if (call_node->type() != schema::PrimitiveType_Call) {
MS_LOG(ERROR) << "input node is not call node.";
return {};
@ -628,7 +629,7 @@ bool KernelExecUtil::IsNonTailCallSubGraph(KernelExec *kernel) { return false; }
bool KernelExecUtil::IsTailCallSubGraph(KernelExec *kernel) { return false; }
std::vector<KernelExec *> KernelExecUtil::GetCallInputPartials(KernelExec *call_node) { return {}; }
std::vector<KernelExec *> KernelExecUtil::GetCallInputPartials(const KernelExec *call_node) { return {}; }
std::vector<KernelExec *> KernelExecUtil::GetCallInputPartialsCorrespondingOutputSubgraph(KernelExec *call_node) {
return {};

View File

@ -35,7 +35,7 @@ class KernelExecUtil {
static bool IsSwitchTypeCall(KernelExec *kernel);
static bool IsNonTailCall(const KernelExec *node);
static bool IsTailCall(const KernelExec *node);
static std::vector<KernelExec *> GetCallInputPartials(KernelExec *call_node);
static std::vector<KernelExec *> GetCallInputPartials(const KernelExec *call_node);
static KernelExec *GetPartialOutputCall(const KernelExec *partial_node);
static bool IsNonTailCallSubGraph(KernelExec *kernel);
static bool IsTailCallSubGraph(KernelExec *kernel);

View File

@ -97,6 +97,10 @@ class LiteOpActor : public OpActor<lite::Tensor> {
std::unordered_map<Tensor *, Tensor *> *isolate_input_map_ = nullptr; /* real obj in session */
lite::InnerContext *ctx_ = nullptr;
kernel::KernelExec *partial_node_ = nullptr;
kernel::KernelExec *call_node_ = nullptr;
bool support_fp16_ = false;
private:
int CreateCommonArrow(const std::unordered_map<void *, std::set<std::pair<AID, size_t>>> &receivers_map,
const std::set<void *> &receiver_tensors, const size_t &output_index,
@ -106,11 +110,6 @@ class LiteOpActor : public OpActor<lite::Tensor> {
const std::unordered_map<AID, std::set<size_t>> &receiver_index_set);
void MarkArrowAsCompiled(const AID *actor_name, size_t to_index,
std::unordered_map<AID, std::set<size_t>> *receiver_index_set);
private:
kernel::KernelExec *partial_node_ = nullptr;
kernel::KernelExec *call_node_ = nullptr;
bool support_fp16_ = false;
};
int MindrtInit();

View File

@ -263,7 +263,8 @@ int MindIRControlFlowAdjust::MoveCallInputsToPartialFusionInputs(const std::set<
return RET_NOT_SUPPORT;
}
std::vector<AnfNodePtr> partial_cnode_inputs = {lite::GetPartialFusionPrim(), make_tuple_op_value_input};
std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(), std::back_inserter(partial_cnode_inputs));
(void)std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(),
std::back_inserter(partial_cnode_inputs));
auto partial_cnode = graph->NewCNode(partial_cnode_inputs);
MS_CHECK_TRUE_MSG(partial_cnode != nullptr, RET_NULL_PTR, "Failed to create C node.");
partial_cnode->set_fullname_with_scope("partial_" + make_tuple_op->fullname_with_scope() + "_" +

View File

@ -48,10 +48,12 @@ bool DeleteDirRecursively(const std::string &dir_name) {
auto real_file_path = RealPath(file_path.c_str());
auto result = unlink(real_file_path.c_str());
if (result != 0) {
closedir(dir);
MS_LOG(ERROR) << "Delete the file(" << real_file_path << ") failed." << ErrnoToString(errno);
return false;
}
}
closedir(dir);
return true;
}
} // namespace
@ -281,7 +283,10 @@ int MindIRSerializer::SplitSave() {
std::string external_local = model_name_ + "_data_" + std::to_string(index);
auto external_local_path = CreateExternalPath(external_local);
if (fs_->FileExist(external_local_path)) {
fs_->DeleteFile(external_local_path);
if (!fs_->DeleteFile(external_local_path)) {
MS_LOG(ERROR) << "delete file failed.";
return RET_ERROR;
}
}
int64_t parameter_size = 0;
int64_t offset = OFFSET;