!35965 modify adaper for SpaceToBatchND or BatchToSpaceND

Merge pull request !35965 from changzherui/mod_adapter
This commit is contained in:
i-robot 2022-06-17 16:01:02 +00:00 committed by Gitee
commit a9c564b83f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 67 additions and 24 deletions

View File

@ -552,6 +552,8 @@ DfGraphConvertor &DfGraphConvertor::ConvertAllNode() {
restore_checkpoint_sout_ << "digraph {" << endl;
// Convert ResizeBilinear attr size to input
ConvertResizeBilinear(anf_graph_);
// Convert SpaceBatch attr to input
ConvertSpaceBatchNd(anf_graph_);
// Convert Tile input1 to int32
ConvertTile(anf_graph_);
// Convert all anf node to Operator
@ -2324,6 +2326,50 @@ void DfGraphConvertor::ConvertResizeBilinear(const FuncGraphPtr anf_graph) {
}
}
void DfGraphConvertor::ConvertSpaceBatchNd(const FuncGraphPtr anf_graph) {
std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph);
for (auto &it : nodes) {
if (it->isa<CNode>()) {
auto node = it->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node);
std::string name = GetCNodeTargetFuncName(node);
if (name == prim::kPrimSpaceToBatchND->name() || name == prim::kPrimBatchToSpaceND->name()) {
AnfNodePtr op = node->input(0);
if (IsValueNode<Primitive>(op)) {
auto prim = GetValueNode<PrimitivePtr>(op);
MS_EXCEPTION_IF_NULL(prim);
ValuePtr block_shape = prim->GetAttr("block_shape");
auto int64_value = GetValue<std::vector<int64_t>>(block_shape);
std::vector<int32_t> int32_value;
(void)std::transform(int64_value.begin(), int64_value.end(), std::back_inserter(int32_value), LongToInt);
auto new_value = NewValueNode(int32_value);
new_value->set_abstract(block_shape->ToAbstract());
node->add_input(new_value);
ValuePtr attr_value = nullptr;
if (name == prim::kPrimSpaceToBatchND->name()) {
attr_value = prim->GetAttr("paddings");
} else {
attr_value = prim->GetAttr("crops");
}
std::vector<int64_t> attr_list;
if (attr_value->isa<ValueList>()) {
const ValueListPtr &value = dyn_cast<ValueList>(attr_value);
for (const auto &item : value->value()) {
if (item->isa<ValueList>()) {
auto value_list = GetValue<std::vector<int64_t>>(item);
std::copy(value_list.begin(), value_list.end(), std::back_inserter(attr_list));
}
}
}
auto new_value_attr = NewValueNode(attr_list);
new_value_attr->set_abstract(attr_value->ToAbstract());
node->add_input(new_value_attr);
}
}
}
}
}
AnfNodePtr DfGraphConvertor::CreateCast(const AnfNodePtr &input, const TypePtr &dst_type) const {
auto func_graph = input->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);

View File

