forked from mindspore-Ecosystem/mindspore
!5022 delete GetPrimitiveT
Merge pull request !5022 from yeyunpeng2020/master
This commit is contained in:
commit
6763b63ca5
|
@ -73,13 +73,11 @@ void Converter::FreeFuncGraph(const FuncGraphPtr &func_graph) {
|
|||
return;
|
||||
}
|
||||
if (primT->value.type == schema::PrimitiveType_TupleGetItem ||
|
||||
primT->value.type == schema::PrimitiveType_MakeTuple ||
|
||||
primT->value.type == schema::PrimitiveType_Return) {
|
||||
primT->value.type == schema::PrimitiveType_MakeTuple || primT->value.type == schema::PrimitiveType_Return) {
|
||||
delete primT;
|
||||
primitiveT_value->SetPrimitiveT(nullptr);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
MetaGraphT *Converter::Convert(const converter::Flags *flag) {
|
||||
// parse the model and weight file to generate inference data structure
|
||||
|
|
|
@ -93,7 +93,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
|
|||
return false;
|
||||
}
|
||||
|
||||
auto type = primitiveT_value->GetPrimitiveT()->value.type;
|
||||
auto type = (schema::PrimitiveType)primitiveT_value->Type();
|
||||
MS_LOG(INFO) << "Primitive type: " << type;
|
||||
static const std::vector<schema::PrimitiveType> uint8OpList = {
|
||||
schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,
|
||||
|
|
|
@ -170,7 +170,7 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) {
|
|||
if (a.m_ptr->isa<lite::PrimitiveC>() && b.m_ptr->isa<lite::PrimitiveC>()) {
|
||||
auto a_value_node_ptr = a.m_ptr->cast<PrimitiveCPtr>();
|
||||
auto b_value_node_ptr = b.m_ptr->cast<PrimitiveCPtr>();
|
||||
return a_value_node_ptr->GetPrimitiveT()->value.type == b_value_node_ptr->GetPrimitiveT()->value.type;
|
||||
return a_value_node_ptr->Type() == b_value_node_ptr->Type();
|
||||
}
|
||||
|
||||
return a == b;
|
||||
|
@ -316,7 +316,7 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) {
|
|||
if (utils::isa<PrimitiveCPtr>(value)) {
|
||||
auto primitive = value->cast<PrimitiveCPtr>();
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
return primitive->GetPrimitiveT()->value.type;
|
||||
return (schema::PrimitiveType)primitive->Type();
|
||||
} else if (utils::isa<Primitive>(value)) {
|
||||
auto primitive = value->cast<PrimitivePtr>();
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
|
|
|
@ -73,26 +73,6 @@ const std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
|
|||
}
|
||||
return input_tensors;
|
||||
}
|
||||
schema::Primitive *PackPrimitiveT(const CNodePtr &cnode) {
|
||||
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
||||
if (primitiveT_value == nullptr) {
|
||||
MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto *lite_primitive = primitiveT_value->GetPrimitiveT();
|
||||
if (lite_primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "Primitive in primitiveT_value is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
auto offset = schema::Primitive::Pack(builder, lite_primitive);
|
||||
builder.Finish(offset);
|
||||
auto buf = builder.GetBufferPointer();
|
||||
auto primitive = flatbuffers::GetRoot<schema::Primitive>(buf);
|
||||
return const_cast<schema::Primitive *>(primitive);
|
||||
}
|
||||
const ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
|
||||
auto parameter = func_graph->add_parameter();
|
||||
std::vector<int> shape;
|
||||
|
@ -175,16 +155,10 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
|
|||
}
|
||||
MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope();
|
||||
auto output_nums = GetOutputTensorNum(input_cnode);
|
||||
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
|
||||
std::vector<Tensor *> output_tensors{output_nums, new Tensor()};
|
||||
auto scheam_primitive = PackPrimitiveT(input_cnode);
|
||||
auto lite_primitive = mindspore::lite::PrimitiveC::UnPackFromSchemaPrimitive(scheam_primitive);
|
||||
if (lite_primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "constant_folding schedule node lite primitive nullptr";
|
||||
FreeInputTensor(&input_tensors);
|
||||
return nullptr;
|
||||
}
|
||||
lite_primitive->InferShape(input_tensors, output_tensors);
|
||||
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, lite_primitive);
|
||||
primitiveT_value->InferShape(input_tensors, output_tensors);
|
||||
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, primitiveT_value.get());
|
||||
if (lite_kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
|
||||
FreeInputTensor(&input_tensors);
|
||||
|
|
|
@ -22,6 +22,8 @@
|
|||
#include "utils/utils.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "securec/include/securec.h"
|
||||
#include "src/ops/batch_norm.h"
|
||||
#include "src/ops/fused_batchnorm.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
namespace {
|
||||
|
@ -94,7 +96,7 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const {
|
|||
auto bn_mean_var = std::make_shared<CondVar>(IsParamNode);
|
||||
auto bn_variable_var = std::make_shared<CondVar>(IsParamNode);
|
||||
auto bn_other_var = std::make_shared<SeqVar>();
|
||||
return VectorRef({bn_var, conv_var, bn_mean_var, bn_variable_var, bn_other_var});;
|
||||
return VectorRef({bn_var, conv_var, bn_mean_var, bn_variable_var, bn_other_var});
|
||||
}
|
||||
// BatchNorm weight Tensor definition:
|
||||
// caffe
|
||||
|
@ -106,7 +108,7 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const {
|
|||
// estimated_mean --2
|
||||
// estimated_variance --3
|
||||
const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kernel_num, float *trans_scale,
|
||||
float *trans_bias) const {
|
||||
float *trans_bias) const {
|
||||
MS_ASSERT(bn_node != nullptr);
|
||||
AnfNodePtr bn_mean_node = nullptr;
|
||||
AnfNodePtr bn_variance_node = nullptr;
|
||||
|
@ -119,13 +121,19 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern
|
|||
bn_variance_node = bn_node->input(kCaffeBNVarIndex);
|
||||
CheckIfNodeIsParam(bn_mean_node);
|
||||
CheckIfNodeIsParam(bn_variance_node);
|
||||
eps = primitiveT_value->GetPrimitiveT()->value.AsBatchNorm()->epsilon;
|
||||
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::BatchNorm>>(primitiveT_value));
|
||||
auto primc = utils::cast<std::shared_ptr<mindspore::lite::BatchNorm>>(primitiveT_value);
|
||||
MS_ASSERT(primc != nullptr);
|
||||
eps = primc->GetEpsilon();
|
||||
} else if (GetCNodeType(bn_node) == schema::PrimitiveType_FusedBatchNorm) {
|
||||
bn_scale_node = bn_node->input(kTFBNScaleIndex);
|
||||
bn_bias_node = bn_node->input(kTFBNBiasIndex);
|
||||
bn_mean_node = bn_node->input(kTFBNMeanIndex);
|
||||
bn_variance_node = bn_node->input(kTFBNVarIndex);
|
||||
eps = primitiveT_value->GetPrimitiveT()->value.AsFusedBatchNorm()->epsilon;
|
||||
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitiveT_value));
|
||||
auto primc = utils::cast<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitiveT_value);
|
||||
MS_ASSERT(primc != nullptr);
|
||||
eps = primc->GetEpsilon();
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "not caffe or tf batchnorm op.";
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue