forked from mindspore-Ecosystem/mindspore
!5040 rename primitiveTvalue to primitive_c
Merge pull request !5040 from yeyunpeng2020/master
commit 3162b12552
@@ -165,14 +165,14 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) {
   auto cnodes = func_graph->GetOrderedCnodes();
   auto meta_graphT = std::make_unique<schema::MetaGraphT>();
   for (const auto &cnode : cnodes) {
-    auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
-    if (primitiveT_value == nullptr) {
-      MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
+    auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
+    if (primitive_c == nullptr) {
+      MS_LOG(ERROR) << "primitive_c is nullptr";
       return nullptr;
     }
-    auto primT = primitiveT_value->GetPrimitiveT();
-    if (primitiveT_value->Type() == schema::PrimitiveType_TupleGetItem ||
-        primitiveT_value->Type() == schema::PrimitiveType_MakeTuple) {
+    auto primT = primitive_c->GetPrimitiveT();
+    if (primitive_c->Type() == schema::PrimitiveType_TupleGetItem ||
+        primitive_c->Type() == schema::PrimitiveType_MakeTuple) {
       continue;
     }
     RemoveIfMakeTuple(cnode);
@@ -196,7 +196,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) {
       return nullptr;
     }
     SetOpOutputNode(cnode, meta_graphT, node.get());
-    ret = ConvertQuantParam(meta_graphT, primitiveT_value, node);
+    ret = ConvertQuantParam(meta_graphT, primitive_c, node);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "ConvertQuantParam failed";
       return nullptr;

@@ -62,12 +62,12 @@ void Converter::FreeFuncGraph(const FuncGraphPtr &func_graph) {
   MS_ASSERT(func_graph != nullptr);
   auto cnodes = func_graph->GetOrderedCnodes();
   for (auto &cnode : cnodes) {
-    auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
-    if (primitiveT_value == nullptr) {
-      MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
+    auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
+    if (primitive_c == nullptr) {
+      MS_LOG(ERROR) << "primitive_c is nullptr";
       return;
     }
-    auto primT = primitiveT_value->GetPrimitiveT();
+    auto primT = primitive_c->GetPrimitiveT();
     if (primT == nullptr) {
       MS_LOG(ERROR) << "PrimitiveT is nullptr";
       return;
@@ -75,7 +75,7 @@ void Converter::FreeFuncGraph(const FuncGraphPtr &func_graph) {
     if (primT->value.type == schema::PrimitiveType_TupleGetItem ||
         primT->value.type == schema::PrimitiveType_MakeTuple || primT->value.type == schema::PrimitiveType_Return) {
       delete primT;
-      primitiveT_value->SetPrimitiveT(nullptr);
+      primitive_c->SetPrimitiveT(nullptr);
     }
   }
 }

@@ -534,7 +534,7 @@ PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
   return RET_OK;
 }

-STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitiveT_value,
+STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c,
                                             bool perchanel, bool depthwise) {
   // const vector<int> dims = filter->dims;
   // perlayer
@@ -552,7 +552,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P
     MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value";
     return RET_ERROR;
   }
-  auto status = QuantFilter(paramValue, primitiveT_value, QuantType_PostTraining, quant_max, quant_min, bit_num,
+  auto status = QuantFilter(paramValue, primitive_c, QuantType_PostTraining, quant_max, quant_min, bit_num,
                             perchanel, depthwise);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed: " << status;
@@ -573,8 +573,8 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P
   return RET_OK;
 }

-STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC> primitiveT_value) {
-  if (primitiveT_value == nullptr || bias == nullptr) {
+STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC> primitive_c) {
+  if (primitive_c == nullptr || bias == nullptr) {
     MS_LOG(ERROR) << "null pointer!";
     return RET_NULL_PTR;
   }
@@ -583,7 +583,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi
   auto bias_default_param = bias_parameter_ptr->default_param();
   auto bias_param = std::dynamic_pointer_cast<ParamValueLite>(bias_default_param);

-  auto active_weight_quant_params = primitiveT_value->GetInputQuantParams();
+  auto active_weight_quant_params = primitive_c->GetInputQuantParams();
   if (active_weight_quant_params.size() != 2) {
     MS_LOG(ERROR) << "unexpected active_weight_quant_params size: " << active_weight_quant_params.size();
     return RET_ERROR;
@@ -627,7 +627,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi
     quant_param.inited = true;
     quant_params.emplace_back(quant_param);
   }
-  primitiveT_value->AddInputQuantParam(quant_params);
+  primitive_c->AddInputQuantParam(quant_params);
   // quant bias data
   int32_t *quant_datas = new (std::nothrow) int32_t[shape_size];
   if (quant_datas == nullptr) {
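
Note on the DoBiasQuant hunks above: the size-2 check on GetInputQuantParams() pulls the activation and weight quant params, which matches the usual convention where the bias scale is their product and the bias is stored as int32. A minimal standalone sketch of that convention, not the converter's actual helper (QuantizeBias and its parameters are illustrative):

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative only: bias_scale = activation_scale * weight_scale, zero point 0,
// so int32 bias values line up with the conv/matmul accumulator.
// Assumes both scales are positive.
std::vector<int32_t> QuantizeBias(const std::vector<float> &bias,
                                  float activation_scale, float weight_scale) {
  const double bias_scale = static_cast<double>(activation_scale) * weight_scale;
  std::vector<int32_t> quantized(bias.size());
  for (size_t i = 0; i < bias.size(); ++i) {
    quantized[i] = static_cast<int32_t>(std::round(bias[i] / bias_scale));
  }
  return quantized;
}
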
@@ -683,18 +683,18 @@ STATUS PostTrainingQuantizer::QuantNode() {
       MS_LOG(INFO) << cnode_name << " can not do quant";
       continue;
     }
-    auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
-    if (primitiveT_value == nullptr) {
-      MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
+    auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
+    if (primitive_c == nullptr) {
+      MS_LOG(ERROR) << "primitive_c is nullptr";
       continue;
     }
     if (input_scale.find(cnode) == input_scale.end()) {
-      primitiveT_value->SetQuantType(schema::QuantType_QUANT_NONE);
+      primitive_c->SetQuantType(schema::QuantType_QUANT_NONE);
       continue;
     }
-    primitiveT_value->ClearInputOutputQuantParam();
+    primitive_c->ClearInputOutputQuantParam();
     auto op_name = cnode->fullname_with_scope();
-    auto op_type = (schema::PrimitiveType)primitiveT_value->Type();
+    auto op_type = (schema::PrimitiveType)primitive_c->Type();
     MS_LOG(INFO) << "OpName: " << op_name;
     if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D &&
         op_type != PrimitiveType_FullConnection) {
@@ -715,35 +715,35 @@ STATUS PostTrainingQuantizer::QuantNode() {
           auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
           if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) {
             MS_LOG(DEBUG) << "this parameter do quant";
-            DoWeightQuant(input_node, primitiveT_value, false, false);
+            DoWeightQuant(input_node, primitive_c, false, false);
           } else {
             MS_LOG(DEBUG) << "this parameter no need to do quant";
           }
           continue;
         }
         auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
-        auto input_cnode_primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
-        if (input_cnode_primitiveT_value == nullptr) {
+        auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
+        if (input_cnode_primitive_c == nullptr) {
          MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": "
                        << " PrimitiveC is null";
          continue;
        }
-        if (!input_cnode_primitiveT_value->GetOutputQuantParams().empty()) {
-          for (auto &quant_param : input_cnode_primitiveT_value->GetOutputQuantParams()) {
-            primitiveT_value->AddInputQuantParam(quant_param);
+        if (!input_cnode_primitive_c->GetOutputQuantParams().empty()) {
+          for (auto &quant_param : input_cnode_primitive_c->GetOutputQuantParams()) {
+            primitive_c->AddInputQuantParam(quant_param);
           }
         } else {
           // do input quant
           double scale = input_scale[cnode];
           int32_t zp = input_zero_point[cnode];
-          DoQuantInput(scale, zp, &input_min_max[cnode], primitiveT_value);
+          DoQuantInput(scale, zp, &input_min_max[cnode], primitive_c);
         }
       }
     } else {
       // do input quant
       double scale = input_scale[cnode];
       int32_t convInputzeropoint = input_zero_point[cnode];
-      DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitiveT_value);
+      DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c);
       // do weight quant
       auto weight = cnode->input(2);
       bool depthwise = op_type == PrimitiveType_DepthwiseConv2D;
@@ -751,18 +751,18 @@ STATUS PostTrainingQuantizer::QuantNode() {
       if (op_type == PrimitiveType_FullConnection) {
         perchannel = false;
       }
-      DoWeightQuant(weight, primitiveT_value, perchannel, depthwise);
+      DoWeightQuant(weight, primitive_c, perchannel, depthwise);
       // do bias quant
       if (cnode->inputs().size() == 4) {
         auto bias = cnode->input(3);
-        DoBiasQuant(bias, primitiveT_value);
+        DoBiasQuant(bias, primitive_c);
       }
     }
     // do output quant
     double OutputScale = output_scale[cnode];
     int32_t OutputZeropoint = output_zeropoint[cnode];
-    DoQuantOutput(OutputScale, OutputZeropoint, &output_min_max[cnode], primitiveT_value);
-    primitiveT_value->SetQuantType(schema::QuantType_PostTraining);
+    DoQuantOutput(OutputScale, OutputZeropoint, &output_min_max[cnode], primitive_c);
+    primitive_c->SetQuantType(schema::QuantType_PostTraining);
   }
   return RET_OK;
 }

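For reference on the scale/zero-point pairs that DoQuantInput and DoQuantOutput record above: they encode the standard affine int8 mapping real ≈ scale * (q - zero_point). A hedged standalone sketch of that mapping, not the quantizer's own helpers:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Quantize a real value to int8 with the given scale and zero point, clamping to [-128, 127].
int8_t QuantizeValue(float real, double scale, int32_t zero_point) {
  int32_t q = static_cast<int32_t>(std::round(real / scale)) + zero_point;
  return static_cast<int8_t>(std::min<int32_t>(127, std::max<int32_t>(-128, q)));
}

// Recover the approximate real value from its quantized representation.
float DequantizeValue(int8_t q, double scale, int32_t zero_point) {
  return static_cast<float>(scale * (q - zero_point));
}
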
@@ -95,10 +95,10 @@ class PostTrainingQuantizer : public Quantizer {
   STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>);
   STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>);

-  STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitiveT_value, bool perchannel,
+  STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel,
                        bool depthwise);

-  STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC> primitiveT_value);
+  STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC> primitive_c);
 };

 struct DivergInfo;

@@ -44,17 +44,17 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
   bool first = true;

   for (auto &cnode : cnodes) {
-    auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
+    auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
     auto curnode_quant_type = schema::QuantType_QUANT_NONE;
-    if (primitiveT_value == nullptr) {
-      MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope();
+    if (primitive_c == nullptr) {
+      MS_LOG(WARNING) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
     } else {
-      curnode_quant_type = primitiveT_value->GetQuantType();
+      curnode_quant_type = primitive_c->GetQuantType();
     }
     if (first) {
       if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) {
         auto value_node =
-          NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams().front());
+          NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitive_c->GetInputQuantParams().front());
         std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)};
         auto quant_cast_cnode = graph->NewCNode(op_inputs);
         quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast");
@@ -72,24 +72,24 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
         continue;
       }
       auto input_cnode = std::dynamic_pointer_cast<CNode>(input_node);
-      auto input_cnode_primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
-      if (input_cnode_primitiveT_value == nullptr) {
+      auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
+      if (input_cnode_primitive_c == nullptr) {
        MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": "
                      << " PrimitiveC is null";
        continue;
      }
-      auto input_cnode_quant_type = input_cnode_primitiveT_value->GetQuantType();
+      auto input_cnode_quant_type = input_cnode_primitive_c->GetQuantType();

       if (curnode_quant_type != input_cnode_quant_type) {
         ValueNodePtr value_node = nullptr;
         if (curnode_quant_type == schema::QuantType_PostTraining &&
             input_cnode_quant_type == schema::QuantType_QUANT_NONE) {
           value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8,
-                                             primitiveT_value->GetInputQuantParams().front());
+                                             primitive_c->GetInputQuantParams().front());
         } else if (curnode_quant_type == schema::QuantType_QUANT_NONE &&
                    input_cnode_quant_type == schema::QuantType_PostTraining) {
           value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32,
-                                             input_cnode_primitiveT_value->GetInputQuantParams().front());
+                                             input_cnode_primitive_c->GetInputQuantParams().front());
         }
         if (value_node == nullptr) {
           MS_LOG(WARNING) << "value_node is null! "

@@ -87,13 +87,13 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
   }
   auto cnode = std::dynamic_pointer_cast<CNode>(node);

-  auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
-  if (primitiveT_value == nullptr) {
-    MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope();
+  auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
+  if (primitive_c == nullptr) {
+    MS_LOG(WARNING) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
     return false;
   }

-  auto type = (schema::PrimitiveType)primitiveT_value->Type();
+  auto type = (schema::PrimitiveType)primitive_c->Type();
   MS_LOG(INFO) << "Primitive type: " << type;
   static const std::vector<schema::PrimitiveType> uint8OpList = {
     schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,

@@ -279,7 +279,7 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
   return RET_OK;
 }

-STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitiveT_value, QuantType quantType,
+STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
                    int quant_max, int quant_min, size_t bitNum, bool per_channel, bool depth_wise) {
   auto dims = weight->tensor_shape();
   if (per_channel) {
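
QuantFilter above takes quant_max, quant_min, bitNum, and a per_channel flag. As a rough illustration of what per-channel weight quantization means, here is a standalone symmetric sketch; QuantizePerChannel, channel_size, and the channel-major layout are assumptions for the example, not the converter's code:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// One scale per output channel; weights are assumed laid out channel by channel.
std::vector<int8_t> QuantizePerChannel(const std::vector<float> &weights,
                                       size_t channel_size, int quant_max,
                                       std::vector<float> *scales) {
  std::vector<int8_t> out(weights.size());
  const size_t channels = weights.size() / channel_size;
  for (size_t c = 0; c < channels; ++c) {
    float abs_max = 0.0f;
    for (size_t i = 0; i < channel_size; ++i) {
      abs_max = std::max(abs_max, std::fabs(weights[c * channel_size + i]));
    }
    // Symmetric range: zero point is 0; guard against an all-zero channel.
    const float scale = abs_max > 0.0f ? abs_max / quant_max : 1.0f;
    scales->push_back(scale);
    for (size_t i = 0; i < channel_size; ++i) {
      const size_t idx = c * channel_size + i;
      out[idx] = static_cast<int8_t>(std::round(weights[idx] / scale));
    }
  }
  return out;
}
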
@@ -450,7 +450,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
     MS_LOG(ERROR) << "quant_params empty";
     return RET_ERROR;
   }
-  primitiveT_value->AddInputQuantParam(quant_params);
+  primitive_c->AddInputQuantParam(quant_params);
   return RET_OK;
 }

@@ -118,7 +118,7 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan
   }();
 }

-STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitiveT_value, QuantType quantType,
+STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
                    int quant_max, int quant_min, size_t bitNum = UINT8_QUANTIZATION, bool per_channel = false,
                    bool depth_wise = false);

@@ -135,6 +135,26 @@ void FreeInputTensor(std::vector<Tensor *> *input_tensor) {
   }
   return;
 }
+schema::Primitive *PackPrimitiveT(const CNodePtr &cnode) {
+  auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
+  if (primitive_c == nullptr) {
+    MS_LOG(ERROR) << "primitive_c is nullptr";
+    return nullptr;
+  }
+
+  auto *lite_primitive = primitive_c->GetPrimitiveT();
+  if (lite_primitive == nullptr) {
+    MS_LOG(ERROR) << "Primitive in primitive_c is nullptr";
+    return nullptr;
+  }
+
+  flatbuffers::FlatBufferBuilder builder(1024);
+  auto offset = schema::Primitive::Pack(builder, lite_primitive);
+  builder.Finish(offset);
+  auto buf = builder.GetBufferPointer();
+  auto primitive = flatbuffers::GetRoot<schema::Primitive>(buf);
+  return const_cast<schema::Primitive *>(primitive);
+}
 const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                         const EquivPtr &) const {
   CheckIfFuncGraphIsNull(func_graph);
@@ -155,10 +175,16 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
   }
   MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope();
   auto output_nums = GetOutputTensorNum(input_cnode);
-  auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
   std::vector<Tensor *> output_tensors{output_nums, new Tensor()};
-  primitiveT_value->InferShape(input_tensors, output_tensors);
-  auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, primitiveT_value.get());
+  auto scheam_primitive = PackPrimitiveT(input_cnode);
+  auto lite_primitive = mindspore::lite::PrimitiveC::UnPackFromSchemaPrimitive(scheam_primitive);
+  if (lite_primitive == nullptr) {
+    MS_LOG(ERROR) << "constant_folding schedule node lite primitive nullptr";
+    FreeInputTensor(&input_tensors);
+    return nullptr;
+  }
+  lite_primitive->InferShape(input_tensors, output_tensors);
+  auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, lite_primitive);
   if (lite_kernel == nullptr) {
     MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
     FreeInputTensor(&input_tensors);

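The new PackPrimitiveT plus PrimitiveC::UnPackFromSchemaPrimitive pair above round-trips the exporter-side PrimitiveT through a flatbuffer so the fold pass can rebuild a runtime-style primitive and reuse its InferShape and kernel selection. A hedged sketch of that round trip, assuming the generated MindSpore Lite schema headers are available; PackToSchemaPrimitive is an illustrative name:

#include "flatbuffers/flatbuffers.h"
// Assumes the generated schema headers declaring schema::Primitive / schema::PrimitiveT.

schema::Primitive *PackToSchemaPrimitive(const schema::PrimitiveT *prim_t,
                                         flatbuffers::FlatBufferBuilder *builder) {
  // Serialize the object-API form (PrimitiveT) into the builder's buffer ...
  auto offset = schema::Primitive::Pack(*builder, prim_t);
  builder->Finish(offset);
  // ... then reinterpret the finished buffer as the table form.
  auto *buf = builder->GetBufferPointer();
  return const_cast<schema::Primitive *>(flatbuffers::GetRoot<schema::Primitive>(buf));
}

// Usage, roughly mirroring the ConstFoldPass change above:
//   flatbuffers::FlatBufferBuilder builder(1024);
//   auto *schema_prim = PackToSchemaPrimitive(primitive_c->GetPrimitiveT(), &builder);
//   auto *lite_prim = mindspore::lite::PrimitiveC::UnPackFromSchemaPrimitive(schema_prim);
//   lite_prim->InferShape(input_tensors, output_tensors);

Passing the FlatBufferBuilder in from the caller keeps the backing buffer alive for as long as the unpacked primitive is used.
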
@@ -62,17 +62,17 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
   }
   auto conv_node = pre_node->cast<CNodePtr>();
   auto node_type = GetCNodeType(conv_node);
-  auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0));
-  MS_ASSERT(primitiveT_value);
+  auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0));
+  MS_ASSERT(primitive_c);
   if (node_type == schema::PrimitiveType_Conv2D) {
-    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value));
-    auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value);
+    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c));
+    auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c);
     MS_ASSERT(primc != nullptr);
     primc->SetActivationType(activation_type);
     return pre_node;
   } else if (node_type == schema::PrimitiveType_DepthwiseConv2D) {
-    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value));
-    auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value);
+    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c));
+    auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c);
     MS_ASSERT(primc != nullptr);
     primc->SetActivationType(activation_type);
     return pre_node;

@@ -160,22 +160,22 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons
   auto conv_node = conv_node_anf->cast<CNodePtr>();
   CheckIfCNodeIsNull(conv_node);
   GenConvNewBias(func_graph, conv_node, add_node);
-  auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0));
-  MS_ASSERT(primitiveT_value != nullptr);
-  auto type = primitiveT_value->Type();
+  auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0));
+  MS_ASSERT(primitive_c != nullptr);
+  auto type = primitive_c->Type();
   if (type == schema::PrimitiveType_Conv2D) {
-    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value));
-    auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value);
+    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c));
+    auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c);
     MS_ASSERT(primc != nullptr);
     primc->SetHasBias(true);
   } else if (type == schema::PrimitiveType_DepthwiseConv2D) {
-    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value));
-    auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value);
+    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c));
+    auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c);
     MS_ASSERT(primc != nullptr);
     primc->SetHasBias(true);
   } else if (type == schema::PrimitiveType_DeConv2D) {
-    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DeConv2D>>(primitiveT_value));
-    auto primc = utils::cast<std::shared_ptr<mindspore::lite::DeConv2D>>(primitiveT_value);
+    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DeConv2D>>(primitive_c));
+    auto primc = utils::cast<std::shared_ptr<mindspore::lite::DeConv2D>>(primitive_c);
     MS_ASSERT(primc != nullptr);
     primc->SetHasBias(true);
   } else {

@@ -115,14 +115,14 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern
   AnfNodePtr bn_scale_node = nullptr;
   AnfNodePtr bn_bias_node = nullptr;
   float eps = 0;
-  auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(bn_node->input(0));
+  auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(bn_node->input(0));
   if (GetCNodeType(bn_node) == schema::PrimitiveType_BatchNorm) {
     bn_mean_node = bn_node->input(kCaffeBNMeanIndex);
     bn_variance_node = bn_node->input(kCaffeBNVarIndex);
     CheckIfNodeIsParam(bn_mean_node);
     CheckIfNodeIsParam(bn_variance_node);
-    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::BatchNorm>>(primitiveT_value));
-    auto primc = utils::cast<std::shared_ptr<mindspore::lite::BatchNorm>>(primitiveT_value);
+    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::BatchNorm>>(primitive_c));
+    auto primc = utils::cast<std::shared_ptr<mindspore::lite::BatchNorm>>(primitive_c);
     MS_ASSERT(primc != nullptr);
     eps = primc->GetEpsilon();
   } else if (GetCNodeType(bn_node) == schema::PrimitiveType_FusedBatchNorm) {
@@ -130,8 +130,8 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern
     bn_bias_node = bn_node->input(kTFBNBiasIndex);
     bn_mean_node = bn_node->input(kTFBNMeanIndex);
     bn_variance_node = bn_node->input(kTFBNVarIndex);
-    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitiveT_value));
-    auto primc = utils::cast<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitiveT_value);
+    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitive_c));
+    auto primc = utils::cast<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitive_c);
     MS_ASSERT(primc != nullptr);
     eps = primc->GetEpsilon();
   } else {

@@ -97,17 +97,17 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
   GenNewConvTensor(func_graph, conv_node, kernel_nums, trans_scale, trans_bias);
   delete[] trans_bias;
   delete[] trans_scale;
-  auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0));
-  MS_ASSERT(primitiveT_value != nullptr);
-  auto type = primitiveT_value->Type();
+  auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0));
+  MS_ASSERT(primitive_c != nullptr);
+  auto type = primitive_c->Type();
   if (type == schema::PrimitiveType_Conv2D) {
-    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value));
-    auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value);
+    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c));
+    auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c);
     MS_ASSERT(primc != nullptr);
     primc->SetHasBias(true);
   } else if (type == schema::PrimitiveType_DepthwiseConv2D) {
-    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value));
-    auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value);
+    MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c));
+    auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c);
     MS_ASSERT(primc != nullptr);
     primc->SetHasBias(true);
   } else {