fix bugs to support ips location

This commit is contained in:
yangjie159 2021-04-01 17:38:49 +08:00
parent ae91575346
commit 2842d633b4
16 changed files with 205 additions and 183 deletions

View File

@ -170,6 +170,8 @@ set(LITE_SRC
${LITE_DIR}/src/ops/populate/erf_populate.cc
${LITE_DIR}/src/ops/populate/exp_populate.cc
${LITE_DIR}/src/ops/populate/strided_slice_populate.cc
${LITE_DIR}/src/ops/populate/lstm_populate.cc
${LITE_DIR}/src/ops/populate/squeeze_populate.cc
### tools
${LITE_DIR}/tools/common/flag_parser.cc
)

View File

@ -19,6 +19,7 @@
#include <map>
#include "coder/allocator/memory_manager.h"
#include "coder/opcoders/op_coder.h"
#include "coder/generator/component/component.h"
namespace mindspore::lite::micro {
void *MemoryAllocator::MallocWeightTensor(TypeId type_id, size_t size, MallocType type) {
@ -27,12 +28,13 @@ void *MemoryAllocator::MallocWeightTensor(TypeId type_id, size_t size, MallocTyp
{kNumberTypeInt16, sizeof(int16_t)}, {kNumberTypeInt8, sizeof(int8_t)}, {kNumberTypeUInt8, sizeof(uint8_t)}};
auto item = size_map.find(type_id);
MS_CHECK_TRUE_RET_NULL(item != size_map.end(), "unsupported type idnex");
size_t type_size = item->second;
std::vector<int> shape = {1, static_cast<int>(size / type_size)};
auto cate = type == kOfflinePackWeight ? Tensor::Category::CONST_TENSOR : Tensor::Category::VAR;
Tensor *weight = new (std::nothrow) lite::Tensor(type_id, shape, schema::Format_NHWC, cate);
MS_CHECK_PTR_RET_NULL(weight);
std::string runtime_addr = net_weight_addr_ + std::to_string(weight_index_++);
std::string runtime_addr = kWeightPrefixName + std::to_string(weight_index_++);
malloc_weights_addr_.insert(std::make_pair(weight, runtime_addr));
if (type == kOfflinePackWeight) {
saved_weights_addr_.insert(std::make_pair(runtime_addr, weight));
@ -41,13 +43,6 @@ void *MemoryAllocator::MallocWeightTensor(TypeId type_id, size_t size, MallocTyp
return weight->data_c();
}
void MemoryAllocator::RecordRuntimeAddrs(const std::string &net_input_addr, const std::string &net_buffer_addr,
const std::string &net_weight_addr) {
net_input_addr_ = net_input_addr;
net_buffer_addr_ = net_buffer_addr;
net_weight_addr_ = net_weight_addr;
}
void MemoryAllocator::Free() {
for (auto iter = malloc_weights_addr_.begin(); iter != malloc_weights_addr_.end();) {
Tensor *tensor = iter->first;
@ -78,7 +73,7 @@ void MemoryAllocator::AssignWorkspaces(void *addr, size_t size) {
is_next_ = false;
offset_ = 0;
}
workspaces_addr_.insert(std::make_pair(addr, net_buffer_addr_ + "+" + std::to_string(tensors_size_ + offset_)));
workspaces_addr_.insert(std::make_pair(addr, kBufferPrefixNameAdd + std::to_string(tensors_size_ + offset_)));
offset_ += size;
if (workspace_size_ < offset_) {
workspace_size_ = offset_;
@ -89,14 +84,14 @@ void MemoryAllocator::RecordTensorsAddr(const std::map<Tensor *, size_t> &offset
for (auto &item : offsets) {
auto tensor = item.first;
auto offset = item.second;
tensors_addr_.insert(std::make_pair(tensor, net_buffer_addr_ + "+" + std::to_string(offset)));
tensors_addr_.insert(std::make_pair(tensor, kBufferPrefixNameAdd + std::to_string(offset)));
}
}
void MemoryAllocator::AssignGraphInputs(const std::vector<Tensor *> &inputs) {
size_t num = inputs.size();
for (size_t i = 0; i < num; ++i) {
tensors_addr_.insert(std::make_pair(inputs.at(i), net_input_addr_ + std::to_string(i)));
tensors_addr_.insert(std::make_pair(inputs.at(i), kInputPrefixName + std::to_string(i)));
}
}
@ -106,7 +101,7 @@ void MemoryAllocator::RecordOriginWeightsAddr(const std::vector<std::unique_ptr<
for (const auto &tensor : inputs) {
if (tensor->category() == Tensor::Category::CONST_TENSOR ||
tensor->category() == Tensor::Category::CONST_SCALAR) {
std::string runtime_addr = net_weight_addr_ + std::to_string(weight_index_);
std::string runtime_addr = kWeightPrefixName + std::to_string(weight_index_);
origin_weights_addr_.insert(std::make_pair(tensor, runtime_addr));
weight_index_++;
}

View File

@ -52,11 +52,6 @@ class MemoryAllocator {
MemoryAllocator(const MemoryAllocator &) = delete;
MemoryAllocator &operator=(const MemoryAllocator &) = delete;
/*
* Record Runtime's addrs of model
*/
void RecordRuntimeAddrs(const std::string &net_input_addr, const std::string &net_buffer_addr,
const std::string &net_weight_addr);
/*
* assign model's input, original weights and all tensors memory addr
*/
@ -76,7 +71,6 @@ class MemoryAllocator {
if (size == 0 || size >= UINT_MAX) {
return nullptr;
}
void *buffer = malloc(size);
if (buffer == nullptr) {
MS_LOG(ERROR) << "malloc memory failed";
@ -109,28 +103,24 @@ class MemoryAllocator {
return type_info + wrap(item->second);
}
auto iter = std::find_if(
tensors_addr_.begin(), tensors_addr_.end(),
[&variable](const std::pair<Tensor *, std::string> &a) { return variable == reinterpret_cast<void *>(a.first); });
if (iter != tensors_addr_.end()) {
return type_info + wrap(iter->second);
}
// find variable in weights map
iter =
auto iter =
std::find_if(malloc_weights_addr_.begin(), malloc_weights_addr_.end(),
[&variable](const std::pair<Tensor *, std::string> &a) { return variable == (a.first)->data_c(); });
if (iter != malloc_weights_addr_.end()) {
return iter->second;
}
// origin weight
iter = std::find_if(origin_weights_addr_.begin(), origin_weights_addr_.end(),
[&variable](const std::pair<Tensor *, std::string> &a) { return variable == a.first; });
if (iter != origin_weights_addr_.end()) {
saved_weights_addr_.insert(std::make_pair(iter->second, reinterpret_cast<Tensor *>(variable)));
if (immutable) {
malloc_weights_addr_.insert({reinterpret_cast<Tensor *>(variable), iter->second});
}
return iter->second;
Tensor *tensor = reinterpret_cast<Tensor *>(t);
auto it = tensors_addr_.find(tensor);
if (it != tensors_addr_.end()) {
return type_info + wrap(it->second);
}
it = origin_weights_addr_.find(tensor);
if (it != origin_weights_addr_.end()) {
saved_weights_addr_.insert(std::make_pair(it->second, tensor));
if (immutable) malloc_weights_addr_.insert(std::make_pair(tensor, it->second));
return it->second;
}
MS_LOG(ERROR) << "uninitialized memory";
return "";
@ -153,9 +143,7 @@ class MemoryAllocator {
void RecordOriginWeightsAddr(const std::vector<std::unique_ptr<OperatorCoder>> &nodes);
void RecordTensorsAddr(const std::map<Tensor *, size_t> &offsets);
private:
MemoryAllocator() = default;
~MemoryAllocator() = default;
std::map<void *, std::string> workspaces_addr_;
@ -170,9 +158,6 @@ class MemoryAllocator {
std::map<Tensor *, std::string> origin_weights_addr_;
std::map<Tensor *, std::string> malloc_weights_addr_;
std::map<Tensor *, std::string> tensors_addr_;
std::string net_input_addr_;
std::string net_buffer_addr_;
std::string net_weight_addr_;
};
} // namespace mindspore::lite::micro
#endif // MINDSPORE_LITE_MICRO_CODER_MEMORY_ALLOCATOR_H_

View File

@ -15,14 +15,14 @@
*/
#include "coder/context.h"
#include "coder/allocator/allocator.h"
#include "coder/generator/component/component.h"
namespace mindspore::lite::micro {
CoderContext::CoderContext() {
this->input_name_ = "g_Input";
this->output_name_ = "g_Output";
this->buffer_name_ = "g_Buffer";
this->weight_name_ = "g_Weight";
this->input_name_ = kInputPrefixName;
this->output_name_ = kOutputPrefixName;
this->buffer_name_ = kBufferPrefixName;
this->weight_name_ = kWeightPrefixName;
}
void CoderContext::AppendCode(const std::string &codeBlock) { this->code_blocks_.emplace_back(codeBlock); }

View File

@ -19,6 +19,12 @@
namespace mindspore::lite::micro {
constexpr auto kInputPrefixName = "g_Input";
constexpr auto kOutputPrefixName = "g_Output";
constexpr auto kWeightPrefixName = "g_Weight";
constexpr auto kBufferPrefixName = "g_Buffer";
constexpr auto kBufferPrefixNameAdd = "g_Buffer + ";
constexpr auto kModelName = "net";
constexpr auto kSourcePath = "/src/";
@ -39,7 +45,7 @@ constexpr auto kExternCpp =
"extern \"C\" {\n"
"#endif\n";
constexpr char kEndExternCpp[] =
constexpr auto kEndExternCpp =
"#ifdef __cplusplus\n"
"}\n"
"#endif\n";

View File

@ -109,10 +109,12 @@ int StridedSliceBaseCoder::DoFastCode(CoderContext *ctx) {
if (cur_outer > cal_num_per_thread_) {
cur_outer = cal_num_per_thread_;
}
code << "uint8_t *cur_in_ptr = " << input_ptr_str << " + "
<< (caled_num * in_shape[split_axis_] + begin_index) * inner_size_ << ";\n";
code << " uint8_t *cur_out_ptr = " << output_ptr_str << " + " << caled_num * out_shape[split_axis_] * inner_size_
<< ";\n";
code << "uint8_t *cur_in_ptr = "
<< "(uint8_t *)(" << input_ptr_str << ")"
<< " + " << (caled_num * in_shape[split_axis_] + begin_index) * inner_size_ << ";\n";
code << " uint8_t *cur_out_ptr = "
<< "(uint8_t *)(" << output_ptr_str << ")"
<< " + " << caled_num * out_shape[split_axis_] * inner_size_ << ";\n";
code.CodeFunction("FastStride", "cur_in_ptr", "cur_out_ptr", out_shape.at(split_axis_),
strided_slice_parameter_->strides_[split_axis_], cur_outer, inner_size_,
in_shape.at(split_axis_) * inner_size_);
@ -124,9 +126,12 @@ int StridedSliceBaseCoder::DoFastCode(CoderContext *ctx) {
if (cal_axis_num > cal_num_per_thread_) {
cal_axis_num = cal_num_per_thread_;
}
code << "uint8_t *cur_in_ptr = " << input_ptr_str << " + "
<< (caled_num * strided_slice_parameter_->strides_[split_axis_] + begin_index) * inner_size_ << ";\n";
code << "uint8_t *cur_out_ptr = " << output_ptr_str << " + " << caled_num * inner_size_ << ";\n";
code << "uint8_t *cur_in_ptr = "
<< "(uint8_t *)(" << input_ptr_str << ")"
<< " + " << (caled_num * strided_slice_parameter_->strides_[split_axis_] + begin_index) * inner_size_ << ";\n";
code << "uint8_t *cur_out_ptr = "
<< "(uint8_t *)(" << output_ptr_str << ")"
<< " + " << caled_num * inner_size_ << ";\n";
code.CodeFunction("FastStride", "cur_in_ptr", "cur_out_ptr", cal_axis_num,
strided_slice_parameter_->strides_[split_axis_], 1, inner_size_, 0);
}

View File

@ -24,8 +24,8 @@
using mindspore::schema::PrimitiveType_LSTM;
namespace mindspore::lite::micro::nnacl {
constexpr int kFifthIndex = 5;
constexpr int kSixthIndex = 6;
constexpr int kFifthIndex = 4;
constexpr int kSixthIndex = 5;
int LstmFP32Coder::InitInputWeightBias(CoderContext *const context) {
NNaclFp32Serializer init_code;
@ -33,7 +33,7 @@ int LstmFP32Coder::InitInputWeightBias(CoderContext *const context) {
MS_CHECK_PTR(weight_i);
size_t weight_i_size = weight_batch_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * sizeof(float);
weight_i_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, kOnlineSize, kOnlinePackWeight));
MS_CHECK_PTR(weight_h_ptr_);
MS_CHECK_PTR(weight_i_ptr_);
init_code.CodeMallocExpression(weight_i_ptr_, weight_i_size);
init_code.CodeFunction("memset", weight_i_ptr_, 0, weight_i_size);
@ -168,6 +168,7 @@ int LstmFP32Coder::DoCode(CoderContext *context) {
},
{
"lstm_fp32.c",
"mul_fp32.c",
});
Tensor *hidden_state = input_tensors_.at(kFifthIndex);
@ -179,12 +180,18 @@ int LstmFP32Coder::DoCode(CoderContext *context) {
Tensor *output_cell_state = output_tensors_[2];
MS_CHECK_PTR(output_hidden_state);
std::vector<std::string> buffers_addr;
for (const auto &buf : buffer_) {
std::string addr = buf == nullptr ? "NULL" : allocator_->GetRuntimeAddr(buf);
buffers_addr.push_back(addr);
}
NNaclFp32Serializer code;
code.CodeStruct("lstm_param", *lstm_param_);
code.CodeArray("buffer", buffers_addr.data(), buffers_addr.size(), false);
code.CodeFunction("memcpy", output_hidden_state, hidden_state, hidden_state->Size());
code.CodeFunction("memcpy", output_cell_state, cell_state, cell_state->Size());
code.CodeFunction("Lstm", output_tensor_, input_tensor_, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_,
output_hidden_state, output_cell_state, buffer_, lstm_param_);
output_hidden_state, output_cell_state, "buffer", "&lstm_param");
context->AppendCode(code.str());
return RET_OK;
}

View File

@ -16,7 +16,6 @@
#include "coder/opcoders/nnacl/fp32/transpose_fp32_coder.h"
#include <vector>
#include <string>
#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
#include "coder/opcoders/file_collector.h"
#include "coder/opcoders/parallel.h"
@ -25,76 +24,129 @@ using mindspore::schema::PrimitiveType_Transpose;
namespace mindspore::lite::micro::nnacl {
int TransposeFp32Coder::Resize() {
num_unit_ = static_cast<int>(input_tensor_->shape().at(transpose_parameter_->perm_[kNHWC_H]));
thread_h_num_ = MSMIN(thread_num_, num_unit_);
MS_CHECK_TRUE(thread_h_num_ > 0, "thread_h_num_ <= 0");
thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_);
std::vector<int> in_shape = input_tensor_->shape();
std::vector<int> out_shape = output_tensor_->shape();
transpose_parameter_->strides_[transpose_parameter_->num_axes_ - 1] = 1;
transpose_parameter_->out_strides_[transpose_parameter_->num_axes_ - 1] = 1;
transpose_parameter_->data_size_ = static_cast<int>(input_tensor_->Size());
for (int i = transpose_parameter_->num_axes_ - 2; i >= 0; i--) {
transpose_parameter_->strides_[i] = in_shape.at(i + 1) * transpose_parameter_->strides_[i + 1];
transpose_parameter_->out_strides_[i] = out_shape.at(i + 1) * transpose_parameter_->out_strides_[i + 1];
if (input_tensors_.size() == 2) {
param_->num_axes_ = input_tensors_.at(1)->ElementsNum();
}
MS_CHECK_TRUE(in_shape.size() > 0, "invalid shape size");
MS_CHECK_TRUE(out_shape.size() > 0, "invalid shape size");
auto in_shape_data_size = static_cast<size_t>(in_shape.size() * sizeof(int));
auto out_shape_data_size = static_cast<size_t>(out_shape.size() * sizeof(int));
in_shape_ = reinterpret_cast<int *>(allocator_->Malloc(kNumberTypeInt32, in_shape_data_size, kOfflinePackWeight));
MS_CHECK_PTR(in_shape_);
if (input_tensors_.at(kInputIndex)->shape().size() != static_cast<size_t>(param_->num_axes_)) {
return RET_OK;
}
// get perm data
MS_ASSERT(input_tensors_.size() == 2);
auto perm_tensor = input_tensors_.at(1);
int *perm_data = reinterpret_cast<int *>(perm_tensor->data_c());
MS_ASSERT(perm_data != nullptr);
for (int i = 0; i < param_->num_axes_; ++i) {
param_->perm_[i] = perm_data[i];
}
auto in_shape = input_tensor_->shape();
auto out_shape = output_tensor_->shape();
param_->strides_[param_->num_axes_ - 1] = 1;
param_->out_strides_[param_->num_axes_ - 1] = 1;
param_->data_size_ = input_tensor_->Size();
for (int i = param_->num_axes_ - 2; i >= 0; i--) {
param_->strides_[i] = in_shape.at(i + 1) * param_->strides_[i + 1];
param_->out_strides_[i] = out_shape.at(i + 1) * param_->out_strides_[i + 1];
}
out_shape_ =
reinterpret_cast<int *>(allocator_->Malloc(kNumberTypeInt32, out_shape.size() * sizeof(int), kOfflinePackWeight));
MS_CHECK_PTR(out_shape_);
MS_CHECK_RET_CODE(memcpy_s(in_shape_, in_shape_data_size, in_shape.data(), in_shape_data_size), "memcpy failed");
MS_CHECK_RET_CODE(memcpy_s(out_shape_, out_shape_data_size, out_shape.data(), out_shape_data_size), "memcpy failed");
memcpy(out_shape_, out_shape.data(), in_shape.size() * sizeof(int));
return RET_OK;
}
int TransposeFp32Coder::Init() {
transpose_parameter_ = reinterpret_cast<TransposeParameter *>(parameter_);
MS_CHECK_PTR(transpose_parameter_);
param_ = reinterpret_cast<TransposeParameter *>(parameter_);
MS_CHECK_PTR(param_);
return Resize();
}
int TransposeFp32Coder::Prepare(CoderContext *const context) {
MS_CHECK_RET_CODE(Init(), "init failed");
int out_dims = static_cast<int>(output_tensor_->shape().size());
auto out_data_dims_size = static_cast<size_t>(out_dims * thread_h_num_ * sizeof(int));
if (out_dims > MAX_TRANSPOSE_DIM_SIZE) {
dim_size_ = reinterpret_cast<int *>(allocator_->Malloc(kNumberTypeInt, out_data_dims_size, kWorkspace));
MS_CHECK_PTR(dim_size_);
position_ = reinterpret_cast<int *>(allocator_->Malloc(kNumberTypeInt, out_data_dims_size, kWorkspace));
MS_CHECK_PTR(position_);
}
return RET_OK;
}
int TransposeFp32Coder::DoCode(CoderContext *const context) {
int num_unit_thread = MSMIN(thread_h_stride_, num_unit_ - kDefaultTaskId * thread_h_stride_);
if (num_unit_thread <= 0) {
return RET_OK;
void TransposeFp32Coder::GetNHNCTransposeFunc() {
auto out_shape = output_tensor_->shape();
if (input_tensor_->shape().size() == 4 && param_->perm_[0] == 0 && param_->perm_[1] == 2 && param_->perm_[2] == 3 &&
param_->perm_[3] == 1) {
nhnc_param_[0] = out_shape[0];
nhnc_param_[1] = out_shape[1] * out_shape[2];
nhnc_param_[2] = out_shape[3];
if (input_tensor_->data_type() == kNumberTypeFloat32) {
NHNCTransposeFunc_ = "PackNCHWToNHWCFp32";
}
}
if (input_tensor_->shape().size() == 4 && param_->perm_[0] == 0 && param_->perm_[1] == 3 && param_->perm_[2] == 1 &&
param_->perm_[3] == 2) {
nhnc_param_[0] = out_shape[0];
nhnc_param_[1] = out_shape[2] * out_shape[3];
nhnc_param_[2] = out_shape[1];
if (input_tensor_->data_type() == kNumberTypeFloat32) {
NHNCTransposeFunc_ = "PackNHWCToNCHWFp32";
}
}
}
int TransposeFp32Coder::DoCode(CoderContext *const context) {
MS_ASSERT(in_tensors_.size() == 1 || in_tensors_.size() == 2);
MS_ASSERT(out_tensors_.size() == 1);
Collect(context,
{
"nnacl/transpose.h",
"nnacl/fp32/transpose.h",
"nnacl/errorcode.h",
"nnacl/fp32/transpose_fp32.h",
},
{
"transpose.c",
"transpose_fp32.c",
});
NNaclFp32Serializer code;
code.CodeStruct("transpose_parameter", *transpose_parameter_);
code.CodeFunction("DoTransposeFp32", input_tensor_, output_tensor_, in_shape_, out_shape_,
"(TransposeParameter *)&transpose_parameter", kDefaultTaskId, num_unit_thread, dim_size_,
position_);
if (input_tensor_->shape().size() != static_cast<size_t>(param_->num_axes_)) {
code.CodeFunction("memcpy", output_tensor_, input_tensor_, input_tensor_->Size());
context->AppendCode(code.str());
return RET_OK;
}
if (input_tensors_.size() == 2) {
auto input_perm = input_tensors_.at(1);
MS_ASSERT(input_perm != nullptr);
MS_ASSERT(input_perm->data_c() != nullptr);
int *perm_data = reinterpret_cast<int *>(input_perm->data_c());
for (int i = 0; i < input_perm->ElementsNum(); ++i) {
param_->perm_[i] = perm_data[i];
}
for (int i = input_perm->ElementsNum(); i < MAX_SHAPE_SIZE; ++i) {
param_->perm_[i] = 0;
}
}
GetNHNCTransposeFunc();
if (!NHNCTransposeFunc_.empty()) {
code.CodeFunction(NHNCTransposeFunc_, input_tensor_, output_tensor_, nhnc_param_[0], nhnc_param_[1], nhnc_param_[2],
kDefaultTaskId, thread_num_);
context->AppendCode(code.str());
return RET_OK;
}
code.CodeStruct("trans_param", *param_);
dims_ = output_tensor_->shape().size();
if (dims_ > MAX_TRANSPOSE_DIM_SIZE) {
int *dim_size = reinterpret_cast<int *>(malloc(dims_ * sizeof(int)));
MS_CHECK_PTR(dim_size);
*(dim_size + dims_ - 1) = 1;
for (int i = dims_ - 1; i > 0; --i) {
*(dim_size + i - 1) = *(dim_size + i) * out_shape_[i];
}
code.CodeArray("dim_size", dim_size, dims_);
int *position = reinterpret_cast<int *>(malloc(dims_ * thread_num_ * sizeof(int)));
MS_CHECK_PTR(position);
code.CodeArray("position", position, dims_ * thread_num_);
code.CodeFunction("TransposeDimsFp32", input_tensor_, output_tensor_, out_shape_, "dim_size", "position",
"&trans_param", kDefaultTaskId, thread_num_);
free(dim_size);
free(position);
} else {
code.CodeFunction("DoTransposeFp32", input_tensor_, output_tensor_, out_shape_, "&trans_param");
}
context->AppendCode(code.str());
return RET_OK;
}

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_FP32_TRANSPOSE_FP32_CODER_H_
#define MINDSPORE_LITE_MICRO_CODER_OPCODERS_NNACL_FP32_TRANSPOSE_FP32_CODER_H_
#include <vector>
#include <string>
#include "coder/opcoders/op_coder.h"
#include "nnacl/transpose.h"
namespace mindspore::lite::micro::nnacl {
@ -38,15 +39,13 @@ class TransposeFp32Coder final : public OperatorCoder {
int Init();
private:
TransposeParameter *transpose_parameter_ = nullptr;
int thread_num_{1};
int thread_h_stride_{0};
int thread_h_num_{0};
int num_unit_{0};
int *in_shape_{nullptr};
void GetNHNCTransposeFunc();
TransposeParameter *param_{nullptr};
int *out_shape_{nullptr};
int *dim_size_{nullptr};
int *position_{nullptr};
std::string NHNCTransposeFunc_;
int nhnc_param_[3];
int dims_{0};
};
} // namespace mindspore::lite::micro::nnacl

View File

@ -52,25 +52,12 @@ std::unique_ptr<OperatorCoder> OpCoderBuilder::build() {
<< " code_target: " << target_ << " data_type: " << EnumNameDataType(data_type_);
return op_coder;
}
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
ParameterGen paramGen =
PopulateRegistry::GetInstance()->GetParameterCreator(GetPrimitiveType(node_->primitive_), schema_version);
if (paramGen == nullptr) {
MS_LOG(ERROR) << "parameter generator is null";
return nullptr;
}
OpParameter *parameter = paramGen(node_->primitive_);
if (parameter == nullptr) {
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
<< PrimitiveTypeName(GetPrimitiveType(node_->primitive_));
return nullptr;
}
op_coder->set_input_tensor_indices(input_indices_);
op_coder->set_output_tensor_indices(output_indices_);
int thread_num = support_parallel_ ? kMaxThreadNumSupported : 1;
op_coder->set_thread_num(thread_num);
parameter->thread_num_ = thread_num;
op_coder->set_parameter(parameter);
parameter_->thread_num_ = thread_num;
op_coder->set_parameter(parameter_);
op_coder->set_type(primitive_type);
return op_coder;
}
@ -90,6 +77,11 @@ OpCoderBuilder &OpCoderBuilder::node(const Model::Node *node) {
return *this;
}
OpCoderBuilder &OpCoderBuilder::parameter(OpParameter *parameter) {
this->parameter_ = parameter;
return *this;
}
OpCoderBuilder &OpCoderBuilder::data_type(TypeId data_type) {
this->data_type_ = data_type;
return *this;

View File

@ -33,6 +33,8 @@ class OpCoderBuilder {
OpCoderBuilder &node(const Model::Node *node);
OpCoderBuilder &parameter(OpParameter *parameter);
OpCoderBuilder &data_type(TypeId data_type);
OpCoderBuilder &mode(CodeMode mode);
@ -54,6 +56,8 @@ class OpCoderBuilder {
const mindspore::lite::Model::Node *node_ = nullptr;
OpParameter *parameter_{nullptr};
size_t node_index_{0};
Target target_{kTargetUnknown};

View File

@ -102,10 +102,10 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const TileParamete
}
void NNaclFp32Serializer::CodeStruct(const std::string &name, const TransposeParameter &transpose_parameter) {
CodeBaseStruct("TransposeParameter", name, transpose_parameter.op_parameter_, ToString(transpose_parameter.perm_),
transpose_parameter.conjugate_, ToString(transpose_parameter.strides_),
ToString(transpose_parameter.out_strides_), transpose_parameter.num_axes_,
transpose_parameter.data_size_);
CodeBaseStruct<false>(
"TransposeParameter", name, transpose_parameter.op_parameter_, ToString(transpose_parameter.perm_),
transpose_parameter.perm_size_, transpose_parameter.conjugate_, ToString(transpose_parameter.strides_),
ToString(transpose_parameter.out_strides_), transpose_parameter.num_axes_, transpose_parameter.data_size_);
}
void NNaclFp32Serializer::CodeStruct(const std::string &name, const LstmParameter &lstm_parameter) {

View File

@ -221,7 +221,7 @@ class Serializer {
if (t == nullptr) {
code << "NULL";
} else {
std::string name = MemoryAllocator::GetInstance()->GetRuntimeAddr(t);
std::string name = MemoryAllocator::GetInstance()->GetRuntimeAddr(t, true);
if (name.empty()) {
MS_LOG(ERROR) << "pointer is not allocated by the allocator";
exit(1);

View File

@ -32,6 +32,7 @@
#include "src/runtime/infer_manager.h"
#include "src/scheduler.h"
#include "include/errorcode.h"
#include "include/model.h"
#include "src/common/file_utils.h"
#include "coder/opcoders/nnacl/dequant/de_quant.h"
@ -39,57 +40,6 @@ namespace mindspore::lite::micro {
CoderSession::CoderSession() { allocator_ = MemoryAllocator::GetInstance(); }
int CoderSession::InferShape() {
const Model *model = coder_graph_->model();
std::vector<lite::Tensor *> all_tensors = coder_graph_->all_tensors();
size_t nodes_num = model->all_nodes_.size();
for (size_t i = 0; i < nodes_num; ++i) {
auto curr_node = model->all_nodes_.at(i);
if (!curr_node) {
MS_LOG(ERROR) << "model's node is null, who's index is " << i << ". InferShape failed ";
return RET_ERROR;
}
std::vector<Tensor *> inputs;
std::vector<Tensor *> outputs;
size_t input_nums = curr_node->input_indices_.size();
inputs.reserve(input_nums);
for (size_t j = 0; j < input_nums; ++j) {
inputs.push_back(all_tensors.at(curr_node->input_indices_.at(j)));
}
size_t output_nums = curr_node->output_indices_.size();
outputs.reserve(output_nums);
for (size_t j = 0; j < output_nums; ++j) {
outputs.push_back(all_tensors.at(curr_node->output_indices_.at(j)));
}
auto primitive = curr_node->primitive_;
MS_CHECK_PTR(primitive);
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
auto parame_gen = PopulateRegistry::GetInstance()->GetParameterCreator(GetPrimitiveType(primitive), schema_version);
if (parame_gen == nullptr) {
MS_LOG(ERROR) << "parameter generator is nullptr.";
return RET_NULL_PTR;
}
auto parameter = parame_gen(primitive);
if (parameter == nullptr) {
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << PrimitiveTypeName(GetPrimitiveType(primitive));
return RET_ERROR;
}
parameter->infer_flag_ = true;
auto ret = KernelInferShape(inputs, &outputs, parameter);
if (ret == RET_INFER_INVALID) {
MS_LOG(INFO) << "InferShape shouldn't be done before runtime, name: " << curr_node->name_
<< ", type: " << PrimitiveTypeName(GetPrimitiveType(primitive)) << "flag set to false.";
parameter->infer_flag_ = false;
} else if (ret != RET_OK) {
MS_LOG(ERROR) << "InferShape failed, name: " << curr_node->name_
<< ", type: " << PrimitiveTypeName(GetPrimitiveType(primitive));
return RET_ERROR;
}
}
return RET_OK;
}
void CoderSession::EndCode() {
context_->set_tensor_map(allocator_->tensors_map());
context_->set_saved_weights(allocator_->saved_weights());
@ -188,7 +138,6 @@ int CoderSession::Init(const std::string &model_path) {
MS_CHECK_PTR(model);
coder_graph_ = std::make_unique<CoderGraph>(model);
context_ = std::make_unique<CoderContext>();
allocator_->RecordRuntimeAddrs(context_->input_name(), context_->buffer_name(), context_->weight_name());
MS_LOG(INFO) << "CoderSession::Init done";
return RET_OK;
}
@ -251,6 +200,29 @@ int CoderSession::InitTensorsRef() {
return RET_OK;
}
OpParameter *CoderSession::GenParameterAndInfer(const Model::Node *node, const std::vector<lite::Tensor *> &inputs,
std::vector<lite::Tensor *> *outputs) const {
auto primitive = node->primitive_;
MS_CHECK_PTR_RET_NULL(primitive);
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
auto parame_gen = PopulateRegistry::GetInstance()->GetParameterCreator(GetPrimitiveType(primitive), schema_version);
MS_CHECK_PTR_RET_NULL(parame_gen);
auto parameter = parame_gen(primitive);
MS_CHECK_PTR_RET_NULL(parameter);
parameter->infer_flag_ = true;
auto ret = KernelInferShape(inputs, outputs, parameter);
if (ret == RET_INFER_INVALID) {
MS_LOG(INFO) << "InferShape shouldn't be done before runtime, name: " << node->name_
<< ", type: " << PrimitiveTypeName(GetPrimitiveType(primitive)) << "flag set to false.";
parameter->infer_flag_ = false;
} else if (ret != RET_OK) {
MS_LOG(ERROR) << "InferShape failed, name: " << node->name_
<< ", type: " << PrimitiveTypeName(GetPrimitiveType(primitive));
return nullptr;
}
return parameter;
}
int CoderSession::CreateOpCoders() {
const Model *model = coder_graph_->model();
if (model == nullptr) {
@ -308,11 +280,13 @@ int CoderSession::CreateOpCoders() {
MS_LOG(ERROR) << "node: " << node->name_ << "has no outputs tensor";
return RET_ERROR;
}
OpParameter *parameter = GenParameterAndInfer(node, inputs, &outputs);
MS_CHECK_PTR(parameter);
TypeId tensor_data_type = inputs.at(0)->data_type();
std::unique_ptr<OperatorCoder> op_coder = builder.inputs(inputs)
.outputs(outputs)
.node(node)
.parameter(parameter)
.target(code_target)
.support_parallel(support_parallel)
.data_type(tensor_data_type)
@ -337,7 +311,6 @@ int CoderSession::InitCodeGraph() {
int CoderSession::CompileGraph() {
MS_LOG(INFO) << "CompileGraph";
MS_CHECK_RET_CODE(InitCodeGraph(), "InitGraphInOutTensors failed");
MS_CHECK_RET_CODE(InferShape(), "do infershape failed!");
MS_CHECK_RET_CODE(CreateOpCoders(), "CreateOpCoders failed!");
MS_CHECK_RET_CODE(InitTensorsRef(), "InitTensorsRefcount failed!");
return RET_OK;

View File

@ -43,12 +43,13 @@ class CoderSession {
int GenerateCode();
private:
OpParameter *GenParameterAndInfer(const Model::Node *node, const std::vector<lite::Tensor *> &inputs,
std::vector<lite::Tensor *> *outputs) const;
int InitOpcodersInputsAndOutputs();
int InitTensorsRef();
int CreateOpCoders();
int InitCodeGraph();
int CompileGraph();
int InferShape();
void EndCode();
std::unique_ptr<CoderGraph> coder_graph_{nullptr};

View File

@ -55,6 +55,7 @@ std::string GetVariableTypeName() {
{std::type_index(typeid(double)), "double"},
{std::type_index(typeid(::QuantArg)), "QuantArg"},
{std::type_index(typeid(void *)), "void *"},
{std::type_index(typeid(std::string)), "float *"},
{std::type_index(typeid(int *)), "int *"},
{std::type_index(typeid(int32_t *)), "int32_t *"},
{std::type_index(typeid(int16_t *)), "int16_t *"},