forked from mindspore-Ecosystem/mindspore
!8807 [MS][LITE] fix static check error
From: @jianghui58 Reviewed-by: @zhanghaibo5,@HilbertDavid Signed-off-by: @zhanghaibo5,@HilbertDavid
This commit is contained in:
commit
e6baa0b25e
|
@ -391,6 +391,10 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano
|
|||
if (value->isa<tensor::Tensor>()) {
|
||||
auto valueAbstract = valueNode->abstract();
|
||||
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
|
||||
if (abstractTensor == nullptr || abstractTensor->element() == nullptr) {
|
||||
MS_LOG(ERROR) << "abstractTensor or abstractTensor->element() is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto typePtr = abstractTensor->element()->GetTypeTrack();
|
||||
paramTensor->dataType = typePtr->type_id();
|
||||
auto shape_vector = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
|
||||
|
@ -404,7 +408,11 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano
|
|||
paramTensor->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
auto data = value->cast<tensor::TensorPtr>();
|
||||
paramTensor->data.resize(data->Size());
|
||||
memcpy(paramTensor->data.data(), data->data_c(), data->Size());
|
||||
auto ret = memcpy_s(paramTensor->data.data(), data->Size(), data->data_c(), data->Size());
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy_s error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
|
||||
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
|
||||
meta_graphT->allTensors.emplace_back(std::move(paramTensor));
|
||||
|
@ -417,7 +425,11 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano
|
|||
paramTensor->nodeType = schema::NodeType::NodeType_ValueNode;
|
||||
int real_data = CastToInt(value, false).front();
|
||||
paramTensor->data.resize(sizeof(int32_t));
|
||||
memcpy(paramTensor->data.data(), &real_data, sizeof(int32_t));
|
||||
auto ret = memcpy_s(paramTensor->data.data(), sizeof(int32_t), &real_data, sizeof(int32_t));
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy_s error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
|
||||
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
|
||||
meta_graphT->allTensors.emplace_back(std::move(paramTensor));
|
||||
|
@ -526,6 +538,10 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
|
|||
auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
|
||||
for (size_t i = 0; i < tuple->size(); i++) {
|
||||
auto msTensor = new (std::nothrow) schema::TensorT();
|
||||
if (msTensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new msTensor failed";
|
||||
return;
|
||||
}
|
||||
msTensor->nodeType = schema::NodeType_CNode;
|
||||
fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
|
||||
#ifdef SUPPORT_TRAIN
|
||||
|
@ -553,6 +569,10 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
|
|||
}
|
||||
} else {
|
||||
auto ms_tensor = new (std::nothrow) schema::TensorT();
|
||||
if (ms_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new tensor failed";
|
||||
return;
|
||||
}
|
||||
ms_tensor->nodeType = schema::NodeType_CNode;
|
||||
ms_tensor->dataType = TypeId::kNumberTypeFloat32;
|
||||
fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
|
||||
|
|
|
@ -154,6 +154,10 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
{
|
||||
Optimizer quantNodeOptimizer;
|
||||
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
|
||||
if (dTypeTransPass == nullptr) {
|
||||
MS_LOG(ERROR) << "new dTypeTransPass failed";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
dTypeTransPass->SetInputDataDType(ctx.inputDataType);
|
||||
dTypeTransPass->SetOutputDataDType(ctx.outputDataType);
|
||||
quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
|
||||
|
|
|
@ -331,9 +331,17 @@ STATUS BatchNormFoldFusionPass::GenNewWeightTensor() {
|
|||
MS_ASSERT(muTensor->dataType == DataType_DT_FLOAT);
|
||||
void *miData = muTensor->data.data();
|
||||
auto *castedMiData = static_cast<float *>(miData);
|
||||
if (channelOut == 0) {
|
||||
MS_LOG(ERROR) << "divisor 'channelOut' cannot be 0";
|
||||
return RET_ERROR;
|
||||
}
|
||||
size_t stride = weightShapeSize / channelOut;
|
||||
for (int i = 0; i < channelOut; i++) {
|
||||
for (size_t j = 0; j < stride; j++) {
|
||||
if (fabs(castedMiData[i]) <= 0.0f) {
|
||||
MS_LOG(ERROR) << "divisor 'castedMiData' cannot be 0";
|
||||
return RET_ERROR;
|
||||
}
|
||||
castedNewWeightData[i * stride + j] = castedOldWeightData[i * stride + j] * castedGammaData[i] / castedMiData[i];
|
||||
}
|
||||
}
|
||||
|
@ -367,6 +375,10 @@ STATUS BatchNormFoldFusionPass::GenNewBiasTensor() { // bias has no quant
|
|||
void *sigmaData = sigmaTensor->data.data();
|
||||
auto *castedSigmaData = static_cast<float *>(sigmaData);
|
||||
for (int i = 0; i < channelOut; i++) {
|
||||
if (fabs(castedSigmaData[i]) <= 0.0f) {
|
||||
MS_LOG(ERROR) << "divisor 'castedSigmaData' cannot be 0";
|
||||
return RET_ERROR;
|
||||
}
|
||||
castedNewBiasData[i] = castedBetaData[i] - castedGammaData[i] * castedMiData[i] / castedSigmaData[i];
|
||||
}
|
||||
return RET_OK;
|
||||
|
|
|
@ -83,5 +83,4 @@ class BatchNormFoldFusionPass : public FusionPass {
|
|||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_PREDICT_BATCHNORM_FOLD_FUSION_PASS_H
|
||||
|
|
|
@ -63,7 +63,8 @@ STATUS FormatTransFusionPass::DefinePattern() {
|
|||
|
||||
passOp->left = nc2nhOp;
|
||||
nh2ncOp->left = passOp;
|
||||
std::unique_ptr<FusionPattern> nc2NhAndNh2NcPassFusionPattern(new FusionPattern(kNc2NhAndNh2NcPassFusionPattern));
|
||||
std::unique_ptr<FusionPattern> nc2NhAndNh2NcPassFusionPattern(new (std::nothrow)
|
||||
FusionPattern(kNc2NhAndNh2NcPassFusionPattern));
|
||||
if (nc2NhAndNh2NcPassFusionPattern == nullptr) {
|
||||
MS_LOG(ERROR) << "new " << kNc2NhAndNh2NcPassFusionPattern << "failed";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS FusionPass::Run(schema::MetaGraphT *graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
auto ret = DefinePattern();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "DefinePattern Error " << ret;
|
||||
|
|
|
@ -94,7 +94,7 @@ STATUS MatMulBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &p
|
|||
|
||||
// 2. change matmul to full connection op
|
||||
matMulNode->name += "-fc";
|
||||
std::unique_ptr<FullConnectionT> fcAttr(new FullConnectionT());
|
||||
std::unique_ptr<FullConnectionT> fcAttr(new (std::nothrow) FullConnectionT());
|
||||
if (fcAttr == nullptr) {
|
||||
MS_LOG(ERROR) << "new FullConnectionT node failed";
|
||||
return RET_ERROR;
|
||||
|
@ -151,7 +151,7 @@ STATUS MatMulBiasAddFusionPass::InsertTransposeNode(MetaGraphT *graph, const std
|
|||
}
|
||||
transNode->name = "transpose" + std::to_string(id++);
|
||||
transNode->primitive->value.type = schema::PrimitiveType_Transpose;
|
||||
std::unique_ptr<TransposeT> transposeParam(new TransposeT());
|
||||
std::unique_ptr<TransposeT> transposeParam(new (std::nothrow) TransposeT());
|
||||
if (transposeParam == nullptr) {
|
||||
MS_LOG(ERROR) << "new transposeParam failed";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -137,7 +137,7 @@ STATUS MulAddFusionPass::AddNewScaleNode(MetaGraphT *graph, const std::unique_pt
|
|||
MS_ASSERT(addNode != nullptr);
|
||||
// replace mulNode as scale
|
||||
mulNode->primitive->value.type = schema::PrimitiveType_Scale;
|
||||
std::unique_ptr<ScaleT> scaleParam(new ScaleT());
|
||||
std::unique_ptr<ScaleT> scaleParam(new (std::nothrow) ScaleT());
|
||||
if (scaleParam == nullptr) {
|
||||
MS_LOG(ERROR) << "new transposeParam failed";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -69,9 +69,9 @@ STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std:
|
|||
MS_ASSERT(graph != nullptr);
|
||||
MS_ASSERT(bnNode != nullptr);
|
||||
bnNode->primitive->value.type = schema::PrimitiveType_Scale;
|
||||
std::unique_ptr<ScaleT> scaleParam(new ScaleT());
|
||||
std::unique_ptr<ScaleT> scaleParam(new (std::nothrow) ScaleT());
|
||||
if (scaleParam == nullptr) {
|
||||
MS_LOG(ERROR) << "new transposeParam failed";
|
||||
MS_LOG(ERROR) << "new scaleParam failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
scaleParam->axis = NCHW_DIM_C;
|
||||
|
@ -104,7 +104,7 @@ STATUS BatchNormConvertScalePass::GenNewScaleTensor(MetaGraphT *graph, const std
|
|||
newScaleWeightTensor->data.resize(weightShapeSize * sizeof(float));
|
||||
auto ret = memcpy_s(newScaleWeightTensor->data.data(), weightShapeSize * sizeof(float), transScale,
|
||||
weightShapeSize * sizeof(float));
|
||||
if (ret != RET_OK) {
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy error: " << ret;
|
||||
delete[] transScale;
|
||||
delete[] transBias;
|
||||
|
@ -127,7 +127,7 @@ STATUS BatchNormConvertScalePass::GenNewScaleTensor(MetaGraphT *graph, const std
|
|||
newScaleBiasTensor->data.resize(weightShapeSize * sizeof(float));
|
||||
ret = memcpy_s(newScaleBiasTensor->data.data(), weightShapeSize * sizeof(float), transBias,
|
||||
weightShapeSize * sizeof(float));
|
||||
if (ret != RET_OK) {
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy error: " << ret;
|
||||
delete[] transScale;
|
||||
delete[] transBias;
|
||||
|
@ -166,9 +166,17 @@ STATUS BatchNormConvertScalePass::GetTransParam(MetaGraphT *graph, const std::un
|
|||
return status;
|
||||
}
|
||||
this->transScale = new (std::nothrow) float[bnChannel];
|
||||
if (this->transScale == nullptr) {
|
||||
MS_LOG(ERROR) << "new transScale failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->transBias = new (std::nothrow) float[bnChannel];
|
||||
if (this->transBias == nullptr) {
|
||||
MS_LOG(ERROR) << "new transBias failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
// cal transScale, tf : scale/sqrt(variance + eps); caffe : 1/sqrt(variance + eps)
|
||||
if (memcpy_s(transScale, bnChannel * sizeof(float), varianceData, bnChannel * sizeof(float)) != 0) {
|
||||
if (memcpy_s(transScale, bnChannel * sizeof(float), varianceData, bnChannel * sizeof(float)) != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy_s transScale error";
|
||||
delete[] transScale;
|
||||
delete[] transBias;
|
||||
|
@ -180,6 +188,10 @@ STATUS BatchNormConvertScalePass::GetTransParam(MetaGraphT *graph, const std::un
|
|||
for (uint32_t i = 0; i < bnChannel; i++) {
|
||||
float tmp = transScale[i] + eps;
|
||||
tmp = pow(tmp, POW_NUM);
|
||||
if (tmp <= 0.0f) {
|
||||
MS_LOG(ERROR) << "divisor 'tmp' cannot be 0";
|
||||
return RET_ERROR;
|
||||
}
|
||||
transScale[i] = 1 / tmp;
|
||||
}
|
||||
|
||||
|
@ -278,6 +290,7 @@ STATUS BatchNormConvertScalePass::GetBnWeightTensors(MetaGraphT *graph, BNWeight
|
|||
STATUS BatchNormConvertScalePass::GetBnEpsilon(const std::unique_ptr<CNodeT> &bnNode) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
MS_ASSERT(bnNode != nullptr);
|
||||
MS_ASSERT(bnNode->primitive != nullptr);
|
||||
if (bnNode->primitive->value.type == schema::PrimitiveType_FusedBatchNorm) {
|
||||
eps = bnNode->primitive->value.AsFusedBatchNorm()->epsilon;
|
||||
} else if (bnNode->primitive->value.type == schema::PrimitiveType_BatchNorm) {
|
||||
|
|
|
@ -32,6 +32,10 @@ STATUS IsolateDropoutNode(schema::MetaGraphT *graphT, size_t nodeIdx) {
|
|||
}
|
||||
|
||||
CNodeT *node = graphT->nodes.at(nodeIdx).get();
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "node is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto inputTensorIdxes = node->inputIndex;
|
||||
auto outputTensorIdxes = node->outputIndex;
|
||||
auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx);
|
||||
|
@ -103,6 +107,10 @@ STATUS DropoutNodeRemovePass::Run(schema::MetaGraphT *graph) {
|
|||
bool ifChanged = false;
|
||||
for (size_t i = 0; i < graph->nodes.size(); i++) {
|
||||
auto &node = graph->nodes.at(i);
|
||||
if (node->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "node->primitive is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (node->primitive->value.type == schema::PrimitiveType_Dropout) {
|
||||
ifChanged = true;
|
||||
auto status = IsolateDropoutNode(graph, i);
|
||||
|
|
|
@ -79,5 +79,4 @@ class DTypeTransPass : public GraphPass {
|
|||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_PREDICT_DTYPE_TRANS_PASS_H
|
||||
|
|
|
@ -186,6 +186,7 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
|
|||
NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place,
|
||||
size_t inoutIdx, FormatTransNodeType nodeType, STATUS *errorCode) {
|
||||
MS_ASSERT((*existNodeIter) != nullptr);
|
||||
MS_ASSERT(graph != nullptr);
|
||||
auto existNodeName = (*existNodeIter)->name;
|
||||
std::string tileName;
|
||||
if (place == kBefore) {
|
||||
|
|
|
@ -55,5 +55,4 @@ class FormatTransPass : public GraphPass {
|
|||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H
|
||||
|
|
|
@ -105,6 +105,7 @@ STATUS ConvertNcTensor2Nh(TensorT *tensor, const std::vector<int> &pad_dims) {
|
|||
return RET_OK;
|
||||
}
|
||||
STATUS GlobalFormatTransformPass::TransWeightToNhwc(MetaGraphT *graph, const std::set<size_t> &pre_not_trans_nodes) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
if (pre_not_trans_nodes.empty()) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
|
||||
namespace mindspore::lite {
|
||||
STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
auto *quantParamRegister = QuantParamCalcRegister::GetInstance();
|
||||
|
||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||
|
@ -58,6 +59,7 @@ STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) {
|
|||
}
|
||||
|
||||
void InferQuantParamPass::DetermineNodeQuantType(const schema::MetaGraphT &graph, schema::CNodeT *cnode) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
bool canQuant = true;
|
||||
for (auto &inputTensorIdx : cnode->inputIndex) {
|
||||
|
|
|
@ -35,5 +35,4 @@ class InferQuantParamPass : public GraphPass {
|
|||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_INFER_QUANT_PARAM_PASS_H
|
||||
|
|
|
@ -31,6 +31,7 @@ constexpr int DEFAULT_DIM_VALUE = -1;
|
|||
namespace {
|
||||
std::vector<Tensor *> ConvertTensorToLiteTensor(MetaGraphT *graph, const std::vector<uint32_t> &tensor_indexs,
|
||||
const schema::PrimitiveType node_type) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
std::vector<Tensor *> lite_tensors;
|
||||
for (size_t i = 0; i < tensor_indexs.size(); i++) {
|
||||
auto &tensorT = graph->allTensors.at(tensor_indexs[i]);
|
||||
|
|
|
@ -32,5 +32,4 @@ class IsolatedNodeRemovePass : public GraphPass {
|
|||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
namespace mindspore::lite {
|
||||
STATUS SetUnusedQuantParamToDefaultPass::Run(schema::MetaGraphT *graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
for (auto &tensor : graph->allTensors) {
|
||||
for (auto &quant_param : tensor->quantParams) {
|
||||
quant_param->min = 0;
|
||||
|
@ -29,5 +30,4 @@ STATUS SetUnusedQuantParamToDefaultPass::Run(schema::MetaGraphT *graph) {
|
|||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -30,5 +30,4 @@ class SetUnusedQuantParamToDefaultPass : public GraphPass {
|
|||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_UNUSED_QUANT_PARAM_DATA_REMOVE_PASS_H
|
||||
|
|
|
@ -23,7 +23,12 @@
|
|||
|
||||
namespace mindspore::lite {
|
||||
STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
for (auto &node : graph->nodes) {
|
||||
if (node == nullptr || node->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << " node or node->primitive is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (node->primitive->value.type == PrimitiveType_QuantDTypeCast) {
|
||||
auto attr = node->primitive->value.AsQuantDTypeCast();
|
||||
auto &inputTensor = graph->allTensors.at(node->inputIndex.front());
|
||||
|
@ -97,6 +102,10 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
|
|||
}
|
||||
void *biasData = tensor->data.data();
|
||||
auto *rawDatas = static_cast<float *>(biasData);
|
||||
if (fabs(quantParam->scale) <= 0.0f) {
|
||||
MS_LOG(ERROR) << "divisor 'scale' cannot be 0";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (size_t i = 0; i < bShapeSize; ++i) {
|
||||
qDatas[i] = (int32_t)std::round(rawDatas[i] / quantParam->scale);
|
||||
}
|
||||
|
@ -117,5 +126,4 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
|
|||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -32,5 +32,4 @@ class TensorQuantPass : public GraphPass {
|
|||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_TENSOR_QUANT_PASS_H
|
||||
|
|
|
@ -70,6 +70,7 @@ STATUS TopologicalSortPass::Run(schema::MetaGraphT *graph) {
|
|||
|
||||
bool TopologicalSortPass::IsNodeNonDepend(const std::unique_ptr<schema::CNodeT> &node,
|
||||
const std::vector<size_t> &sinkedTensorIdxes) {
|
||||
MS_ASSERT(node != nullptr);
|
||||
for (auto inputIdx : node->inputIndex) {
|
||||
if (!IsContain(sinkedTensorIdxes, size_t(inputIdx))) {
|
||||
return false;
|
||||
|
|
|
@ -37,5 +37,4 @@ class TopologicalSortPass : public GraphPass {
|
|||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_PREDICT_TOPOLOGICAL_SORT_PASS_H
|
||||
|
|
|
@ -26,6 +26,8 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
MS_ASSERT(node != nullptr);
|
||||
auto input_node_indexes = GetInputNodeIdx(*graph, *node);
|
||||
pre_type_ = schema::PrimitiveType_NONE;
|
||||
size_t has_trans_count = 0;
|
||||
|
@ -34,6 +36,8 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p
|
|||
MS_ASSERT(graph->nodes.size() > input_node_index);
|
||||
auto &pre_node = graph->nodes.at(input_node_index);
|
||||
MS_ASSERT(pre_node != nullptr);
|
||||
MS_ASSERT(pre_node->primitive != nullptr);
|
||||
MS_ASSERT(pre_node->primitive->value != nullptr);
|
||||
if (pre_type_ == schema::PrimitiveType_NONE) {
|
||||
if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
|
||||
pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
|
||||
|
@ -61,6 +65,8 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p
|
|||
MS_ASSERT(graph->nodes.size() > output_node_index);
|
||||
auto &post_node = graph->nodes.at(output_node_index);
|
||||
MS_ASSERT(post_node != nullptr);
|
||||
MS_ASSERT(post_node->primitive != nullptr);
|
||||
MS_ASSERT(post_node->primitive->value != nullptr);
|
||||
if (post_type_ == schema::PrimitiveType_NONE) {
|
||||
if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
|
||||
post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
|
||||
|
|
|
@ -46,12 +46,13 @@ void Optimizer::AddPass(NodePass *nodePass) {
|
|||
}
|
||||
|
||||
STATUS Optimizer::Run(schema::MetaGraphT *graphDefT) {
|
||||
MS_ASSERT(graphDefT != nullptr);
|
||||
STATUS status;
|
||||
bool ifNotChanged = true;
|
||||
// each node should go through all node pass not each node pass go through all node
|
||||
for (auto &opDef : graphDefT->nodes) {
|
||||
for (auto pass : this->nodePasses) {
|
||||
status = pass->Run(new GraphNode(graphDefT, opDef.get()));
|
||||
status = pass->Run(new (std::nothrow) GraphNode(graphDefT, opDef.get()));
|
||||
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
|
||||
MS_LOG(ERROR) << "Run NodePass failed";
|
||||
return status;
|
||||
|
|
Loading…
Reference in New Issue