forked from mindspore-Ecosystem/mindspore
add unknown_shape attr for fusion json
This commit is contained in:
parent
6f1535911a
commit
925778015d
|
@ -311,4 +311,10 @@ bool FusionBuildTbeJsonCreator::AttrsJsonPostProcessing(const AnfNodePtr &anf_no
|
||||||
// tbe::TbeAdapter::CastAttrJsonPost(anf_node, attrs_json);
|
// tbe::TbeAdapter::CastAttrJsonPost(anf_node, attrs_json);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void FusionBuildTbeJsonCreator::GenOtherJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) {
|
||||||
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(compute_json);
|
||||||
|
(*compute_json)[kJUnknowShape] = tbe::TbeDynamicShapeUtil::GetDynamicShapeAttr(anf_node);
|
||||||
|
}
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
|
@ -38,6 +38,7 @@ class FusionBuildTbeJsonCreator : public TbeJsonCreator {
|
||||||
std::vector<nlohmann::json> *op_list_json, const ANodeFusionDataTypeMap &spec_data_input);
|
std::vector<nlohmann::json> *op_list_json, const ANodeFusionDataTypeMap &spec_data_input);
|
||||||
bool AttrsJsonPostProcessing(const AnfNodePtr &anf_node, const OpInfoPtr &op_info_ptr,
|
bool AttrsJsonPostProcessing(const AnfNodePtr &anf_node, const OpInfoPtr &op_info_ptr,
|
||||||
nlohmann::json *attrs_json) override;
|
nlohmann::json *attrs_json) override;
|
||||||
|
void GenOtherJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
AnfNodePtr GetInputCNode(const AnfNodePtr &node, const nlohmann::json &input_desc);
|
AnfNodePtr GetInputCNode(const AnfNodePtr &node, const nlohmann::json &input_desc);
|
||||||
|
|
|
@ -306,7 +306,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_fusion_common) {
|
||||||
nlohmann::json fusion_json;
|
nlohmann::json fusion_json;
|
||||||
auto tbe_json_creator = std::make_shared<FusionBuildTbeJsonCreator>();
|
auto tbe_json_creator = std::make_shared<FusionBuildTbeJsonCreator>();
|
||||||
EXPECT_TRUE(tbe_json_creator->GenJson(fusion_scope_info, &fusion_json));
|
EXPECT_TRUE(tbe_json_creator->GenJson(fusion_scope_info, &fusion_json));
|
||||||
EXPECT_EQ(tbe_json_creator->GetJsonHash(), 4464178465553346953U);
|
EXPECT_EQ(tbe_json_creator->GetJsonHash(), 4295704009218326208U);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestHWTBEJsonCreator, test_fusion_add_conv2d) {
|
TEST_F(TestHWTBEJsonCreator, test_fusion_add_conv2d) {
|
||||||
|
@ -365,7 +365,7 @@ TEST_F(TestHWTBEJsonCreator, test_fusion_add_conv2d) {
|
||||||
nlohmann::json fusion_json;
|
nlohmann::json fusion_json;
|
||||||
auto tbe_json_creator = std::make_shared<FusionBuildTbeJsonCreator>();
|
auto tbe_json_creator = std::make_shared<FusionBuildTbeJsonCreator>();
|
||||||
EXPECT_TRUE(tbe_json_creator->GenJson(fusion_scope_info, &fusion_json));
|
EXPECT_TRUE(tbe_json_creator->GenJson(fusion_scope_info, &fusion_json));
|
||||||
EXPECT_EQ(tbe_json_creator->GetJsonHash(), 2100526894749019474U);
|
EXPECT_EQ(tbe_json_creator->GetJsonHash(), 3224628976775251376U);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
Loading…
Reference in New Issue