diff --git a/mindspore/ccsrc/transform/graph_ir/convert.cc b/mindspore/ccsrc/transform/graph_ir/convert.cc index f9ff5d88fec..abd6b347e18 100644 --- a/mindspore/ccsrc/transform/graph_ir/convert.cc +++ b/mindspore/ccsrc/transform/graph_ir/convert.cc @@ -552,6 +552,8 @@ DfGraphConvertor &DfGraphConvertor::ConvertAllNode() { restore_checkpoint_sout_ << "digraph {" << endl; // Convert ResizeBilinear attr size to input ConvertResizeBilinear(anf_graph_); + // Convert SpaceBatch attr to input + ConvertSpaceBatchNd(anf_graph_); // Convert Tile input1 to int32 ConvertTile(anf_graph_); // Convert all anf node to Operator @@ -2324,6 +2326,50 @@ void DfGraphConvertor::ConvertResizeBilinear(const FuncGraphPtr anf_graph) { } } +void DfGraphConvertor::ConvertSpaceBatchNd(const FuncGraphPtr anf_graph) { + std::vector nodes = GetOrderedCNodes(anf_graph); + for (auto &it : nodes) { + if (it->isa()) { + auto node = it->cast(); + MS_EXCEPTION_IF_NULL(node); + std::string name = GetCNodeTargetFuncName(node); + if (name == prim::kPrimSpaceToBatchND->name() || name == prim::kPrimBatchToSpaceND->name()) { + AnfNodePtr op = node->input(0); + if (IsValueNode(op)) { + auto prim = GetValueNode(op); + MS_EXCEPTION_IF_NULL(prim); + ValuePtr block_shape = prim->GetAttr("block_shape"); + auto int64_value = GetValue>(block_shape); + std::vector int32_value; + (void)std::transform(int64_value.begin(), int64_value.end(), std::back_inserter(int32_value), LongToInt); + auto new_value = NewValueNode(int32_value); + new_value->set_abstract(block_shape->ToAbstract()); + node->add_input(new_value); + ValuePtr attr_value = nullptr; + if (name == prim::kPrimSpaceToBatchND->name()) { + attr_value = prim->GetAttr("paddings"); + } else { + attr_value = prim->GetAttr("crops"); + } + std::vector attr_list; + if (attr_value->isa()) { + const ValueListPtr &value = dyn_cast(attr_value); + for (const auto &item : value->value()) { + if (item->isa()) { + auto value_list = GetValue>(item); + std::copy(value_list.begin(), value_list.end(), std::back_inserter(attr_list)); + } + } + } + auto new_value_attr = NewValueNode(attr_list); + new_value_attr->set_abstract(attr_value->ToAbstract()); + node->add_input(new_value_attr); + } + } + } + } +} + AnfNodePtr DfGraphConvertor::CreateCast(const AnfNodePtr &input, const TypePtr &dst_type) const { auto func_graph = input->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); diff --git a/mindspore/ccsrc/transform/graph_ir/convert.h b/mindspore/ccsrc/transform/graph_ir/convert.h index 4eefd512558..3746aaa2278 100644 --- a/mindspore/ccsrc/transform/graph_ir/convert.h +++ b/mindspore/ccsrc/transform/graph_ir/convert.h @@ -172,6 +172,7 @@ class DfGraphConvertor { void ConvertMakeTuple(const CNodePtr node); void ConvertTopK(const CNodePtr node); void ConvertResizeBilinear(const FuncGraphPtr anf_graph); + void ConvertSpaceBatchNd(const FuncGraphPtr anf_graph); void ConvertTile(const FuncGraphPtr anf_graph); AnfNodePtr CreateCast(const AnfNodePtr &input, const TypePtr &dst_type) const; void ConvertReshape(const CNodePtr node); diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/matrix_calculation_ops_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare/matrix_calculation_ops_declare.cc index af0b6f84b57..af1d805b149 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare/matrix_calculation_ops_declare.cc +++ b/mindspore/ccsrc/transform/graph_ir/op_declare/matrix_calculation_ops_declare.cc @@ -97,11 +97,11 @@ ATTR_MAP(MatMulV2) = {{"transpose_a", ATTR_DESC(transpose_x1, AnyTraits()) OUTPUT_MAP(MatMulV2) = {{0, OUTPUT_DESC(y)}}; REG_ADPT_DESC(MatMulV2, prim::kPrimMatMul->name(), ADPT_DESC(MatMulV2)) -// MatrixDiagD -INPUT_MAP(MatrixDiagD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(assist)}}; -ATTR_MAP(MatrixDiagD) = EMPTY_ATTR_MAP; -OUTPUT_MAP(MatrixDiagD) = {{0, OUTPUT_DESC(y)}}; -REG_ADPT_DESC(MatrixDiagD, kNameMatrixDiagD, ADPT_DESC(MatrixDiagD)) +// MatrixDiag +INPUT_MAP(MatrixDiag) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(MatrixDiag) = EMPTY_ATTR_MAP; +OUTPUT_MAP(MatrixDiag) = {{0, OUTPUT_DESC(y)}}; +REG_ADPT_DESC(MatrixDiag, kNameMatrixDiagD, ADPT_DESC(MatrixDiag)) // MatrixDiagPartD INPUT_MAP(MatrixDiagPartD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(assist)}}; diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/matrix_calculation_ops_declare.h b/mindspore/ccsrc/transform/graph_ir/op_declare/matrix_calculation_ops_declare.h index 9a9f7723e3e..fdd2f0dbee4 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare/matrix_calculation_ops_declare.h +++ b/mindspore/ccsrc/transform/graph_ir/op_declare/matrix_calculation_ops_declare.h @@ -68,8 +68,8 @@ DECLARE_OP_USE_OUTPUT(MatMul) DECLARE_OP_ADAPTER(MatMulV2) DECLARE_OP_USE_OUTPUT(MatMulV2) -DECLARE_OP_ADAPTER(MatrixDiagD) -DECLARE_OP_USE_OUTPUT(MatrixDiagD) +DECLARE_OP_ADAPTER(MatrixDiag) +DECLARE_OP_USE_OUTPUT(MatrixDiag) DECLARE_OP_ADAPTER(MatrixDiagPartD) DECLARE_OP_USE_OUTPUT(MatrixDiagPartD) diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/transformation_ops_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare/transformation_ops_declare.cc index bb9941dbea8..a086ff48ca4 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare/transformation_ops_declare.cc +++ b/mindspore/ccsrc/transform/graph_ir/op_declare/transformation_ops_declare.cc @@ -68,12 +68,10 @@ OUTPUT_MAP(SpaceToBatchD) = {{0, OUTPUT_DESC(y)}}; REG_ADPT_DESC(SpaceToBatchD, kNameSpaceToBatch, ADPT_DESC(SpaceToBatchD)) // SpaceToBatchNDD -INPUT_MAP(SpaceToBatchNDD) = {{1, INPUT_DESC(x)}}; -ATTR_MAP(SpaceToBatchNDD) = { - {"block_shape", ATTR_DESC(block_shape, AnyTraits>())}, - {"paddings", ATTR_DESC(paddings, AnyTraits>>(), AnyTraits>())}}; -OUTPUT_MAP(SpaceToBatchNDD) = {{0, OUTPUT_DESC(y)}}; -REG_ADPT_DESC(SpaceToBatchNDD, kNameSpaceToBatchNDD, ADPT_DESC(SpaceToBatchNDD)) +INPUT_MAP(SpaceToBatchND) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(block_shape)}, {3, INPUT_DESC(paddings)}}; +ATTR_MAP(SpaceToBatchND) = EMPTY_ATTR_MAP; +OUTPUT_MAP(SpaceToBatchND) = {{0, OUTPUT_DESC(y)}}; +REG_ADPT_DESC(SpaceToBatchND, kNameSpaceToBatchNDD, ADPT_DESC(SpaceToBatchND)) // BatchToSpaceD INPUT_MAP(BatchToSpaceD) = {{1, INPUT_DESC(x)}}; @@ -83,11 +81,9 @@ ATTR_MAP(BatchToSpaceD) = { OUTPUT_MAP(BatchToSpaceD) = {{0, OUTPUT_DESC(y)}}; REG_ADPT_DESC(BatchToSpaceD, kNameBatchToSpace, ADPT_DESC(BatchToSpaceD)) -// BatchToSpaceNDD -INPUT_MAP(BatchToSpaceNDD) = {{1, INPUT_DESC(x)}}; -ATTR_MAP(BatchToSpaceNDD) = { - {"block_shape", ATTR_DESC(block_shape, AnyTraits>())}, - {"crops", ATTR_DESC(crops, AnyTraits>>(), AnyTraits>())}}; -OUTPUT_MAP(BatchToSpaceNDD) = {{0, OUTPUT_DESC(y)}}; -REG_ADPT_DESC(BatchToSpaceNDD, kNameBatchToSpaceNd, ADPT_DESC(BatchToSpaceNDD)) +// BatchToSpaceND +INPUT_MAP(BatchToSpaceND) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(block_shape)}, {3, INPUT_DESC(crops)}}; +ATTR_MAP(BatchToSpaceND) = EMPTY_ATTR_MAP; +OUTPUT_MAP(BatchToSpaceND) = {{0, OUTPUT_DESC(y)}}; +REG_ADPT_DESC(BatchToSpaceND, kNameBatchToSpaceNd, ADPT_DESC(BatchToSpaceND)) } // namespace mindspore::transform diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/transformation_ops_declare.h b/mindspore/ccsrc/transform/graph_ir/op_declare/transformation_ops_declare.h index 453a908b8c4..647b60ce6da 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare/transformation_ops_declare.h +++ b/mindspore/ccsrc/transform/graph_ir/op_declare/transformation_ops_declare.h @@ -44,13 +44,13 @@ DECLARE_OP_USE_OUTPUT(DepthToSpace) DECLARE_OP_ADAPTER(SpaceToBatchD) DECLARE_OP_USE_OUTPUT(SpaceToBatchD) -DECLARE_OP_ADAPTER(SpaceToBatchNDD) -DECLARE_OP_USE_OUTPUT(SpaceToBatchNDD) +DECLARE_OP_ADAPTER(SpaceToBatchND) +DECLARE_OP_USE_OUTPUT(SpaceToBatchND) DECLARE_OP_ADAPTER(BatchToSpaceD) DECLARE_OP_USE_OUTPUT(BatchToSpaceD) -DECLARE_OP_ADAPTER(BatchToSpaceNDD) -DECLARE_OP_USE_OUTPUT(BatchToSpaceNDD) +DECLARE_OP_ADAPTER(BatchToSpaceND) +DECLARE_OP_USE_OUTPUT(BatchToSpaceND) } // namespace mindspore::transform #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_TRANSFORMATION_OPS_DECLARE_H_