forked from mindspore-Ecosystem/mindspore
tf_deconv_bn_fusion
This commit is contained in:
parent
c2d9e1f396
commit
c5b5e4447d
|
@ -56,3 +56,6 @@ mtk_age_gender.pb 1
|
|||
mtk_model_ckpt.pb 1
|
||||
mtk_model_face_dress.pb 1;1,128,128,3
|
||||
mtk_model_normalize_object_scene_ps_20200519.pb 1;1,224,224,3
|
||||
ml_ocr_latin.pb 1
|
||||
ml_noya_tts_melgan.pb 1;16,16,80
|
||||
ml_video_edit_oneclick_adaptis.pb 3
|
||||
|
|
|
@ -26,7 +26,7 @@ namespace lite {
|
|||
STATUS TFDeconvParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
|
||||
std::vector<std::string> *inputs, int *output_size) {
|
||||
MS_LOG(INFO) << "TF DeConvParser";
|
||||
MS_LOG(DEBUG) << "TF DeConvParser";
|
||||
if (primitiveC == nullptr || output_size == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -71,8 +71,8 @@ STATUS TFDeconvParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
}
|
||||
attr->kernelH = kernels[0];
|
||||
attr->kernelW = kernels[1];
|
||||
attr->channelIn = kernels[2];
|
||||
attr->channelOut = kernels[3];
|
||||
attr->channelOut = kernels[2];
|
||||
attr->channelIn = kernels[3];
|
||||
} else {
|
||||
attr->kernelH = -1;
|
||||
attr->kernelW = -1;
|
||||
|
|
|
@ -45,6 +45,27 @@ STATUS TFLogicalParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
primitive->value.type = schema::PrimitiveType_LogicalAnd;
|
||||
primitive->value.value = attr.release();
|
||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||
} else if (tf_op.op() == "LogicalOr") {
|
||||
auto attr = std::make_unique<schema::LogicalOrT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_LogicalOr;
|
||||
primitive->value.value = attr.release();
|
||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||
} else if (tf_op.op() == "LogicalNot") {
|
||||
auto attr = std::make_unique<schema::LogicalNotT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_LogicalNot;
|
||||
primitive->value.value = attr.release();
|
||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||
} else {
|
||||
MS_LOG(ERROR) << tf_op.op() << " is not supported.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (*primitiveC == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveC is nullptr";
|
||||
|
@ -59,5 +80,7 @@ STATUS TFLogicalParser::Parse(const tensorflow::NodeDef &tf_op,
|
|||
return RET_OK;
|
||||
}
|
||||
TFNodeRegistrar g_tfLogicalAndParser("LogicalAnd", new TFLogicalParser());
|
||||
TFNodeRegistrar g_tfLogicalOrParser("LogicalOr", new TFLogicalParser());
|
||||
TFNodeRegistrar g_tfLogicalNotParser("LogicalNot", new TFLogicalParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -222,8 +222,16 @@ void ConvTransformFusion::CalNewWeightTensor(const CNodePtr &conv_node, const Pa
|
|||
return;
|
||||
}
|
||||
if (this->fmk_type_ == lite::converter::FmkType_TF) {
|
||||
for (int i = 0; i < weight_shape_size; i++) {
|
||||
tmp_weight_data[i] = weight_data[i] * trans_scale[i % kernel_num];
|
||||
auto group = primc->GetGroup();
|
||||
auto cin_group = weight_tensor->tensor_shape()[3] / group;
|
||||
int area_size = weight_tensor->tensor_shape()[0] * weight_tensor->tensor_shape()[1];
|
||||
for (int j = 0; j < area_size; j++) {
|
||||
for (int i = 0; i < kernel_num; ++i) {
|
||||
for (int k = 0; k < cin_group; ++k) {
|
||||
tmp_weight_data[k + i * cin_group + j * kernel_num * cin_group] =
|
||||
weight_data[k + i * cin_group + j * kernel_num * cin_group] * trans_scale[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto group = primc->GetGroup();
|
||||
|
|
Loading…
Reference in New Issue