forked from mindspore-Ecosystem/mindspore

commit 6763b63ca5
!5022 delete GetPrimitiveT
Merge pull request !5022 from yeyunpeng2020/master
@@ -73,13 +73,11 @@ void Converter::FreeFuncGraph(const FuncGraphPtr &func_graph) {
       return;
     }
     if (primT->value.type == schema::PrimitiveType_TupleGetItem ||
-        primT->value.type == schema::PrimitiveType_MakeTuple ||
-        primT->value.type == schema::PrimitiveType_Return) {
+        primT->value.type == schema::PrimitiveType_MakeTuple || primT->value.type == schema::PrimitiveType_Return) {
       delete primT;
       primitiveT_value->SetPrimitiveT(nullptr);
     }
   }
-  return;
 }
 MetaGraphT *Converter::Convert(const converter::Flags *flag) {
   // parse the model and weight file to generate inference data structure
@@ -93,7 +93,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
     return false;
   }

-  auto type = primitiveT_value->GetPrimitiveT()->value.type;
+  auto type = (schema::PrimitiveType)primitiveT_value->Type();
   MS_LOG(INFO) << "Primitive type: " << type;
   static const std::vector<schema::PrimitiveType> uint8OpList = {
       schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,
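Note on the change pattern in the hunks above and below: call sites used to reach through GetPrimitiveT() into the raw flatbuffer table (primT->value.type) to learn an op's type; after this patch they ask the PrimitiveC wrapper directly via Type(). A minimal self-contained C++ sketch of that encapsulation, using simplified stand-in types rather than the real MindSpore Lite classes:

    #include <iostream>

    // Stand-ins for schema::PrimitiveT and lite::PrimitiveC (assumptions for
    // illustration, not the actual MindSpore declarations).
    enum PrimitiveType { PrimitiveType_MakeTuple = 1, PrimitiveType_Return = 2 };
    struct PrimitiveT {
      struct Value { PrimitiveType type; } value;
    };

    class PrimitiveC {
     public:
      explicit PrimitiveC(PrimitiveT *prim)
          : primitive_(prim), op_type_(prim->value.type) {}
      // Old style: hand out the raw table; every caller digs into value.type.
      PrimitiveT *GetPrimitiveT() const { return primitive_; }
      // New style: the wrapper remembers its own type, so the raw table can
      // be deleted (as FreeFuncGraph now does) without breaking type queries.
      PrimitiveType Type() const { return op_type_; }
      void SetPrimitiveT(PrimitiveT *prim) { primitive_ = prim; }

     private:
      PrimitiveT *primitive_;
      PrimitiveType op_type_;
    };

    int main() {
      auto *prim = new PrimitiveT{{PrimitiveType_Return}};
      PrimitiveC wrapper(prim);
      delete prim;                          // mirrors "delete primT;"
      wrapper.SetPrimitiveT(nullptr);       // mirrors the FreeFuncGraph hunk
      std::cout << wrapper.Type() << "\n";  // 2 -- still answerable
      return 0;
    }

Whether the real PrimitiveC caches the type this way is an implementation detail; the point is that callers no longer depend on the raw PrimitiveT outliving them.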
@@ -170,7 +170,7 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) {
   if (a.m_ptr->isa<lite::PrimitiveC>() && b.m_ptr->isa<lite::PrimitiveC>()) {
     auto a_value_node_ptr = a.m_ptr->cast<PrimitiveCPtr>();
     auto b_value_node_ptr = b.m_ptr->cast<PrimitiveCPtr>();
-    return a_value_node_ptr->GetPrimitiveT()->value.type == b_value_node_ptr->GetPrimitiveT()->value.type;
+    return a_value_node_ptr->Type() == b_value_node_ptr->Type();
   }

   return a == b;
@@ -316,7 +316,7 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) {
   if (utils::isa<PrimitiveCPtr>(value)) {
     auto primitive = value->cast<PrimitiveCPtr>();
     MS_ASSERT(primitive != nullptr);
-    return primitive->GetPrimitiveT()->value.type;
+    return (schema::PrimitiveType)primitive->Type();
   } else if (utils::isa<Primitive>(value)) {
     auto primitive = value->cast<PrimitivePtr>();
     MS_ASSERT(primitive != nullptr);
@@ -73,26 +73,6 @@ const std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
   }
   return input_tensors;
 }
-schema::Primitive *PackPrimitiveT(const CNodePtr &cnode) {
-  auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
-  if (primitiveT_value == nullptr) {
-    MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
-    return nullptr;
-  }
-
-  auto *lite_primitive = primitiveT_value->GetPrimitiveT();
-  if (lite_primitive == nullptr) {
-    MS_LOG(ERROR) << "Primitive in primitiveT_value is nullptr";
-    return nullptr;
-  }
-
-  flatbuffers::FlatBufferBuilder builder(1024);
-  auto offset = schema::Primitive::Pack(builder, lite_primitive);
-  builder.Finish(offset);
-  auto buf = builder.GetBufferPointer();
-  auto primitive = flatbuffers::GetRoot<schema::Primitive>(buf);
-  return const_cast<schema::Primitive *>(primitive);
-}
 const ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
   auto parameter = func_graph->add_parameter();
   std::vector<int> shape;
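A review note on the deleted helper, not part of the patch itself: PackPrimitiveT serialized the in-memory PrimitiveT into a FlatBuffer only for the const-fold pass to immediately re-read it, and it was also unsafe as written. Because builder is stack-local, the schema::Primitive* obtained via flatbuffers::GetRoot pointed into storage that was destroyed when the function returned. A self-contained sketch of that lifetime hazard, with SimpleBuilder as a hypothetical stand-in for flatbuffers::FlatBufferBuilder:

    #include <cstdint>
    #include <vector>

    // Hypothetical stand-in for flatbuffers::FlatBufferBuilder: like the
    // real builder, it owns the serialized bytes.
    struct SimpleBuilder {
      std::vector<uint8_t> bytes;
      const uint8_t *GetBufferPointer() const { return bytes.data(); }
    };

    const uint8_t *PackAndReturn() {
      SimpleBuilder builder;              // stack-local, as in PackPrimitiveT
      builder.bytes = {1, 2, 3};          // pretend-serialized primitive
      return builder.GetBufferPointer();  // dangles once builder is destroyed
    }

    int main() { return 0; }  // dereferencing PackAndReturn()'s result is UB

Fetching the PrimitiveC straight off the CNode, as the new const-fold code below does, removes both the extra copy and the lifetime question.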
@@ -175,16 +155,10 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
   }
   MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope();
   auto output_nums = GetOutputTensorNum(input_cnode);
+  auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
   std::vector<Tensor *> output_tensors{output_nums, new Tensor()};
-  auto scheam_primitive = PackPrimitiveT(input_cnode);
-  auto lite_primitive = mindspore::lite::PrimitiveC::UnPackFromSchemaPrimitive(scheam_primitive);
-  if (lite_primitive == nullptr) {
-    MS_LOG(ERROR) << "constant_folding schedule node lite primitive nullptr";
-    FreeInputTensor(&input_tensors);
-    return nullptr;
-  }
-  lite_primitive->InferShape(input_tensors, output_tensors);
-  auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, lite_primitive);
+  primitiveT_value->InferShape(input_tensors, output_tensors);
+  auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, primitiveT_value.get());
   if (lite_kernel == nullptr) {
     MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
     FreeInputTensor(&input_tensors);
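One behavioral difference worth flagging: the removed path null-checked both primitiveT_value (inside PackPrimitiveT) and the unpacked lite_primitive before use, while the new path calls primitiveT_value->InferShape() without checking what GetValueNode returned. A defensive guard in the style of the surrounding code might look like this (a sketch, not part of the patch):

    auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
    if (primitiveT_value == nullptr) {
      MS_LOG(ERROR) << "primitiveT_value is nullptr";
      FreeInputTensor(&input_tensors);
      return nullptr;
    }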
@@ -22,6 +22,8 @@
 #include "utils/utils.h"
 #include "tools/optimizer/common/gllo_utils.h"
 #include "securec/include/securec.h"
+#include "src/ops/batch_norm.h"
+#include "src/ops/fused_batchnorm.h"

 namespace mindspore::opt {
 namespace {
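The two added includes pull in the concrete op wrappers (mindspore::lite::BatchNorm and mindspore::lite::FusedBatchNorm) that the epsilon lookup further down now casts primitiveT_value to.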
@@ -94,7 +96,7 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const {
   auto bn_mean_var = std::make_shared<CondVar>(IsParamNode);
   auto bn_variable_var = std::make_shared<CondVar>(IsParamNode);
   auto bn_other_var = std::make_shared<SeqVar>();
-  return VectorRef({bn_var, conv_var, bn_mean_var, bn_variable_var, bn_other_var});;
+  return VectorRef({bn_var, conv_var, bn_mean_var, bn_variable_var, bn_other_var});
 }
 // BatchNorm weight Tensor definition:
 // caffe
@@ -106,7 +108,7 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const {
 // estimated_mean --2
 // estimated_variance --3
 const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kernel_num, float *trans_scale,
                                                float *trans_bias) const {
   MS_ASSERT(bn_node != nullptr);
   AnfNodePtr bn_mean_node = nullptr;
   AnfNodePtr bn_variance_node = nullptr;
@@ -119,13 +121,19 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern
     bn_variance_node = bn_node->input(kCaffeBNVarIndex);
     CheckIfNodeIsParam(bn_mean_node);
     CheckIfNodeIsParam(bn_variance_node);
-    eps = primitiveT_value->GetPrimitiveT()->value.AsBatchNorm()->epsilon;
+    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::BatchNorm>>(primitiveT_value));
+    auto primc = utils::cast<std::shared_ptr<mindspore::lite::BatchNorm>>(primitiveT_value);
+    MS_ASSERT(primc != nullptr);
+    eps = primc->GetEpsilon();
   } else if (GetCNodeType(bn_node) == schema::PrimitiveType_FusedBatchNorm) {
     bn_scale_node = bn_node->input(kTFBNScaleIndex);
     bn_bias_node = bn_node->input(kTFBNBiasIndex);
     bn_mean_node = bn_node->input(kTFBNMeanIndex);
     bn_variance_node = bn_node->input(kTFBNVarIndex);
-    eps = primitiveT_value->GetPrimitiveT()->value.AsFusedBatchNorm()->epsilon;
+    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitiveT_value));
+    auto primc = utils::cast<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitiveT_value);
+    MS_ASSERT(primc != nullptr);
+    eps = primc->GetEpsilon();
   } else {
     MS_LOG(EXCEPTION) << "not caffe or tf batchnorm op.";
   }
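The epsilon retrieval follows the same encapsulation move as the Type() changes: instead of reaching into the flatbuffer union (value.AsBatchNorm()->epsilon), the fusion pass downcasts the shared PrimitiveC to the concrete op wrapper and reads a typed getter. A self-contained sketch with simplified stand-in classes (not the real MindSpore Lite types), using std::dynamic_pointer_cast where the patch uses utils::cast:

    #include <cassert>
    #include <iostream>
    #include <memory>

    // Simplified stand-ins for illustration only.
    struct PrimitiveC { virtual ~PrimitiveC() = default; };
    struct BatchNorm : PrimitiveC {
      float GetEpsilon() const { return 1e-5f; }  // typed getter, as in the patch
    };

    int main() {
      std::shared_ptr<PrimitiveC> primitiveT_value = std::make_shared<BatchNorm>();
      // Downcast to the concrete op before touching op-specific fields.
      auto primc = std::dynamic_pointer_cast<BatchNorm>(primitiveT_value);
      assert(primc != nullptr);
      std::cout << primc->GetEpsilon() << "\n";  // 1e-05
      return 0;
    }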