delete GetPrimitiveT in project

This commit is contained in:
yeyunpeng 2020-08-24 11:25:50 +08:00
parent e3899c552c
commit a095f72e03
5 changed files with 19 additions and 39 deletions

View File

@ -73,13 +73,11 @@ void Converter::FreeFuncGraph(const FuncGraphPtr &func_graph) {
return;
}
if (primT->value.type == schema::PrimitiveType_TupleGetItem ||
primT->value.type == schema::PrimitiveType_MakeTuple ||
primT->value.type == schema::PrimitiveType_Return) {
primT->value.type == schema::PrimitiveType_MakeTuple || primT->value.type == schema::PrimitiveType_Return) {
delete primT;
primitiveT_value->SetPrimitiveT(nullptr);
}
}
return;
}
MetaGraphT *Converter::Convert(const converter::Flags *flag) {
// parse the model and weight file to generate inference data structure

View File

@ -93,7 +93,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
return false;
}
auto type = primitiveT_value->GetPrimitiveT()->value.type;
auto type = (schema::PrimitiveType)primitiveT_value->Type();
MS_LOG(INFO) << "Primitive type: " << type;
static const std::vector<schema::PrimitiveType> uint8OpList = {
schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,

View File

@ -170,7 +170,7 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) {
if (a.m_ptr->isa<lite::PrimitiveC>() && b.m_ptr->isa<lite::PrimitiveC>()) {
auto a_value_node_ptr = a.m_ptr->cast<PrimitiveCPtr>();
auto b_value_node_ptr = b.m_ptr->cast<PrimitiveCPtr>();
return a_value_node_ptr->GetPrimitiveT()->value.type == b_value_node_ptr->GetPrimitiveT()->value.type;
return a_value_node_ptr->Type() == b_value_node_ptr->Type();
}
return a == b;
@ -316,7 +316,7 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) {
if (utils::isa<PrimitiveCPtr>(value)) {
auto primitive = value->cast<PrimitiveCPtr>();
MS_ASSERT(primitive != nullptr);
return primitive->GetPrimitiveT()->value.type;
return (schema::PrimitiveType)primitive->Type();
} else if (utils::isa<Primitive>(value)) {
auto primitive = value->cast<PrimitivePtr>();
MS_ASSERT(primitive != nullptr);

View File

@ -73,26 +73,6 @@ const std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
}
return input_tensors;
}
schema::Primitive *PackPrimitiveT(const CNodePtr &cnode) {
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitiveT_value == nullptr) {
MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
return nullptr;
}
auto *lite_primitive = primitiveT_value->GetPrimitiveT();
if (lite_primitive == nullptr) {
MS_LOG(ERROR) << "Primitive in primitiveT_value is nullptr";
return nullptr;
}
flatbuffers::FlatBufferBuilder builder(1024);
auto offset = schema::Primitive::Pack(builder, lite_primitive);
builder.Finish(offset);
auto buf = builder.GetBufferPointer();
auto primitive = flatbuffers::GetRoot<schema::Primitive>(buf);
return const_cast<schema::Primitive *>(primitive);
}
const ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
auto parameter = func_graph->add_parameter();
std::vector<int> shape;
@ -175,16 +155,10 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
}
MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope();
auto output_nums = GetOutputTensorNum(input_cnode);
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
std::vector<Tensor *> output_tensors{output_nums, new Tensor()};
auto scheam_primitive = PackPrimitiveT(input_cnode);
auto lite_primitive = mindspore::lite::PrimitiveC::UnPackFromSchemaPrimitive(scheam_primitive);
if (lite_primitive == nullptr) {
MS_LOG(ERROR) << "constant_folding schedule node lite primitive nullptr";
FreeInputTensor(&input_tensors);
return nullptr;
}
lite_primitive->InferShape(input_tensors, output_tensors);
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, lite_primitive);
primitiveT_value->InferShape(input_tensors, output_tensors);
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, primitiveT_value.get());
if (lite_kernel == nullptr) {
MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
FreeInputTensor(&input_tensors);

View File

@ -22,6 +22,8 @@
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"
#include "src/ops/batch_norm.h"
#include "src/ops/fused_batchnorm.h"
namespace mindspore::opt {
namespace {
@ -94,7 +96,7 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const {
auto bn_mean_var = std::make_shared<CondVar>(IsParamNode);
auto bn_variable_var = std::make_shared<CondVar>(IsParamNode);
auto bn_other_var = std::make_shared<SeqVar>();
return VectorRef({bn_var, conv_var, bn_mean_var, bn_variable_var, bn_other_var});;
return VectorRef({bn_var, conv_var, bn_mean_var, bn_variable_var, bn_other_var});
}
// BatchNorm weight Tensor definition:
// caffe
@ -119,13 +121,19 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern
bn_variance_node = bn_node->input(kCaffeBNVarIndex);
CheckIfNodeIsParam(bn_mean_node);
CheckIfNodeIsParam(bn_variance_node);
eps = primitiveT_value->GetPrimitiveT()->value.AsBatchNorm()->epsilon;
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::BatchNorm>>(primitiveT_value));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::BatchNorm>>(primitiveT_value);
MS_ASSERT(primc != nullptr);
eps = primc->GetEpsilon();
} else if (GetCNodeType(bn_node) == schema::PrimitiveType_FusedBatchNorm) {
bn_scale_node = bn_node->input(kTFBNScaleIndex);
bn_bias_node = bn_node->input(kTFBNBiasIndex);
bn_mean_node = bn_node->input(kTFBNMeanIndex);
bn_variance_node = bn_node->input(kTFBNVarIndex);
eps = primitiveT_value->GetPrimitiveT()->value.AsFusedBatchNorm()->epsilon;
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitiveT_value));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitiveT_value);
MS_ASSERT(primc != nullptr);
eps = primc->GetEpsilon();
} else {
MS_LOG(EXCEPTION) << "not caffe or tf batchnorm op.";
}