@ -172,6 +172,7 @@ class DfGraphConvertor {
void ConvertMakeTuple(const CNodePtr node);
void ConvertTopK(const CNodePtr node);
void ConvertResizeBilinear(const FuncGraphPtr anf_graph);
void ConvertSpaceBatchNd(const FuncGraphPtr anf_graph);
void ConvertTile(const FuncGraphPtr anf_graph);
AnfNodePtr CreateCast(const AnfNodePtr &input, const TypePtr &dst_type) const;
void ConvertReshape(const CNodePtr node);

View File

@ -97,11 +97,11 @@ ATTR_MAP(MatMulV2) = {{"transpose_a", ATTR_DESC(transpose_x1, AnyTraits<bool>())
OUTPUT_MAP(MatMulV2) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(MatMulV2, prim::kPrimMatMul->name(), ADPT_DESC(MatMulV2))
// MatrixDiagD
INPUT_MAP(MatrixDiagD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(assist)}};
ATTR_MAP(MatrixDiagD) = EMPTY_ATTR_MAP;
OUTPUT_MAP(MatrixDiagD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(MatrixDiagD, kNameMatrixDiagD, ADPT_DESC(MatrixDiagD))
// MatrixDiag
INPUT_MAP(MatrixDiag) = {{1, INPUT_DESC(x)}};
ATTR_MAP(MatrixDiag) = EMPTY_ATTR_MAP;
OUTPUT_MAP(MatrixDiag) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(MatrixDiag, kNameMatrixDiagD, ADPT_DESC(MatrixDiag))
// MatrixDiagPartD
INPUT_MAP(MatrixDiagPartD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(assist)}};

View File

@ -68,8 +68,8 @@ DECLARE_OP_USE_OUTPUT(MatMul)
DECLARE_OP_ADAPTER(MatMulV2)
DECLARE_OP_USE_OUTPUT(MatMulV2)
DECLARE_OP_ADAPTER(MatrixDiagD)
DECLARE_OP_USE_OUTPUT(MatrixDiagD)
DECLARE_OP_ADAPTER(MatrixDiag)
DECLARE_OP_USE_OUTPUT(MatrixDiag)
DECLARE_OP_ADAPTER(MatrixDiagPartD)
DECLARE_OP_USE_OUTPUT(MatrixDiagPartD)

View File

@ -68,12 +68,10 @@ OUTPUT_MAP(SpaceToBatchD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(SpaceToBatchD, kNameSpaceToBatch, ADPT_DESC(SpaceToBatchD))
// SpaceToBatchNDD
INPUT_MAP(SpaceToBatchNDD) = {{1, INPUT_DESC(x)}};
ATTR_MAP(SpaceToBatchNDD) = {
{"block_shape", ATTR_DESC(block_shape, AnyTraits<std::vector<int64_t>>())},
{"paddings", ATTR_DESC(paddings, AnyTraits<std::vector<std::vector<int64_t>>>(), AnyTraits<std::vector<int64_t>>())}};
OUTPUT_MAP(SpaceToBatchNDD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(SpaceToBatchNDD, kNameSpaceToBatchNDD, ADPT_DESC(SpaceToBatchNDD))
INPUT_MAP(SpaceToBatchND) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(block_shape)}, {3, INPUT_DESC(paddings)}};
ATTR_MAP(SpaceToBatchND) = EMPTY_ATTR_MAP;
OUTPUT_MAP(SpaceToBatchND) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(SpaceToBatchND, kNameSpaceToBatchNDD, ADPT_DESC(SpaceToBatchND))
// BatchToSpaceD
INPUT_MAP(BatchToSpaceD) = {{1, INPUT_DESC(x)}};
@ -83,11 +81,9 @@ ATTR_MAP(BatchToSpaceD) = {
OUTPUT_MAP(BatchToSpaceD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(BatchToSpaceD, kNameBatchToSpace, ADPT_DESC(BatchToSpaceD))
// BatchToSpaceNDD
INPUT_MAP(BatchToSpaceNDD) = {{1, INPUT_DESC(x)}};
ATTR_MAP(BatchToSpaceNDD) = {
{"block_shape", ATTR_DESC(block_shape, AnyTraits<std::vector<int64_t>>())},
{"crops", ATTR_DESC(crops, AnyTraits<std::vector<std::vector<int64_t>>>(), AnyTraits<std::vector<int64_t>>())}};
OUTPUT_MAP(BatchToSpaceNDD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(BatchToSpaceNDD, kNameBatchToSpaceNd, ADPT_DESC(BatchToSpaceNDD))
// BatchToSpaceND
INPUT_MAP(BatchToSpaceND) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(block_shape)}, {3, INPUT_DESC(crops)}};
ATTR_MAP(BatchToSpaceND) = EMPTY_ATTR_MAP;
OUTPUT_MAP(BatchToSpaceND) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(BatchToSpaceND, kNameBatchToSpaceNd, ADPT_DESC(BatchToSpaceND))
} // namespace mindspore::transform

View File

@ -44,13 +44,13 @@ DECLARE_OP_USE_OUTPUT(DepthToSpace)
DECLARE_OP_ADAPTER(SpaceToBatchD)
DECLARE_OP_USE_OUTPUT(SpaceToBatchD)
DECLARE_OP_ADAPTER(SpaceToBatchNDD)
DECLARE_OP_USE_OUTPUT(SpaceToBatchNDD)
DECLARE_OP_ADAPTER(SpaceToBatchND)
DECLARE_OP_USE_OUTPUT(SpaceToBatchND)
DECLARE_OP_ADAPTER(BatchToSpaceD)
DECLARE_OP_USE_OUTPUT(BatchToSpaceD)
DECLARE_OP_ADAPTER(BatchToSpaceNDD)
DECLARE_OP_USE_OUTPUT(BatchToSpaceNDD)
DECLARE_OP_ADAPTER(BatchToSpaceND)
DECLARE_OP_USE_OUTPUT(BatchToSpaceND)
} // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_TRANSFORMATION_OPS_DECLARE_H_