add unknown_shape attr for fusion json

This commit is contained in:
lby 2021-09-26 10:03:14 +08:00
parent 6f1535911a
commit 925778015d
3 changed files with 9 additions and 2 deletions

View File

@ -311,4 +311,10 @@ bool FusionBuildTbeJsonCreator::AttrsJsonPostProcessing(const AnfNodePtr &anf_no
// tbe::TbeAdapter::CastAttrJsonPost(anf_node, attrs_json);
return true;
}
void FusionBuildTbeJsonCreator::GenOtherJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(compute_json);
(*compute_json)[kJUnknowShape] = tbe::TbeDynamicShapeUtil::GetDynamicShapeAttr(anf_node);
}
} // namespace mindspore::kernel

View File

@ -38,6 +38,7 @@ class FusionBuildTbeJsonCreator : public TbeJsonCreator {
std::vector<nlohmann::json> *op_list_json, const ANodeFusionDataTypeMap &spec_data_input);
bool AttrsJsonPostProcessing(const AnfNodePtr &anf_node, const OpInfoPtr &op_info_ptr,
nlohmann::json *attrs_json) override;
void GenOtherJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) override;
private:
AnfNodePtr GetInputCNode(const AnfNodePtr &node, const nlohmann::json &input_desc);

View File

@ -306,7 +306,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_fusion_common) {
nlohmann::json fusion_json;
auto tbe_json_creator = std::make_shared<FusionBuildTbeJsonCreator>();
EXPECT_TRUE(tbe_json_creator->GenJson(fusion_scope_info, &fusion_json));
EXPECT_EQ(tbe_json_creator->GetJsonHash(), 4464178465553346953U);
EXPECT_EQ(tbe_json_creator->GetJsonHash(), 4295704009218326208U);
}
TEST_F(TestHWTBEJsonCreator, test_fusion_add_conv2d) {
@ -365,7 +365,7 @@ TEST_F(TestHWTBEJsonCreator, test_fusion_add_conv2d) {
nlohmann::json fusion_json;
auto tbe_json_creator = std::make_shared<FusionBuildTbeJsonCreator>();
EXPECT_TRUE(tbe_json_creator->GenJson(fusion_scope_info, &fusion_json));
EXPECT_EQ(tbe_json_creator->GetJsonHash(), 2100526894749019474U);
EXPECT_EQ(tbe_json_creator->GetJsonHash(), 3224628976775251376U);
}
} // namespace mindspore::kernel