add fp16 sub_graph preProcess

This commit is contained in:
hangangqiang 2020-10-31 17:17:00 +08:00
parent e675e78b2c
commit 261023d980
38 changed files with 142 additions and 68 deletions

View File

@ -106,7 +106,7 @@ class LogWriter {
} // namespace mindspore
#ifdef DEBUG
#ifdef Debug
#include <cassert>
#define MS_ASSERT(f) assert(f)
#else

View File

@ -43,6 +43,5 @@ class Executor {
int TransformTensorLayout(Tensor *tensor, schema::Format dst_format, Allocator *allocator = nullptr);
};
} // namespace mindspore::lite
#endif

View File

@ -162,7 +162,6 @@ void LiteSession::InitGraphInputMSTensors() {
void LiteSession::InitGraphOutputTensors(const lite::Model *model) {
MS_ASSERT(model != nullptr);
MS_ASSERT(this->outputs_.empty());
MS_ASSERT(meta_graph != nullptr);
auto graph_out_size = model->sub_graphs_.front()->output_indices_.size();
for (size_t i = 0; i < graph_out_size; ++i) {
auto out_tensor_idx = model->sub_graphs_.front()->output_indices_[i];
@ -181,7 +180,7 @@ void LiteSession::InitGraphInputMap(const lite::Model *model) {
for (auto in_node_index : graph_input_node_indexes) {
auto in_node = model->all_nodes_[in_node_index];
MS_ASSERT(in_node != nullptr);
MS_ASSERT(this->input_map_.find(in_node->name()->str()) == this->input_map_.end());
MS_ASSERT(this->input_map_.find(in_node->name_) == this->input_map_.end());
auto in_size = in_node->input_indices_.size();
for (size_t i = 0; i < in_size; ++i) {
auto in_tensor_index = size_t(in_node->input_indices_[i]);
@ -215,7 +214,6 @@ void LiteSession::InitGraphOutputNodeMap(const lite::Model *model) {
for (auto out_node_index : graph_output_node_indexes) {
auto out_node = model->all_nodes_[out_node_index];
MS_ASSERT(out_node != nullptr);
MS_ASSERT(this->output_map_.find(out_node->name()->str()) == this->output_map_.end());
auto out_size = out_node->output_indices_.size();
for (size_t i = 0; i < out_size; ++i) {
auto out_tensor_index = out_node->output_indices_[i];

View File

@ -95,7 +95,6 @@ int Cast::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
}
output->SetFormat(input->GetFormat());
MS_ASSERT(cast_prim != nullptr);
output->set_data_type(static_cast<TypeId>(GetDstT()));
if (!GetInferFlag()) {
return RET_OK;

View File

@ -104,7 +104,6 @@ int Concat::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
return RET_OK;
}
MS_ASSERT(concat_prim != nullptr);
auto input0_shape = inputs_.at(0)->shape();
auto axis = GetAxis() < 0 ? GetAxis() + input0_shape.size() : GetAxis();
if (axis < 0 || axis >= input0_shape.size()) {

View File

@ -197,7 +197,8 @@ int Conv2DGradFilter::InferShape(std::vector<Tensor *> inputs, std::vector<Tenso
auto *in0 = inputs.at(0);
auto *in = inputs.at(2);
MS_ASSERT(out != nullptr);
MS_ASSERT(in0 != nullptr);
MS_ASSERT(in != nullptr);
std::vector<int> output_shape;
int *out_shape = reinterpret_cast<int *>(in->MutableData());

View File

@ -200,7 +200,8 @@ int Conv2DGradInput::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor
auto *in0 = inputs.at(0);
auto *in = inputs.at(2);
MS_ASSERT(out != nullptr);
MS_ASSERT(in0 != nullptr);
MS_ASSERT(in != nullptr);
std::vector<int> output_shape;
int *out_shape = reinterpret_cast<int *>(in->MutableData());

View File

@ -169,8 +169,8 @@ int DetectionPostProcess::InferShape(std::vector<lite::Tensor *> inputs_, std::v
const auto input_anchors_shape = anchors->shape();
MS_ASSERT(input_scores_shape[2] >= GetNumClasses());
MS_ASSERT(input_scores_shape[2] - GetNumClasses() <= 1);
MS_ASSERT(input_box_shape[1] = input_scores_shape[1]);
MS_ASSERT(input_box_shape[1] = input_anchors_shape[0]);
MS_ASSERT(input_box_shape[1] == input_scores_shape[1]);
MS_ASSERT(input_box_shape[1] == input_anchors_shape[0]);
auto detected_boxes = outputs_.at(0);
MS_ASSERT(detected_boxes != nullptr);

View File

@ -113,7 +113,6 @@ int Gather::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
return RET_OK;
}
MS_ASSERT(gather_prim != nullptr);
int axis = GetAxis();
int batch_dims = GetBatchDims();
if (axis < 0) {

View File

@ -149,7 +149,8 @@ int GroupConv2DGradInput::InferShape(std::vector<Tensor *> inputs, std::vector<T
auto *in0 = inputs.at(0);
auto *in = inputs.at(2);
MS_ASSERT(out != nullptr);
MS_ASSERT(in0 != nullptr);
MS_ASSERT(in != nullptr);
std::vector<int> output_shape;
int *out_shape = reinterpret_cast<int *>(in->MutableData());

View File

@ -46,7 +46,6 @@ int HashtableLookup::InferShape(std::vector<Tensor *> inputs_, std::vector<Tenso
auto output = outputs_.at(0);
auto hits = outputs_.at(1);
MS_ASSERT(input != nullptr);
MS_ASSERT(keys != nullptr);
MS_ASSERT(values != nullptr);
MS_ASSERT(output != nullptr);
MS_ASSERT(hits != nullptr);

View File

@ -65,7 +65,7 @@ int LshProjection::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor
if (inputs_.size() == kMultiNum) {
MS_ASSERT(inputs_.at(2)->shape().size() == 1);
MS_ASSERT(inputs_.at(2)->DimensionSize(0) == in_value->DimensionSize(0));
MS_ASSERT(inputs_.at(2)->DimensionSize(0) == inputs_.at(1)->DimensionSize(0));
}
auto out_tensor = outputs_.front();

View File

@ -60,7 +60,7 @@ int Lstm::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto weight_i = inputs_[1];
MS_ASSERT(input0 != nullptr);
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
for (int i = 0; i < kLstmOutputNum; i++) {

View File

@ -186,7 +186,6 @@ int Pooling::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
int input_h = input->shape().at(1);
int input_w = input->shape().at(2);
MS_ASSERT(pooling_prim != nullptr);
auto window_h = GetWindowH();
auto window_w = GetWindowW();
if (GetGlobal()) {

View File

@ -132,7 +132,7 @@ constexpr int kPriorBoxW = 1;
constexpr int kPriorBoxC = 2;
} // namespace
int PriorBox::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(param != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.at(0);
MS_ASSERT(input != nullptr);
auto output = outputs_.at(0);
@ -144,7 +144,6 @@ int PriorBox::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> ou
}
std::vector<float> different_aspect_ratios{1.0f};
auto aspect_ratios = GetAspectRatios();
MS_ASSERT(aspect_ratios != nullptr);
for (size_t i = 0; i < aspect_ratios.size(); i++) {
float ratio = aspect_ratios[i];
bool exist = std::any_of(different_aspect_ratios.begin(), different_aspect_ratios.end(),

View File

@ -59,7 +59,7 @@ int QuantDTypeCast::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
MS_ASSERT(input->data_type() == param->srcT);
MS_ASSERT(input->data_type() == this->GetSrcT());
output->set_data_type(static_cast<TypeId>(GetDstT()));
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {

View File

@ -64,8 +64,6 @@ int Range::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
MS_ASSERT(range_prim != nullptr);
output->set_data_type(input->data_type());
output->SetFormat(input->GetFormat());
if (!GetInferFlag()) {

View File

@ -176,7 +176,6 @@ int Reshape::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
return RET_OK;
}
MS_ASSERT(reshape_prim != nullptr);
std::vector<int> out_shape;
if (inputs_.size() == kDoubleNum) {
auto shape_tensor = inputs_.at(1);

View File

@ -54,7 +54,6 @@ Registry SparseToDenseRegistry(schema::PrimitiveType_SparseToDense, SparseToDens
int SparseToDense::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
MS_ASSERT(output_shape != nullptr);
auto output = outputs_.front();
if (output == nullptr) {
MS_LOG(ERROR) << "output null pointer dereferencing.";

View File

@ -104,7 +104,6 @@ int Split::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
MS_ASSERT(spilt_prim != nullptr);
if (inputs_.size() != kSplitInputNum) {
MS_LOG(ERROR) << "inputs number is not equal to " << kSplitInputNum;
return RET_ERROR;

View File

@ -138,7 +138,6 @@ int Tile::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
return RET_OK;
}
MS_ASSERT(tile_prim != nullptr);
std::vector<int> out_shape;
std::vector<int> multiples = GetMultiples();
const size_t in_dims = input->shape().size();

View File

@ -71,7 +71,6 @@ int TopK::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
if (!GetInferFlag()) {
return RET_OK;
}
MS_ASSERT(topk_prim != nullptr);
auto out_shape = input->shape();
out_shape[out_shape.size() - 1] = GetK();
if (inputs_.size() == kDoubleNum && inputs_.at(1)->data_c() != nullptr) {

View File

@ -82,7 +82,7 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
fp16_bias_data[i] = (float16_t)ori_bias[i];
}
} else {
MS_ASSERT(inputs_.size() == kInputSize1);
MS_ASSERT(in_tensors_.size() == kInputSize1);
}
return RET_OK;
}

View File

@ -97,7 +97,7 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
fp16_bias_data[i] = (float16_t)ori_bias[i];
}
} else {
MS_ASSERT(inputs_.size() == kInputSize1);
MS_ASSERT(in_tensors_.size() == kInputSize1);
}
return RET_OK;
}

View File

@ -346,7 +346,7 @@ int DeConvWinogradFp16CPUKernel::InitDataParam() {
fp16_bias_data[i] = (float16_t)src_bias[i];
}
} else {
MS_ASSERT(inputs_.size() == kInputSize1);
MS_ASSERT(in_tensors_.size() == kInputSize1);
}
return RET_OK;

View File

@ -127,8 +127,8 @@ static int TransposeFp16Run(void *cdata, int task_id) {
}
int TransposeFp16CPUKernel::Run() {
MS_ASSERT(in_tensors_.size() == TransposeInputNum);
MS_ASSERT(out_tensors_.size() == TransposeOutputNum);
MS_ASSERT(in_tensors_.size() == 1);
MS_ASSERT(out_tensors_.size() == 1);
auto &in_tensor = in_tensors_.front();
auto &out_tensor = out_tensors_.front();
if (in_tensor == nullptr || out_tensor == nullptr) {

View File

@ -107,7 +107,7 @@ int PReluCPUKernel::ProcessShareChannelInput() {
}
int PReluCPUKernel::Run() {
MS_ASSERT(in_shape.size() >= 2);
MS_ASSERT(in_tensors_.size() >= 2);
auto input_tensor = in_tensors_[0];
ori_input_ = reinterpret_cast<float *>(input_tensor->MutableData());
output_data_ = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());

View File

@ -77,7 +77,7 @@ kernel::LiteKernel *CpuTopKFp32KernelCreator(const std::vector<lite::Tensor *> &
MS_LOG(ERROR) << "input parameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == PrimitiveType_Tile);
MS_ASSERT(desc.type == PrimitiveType_TopK);
auto *kernel = new (std::nothrow) TopKCPUKernel(parameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new TopKCPUKernel fail!";

View File

@ -116,8 +116,8 @@ int TransposeFp32Run(void *cdata, int task_id) {
}
int TransposeCPUKernel::Run() {
MS_ASSERT(in_tensors_.size() == TransposeInputNum);
MS_ASSERT(out_tensors_.size() == TransposeOutputNum);
MS_ASSERT(in_tensors_.size() == 1);
MS_ASSERT(out_tensors_.size() == 1);
auto &in_tensor = in_tensors_.front();
auto &out_tensor = out_tensors_.front();
if (in_tensor == nullptr || out_tensor == nullptr) {

View File

@ -108,7 +108,7 @@ int ArithmeticInt8CPUKernel::DoArithmetic(int thread_id) {
auto element_num = out_tensors_[0]->ElementsNum();
auto param = reinterpret_cast<ArithmeticParameter *>(op_parameter_);
if (param->broadcasting_ && arithmetic_run_ != nullptr) {
MS_ASSERT(opParameter->thread_num_ != 0);
MS_ASSERT(op_parameter_->thread_num_ != 0);
int stride = UP_DIV(element_num, op_parameter_->thread_num_);
int count = MSMIN(stride, element_num - stride * thread_id);
if (count <= 0) {

View File

@ -94,7 +94,6 @@ int LeakyReluInt8CPUKernel::ReSize() {
auto *input_tensor = in_tensors_.at(kInputIndex);
auto *out_tensor = out_tensors_.at(kOutputIndex);
auto input_dim = input_tensor->shape().size();
MS_ASSERT(input_dim <= CROP_OFFSET_MAX_SIZE);
quant_prelu_parm_.input_dim_ = input_dim;
quant_prelu_parm_.element_num = in_tensors_[0]->Size();
auto input_shape = input_tensor->shape();

View File

@ -372,7 +372,7 @@ int ReduceMeanPatternInt8Impl(void *cdata, int task_id) {
}
void ReduceInt8CPUKernel::GetQuantArgs(size_t i) {
MS_ASSERT(i < static_cast<size_t>(num_axis_));
MS_ASSERT(i < static_cast<size_t>(num_axes_));
if (mode_ == static_cast<int>(schema::ReduceMode_ReduceMean)) {
quant_arg_.mean_multiplier_ = mean_multipliers_[i]->multiplier_;
quant_arg_.mean_left_shift_ = mean_multipliers_[i]->left_shift_;

View File

@ -44,7 +44,7 @@ int SplitInt8CPUKernel::Init() {
auto in_quant_args = in_tensor->GetQuantParams();
param->quant_arg_.in_args_.scale_ = in_quant_args.front().scale;
param->quant_arg_.in_args_.zp_ = in_quant_args.front().zeroPoint;
MS_ASSERT(param->num_split_ == outputs_.size());
MS_ASSERT(param->num_split_ == this->out_tensors_.size());
for (int i = 0; i < param->num_split_; i++) {
auto *out_tensor = out_tensors_.at(i);
auto out_quant_args = out_tensor->GetQuantParams();
@ -91,7 +91,7 @@ int SplitInt8Run(void *cdata, int task_id) {
int SplitInt8CPUKernel::Run() {
auto in_tensor = in_tensors_.at(kInputIndex);
input_ptr_ = reinterpret_cast<int8_t *>(in_tensor->MutableData());
MS_ASSERT(param->num_split_ == outputs_.size());
MS_ASSERT(param->num_split_ == this->out_tensors_.size());
for (int i = 0; i < param->num_split_; i++) {
output_ptr_[i] = reinterpret_cast<int8_t *>(out_tensors_.at(i)->data_c());
}

View File

@ -74,7 +74,7 @@ int SqueezeInt8CPUKernel::Init() {
quant_Squeeze_parm_->in_quant_args_[i].zp_ = quant_args.front().zeroPoint;
}
MS_ASSERT(outputs_.size() == 1);
MS_ASSERT(this->out_tensors_.size() == 1);
auto output_tensor = out_tensors_.at(0);
MS_ASSERT(output_tensor != nullptr);
auto quant_args = output_tensor->GetQuantParams();
@ -94,7 +94,6 @@ int SqueezeInt8CPUKernel::ReSize() {
auto *input_tensor = in_tensors_.at(i);
MS_ASSERT(input_tensor != nullptr);
auto input_size = input_tensor->shape().size();
MS_ASSERT(input_size != NULL);
quant_Squeeze_parm_->input_shapes_[i] = reinterpret_cast<int *>(malloc(sizeof(int) * input_size));
if (quant_Squeeze_parm_->input_shapes_[i] == nullptr) {
MS_LOG(ERROR) << "Null pointer reference: quant_Squeeze_parm_->input_shapes_[" << i << "].";
@ -113,7 +112,6 @@ int SqueezeInt8CPUKernel::ReSize() {
auto output_tensor = out_tensors_.at(0);
MS_ASSERT(output_tensor != nullptr);
auto output_shape = output_tensor->shape();
MS_ASSERT(output_shape != NULL);
auto output_dim = output_shape.size();
quant_Squeeze_parm_->output_dim_ = output_dim;
int output_size = 1;

View File

@ -100,8 +100,6 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector<lite::Tensor *> &in_te
out_tensors->clear();
out_parameters->clear();
out_convert_ops->clear();
MS_ASSERT(in_tensors.size() == to_kernels.size());
MS_ASSERT(in_tensors.size() == from_kernels.size());
std::vector<std::vector<kernel::LiteKernel *>> loop_kernels;
if (mem_type == MemType::BUF) {
GetKernelFromToTensor(in_tensors, nodes_, &loop_kernels, true);

View File

@ -307,7 +307,7 @@ TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_ten
return dtype;
}
}
MS_ASSERT(in_tensors.size() > 0);
MS_ASSERT(!in_tensors.empty());
return in_tensors[0]->data_type();
}

View File

@ -53,11 +53,11 @@ std::string SubGraphKernel::ToString() const {
oss << " " << tensor;
}
oss << std::endl << "Subgraph input kernels :" << std::endl;
for (auto kernel : this->in_kernels_) {
for (auto kernel : this->in_nodes_) {
oss << " " << kernel->ToString() << std::endl;
}
oss << std::endl << "Subgraph output kernels :" << std::endl;
for (auto kernel : this->out_kernels_) {
for (auto kernel : this->out_nodes_) {
oss << " " << kernel->ToString() << std::endl;
}
oss << std::endl << nodes_.size() << " nodes in subgraph :";
@ -72,12 +72,7 @@ int SubGraphKernel::Run() {
MS_LOG(ERROR) << "executor is nullptr";
return RET_ERROR;
}
auto ret = executor_->Prepare(this->nodes_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare failed: " << ret;
return ret;
}
ret = executor_->Run(this->in_tensors_, this->out_tensors_, this->nodes_, this->context_->allocator.get());
auto ret = executor_->Run(this->in_tensors_, this->out_tensors_, this->nodes_, this->context_->allocator.get());
if (ret != RET_OK) {
MS_LOG(ERROR) << "Run sub graph failed: " << ret;
return ret;
@ -90,12 +85,7 @@ int SubGraphKernel::Run(const KernelCallBack &before, const KernelCallBack &afte
MS_LOG(ERROR) << "executor is nullptr";
return RET_ERROR;
}
auto ret = executor_->Prepare(this->nodes_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare failed: " << ret;
return ret;
}
ret =
auto ret =
executor_->Run(this->in_tensors_, this->out_tensors_, this->nodes_, this->context_->allocator.get(), before, after);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Run sub graph failed: " << ret;
@ -168,9 +158,61 @@ int CpuSubGraph::Prepare() {
return RET_OK;
}
// Pre-processing hook of the FP32 CPU subgraph: performs no work and always reports success.
int CpuFp32SubGraph::PreProcess() {
  return RET_OK;
}
// Releases every record kept in origin_input_data_ and then empties the list.
//
// For each non-null record:
//   - the saved tensor buffer (data_) is released, if still set, through the allocator recorded next to
//     it, or with free() when no allocator was recorded;
//   - the record itself is released through the context allocator when one exists, otherwise with free().
//
// Null entries are skipped: PreProcess() appends a nullptr placeholder for every input tensor that is not
// float32, so a null entry is a normal state here and must not be dereferenced.
void CpuFp16SubGraph::FreeOriginInputData() {
  for (auto *data_store : this->origin_input_data_) {
    if (data_store == nullptr) {
      // placeholder for an input that was never converted; nothing was stored for it
      continue;
    }
    // free data in data_store
    if (data_store->data_ != nullptr) {
      if (data_store->allocator_ == nullptr) {
        free(data_store->data_);
      } else {
        data_store->allocator_->Free(data_store->data_);
      }
    }
    // free data_store
    if (this->context_->allocator != nullptr) {
      this->context_->allocator->Free(data_store);
    } else {
      free(data_store);
    }
  }
  // dropping the entries is what invalidates the released pointers
  this->origin_input_data_.clear();
}
int CpuFp16SubGraph::PreProcess() {
auto fp32_to_fp16_cast_func = Float16CastUtil::GetInstance()->float32_to_float16_func_;
if (fp32_to_fp16_cast_func == nullptr) {
MS_LOG(ERROR) << "Can not find cast fp32 to fp16 func";
return RET_ERROR;
}
MS_ASSERT(origin_input_data_.empty());
for (auto tensor : this->in_tensors_) {
MS_ASSERT(tensor != nullptr);
if (tensor->data_type() == kNumberTypeFloat32) {
auto float32_data = tensor->data_c();
MS_ASSERT(float32_data != nullptr);
tensor->set_data(nullptr);
tensor->set_data_type(TypeId::kNumberTypeFloat16);
auto ret = tensor->MallocData();
if (RET_OK != ret) {
MS_LOG(ERROR) << "malloc data failed";
this->FreeOriginInputData();
return RET_ERROR;
}
MS_ASSERT(tensor->data_c() != nullptr);
fp32_to_fp16_cast_func(float32_data, tensor->data_c(), tensor->ElementsNum());
auto *data_store = DataStore::CreateDataStore(float32_data, tensor->allocator(), this->context_->allocator.get());
if (data_store == nullptr) {
MS_LOG(ERROR) << "Create DataStore failed";
this->FreeOriginInputData();
return RET_ERROR;
}
origin_input_data_.emplace_back(data_store);
} else {
origin_input_data_.emplace_back(nullptr);
}
}
for (auto kernel : this->nodes_) {
for (auto tensor : kernel->out_tensors()) {
if (tensor->data_type() == kNumberTypeFloat32) {
@ -188,6 +230,7 @@ int CpuFp16SubGraph::PostProcess() {
return RET_ERROR;
}
for (auto tensor : this->out_tensors_) {
MS_ASSERT(tensor != nullptr);
if (tensor->data_type() == kNumberTypeFloat16) {
auto float16_data = tensor->data_c();
MS_ASSERT(float16_data != nullptr);
@ -212,6 +255,21 @@ int CpuFp16SubGraph::PostProcess() {
}
}
}
MS_ASSERT(this->origin_input_data_.size() == this->in_tensors_.size());
for (size_t i = 0; i < this->in_tensors_.size(); i++) {
auto tensor = in_tensors_.at(i);
MS_ASSERT(tensor != nullptr);
if (tensor->data_type() == kNumberTypeFloat16) {
tensor->FreeData();
auto origin_tensor_data = origin_input_data_.at(i);
MS_ASSERT(origin_tensor_data != nullptr);
MS_ASSERT(origin_tensor_data->data_ != nullptr);
tensor->set_data(origin_tensor_data->data_);
tensor->set_data_type(kNumberTypeFloat32);
origin_tensor_data->data_ = nullptr;
}
}
this->FreeOriginInputData();
return RET_OK;
}
} // namespace mindspore::kernel

View File

@ -22,6 +22,7 @@
#include <vector>
#include "src/lite_kernel.h"
#include "src/executor.h"
#include "src/common/log_adapter.h"
#ifdef ENABLE_ARM64
#include "nnacl/optimized_kernel.h"
#endif
@ -58,14 +59,37 @@ class Float16CastUtil {
Float16CastFunc float32_to_float16_func_ = nullptr;
};
// Record of a subgraph input tensor's original buffer, kept between PreProcess and PostProcess.
// data_      : the original data pointer of the tensor.
// allocator_ : the allocator stored alongside that pointer (may be nullptr).
struct DataStore {
  void *data_ = nullptr;
  lite::Allocator *allocator_ = nullptr;

  // Allocates a DataStore and fills it with `data` and `data_allocator`.
  // The storage for the record itself comes from `allocator` when it is non-null, otherwise from malloc().
  // Returns nullptr (after logging an error) when that allocation fails.
  static DataStore *CreateDataStore(void *data = nullptr, lite::Allocator *data_allocator = nullptr,
                                    lite::Allocator *allocator = nullptr) {
    void *raw_memory = nullptr;
    if (allocator != nullptr) {
      raw_memory = allocator->Malloc(sizeof(DataStore));
    } else {
      raw_memory = malloc(sizeof(DataStore));
    }
    if (raw_memory == nullptr) {
      MS_LOG(ERROR) << "Malloc tensor_data failed";
      return nullptr;
    }
    auto *store = static_cast<DataStore *>(raw_memory);
    store->data_ = data;
    store->allocator_ = data_allocator;
    return store;
  }
};
class SubGraphKernel : public LiteKernel {
public:
explicit SubGraphKernel(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const std::vector<LiteKernel *> &in_kernels, const std::vector<LiteKernel *> &out_kernels,
std::vector<LiteKernel *> nodes, const lite::InnerContext *ctx)
: LiteKernel(nullptr, inputs, outputs, ctx, nullptr), nodes_(std::move(nodes)) {
in_kernels_ = in_kernels;
out_kernels_ = out_kernels;
: LiteKernel(nullptr, inputs, outputs, ctx, nullptr),
nodes_(std::move(nodes)),
in_nodes_(in_kernels),
out_nodes_(out_kernels) {
subgraph_type_ = kCpuFP32SubGraph;
}
@ -97,6 +121,10 @@ class SubGraphKernel : public LiteKernel {
protected:
std::vector<LiteKernel *> nodes_;
// entry nodes in nodes
std::vector<LiteKernel *> in_nodes_;
// exit nodes in nodes
std::vector<LiteKernel *> out_nodes_;
mindspore::lite::Executor *executor_ = nullptr;
};
@ -119,7 +147,7 @@ class CpuSubGraph : public SubGraphKernel {
int Run(const KernelCallBack &before, const KernelCallBack &after) override {
return SubGraphKernel::Run(before, after);
};
int PostProcess() override { return mindspore::lite::RET_OK; }
int PostProcess() override { return SubGraphKernel::PostProcess(); }
};
class CpuFp32SubGraph : public CpuSubGraph {
@ -134,12 +162,12 @@ class CpuFp32SubGraph : public CpuSubGraph {
~CpuFp32SubGraph() override = default;
int Init() override { return mindspore::lite::RET_ERROR; }
int PreProcess() override;
int PreProcess() override { return CpuSubGraph::PreProcess(); }
int Run() override { return CpuSubGraph::Run(); }
int Run(const KernelCallBack &before, const KernelCallBack &after) override {
return CpuSubGraph::Run(before, after);
};
int PostProcess() override { return mindspore::lite::RET_OK; }
int PostProcess() override { return CpuSubGraph::PostProcess(); }
};
class CpuFp16SubGraph : public CpuSubGraph {
@ -160,6 +188,12 @@ class CpuFp16SubGraph : public CpuSubGraph {
return CpuSubGraph::Run(before, after);
};
int PostProcess() override;
private:
void FreeOriginInputData();
private:
std::vector<DataStore *> origin_input_data_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_SUB_GRAPH_H