forked from mindspore-Ecosystem/mindspore
!35965 modify adaper for SpaceToBatchND or BatchToSpaceND
Merge pull request !35965 from changzherui/mod_adapter
This commit is contained in:
commit
a9c564b83f
|
@ -552,6 +552,8 @@ DfGraphConvertor &DfGraphConvertor::ConvertAllNode() {
|
|||
restore_checkpoint_sout_ << "digraph {" << endl;
|
||||
// Convert ResizeBilinear attr size to input
|
||||
ConvertResizeBilinear(anf_graph_);
|
||||
// Convert SpaceBatch attr to input
|
||||
ConvertSpaceBatchNd(anf_graph_);
|
||||
// Convert Tile input1 to int32
|
||||
ConvertTile(anf_graph_);
|
||||
// Convert all anf node to Operator
|
||||
|
@ -2324,6 +2326,50 @@ void DfGraphConvertor::ConvertResizeBilinear(const FuncGraphPtr anf_graph) {
|
|||
}
|
||||
}
|
||||
|
||||
void DfGraphConvertor::ConvertSpaceBatchNd(const FuncGraphPtr anf_graph) {
|
||||
std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph);
|
||||
for (auto &it : nodes) {
|
||||
if (it->isa<CNode>()) {
|
||||
auto node = it->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
std::string name = GetCNodeTargetFuncName(node);
|
||||
if (name == prim::kPrimSpaceToBatchND->name() || name == prim::kPrimBatchToSpaceND->name()) {
|
||||
AnfNodePtr op = node->input(0);
|
||||
if (IsValueNode<Primitive>(op)) {
|
||||
auto prim = GetValueNode<PrimitivePtr>(op);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
ValuePtr block_shape = prim->GetAttr("block_shape");
|
||||
auto int64_value = GetValue<std::vector<int64_t>>(block_shape);
|
||||
std::vector<int32_t> int32_value;
|
||||
(void)std::transform(int64_value.begin(), int64_value.end(), std::back_inserter(int32_value), LongToInt);
|
||||
auto new_value = NewValueNode(int32_value);
|
||||
new_value->set_abstract(block_shape->ToAbstract());
|
||||
node->add_input(new_value);
|
||||
ValuePtr attr_value = nullptr;
|
||||
if (name == prim::kPrimSpaceToBatchND->name()) {
|
||||
attr_value = prim->GetAttr("paddings");
|
||||
} else {
|
||||
attr_value = prim->GetAttr("crops");
|
||||
}
|
||||
std::vector<int64_t> attr_list;
|
||||
if (attr_value->isa<ValueList>()) {
|
||||
const ValueListPtr &value = dyn_cast<ValueList>(attr_value);
|
||||
for (const auto &item : value->value()) {
|
||||
if (item->isa<ValueList>()) {
|
||||
auto value_list = GetValue<std::vector<int64_t>>(item);
|
||||
std::copy(value_list.begin(), value_list.end(), std::back_inserter(attr_list));
|
||||
}
|
||||
}
|
||||
}
|
||||
auto new_value_attr = NewValueNode(attr_list);
|
||||
new_value_attr->set_abstract(attr_value->ToAbstract());
|
||||
node->add_input(new_value_attr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr DfGraphConvertor::CreateCast(const AnfNodePtr &input, const TypePtr &dst_type) const {
|
||||
auto func_graph = input->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
|
|
@ -172,6 +172,7 @@ class DfGraphConvertor {
|
|||
void ConvertMakeTuple(const CNodePtr node);
|
||||
void ConvertTopK(const CNodePtr node);
|
||||
void ConvertResizeBilinear(const FuncGraphPtr anf_graph);
|
||||
void ConvertSpaceBatchNd(const FuncGraphPtr anf_graph);
|
||||
void ConvertTile(const FuncGraphPtr anf_graph);
|
||||
AnfNodePtr CreateCast(const AnfNodePtr &input, const TypePtr &dst_type) const;
|
||||
void ConvertReshape(const CNodePtr node);
|
||||
|
|
|
@ -97,11 +97,11 @@ ATTR_MAP(MatMulV2) = {{"transpose_a", ATTR_DESC(transpose_x1, AnyTraits<bool>())
|
|||
OUTPUT_MAP(MatMulV2) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(MatMulV2, prim::kPrimMatMul->name(), ADPT_DESC(MatMulV2))
|
||||
|
||||
// MatrixDiagD
|
||||
INPUT_MAP(MatrixDiagD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(assist)}};
|
||||
ATTR_MAP(MatrixDiagD) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(MatrixDiagD) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(MatrixDiagD, kNameMatrixDiagD, ADPT_DESC(MatrixDiagD))
|
||||
// MatrixDiag
|
||||
INPUT_MAP(MatrixDiag) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(MatrixDiag) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(MatrixDiag) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(MatrixDiag, kNameMatrixDiagD, ADPT_DESC(MatrixDiag))
|
||||
|
||||
// MatrixDiagPartD
|
||||
INPUT_MAP(MatrixDiagPartD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(assist)}};
|
||||
|
|
|
@ -68,8 +68,8 @@ DECLARE_OP_USE_OUTPUT(MatMul)
|
|||
DECLARE_OP_ADAPTER(MatMulV2)
|
||||
DECLARE_OP_USE_OUTPUT(MatMulV2)
|
||||
|
||||
DECLARE_OP_ADAPTER(MatrixDiagD)
|
||||
DECLARE_OP_USE_OUTPUT(MatrixDiagD)
|
||||
DECLARE_OP_ADAPTER(MatrixDiag)
|
||||
DECLARE_OP_USE_OUTPUT(MatrixDiag)
|
||||
|
||||
DECLARE_OP_ADAPTER(MatrixDiagPartD)
|
||||
DECLARE_OP_USE_OUTPUT(MatrixDiagPartD)
|
||||
|
|
|
@ -68,12 +68,10 @@ OUTPUT_MAP(SpaceToBatchD) = {{0, OUTPUT_DESC(y)}};
|
|||
REG_ADPT_DESC(SpaceToBatchD, kNameSpaceToBatch, ADPT_DESC(SpaceToBatchD))
|
||||
|
||||
// SpaceToBatchNDD
|
||||
INPUT_MAP(SpaceToBatchNDD) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(SpaceToBatchNDD) = {
|
||||
{"block_shape", ATTR_DESC(block_shape, AnyTraits<std::vector<int64_t>>())},
|
||||
{"paddings", ATTR_DESC(paddings, AnyTraits<std::vector<std::vector<int64_t>>>(), AnyTraits<std::vector<int64_t>>())}};
|
||||
OUTPUT_MAP(SpaceToBatchNDD) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(SpaceToBatchNDD, kNameSpaceToBatchNDD, ADPT_DESC(SpaceToBatchNDD))
|
||||
INPUT_MAP(SpaceToBatchND) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(block_shape)}, {3, INPUT_DESC(paddings)}};
|
||||
ATTR_MAP(SpaceToBatchND) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(SpaceToBatchND) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(SpaceToBatchND, kNameSpaceToBatchNDD, ADPT_DESC(SpaceToBatchND))
|
||||
|
||||
// BatchToSpaceD
|
||||
INPUT_MAP(BatchToSpaceD) = {{1, INPUT_DESC(x)}};
|
||||
|
@ -83,11 +81,9 @@ ATTR_MAP(BatchToSpaceD) = {
|
|||
OUTPUT_MAP(BatchToSpaceD) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(BatchToSpaceD, kNameBatchToSpace, ADPT_DESC(BatchToSpaceD))
|
||||
|
||||
// BatchToSpaceNDD
|
||||
INPUT_MAP(BatchToSpaceNDD) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(BatchToSpaceNDD) = {
|
||||
{"block_shape", ATTR_DESC(block_shape, AnyTraits<std::vector<int64_t>>())},
|
||||
{"crops", ATTR_DESC(crops, AnyTraits<std::vector<std::vector<int64_t>>>(), AnyTraits<std::vector<int64_t>>())}};
|
||||
OUTPUT_MAP(BatchToSpaceNDD) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(BatchToSpaceNDD, kNameBatchToSpaceNd, ADPT_DESC(BatchToSpaceNDD))
|
||||
// BatchToSpaceND
|
||||
INPUT_MAP(BatchToSpaceND) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(block_shape)}, {3, INPUT_DESC(crops)}};
|
||||
ATTR_MAP(BatchToSpaceND) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(BatchToSpaceND) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(BatchToSpaceND, kNameBatchToSpaceNd, ADPT_DESC(BatchToSpaceND))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -44,13 +44,13 @@ DECLARE_OP_USE_OUTPUT(DepthToSpace)
|
|||
DECLARE_OP_ADAPTER(SpaceToBatchD)
|
||||
DECLARE_OP_USE_OUTPUT(SpaceToBatchD)
|
||||
|
||||
DECLARE_OP_ADAPTER(SpaceToBatchNDD)
|
||||
DECLARE_OP_USE_OUTPUT(SpaceToBatchNDD)
|
||||
DECLARE_OP_ADAPTER(SpaceToBatchND)
|
||||
DECLARE_OP_USE_OUTPUT(SpaceToBatchND)
|
||||
|
||||
DECLARE_OP_ADAPTER(BatchToSpaceD)
|
||||
DECLARE_OP_USE_OUTPUT(BatchToSpaceD)
|
||||
|
||||
DECLARE_OP_ADAPTER(BatchToSpaceNDD)
|
||||
DECLARE_OP_USE_OUTPUT(BatchToSpaceNDD)
|
||||
DECLARE_OP_ADAPTER(BatchToSpaceND)
|
||||
DECLARE_OP_USE_OUTPUT(BatchToSpaceND)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_TRANSFORMATION_OPS_DECLARE_H_
|
||||
|
|
Loading…
Reference in New Issue