forked from mindspore-Ecosystem/mindspore
commit f957e4f588
@@ -24,21 +24,25 @@ int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model) {
     return RET_ERROR;
   }
   subgraph->name_ = sub_graph.name()->c_str();
+  MS_ASSERT(sub_graph.inputIndices() != nullptr);
   auto in_count = sub_graph.inputIndices()->size();
   for (uint32_t i = 0; i < in_count; ++i) {
-    subgraph->input_indices_.push_back(size_t(sub_graph.inputIndices()->GetAs<uint32_t>(i)));
+    subgraph->input_indices_.push_back(sub_graph.inputIndices()->Get(i));
   }
+  MS_ASSERT(sub_graph.outputIndices() != nullptr);
   auto out_count = sub_graph.outputIndices()->size();
   for (uint32_t i = 0; i < out_count; ++i) {
-    subgraph->output_indices_.push_back(size_t(sub_graph.outputIndices()->GetAs<uint32_t>(i)));
+    subgraph->output_indices_.push_back(sub_graph.outputIndices()->Get(i));
   }
+  MS_ASSERT(sub_graph.nodeIndices() != nullptr);
   auto node_count = sub_graph.nodeIndices()->size();
   for (uint32_t i = 0; i < node_count; ++i) {
-    subgraph->node_indices_.push_back(size_t(sub_graph.nodeIndices()->GetAs<uint32_t>(i)));
+    subgraph->node_indices_.push_back(sub_graph.nodeIndices()->Get(i));
   }
-  auto tensor_count = sub_graph.nodeIndices()->size();
+  MS_ASSERT(sub_graph.tensorIndices() != nullptr);
+  auto tensor_count = sub_graph.tensorIndices()->size();
   for (uint32_t i = 0; i < tensor_count; ++i) {
-    subgraph->tensor_indices_.push_back(size_t(sub_graph.tensorIndices()->GetAs<uint32_t>(i)));
+    subgraph->tensor_indices_.push_back(sub_graph.tensorIndices()->Get(i));
   }
   model->sub_graphs_.push_back(subgraph);
   return RET_OK;

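The change above swaps the GetAs<uint32_t> casts for plain Get and asserts each flatbuffers vector before touching it. A minimal self-contained sketch of that guard-then-read pattern; IndexVector and CopyIndices are hypothetical stand-ins, not the mindspore or flatbuffers API:

#include <cstdint>
#include <vector>

// Stand-in for a flatbuffers vector accessor, which can return nullptr
// when the field is absent from the serialized buffer.
struct IndexVector {
  std::vector<uint32_t> data;
  uint32_t size() const { return static_cast<uint32_t>(data.size()); }
  uint32_t Get(uint32_t i) const { return data[i]; }
};

// Verify the vector exists before reading, as the MS_ASSERT lines do.
bool CopyIndices(const IndexVector *src, std::vector<size_t> *dst) {
  if (src == nullptr || dst == nullptr) {
    return false;  // absent field: fail instead of dereferencing
  }
  for (uint32_t i = 0; i < src->size(); ++i) {
    dst->push_back(src->Get(i));  // uint32_t widens to size_t implicitly
  }
  return true;
}
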
@@ -860,9 +860,15 @@ ThreadPool *CreateThreadPool(int thread_num, int mode) {
     if (ret != RET_TP_OK) {
       LOG_ERROR("create thread %d failed", i);
+      DestroyThreadPool(thread_pool);
+      thread_pool = NULL;
       return NULL;
     }
   }
+  if (thread_pool == NULL) {
+    LOG_ERROR("create thread failed");
+    DestroyThreadPool(thread_pool);
+    return NULL;
+  }
   return thread_pool;
 }

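The hunk above makes every failure path release the half-built pool before returning. A reduced sketch of that cleanup-on-partial-failure shape, assuming hypothetical Pool/StartWorker/DestroyPool helpers:

#include <cstdio>
#include <new>

struct Pool { int started = 0; };

static int StartWorker(Pool *p, int i) { return (p != nullptr && i < 8) ? 0 : -1; }  // stub
static void DestroyPool(Pool *p) { delete p; }

Pool *CreatePool(int n) {
  Pool *p = new (std::nothrow) Pool;
  if (p == nullptr) {
    return nullptr;
  }
  for (int i = 0; i < n; ++i) {
    if (StartWorker(p, i) != 0) {
      fprintf(stderr, "create thread %d failed\n", i);
      DestroyPool(p);   // tear down the workers created so far
      return nullptr;   // never hand back a half-built pool
    }
    ++p->started;
  }
  return p;
}
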
@@ -109,7 +109,7 @@ void *WorkspacePool::AllocWorkSpaceMem(size_t size) {
     }
   }
   allocList.emplace_back(alloc);
-  return alloc.second;
+  return alloc.second != nullptr ? alloc.second : nullptr;
 }

 void WorkspacePool::FreeWorkSpaceMem(const void *ptr) {

@@ -120,6 +120,10 @@ int Benchmark::ReadInputFile() {
       return RET_ERROR;
     }
     auto input_data = cur_tensor->MutableData();
+    if (input_data == nullptr) {
+      MS_LOG(ERROR) << "input_data is nullptr.";
+      return RET_ERROR;
+    }
     memcpy(input_data, bin_buf, tensor_data_size);
   }
   delete[] bin_buf;

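The added null check guards the tensor's backing buffer before memcpy. A hedged sketch of the same guard, with FakeTensor as a stand-in for tensor::MSTensor (whose MutableData() may allocate lazily and can fail):

#include <cstring>
#include <vector>

struct FakeTensor {
  std::vector<char> buf;
  void *MutableData() { return buf.empty() ? nullptr : buf.data(); }
};

bool FillTensor(FakeTensor *t, const char *src, std::size_t n) {
  if (t == nullptr) return false;
  void *dst = t->MutableData();
  if (dst == nullptr || n > t->buf.size()) {
    return false;  // no buffer (or too small): report instead of crashing
  }
  std::memcpy(dst, src, n);
  return true;
}
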
@@ -232,7 +236,7 @@ int Benchmark::CompareOutput() {
   }
   float mean_bias;
   if (total_size != 0) {
-    mean_bias = total_bias / total_size * 100;
+    mean_bias = total_bias / float_t(total_size) * 100;
   } else {
     mean_bias = 0;
   }

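total_bias is already a float here, so the added float_t cast mostly placates implicit-conversion checks; the trap it echoes is the all-integer case, where division truncates before the scaling. A two-line demonstration:

#include <iostream>

int main() {
  int bias = 7, size = 2;
  std::cout << bias / size * 100 << "\n";                      // 300: 7/2 truncates to 3 first
  std::cout << bias / static_cast<float>(size) * 100 << "\n";  // 350: float division keeps 3.5
  return 0;
}
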
@@ -286,21 +290,26 @@ int Benchmark::CompareStringData(const std::string &name, tensor::MSTensor *tens
 int Benchmark::CompareDataGetTotalBiasAndSize(const std::string &name, tensor::MSTensor *tensor, float *total_bias,
                                               int *total_size) {
   float bias = 0;
+  auto mutableData = tensor->MutableData();
+  if (mutableData == nullptr) {
+    MS_LOG(ERROR) << "mutableData is nullptr.";
+    return RET_ERROR;
+  }
   switch (msCalibDataType) {
     case TypeId::kNumberTypeFloat: {
-      bias = CompareData<float>(name, tensor->shape(), tensor->MutableData());
+      bias = CompareData<float>(name, tensor->shape(), mutableData);
       break;
     }
     case TypeId::kNumberTypeInt8: {
-      bias = CompareData<int8_t>(name, tensor->shape(), tensor->MutableData());
+      bias = CompareData<int8_t>(name, tensor->shape(), mutableData);
       break;
     }
     case TypeId::kNumberTypeUInt8: {
-      bias = CompareData<uint8_t>(name, tensor->shape(), tensor->MutableData());
+      bias = CompareData<uint8_t>(name, tensor->shape(), mutableData);
       break;
     }
     case TypeId::kNumberTypeInt32: {
-      bias = CompareData<int32_t>(name, tensor->shape(), tensor->MutableData());
+      bias = CompareData<int32_t>(name, tensor->shape(), mutableData);
       break;
     }
     default:

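The rewrite fetches MutableData() once, null-checks it, then hands the same pointer to every case. A reduced sketch of that hoist-then-dispatch shape; TypeId and Compare<T> here are simplified stand-ins for the benchmark's types:

#include <cstddef>
#include <cstdint>

enum class TypeId { kFloat, kInt8 };

template <typename T>
float Compare(const void *data, std::size_t n) {
  const T *p = static_cast<const T *>(data);
  float acc = 0;
  for (std::size_t i = 0; i < n; ++i) acc += static_cast<float>(p[i]);
  return acc;
}

float Dispatch(TypeId t, const void *data, std::size_t n) {
  if (data == nullptr) return -1.0f;  // checked once, reused by every case
  switch (t) {
    case TypeId::kFloat: return Compare<float>(data, n);
    case TypeId::kInt8:  return Compare<int8_t>(data, n);
  }
  return -1.0f;
}
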
@@ -420,6 +429,10 @@ int Benchmark::PrintInputData() {
     }
     size_t print_num = std::min(input->ElementsNum(), 20);
     const void *in_data = input->MutableData();
+    if (in_data == nullptr) {
+      MS_LOG(ERROR) << "in_data is nullptr.";
+      return RET_ERROR;
+    }

     for (size_t j = 0; j < print_num; j++) {
       if (tensor_data_type == TypeId::kNumberTypeFloat32 || tensor_data_type == TypeId::kNumberTypeFloat) {

@@ -723,7 +736,7 @@ int Benchmark::PrintResult(const std::vector<std::string> &title,
     }
     columns.push_back(iter.first);

-    len = snprintf(stringBuf[1], sizeof(stringBuf[1]), "%f", iter.second.second / flags_->loop_count_);
+    len = snprintf(stringBuf[1], sizeof(stringBuf[1]), "%f", iter.second.second / float_t(flags_->loop_count_));
     if (len > columnLenMax.at(1)) {
       columnLenMax.at(1) = len + 4;
     }

@@ -760,9 +773,9 @@ int Benchmark::PrintResult(const std::vector<std::string> &title,
     printf("%s\t", printBuf.c_str());
   }
   printf("\n");
-  for (size_t i = 0; i < rows.size(); i++) {
+  for (auto &row : rows) {
     for (int j = 0; j < 5; j++) {
-      auto printBuf = rows[i][j];
+      auto printBuf = row[j];
       printBuf.resize(columnLenMax.at(j), ' ');
       printf("%s\t", printBuf.c_str());
     }

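The loop rewrite above in miniature: iterating rows directly removes the index bookkeeping and the repeated rows[i] lookup, with no change in what gets printed:

#include <cstdio>
#include <string>
#include <vector>

void PrintRows(const std::vector<std::vector<std::string>> &rows) {
  for (const auto &row : rows) {    // was: for (size_t i = 0; i < rows.size(); i++)
    for (const auto &cell : row) {  // was: auto printBuf = rows[i][j];
      printf("%s\t", cell.c_str());
    }
    printf("\n");
  }
}
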
@@ -772,7 +785,7 @@ int Benchmark::PrintResult(const std::vector<std::string> &title,
   }

 Benchmark::~Benchmark() {
-  for (auto iter : this->benchmark_data_) {
+  for (const auto &iter : this->benchmark_data_) {
     delete (iter.second);
   }
   this->benchmark_data_.clear();

@@ -88,24 +88,24 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
   std::string model_file_;
   std::string in_data_file_;
   std::vector<std::string> input_data_list_;
-  InDataType in_data_type_;
+  InDataType in_data_type_ = kBinary;
   std::string in_data_type_in_ = "bin";
   int cpu_bind_mode_ = 1;
   // MarkPerformance
-  int loop_count_;
-  int num_threads_;
-  bool enable_fp16_;
-  int warm_up_loop_count_;
-  bool time_profiling_;
+  int loop_count_ = 10;
+  int num_threads_ = 2;
+  bool enable_fp16_ = false;
+  int warm_up_loop_count_ = 3;
+  bool time_profiling_ = false;
   // MarkAccuracy
   std::string benchmark_data_file_;
-  std::string benchmark_data_type_;
-  float accuracy_threshold_;
+  std::string benchmark_data_type_ = "FLOAT";
+  float accuracy_threshold_ = 0.5;
   // Resize
-  std::string resize_dims_in_ = "";
+  std::string resize_dims_in_;
   std::vector<std::vector<int>> resize_dims_;

-  std::string device_;
+  std::string device_ = "CPU";
 };

 class MS_API Benchmark {

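The header change replaces uninitialized members with in-class default member initializers, so every constructor starts from a defined state and a flag that is never parsed still reads back a sane value (reading an uninitialized int or bool is undefined behavior). A reduced illustration with a hypothetical struct:

#include <string>

struct FlagDefaults {
  int loop_count = 10;          // was: int loop_count_;  (indeterminate value)
  bool time_profiling = false;  // was: bool time_profiling_;
  std::string device = "CPU";   // std::string was already empty; now the default is documented
};
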
@@ -149,7 +149,7 @@ class MS_API Benchmark {

   // tensorData need to be converter first
   template <typename T>
-  float CompareData(const std::string &nodeName, std::vector<int> msShape, const void *tensor_data) {
+  float CompareData(const std::string &nodeName, const std::vector<int> &msShape, const void *tensor_data) {
     const T *msTensorData = static_cast<const T *>(tensor_data);
     auto iter = this->benchmark_data_.find(nodeName);
     if (iter != this->benchmark_data_.end()) {

@@ -33,9 +33,9 @@ struct Nothing {};

 class FlagParser {
  public:
-  FlagParser() { AddFlag(&FlagParser::help, "help", "print usage message", ""); }
+  FlagParser() { AddFlag(&FlagParser::help, helpStr, "print usage message", ""); }

-  virtual ~FlagParser() {}
+  virtual ~FlagParser() = default;

   // only support read flags from command line
   virtual Option<std::string> ParseFlags(int argc, const char *const *argv, bool supportUnknown = false,

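For a virtual destructor, '= default' and an empty body '{}' are functionally equivalent; the defaulted form states intent and lets the compiler generate it, per the usual core-guidelines preference:

struct Base {
  virtual ~Base() = default;  // was: virtual ~Base() {}
};
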
@@ -60,7 +60,7 @@ class FlagParser {
   // Option-type fields
   template <typename Flags, typename T>
   void AddFlag(Option<T> Flags::*t, const std::string &flagName, const std::string &helpInfo);
-  bool help;
+  bool help{};

  protected:
   template <typename Flags>

@@ -70,14 +70,15 @@ class FlagParser {

   std::string binName;
   Option<std::string> usageMsg;
+  std::string helpStr = "help";

  private:
   struct FlagInfo {
     std::string flagName;
-    bool isRequired;
-    bool isBoolean;
+    bool isRequired = false;
+    bool isBoolean = false;
     std::string helpInfo;
-    bool isParsed;
+    bool isParsed = false;
     std::function<Option<Nothing>(FlagParser *, const std::string &)> parse;
   };

@@ -218,7 +219,7 @@ void FlagParser::AddFlag(T1 Flags::*t1, const std::string &flagName, const std::
     return;
   }

-  Flags *flag = dynamic_cast<Flags *>(this);
+  auto *flag = dynamic_cast<Flags *>(this);
   if (flag == nullptr) {
     return;
   }

@@ -228,7 +229,10 @@ void FlagParser::AddFlag(T1 Flags::*t1, const std::string &flagName, const std::
   // flagItem is as a output parameter
   ConstructFlag(t1, flagName, helpInfo, &flagItem);
   flagItem.parse = [t1](FlagParser *base, const std::string &value) -> Option<Nothing> {
-    Flags *flag = dynamic_cast<Flags *>(base);
+    auto *flag = dynamic_cast<Flags *>(base);
+    if (flag == nullptr) {
+      return Option<Nothing>(None());
+    }
     if (base != nullptr) {
       Option<T1> ret = Option<T1>(GenericParseValue<T1>(value));
       if (ret.IsNone()) {

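The added branch is the point of this hunk: a dynamic_cast across the FlagParser hierarchy can fail, and the lambda previously checked only base before using the cast result. A reduced sketch of the guarded cast, using std::optional as a stand-in for the project's Option type:

#include <optional>
#include <string>

struct Base { virtual ~Base() = default; };
struct Derived : Base { std::string value; };

std::optional<std::string> ReadValue(Base *base) {
  auto *d = dynamic_cast<Derived *>(base);
  if (d == nullptr) {
    return std::nullopt;  // wrong dynamic type: return "none", don't crash
  }
  return d->value;
}
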
@@ -267,7 +271,7 @@ void FlagParser::AddFlag(Option<T> Flags::*t, const std::string &flagName, const
     return;
   }

-  Flags *flag = dynamic_cast<Flags *>(this);
+  auto *flag = dynamic_cast<Flags *>(this);
   if (flag == nullptr) {
     MS_LOG(ERROR) << "dynamic_cast failed";
     return;

@@ -278,7 +282,7 @@ void FlagParser::AddFlag(Option<T> Flags::*t, const std::string &flagName, const
   ConstructFlag(t, flagName, helpInfo, &flagItem);
   flagItem.isRequired = false;
   flagItem.parse = [t](FlagParser *base, const std::string &value) -> Option<Nothing> {
-    Flags *flag = dynamic_cast<Flags *>(base);
+    auto *flag = dynamic_cast<Flags *>(base);
     if (base != nullptr) {
       Option<T> ret = Option<std::string>(GenericParseValue<T>(value));
       if (ret.IsNone()) {

@@ -605,10 +605,6 @@ std::string GetModelName(const std::string &modelFile) {
   std::string modelName = modelFile;
   modelName = modelName.substr(modelName.find_last_of('/') + 1);
   modelName = modelName.substr(0, modelName.find_last_of('.'));
-
-  srand((unsigned)time(NULL));
-  modelName = modelName + std::to_string(rand());
-
   return modelName;
 }
 }  // namespace lite

@@ -101,10 +101,11 @@ STATUS MatMulBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &p
   }
   fcAttr->hasBias = true;
   fcAttr->axis = 1;
   MS_ASSERT(matMulNode->primitive != nullptr);
+  MS_ASSERT(matMulNode->primitive->value != nullptr);
+  MS_ASSERT(matMulNode->primitive->value.AsMatMul() != nullptr);
   transA = matMulNode->primitive->value.AsMatMul()->transposeA;
   transB = matMulNode->primitive->value.AsMatMul()->transposeB;
-  MS_ASSERT(matMulNode->primitive->value.value != nullptr);
   matMulNode->primitive->value.type = schema::PrimitiveType_FullConnection;
   matMulNode->primitive->value.value = fcAttr.release();

@@ -146,6 +146,9 @@ STATUS MulAddFusionPass::AddNewScaleNode(MetaGraphT *graph, const std::unique_pt
   int shape_size = graph->allTensors.at(addBiasIndex)->dims.size();
   scaleParam->axis = 0 - shape_size;
   mulNode->inputIndex.push_back(addBiasIndex);
+  MS_ASSERT(addNode->primitive != nullptr);
+  MS_ASSERT(addNode->primitive->value != nullptr);
+  MS_ASSERT(addNode->primitive->value.AsAdd() != nullptr);
   auto activationType = addNode->primitive->value.AsAdd()->activationType;
   if (activationType == ActivationType_RELU || activationType == ActivationType_RELU6 ||
       activationType == ActivationType_NO_ACTIVATION) {

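These fusion passes repeat the same MS_ASSERT chain: validate every link of node->primitive->value before dereferencing it. A reduced sketch with the standard assert standing in for MS_ASSERT (like assert, such checks are typically compiled out of release builds, so they document invariants rather than guard production):

#include <cassert>

struct Value { void *impl = nullptr; void *AsAdd() { return impl; } };
struct Primitive { Value value; };
struct Node { Primitive *primitive = nullptr; };

void UseNode(Node *node) {
  assert(node != nullptr);
  assert(node->primitive != nullptr);                 // link 1
  assert(node->primitive->value.AsAdd() != nullptr);  // link 2
  // only now is node->primitive->value.AsAdd()->... safe to read
}
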
@@ -159,6 +162,9 @@ STATUS MulAddFusionPass::AddNewScaleNode(MetaGraphT *graph, const std::unique_pt
   } else {
     // repace addnode as activation
     std::unique_ptr<ActivationT> activationParam(new ActivationT());
+    MS_ASSERT(addNode->primitive != nullptr);
+    MS_ASSERT(addNode->primitive->value != nullptr);
+    MS_ASSERT(addNode->primitive->value.AsAdd() != nullptr);
     activationParam->type = addNode->primitive->value.AsAdd()->activationType;
     addNode->primitive->value.type = schema::PrimitiveType_Activation;
     addNode->primitive->value.value = activationParam.release();

@@ -91,6 +91,8 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p
     if (GetCNodeTType(*node) == schema::PrimitiveType_Activation) {
       MS_ASSERT(node != nullptr);
       MS_ASSERT(node->primitive != nullptr);
+      MS_ASSERT(node->primitive->value != nullptr);
+      MS_ASSERT(node->primitive->value.AsActivation() != nullptr);
       if (node->primitive->value.AsActivation() != nullptr &&
           node->primitive->value.AsActivation()->type == schema::ActivationType_LEAKY_RELU) {
         return has_trans_count >= half_count;

@@ -198,6 +200,7 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni
     MS_LOG(ERROR) << "node or primitive null";
     return RET_NULL_PTR;
   }
+  MS_ASSERT(node->primitive->value != nullptr);
   auto type = node->primitive->value.type;
   auto input1_ndim = graph->allTensors.at(node->inputIndex[0])->dims.size();
   if (input1_ndim != 4) {

@@ -213,6 +216,7 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni
     }
   }
   if (type == PrimitiveType_Concat) {
+    MS_ASSERT(node->primitive->value.AsConcat() != nullptr);
     auto origin_axis = node->primitive->value.AsConcat()->axis;
     auto axis_map = GetNc2NhAxisMap();
     if (node->primitive->value.AsConcat() == nullptr) {

@@ -222,6 +226,7 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni
     node->primitive->value.AsConcat()->axis = axis_map[origin_axis];
   }
   if (type == PrimitiveType_Split) {
+    MS_ASSERT(node->primitive->value.AsSplit() != nullptr);
     auto origin_axis = node->primitive->value.AsSplit()->splitDim;
     auto axis_map = GetNc2NhAxisMap();
     if (node->primitive->value.AsSplit() == nullptr) {

@@ -231,6 +236,7 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni
     node->primitive->value.AsSplit()->splitDim = axis_map[origin_axis];
   }
   if (type == PrimitiveType_Crop) {
+    MS_ASSERT(node->primitive->value.AsCrop() != nullptr);
     auto origin_axis = node->primitive->value.AsCrop()->axis;
     auto offsets = node->primitive->value.AsCrop()->offsets;
     auto axis_map = GetNc2NhAxisMap();

@@ -76,6 +76,10 @@ schema::TensorT *ConvertWeight(const caffe::BlobProto &proto) {
   }
   weight->data.resize(count * sizeof(float));
   const float *data_ptr = proto.data().data();
+  if (data_ptr == nullptr) {
+    MS_LOG(ERROR) << "data_ptr is nullptr";
+    return nullptr;
+  }
   if (::memcpy_s(weight->data.data(), count * sizeof(float), (uint8_t *)data_ptr, count * sizeof(float)) != EOK) {
     MS_LOG(ERROR) << "memcpy failed";
     return nullptr;

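The guard covers the empty-blob case: a protobuf repeated field's data() can return nullptr when there are no elements, and handing that to memcpy_s would fail at runtime. A reduced sketch with std::memcpy in place of the securec memcpy_s:

#include <cstring>
#include <vector>

bool CopyWeights(std::vector<char> *dst, const float *src, std::size_t count) {
  if (dst == nullptr || src == nullptr) {
    return false;  // empty blob: nothing to copy
  }
  dst->resize(count * sizeof(float));
  std::memcpy(dst->data(), src, count * sizeof(float));
  return true;
}
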
@@ -157,6 +157,9 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
   auto iter = std::find_if((*nodeIter).attribute().begin(), (*nodeIter).attribute().end(),
                            [](const onnx::AttributeProto &attr) { return attr.name() == "shape"; });
   if (iter != (*nodeIter).attribute().end()) {
+    MS_ASSERT(iter->ints() != nullptr);
+    MS_ASSERT(iter->ints().begin() != nullptr);
+    MS_ASSERT(iter->ints().end() != nullptr);
     dims.insert(dims.begin(), iter->ints().begin(), iter->ints().end());
   }
   attr->channelOut = dims[0];

@@ -40,6 +40,7 @@ STATUS OnnxLpNormParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N

   auto onnx_node_attr = onnx_node.attribute();
   for (int i = 0; i < onnx_node_attr.size(); ++i) {
+    MS_ASSERT(onnx_node_attr.at(i) != nullptr);
     if (onnx_node_attr.at(i).name() == "axis") {
       attr->axis = onnx_node_attr.at(i).i();
     } else if (onnx_node_attr.at(i).name() == "p") {

@@ -40,6 +40,7 @@ STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node
   auto onnx_node_attr = onnx_node.attribute();
   int32_t size = 0;
   for (int i = 0; i < onnx_node_attr.size(); ++i) {
+    MS_ASSERT(onnx_node_attr.at(i) != nullptr);
     if (onnx_node_attr.at(i).name() == "alpha") {
       attr->alpha = onnx_node_attr.at(i).f();
     } else if (onnx_node_attr.at(i).name() == "beta") {

@@ -51,6 +52,11 @@ STATUS OnnxLrnParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Node
       attr->depth_radius = size / 2;
     }
   }
+
+  if (size == 0) {
+    MS_LOG(ERROR) << "Divide-by-zero error.";
+    return RET_ERROR;
+  }
   attr->alpha /= size;

   op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization;

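size accumulates from an optional ONNX attribute and stays 0 when that attribute is absent, so the guard is real: attr->alpha /= size would then yield inf for the float alpha (and would be undefined behavior with integer operands). A reduced form of the check:

#include <cstdio>

bool NormalizeAlpha(float *alpha, int size) {
  if (alpha == nullptr || size == 0) {
    fprintf(stderr, "Divide-by-zero error.\n");
    return false;
  }
  *alpha /= size;  // safe: size is non-zero here
  return true;
}
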
@@ -240,6 +240,7 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
   lite_primitive->InferShape(input_tensors, output_tensors);
   auto primitive = lite_primitive.get();
   MS_ASSERT(primitive != nullptr);
+  MS_ASSERT(primitive->Type() != nullptr);
   auto parameter =
     lite::PopulateRegistry::GetInstance()->getParameterCreator(schema::PrimitiveType(primitive->Type()))(primitive);

@@ -67,8 +67,7 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
   }
   // transform node means scale,bn
   auto transform_node = node->cast<CNodePtr>();
-  if (CheckIfCNodeIsNull(transform_node) != lite::RET_OK ||
-      CheckLeastInputSize(transform_node, 2) != lite::RET_OK) {
+  if (CheckIfCNodeIsNull(transform_node) != lite::RET_OK || CheckLeastInputSize(transform_node, 2) != lite::RET_OK) {
     return nullptr;
   }

@@ -93,6 +92,7 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
   auto trans_bias = new (std::nothrow) float[kernel_nums];
   if (trans_bias == nullptr) {
     MS_LOG(ERROR) << "tensor_data is nullptr";
+    delete[] trans_scale;
     delete[] trans_bias;
     return nullptr;
   }

@@ -234,8 +234,11 @@ const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kerne
     return;
   }

-  delete[] tmp_weight_data;
+  if (tmp_weight_data != nullptr) {
+    delete[] tmp_weight_data;
+  }
 }

 const void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_num, bool bias_flag,
                                                  const float *trans_scale, const float *trans_bias) const {
   MS_ASSERT(bias_data != nullptr);

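Worth noting: 'delete[] p' is already a well-defined no-op when p is nullptr, so the new guard changes no behavior; it mainly satisfies static analyzers. The unconditional form is equivalent:

void Release(float *p) {
  delete[] p;  // safe even when p == nullptr
}
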
@@ -56,6 +56,8 @@ bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) {
     MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveC";
     return RET_ERROR;
   }
+  MS_ASSERT(primT->value != nullptr);
+  MS_ASSERT(primT->value.AsTranspose() != nullptr);
   std::vector<int32_t> perm = primT->value.AsTranspose()->perm;
   if (perm == kPermNCHW) {
     manager->Replace(transpose_cnode, transpose_cnode->input(1));

@@ -77,6 +79,8 @@ bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) {
     MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveT";
     return RET_ERROR;
   }
+  MS_ASSERT(primT->value != nullptr);
+  MS_ASSERT(primT->value.AsTranspose() != nullptr);
   std::vector<int32_t> perm = primT->value.AsTranspose()->perm;
   if (perm == kPermNHWC) {
     manager->Replace(transpose_cnode, transpose_cnode->input(1));