Add Custom actor nodes and custom actor support for dynamic shape case

This commit is contained in:
TronZhang 2022-01-05 09:44:36 +08:00
parent 1b08f35ef1
commit 58f386fe2a
30 changed files with 1893 additions and 39 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -153,6 +153,10 @@
#include "backend/optimizer/ascend/mindir/dynamic_reshape_unify_mindir.h"
#include "backend/optimizer/ascend/mindir/all_to_all_unify_mindir.h"
#include "backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.h"
#include "backend/optimizer/ascend/dynamic_shape/convert_dynamic_op.h"
#include "backend/optimizer/ascend/dynamic_shape/convert_general_op.h"
#include "backend/optimizer/ascend/dynamic_shape/convert_inherited_dynamic_op.h"
#include "backend/optimizer/ascend/dynamic_shape/link_custom_op.h"
#include "backend/optimizer/pass/adjust_depend_for_parallel_optimizer_recompute_all_gather.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_compile.h"
#include "utils/ms_context.h"
@ -618,5 +622,34 @@ void AscendUnifyMindIR(const std::shared_ptr<session::KernelGraph> &graph) {
#endif
}
// Run the dynamic-shape conversion passes on a kernel graph:
//  1. Generate custom actor nodes (infer/init/update) for dynamic, general
//     and inherited-dynamic kernels (Convert* passes).
//  2. Wire them together with Depend edges (LinkCustomOp).
// Dumps the graph IR before/after when save_graphs is enabled.
void AscendDynamicShapeConvert(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  // Fix: kernel_graph is dereferenced below (graph_id(), Optimize) — validate
  // it up-front just like context_ptr.
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
#ifdef ENABLE_DUMP_IR
  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  if (save_graphs) {
    std::string file_name =
      "hwopt_d_before_dynamic_shape_convert_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
    DumpIR(file_name, kernel_graph);
    DumpIRProto(kernel_graph, "before_dynamic_shape_convert_hwopt_" + std::to_string(kernel_graph->graph_id()));
  }
#endif
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto dynamic_shape_convert_pm = std::make_shared<opt::PassManager>("dynamic_shape_convert_pm");
  dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::ConvertDynamicOp>());
  dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::ConvertGeneralOp>());
  dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::ConvertInheritedDynamicOp>());
  dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::LinkCustomOp>());
  optimizer->AddPassManager(dynamic_shape_convert_pm);
  (void)optimizer->Optimize(kernel_graph);
  // Custom actor nodes change the node set; recompute the execution order.
  kernel_graph->SetExecOrderByDefault();
#ifdef ENABLE_DUMP_IR
  if (save_graphs) {
    std::string file_name =
      "hwopt_d_after_dynamic_shape_convert_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
    DumpIR(file_name, kernel_graph);
  }
#endif
}
} // namespace opt
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -28,6 +28,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendUnifyMindIR(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendDynamicShapeConvert(const std::shared_ptr<session::KernelGraph> &kernel_graph);
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,130 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/dynamic_shape/ascend_dynamic_shape_helper.h"
#include <memory>
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "utils/anf_utils.h"
namespace mindspore {
namespace {
// True only for CNodes that are real kernels (per AnfUtils::IsRealKernel).
bool IsRealCNode(const BaseRef &n) {
  if (!utils::isa<CNodePtr>(n)) {
    return false;
  }
  auto cnode = utils::cast<CNodePtr>(n);
  return AnfUtils::IsRealKernel(cnode);
}
} // namespace
namespace opt::dynamic_shape {
// A "general" op is a real kernel that is neither a dynamic (compute-depend)
// op nor one that inherits dynamic shape from its inputs/outputs.
bool IsGeneralOp(const BaseRef &n) { return !IsDynamicOp(n) && !IsInheritedDynamicOp(n) && IsRealCNode(n); }
bool IsDynamicOp(const BaseRef &n) {
if (!IsRealCNode(n)) {
return false;
}
CNodePtr cnode = utils::cast<CNodePtr>(n);
MS_EXCEPTION_IF_NULL(cnode);
auto op_name = AnfAlgo::GetCNodeName(cnode);
return kComputeDepend.find(op_name) != kComputeDepend.end();
}
bool IsInheritedDynamicOp(const BaseRef &n) {
if (IsDynamicOp(n)) {
return false;
}
if (!IsRealCNode(n)) {
return false;
}
CNodePtr cnode = utils::cast<CNodePtr>(n);
MS_EXCEPTION_IF_NULL(cnode);
return AnfAlgo::IsNodeInputDynamicShape(cnode) || AnfUtils::IsNodeOutputDynamicShape(cnode);
}
// Build the custom "infer" actor node for `node`.
// With fake_flag the actor is a no-op placeholder; otherwise it triggers
// InferOp() on the node's kernel mod.
AnfNodePtr GenInferNode(const AnfNodePtr &node, bool fake_flag) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  AnfUtils::CustomActorCallback actor_func = [](void *) -> void { return; };
  if (!fake_flag) {
    auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    actor_func = [kernel_mod](void *) { kernel_mod->InferOp(); };
  }
  auto infer_node = AnfUtils::NewInferActorNode(actor_func, cnode, fake_flag);
  infer_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  return infer_node;
}
// Build the custom "init" actor node for `node`; it calls InitOp() on the
// node's kernel mod before launch.
AnfNodePtr GenInitNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto init_func = [kernel_mod](void *) { kernel_mod->InitOp(); };
  auto init_node = AnfUtils::NewInitActorNode(init_func, cnode);
  init_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  return init_node;
}
// Build the custom "update" actor node for `node`; it calls UpdateOp() on
// the node's kernel mod after launch.
// Some non-dynamic-shape nodes still need a sync after launch for later
// nodes; `just_sync_flag` marks those so they can be told apart from real
// dynamic updates (see IsDynUpdate).
AnfNodePtr GenUpdateNode(const AnfNodePtr &node, bool just_sync_flag) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto update_func = [kernel_mod](void *) { kernel_mod->UpdateOp(); };
  auto update_node = AnfUtils::NewUpdateActorNode(update_func, cnode, just_sync_flag);
  update_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  return update_node;
}
// True when `node` is a real dynamic update actor (not a just-sync one).
// Throws if `node` is not a custom update actor at all.
bool IsDynUpdate(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (AnfUtils::GetCustomActorType(node) != kUpdate) {
    MS_LOG(EXCEPTION) << node->fullname_with_scope() << " is not a custom update node!";
  }
  return !AnfUtils::GetCustomActorJustSyncFlag(node);
}
// Returns the process-wide singleton manager.
// Meyers singleton: function-local static, so initialization is thread-safe
// since C++11.
CustomActorNodeManager &CustomActorNodeManager::Instance() {
  static CustomActorNodeManager instance{};
  return instance;
}
} // namespace opt::dynamic_shape
} // namespace mindspore

View File

@ -0,0 +1,63 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_ASCEND_DYNAMIC_SHAPE_HELPER_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_ASCEND_DYNAMIC_SHAPE_HELPER_H
#include <string>
#include "ir/anf.h"
#include "utils/ms_utils.h"
#include "backend/optimizer/common/optimizer.h"
namespace mindspore::opt::dynamic_shape {
bool IsGeneralOp(const BaseRef &n);
bool IsDynamicOp(const BaseRef &n);
bool IsInheritedDynamicOp(const BaseRef &n);
AnfNodePtr GenInferNode(const AnfNodePtr &node, bool fake_flag = false);
AnfNodePtr GenInitNode(const AnfNodePtr &node);
AnfNodePtr GenUpdateNode(const AnfNodePtr &node, bool just_sync_flag = false);
bool IsDynUpdate(const AnfNodePtr &node);
// The three custom actor nodes generated for one kernel node:
//  - infer_node: runs (or fakes) shape inference before launch.
//  - init_node: initializes the kernel before launch.
//  - update_node: updates/syncs outputs after launch.
struct RelatedCustomActorNode {
  AnfNodePtr infer_node;
  AnfNodePtr init_node;
  AnfNodePtr update_node;
};
// Singleton registry mapping a kernel node to the custom actor nodes
// (infer/init/update) generated for it. Populated by the Convert* passes
// and consumed by the LinkCustomOp pass; call Reset() when linking is done.
class CustomActorNodeManager {
 public:
  static CustomActorNodeManager &Instance();
  // Drop all registrations.
  void Reset() { custom_nodes_map_.clear(); }
  // Associate `node` with its custom actor nodes.
  // NOTE(review): emplace does not overwrite — re-registering the same node
  // silently keeps the first entry; confirm that is intended.
  void Register(const AnfNodePtr &node, const RelatedCustomActorNode &custom_nodes) {
    custom_nodes_map_.emplace(node, custom_nodes);
  }
  // True when `node` was registered by one of the Convert* passes.
  bool IsRegistered(const AnfNodePtr &node) const { return custom_nodes_map_.find(node) != custom_nodes_map_.end(); }
  // Lookup; raises an exception when `node` was never registered.
  const RelatedCustomActorNode &GetCustomActorNodes(const AnfNodePtr &node) const {
    if (auto iter = custom_nodes_map_.find(node); iter != custom_nodes_map_.end()) {
      return iter->second;
    }
    MS_LOG(EXCEPTION) << "Not registered node!";
  }

 private:
  CustomActorNodeManager() = default;
  ~CustomActorNodeManager() = default;
  DISABLE_COPY_AND_ASSIGN(CustomActorNodeManager)
  OrderedMap<AnfNodePtr, RelatedCustomActorNode> custom_nodes_map_;
};
} // namespace mindspore::opt::dynamic_shape
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_ASCEND_DYNAMIC_SHAPE_HELPER_H

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/dynamic_shape/convert_dynamic_op.h"
#include <memory>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/ascend/dynamic_shape/ascend_dynamic_shape_helper.h"
namespace mindspore {
namespace opt::dynamic_shape {
// Match any single real kernel whose op is in the compute-depend list.
const BaseRef ConvertDynamicOp::DefinePattern() const { return BaseRef({std::make_shared<CondVar>(IsDynamicOp)}); }
// Register real infer/init/update actor nodes for a dynamic op.
// The node itself is returned unchanged; the actors are only recorded in
// CustomActorNodeManager for the later LinkCustomOp pass.
const AnfNodePtr ConvertDynamicOp::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  RelatedCustomActorNode custom_nodes;
  custom_nodes.infer_node = GenInferNode(node);
  custom_nodes.init_node = GenInitNode(node);
  custom_nodes.update_node = GenUpdateNode(node);
  CustomActorNodeManager::Instance().Register(node, custom_nodes);
  return node;
}
} // namespace opt::dynamic_shape
} // namespace mindspore

View File

@ -0,0 +1,32 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_CONVERT_DYNAMIC_OP_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_CONVERT_DYNAMIC_OP_H
#include "ir/anf.h"
#include "backend/optimizer/common/optimizer.h"
namespace mindspore::opt::dynamic_shape {
// Pass that matches dynamic (compute-depend) ops and registers real
// infer/init/update custom actor nodes for each of them.
class ConvertDynamicOp : public PatternProcessPass {
 public:
  explicit ConvertDynamicOp(bool multigraph = true) : PatternProcessPass("convert_dynamic_op", multigraph) {}
  ~ConvertDynamicOp() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
};
} // namespace mindspore::opt::dynamic_shape
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_CONVERT_DYNAMIC_OP_H

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/dynamic_shape/convert_general_op.h"
#include <memory>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/ascend/dynamic_shape/ascend_dynamic_shape_helper.h"
namespace mindspore {
namespace opt::dynamic_shape {
// Match any single real kernel that is neither dynamic nor inherited-dynamic.
const BaseRef ConvertGeneralOp::DefinePattern() const { return BaseRef({std::make_shared<CondVar>(IsGeneralOp)}); }
// Register custom actor nodes for a general (static-shape) op: a stub
// (fake) infer actor, an init actor, and a just-sync update actor used only
// when a later node needs a sync.
const AnfNodePtr ConvertGeneralOp::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  RelatedCustomActorNode custom_nodes;
  custom_nodes.infer_node = GenInferNode(node, true);
  custom_nodes.init_node = GenInitNode(node);
  custom_nodes.update_node = GenUpdateNode(node, true);  // Use to call sync if needed.
  CustomActorNodeManager::Instance().Register(node, custom_nodes);
  return node;
}
} // namespace opt::dynamic_shape
} // namespace mindspore

View File

@ -0,0 +1,32 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_CONVERT_GENERAL_OP_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_CONVERT_GENERAL_OP_H
#include "ir/anf.h"
#include "backend/optimizer/common/optimizer.h"
namespace mindspore::opt::dynamic_shape {
// Pass that matches general (static-shape) real kernels and registers stub
// infer / init / just-sync update custom actor nodes for them.
class ConvertGeneralOp : public PatternProcessPass {
 public:
  explicit ConvertGeneralOp(bool multigraph = true) : PatternProcessPass("convert_general_op", multigraph) {}
  ~ConvertGeneralOp() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
};
} // namespace mindspore::opt::dynamic_shape
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_CONVERT_GENERAL_OP_H

View File

@ -0,0 +1,43 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/dynamic_shape/convert_inherited_dynamic_op.h"
#include <memory>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/ascend/dynamic_shape/ascend_dynamic_shape_helper.h"
namespace mindspore {
namespace opt::dynamic_shape {
// Match any single real kernel that inherits dynamic shape from its
// inputs/outputs without being a compute-depend op itself.
const BaseRef ConvertInheritedDynamicOp::DefinePattern() const {
  return BaseRef({std::make_shared<CondVar>(IsInheritedDynamicOp)});
}
// Register custom actor nodes for an inherited-dynamic op: a real infer
// actor, an init actor, and a just-sync update actor.
const AnfNodePtr ConvertInheritedDynamicOp::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                    const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  RelatedCustomActorNode custom_nodes;
  custom_nodes.infer_node = GenInferNode(node);
  custom_nodes.init_node = GenInitNode(node);
  custom_nodes.update_node = GenUpdateNode(node, true);  // Use to call sync if needed.
  CustomActorNodeManager::Instance().Register(node, custom_nodes);
  return node;
}
} // namespace opt::dynamic_shape
} // namespace mindspore

View File

@ -0,0 +1,33 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_CONVERT_INHERITED_DYNAMIC_OP_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_CONVERT_INHERITED_DYNAMIC_OP_H
#include "ir/anf.h"
#include "backend/optimizer/common/optimizer.h"
namespace mindspore::opt::dynamic_shape {
// Pass that matches real kernels which inherit dynamic shape from their
// inputs/outputs and registers infer / init / just-sync update actor nodes.
class ConvertInheritedDynamicOp : public PatternProcessPass {
 public:
  explicit ConvertInheritedDynamicOp(bool multigraph = true)
      : PatternProcessPass("convert_inherited_dynamic_op", multigraph) {}
  ~ConvertInheritedDynamicOp() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
};
} // namespace mindspore::opt::dynamic_shape
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_CONVERT_INHERITED_DYNAMIC_OP_H

View File

@ -0,0 +1,212 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/dynamic_shape/link_custom_op.h"
#include <memory>
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/ascend/dynamic_shape/ascend_dynamic_shape_helper.h"
namespace mindspore {
namespace opt::dynamic_shape {
namespace {
constexpr size_t kTupleFirstItemIndex = 0;
constexpr size_t kFirstDataInputIndex = 1;
// Create a Depend(next, prev) CNode in `g`: it forwards `next` while forcing
// `prev` to be executed before it.
AnfNodePtr InsertDepend(const FuncGraphPtr &g, const AnfNodePtr &prev, const AnfNodePtr &next) {
  MS_EXCEPTION_IF_NULL(g);
  MS_EXCEPTION_IF_NULL(prev);
  MS_EXCEPTION_IF_NULL(next);
  std::vector<AnfNodePtr> depend_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())), next,
                                        prev};
  auto depend_node = g->NewCNode(depend_inputs);
  MS_EXCEPTION_IF_NULL(depend_node);
  return depend_node;
}
// Link a node's own custom actors: infer -> init -> launch, and
// launch -> update when the update actor is a real dynamic one.
// The created Depend nodes are appended to `depend_nodes`.
bool LinkInternalOp(const FuncGraphPtr &g, const AnfNodePtr &node, AnfNodePtrList *depend_nodes) {
  MS_EXCEPTION_IF_NULL(g);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(depend_nodes);
  bool changed = false;
  const auto &custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(node);
  if (custom_nodes.infer_node != nullptr) {
    if (custom_nodes.init_node != nullptr) {
      depend_nodes->push_back(InsertDepend(g, custom_nodes.infer_node, custom_nodes.init_node));  // link infer => init
      depend_nodes->push_back(InsertDepend(g, custom_nodes.init_node, node));                     // link init => launch
      changed = true;
    } else {
      MS_LOG(WARNING) << "Node " << node->DebugString() << " has infer node but init node is null.";
    }
  }
  if (IsDynUpdate(custom_nodes.update_node)) {
    depend_nodes->push_back(InsertDepend(g, node, custom_nodes.update_node));  // link launch => update
    changed = true;
  }
  return changed;
}
bool LinkInputOp(const FuncGraphPtr &g, const CNodePtr &cnode, AnfNodePtrList *depend_nodes) {
MS_EXCEPTION_IF_NULL(g);
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(depend_nodes);
bool changed = false;
auto custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(cnode);
if (custom_nodes.infer_node == nullptr) {
return changed;
}
size_t input_num = AnfAlgo::GetInputNum(cnode);
for (size_t i = 0; i < input_num; ++i) {
auto prev = AnfAlgo::GetPrevNodeOutput(cnode, i);
const auto &prev_node = prev.first;
if (prev_node == nullptr || !CustomActorNodeManager::Instance().IsRegistered(prev_node)) {
continue;
}
auto prev_custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(prev_node);
if (prev_custom_nodes.infer_node != nullptr) {
depend_nodes->push_back(
InsertDepend(g, prev_custom_nodes.infer_node, custom_nodes.infer_node)); // link prev.infer => curr.infer
MS_LOG(DEBUG) << "Link from " << prev_node->fullname_with_scope() << " infer "
<< prev_custom_nodes.infer_node->fullname_with_scope() << " to " << cnode->fullname_with_scope()
<< " infer " << custom_nodes.infer_node->fullname_with_scope();
changed = true;
}
if (IsDynUpdate(prev_custom_nodes.update_node)) {
depend_nodes->push_back(
InsertDepend(g, prev_custom_nodes.update_node, custom_nodes.infer_node)); // link prev.update => curr.infer
MS_LOG(DEBUG) << "Link from " << prev_node->fullname_with_scope() << " update "
<< prev_custom_nodes.update_node->fullname_with_scope() << " to " << cnode->fullname_with_scope()
<< " infer " << custom_nodes.infer_node->fullname_with_scope();
changed = true;
}
}
return changed;
}
// For every input index this node's inference depends on (from
// GetDependsFormMap — presumably inputs whose host VALUE is needed for
// infer; confirm against the depends map definition), force the producer's
// just-sync update actor to run between the producer and this node's infer,
// so the data is synchronized before inference.
bool LinkDependSync(const FuncGraphPtr &g, const CNodePtr &cnode, AnfNodePtrList *depend_nodes) {
  MS_EXCEPTION_IF_NULL(g);
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(depend_nodes);
  bool changed = false;
  auto custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(cnode);
  // Nothing to do for nodes without a (real or stub) infer actor.
  if (custom_nodes.infer_node == nullptr) {
    return changed;
  }
  auto dynamic_shape_depends = abstract::GetDependsFormMap(cnode);
  if (dynamic_shape_depends.empty()) {
    return changed;
  }
  for (auto depend_index : dynamic_shape_depends) {
    auto prev = AnfAlgo::GetPrevNodeOutput(cnode, depend_index);
    const auto &prev_node = prev.first;
    if (prev_node == nullptr || !CustomActorNodeManager::Instance().IsRegistered(prev_node)) {
      continue;
    }
    // If previous node is dynamic, so it was already link.
    auto prev_custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(prev_node);
    if (IsDynUpdate(prev_custom_nodes.update_node)) {
      continue;
    }
    // 1. Link prev_node => prev_node.update if its update is just sync.
    depend_nodes->push_back(InsertDepend(g, prev_node, prev_custom_nodes.update_node));
    // 2. Link prev_node.update => cur_node.infer.
    depend_nodes->push_back(InsertDepend(g, prev_custom_nodes.update_node, custom_nodes.infer_node));
    changed = true;
  }
  return changed;
}
/**
 * @brief Attach Custom's Depend nodes with additional MakeTuple and TupleGetItem before graph return.
 *
 *   %0 = A
 *   return %0
 * ---->
 *   %0 = A
 *   %1 = MakeTuple(%0, %depend0, %depend1...)
 *   %2 = TupleGetItem(%1, 0)
 *   return %2
 *
 * @param g Graph.
 * @param depend_nodes Custom's Depend nodes.
 */
void AttachDependNodes(const FuncGraphPtr &g, const AnfNodePtrList &depend_nodes) {
  if (depend_nodes.empty()) {
    return;
  }
  MS_EXCEPTION_IF_NULL(g);
  auto return_node = g->get_return();
  MS_EXCEPTION_IF_NULL(return_node);
  // New MakeTuple node: original return value first, then all Depend nodes
  // so they stay alive in the graph and are executed.
  auto mk_inputs = AnfNodePtrList{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())),
                                  return_node->input(kFirstDataInputIndex)};
  mk_inputs.insert(mk_inputs.end(), depend_nodes.begin(), depend_nodes.end());
  auto make_tuple_node = g->NewCNode(mk_inputs);
  // Get first element item form that maketuple and return.
  // NOTE(review): MakeTuple is built from kPrimMakeTuple->name() but
  // TupleGetItem directly from prim::kTupleGetItem — confirm both forms are
  // equivalent here.
  auto get_1st_item = g->NewCNode(AnfNodePtrList{NewValueNode(std::make_shared<Primitive>(prim::kTupleGetItem)),
                                                 make_tuple_node, NewValueNode(SizeToLong(kTupleFirstItemIndex))});
  // Attach back: the graph now returns TupleGetItem(MakeTuple(...), 0),
  // which evaluates to the original return value.
  return_node->set_input(kFirstDataInputIndex, get_1st_item);
}
} // namespace
// Pass entry: walk the graph in topological order and, for every node the
// Convert* passes registered, create the Depend edges that sequence its
// custom infer/init/update actors. Returns true when the graph changed.
bool LinkCustomOp::Run(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  bool changed = false;
  AnfNodePtrList depend_nodes;
  auto node_list = TopoSort(func_graph->get_return());
  for (const auto &node : node_list) {
    CNodePtr cnode = node->cast<CNodePtr>();
    // Only CNodes registered by the Convert* passes are linked.
    if (cnode == nullptr || !CustomActorNodeManager::Instance().IsRegistered(cnode)) {
      continue;
    }
    // Note: "|| changed" keeps earlier results; Link* is always evaluated
    // because it is on the left of the short-circuit.
    changed = LinkInternalOp(func_graph, cnode, &depend_nodes) || changed;
    changed = LinkInputOp(func_graph, cnode, &depend_nodes) || changed;
    changed = LinkDependSync(func_graph, cnode, &depend_nodes) || changed;
  }
  // The registry is per-pass state; clear it regardless of the outcome.
  CustomActorNodeManager::Instance().Reset();
  if (changed) {
    // Keep the Depend nodes reachable from the graph output.
    AttachDependNodes(func_graph, depend_nodes);
    // Rebuild graph's edge.
    auto mng = func_graph->manager();
    if (mng == nullptr) {
      mng = Manage(func_graph, true);
      func_graph->set_manager(mng);
    }
    mng->RemoveRoots();
    mng->KeepRoots({func_graph});
  }
  return changed;
}
} // namespace opt::dynamic_shape
} // namespace mindspore

View File

@ -0,0 +1,31 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_LINK_CUSTOM_OP_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_LINK_CUSTOM_OP_H
#include "ir/anf.h"
#include "backend/optimizer/common/optimizer.h"
namespace mindspore::opt::dynamic_shape {
// Graph pass that wires the custom actor nodes registered by the Convert*
// passes into the graph with Depend edges (infer -> init -> launch ->
// update, plus producer/consumer ordering).
class LinkCustomOp : public Pass {
 public:
  LinkCustomOp() : Pass("link_custom_op") {}
  ~LinkCustomOp() override = default;
  bool Run(const FuncGraphPtr &func_graph) override;
};
} // namespace mindspore::opt::dynamic_shape
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_LINK_CUSTOM_OP_H

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -29,6 +29,7 @@
#include "runtime/device/kernel_runtime_manager.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/common/helper.h"
#include "utils/anf_utils.h"
namespace mindspore {
namespace session {
@ -878,7 +879,7 @@ void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
auto node = que.front();
que.pop();
MS_EXCEPTION_IF_NULL(node);
if (node->isa<Parameter>() || node->isa<ValueNode>()) {
if (node->isa<Parameter>() || node->isa<ValueNode>() || AnfUtils::IsCustomActorNode(node)) {
seed_nodes->push(node);
continue;
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -30,6 +30,7 @@
#include "pipeline/jit/base.h"
#include "debug/trace.h"
#include "utils/trace_base.h"
#include "utils/anf_utils.h"
namespace mindspore {
const std::string ToShortString(const TypeId &typeId) {
@ -257,6 +258,32 @@ void DumpOperator(const AnfNodePtr &node, const std::shared_ptr<SubGraphIRInfo>
}
}
// Print one Parameter operand of `node` into gsub->buffer, qualifying it
// with its owning func graph when that differs from `node`'s graph.
// NOTE(review): the "Paramter" spelling is kept because callers use it.
void DumpParamterInOperand(const AnfNodePtr &node, const AnfNodePtr &in, OrderedMap<AnfNodePtr, int32_t> *para_map,
                           const std::shared_ptr<SubGraphIRInfo> &gsub) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(node->func_graph());
  MS_EXCEPTION_IF_NULL(in);
  MS_EXCEPTION_IF_NULL(para_map);
  MS_EXCEPTION_IF_NULL(gsub);
  if (in->func_graph() == nullptr) {
    MS_LOG(ERROR) << "Parameter should belong to a func graph. Check func graph: " << node->func_graph();
  }
  const bool cross_graph = in->func_graph() != nullptr && in->func_graph() != node->func_graph();
  if (cross_graph) {
    gsub->buffer << "$(@" << in->func_graph()->ToString() << ":";
  } else {
    gsub->buffer << "%";
  }
  if (auto pos = para_map->find(in); pos != para_map->end()) {
    gsub->buffer << "para" << pos->second << "_" << in->ToString();
  } else {
    gsub->buffer << "para_" << in->ToString();
  }
  if (cross_graph) {
    gsub->buffer << ")";
  }
}
void DumpOperands(const AnfNodePtr &node, OrderedMap<AnfNodePtr, int32_t> *para_map,
const std::shared_ptr<SubGraphIRInfo> &gsub) {
if (node == nullptr || para_map == nullptr || gsub == nullptr) {
@ -275,24 +302,7 @@ void DumpOperands(const AnfNodePtr &node, OrderedMap<AnfNodePtr, int32_t> *para_
gsub->buffer << ", ";
}
if (in->isa<Parameter>()) {
MS_EXCEPTION_IF_NULL(node->func_graph());
if (in->func_graph() == nullptr) {
MS_LOG(ERROR) << "Parameter should belong to a func graph. Check func graph: " << node->func_graph();
}
if (in->func_graph() != nullptr && in->func_graph() != node->func_graph()) {
gsub->buffer << "$(@" << in->func_graph()->ToString() << ":";
} else {
gsub->buffer << "%";
}
auto iter = para_map->find(in);
if (iter == para_map->end()) {
gsub->buffer << "para_" << in->ToString();
} else {
gsub->buffer << "para" << iter->second << "_" << in->ToString();
}
if (in->func_graph() != nullptr && in->func_graph() != node->func_graph()) {
gsub->buffer << ")";
}
DumpParamterInOperand(node, in, para_map, gsub);
} else if (in->isa<CNode>()) {
auto iter = gsub->local_var_map.find(in);
if (iter != gsub->local_var_map.end()) {
@ -308,6 +318,8 @@ void DumpOperands(const AnfNodePtr &node, OrderedMap<AnfNodePtr, int32_t> *para_
} else if (IsValueNode<FuncGraph>(in)) {
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(in);
gsub->buffer << "@" << fg->ToString();
} else if (AnfUtils::IsCustomActorNode(in)) {
gsub->buffer << "%" << AnfUtils::GetCustomActorName(in);
} else {
gsub->buffer << in->ToString();
}
@ -570,6 +582,8 @@ void DumpIRInSubgraph(const std::vector<AnfNodePtr> &nodes, OrderedMap<AnfNodePt
if (node->isa<CNode>()) {
// Print and record output of operator if it is not 'Return'
DumpCNode(node->cast<CNodePtr>(), sub_graph, para_map, gsub, dump_full_name, dump_location);
} else if (AnfUtils::IsCustomActorNode(node)) {
continue;
} else {
gsub->buffer << " " << node->ToString() << std::endl;
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -29,6 +29,7 @@
#include "utils/symbolic.h"
#include "utils/utils.h"
#include "pipeline/jit/base.h"
#include "utils/anf_utils.h"
namespace mindspore {
class ProtoExporter {
@ -399,6 +400,10 @@ std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodeP
return node->ToString();
}
if (AnfUtils::IsCustomActorNode(node)) {
return AnfUtils::GetCustomActorName(node);
}
if (node->isa<ValueNode>()) {
auto iter = const_map_ptr->find(node);
if (iter == const_map_ptr->end()) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -52,6 +52,7 @@ enum class KernelTransformType {
kDeviceDataSourceActor,
kHostDataSourceActor,
kKernelActor,
kCustomActor,
// Super kernel actor represents the sink executing of graph which is the combination of kernels.
kSuperKernelActor,
kCopyActor,

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -152,6 +152,13 @@ void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) {
ofs << "\n";
}
// Writes one custom actor entry: its unique name followed by the common abstract-actor fields.
void DumpCustomActor(const CustomActor *actor, std::ofstream &ofs) {
  MS_EXCEPTION_IF_NULL(actor);
  const auto &actor_name = actor->GetAID().Name();
  ofs << "\tactor_name:" << actor_name << "\n";
  DumpAbstractActor(actor, ofs);
  ofs << "\n";
}
void DumpSuperKernelActor(const SuperKernelActor *actor, std::ofstream &ofs) {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\tactor_name:" << actor->GetAID().Name() << "\n";
@ -448,6 +455,13 @@ void DumpKernelActors(const std::vector<KernelActorPtr> &actors, std::ofstream &
}
}
void DumpCustomActors(const std::vector<CustomActorPtr> &actors, std::ofstream &ofs) {
ofs << "\n\n[Custom actors:" << actors.size() << "]\n";
for (const auto &custom_actor : actors) {
DumpCustomActor(custom_actor.get(), ofs);
}
}
void DumpSuperKernelActors(const std::vector<SuperKernelActorPtr> &actors, std::ofstream &ofs) {
ofs << "\n\n[Super kernel actors:" << actors.size() << "]\n";
for (const auto &super_kernel_actor : actors) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -30,6 +30,7 @@
#include "runtime/framework/actor/super_kernel_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/copy_actor.h"
#include "runtime/framework/actor/custom_actor.h"
#include "runtime/framework/actor/control_flow/control_actor.h"
#include "runtime/framework/actor/control_flow/switch_actor.h"
#include "runtime/framework/actor/control_flow/gather_actor.h"
@ -49,6 +50,7 @@ void DumpSuperKernelActors(const std::vector<SuperKernelActorPtr> &actors, std::
void DumpNoInputKernelActors(const std::vector<AbstractActorPtr> &actors, std::ofstream &ofs);
void DumpCopyActors(const std::vector<CopyActorPtr> &actors, std::ofstream &ofs);
void DumpControlActors(const ControlActorSetPtr &control_actor_set, std::ofstream &ofs);
void DumpCustomActors(const std::vector<CustomActorPtr> &actors, std::ofstream &ofs);
} // namespace runtime
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -29,6 +29,7 @@
#include "runtime/framework/actor/data_source_actor.h"
#include "runtime/framework/actor/loop_count_actor.h"
#include "runtime/framework/actor/kernel_actor.h"
#include "runtime/framework/actor/custom_actor.h"
#include "runtime/framework/actor/super_kernel_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/copy_actor.h"
@ -77,6 +78,7 @@ struct ActorSet {
DataPrepareActorPtr data_prepare_actor_{nullptr};
std::vector<DataSourceActorPtr> data_source_actors_;
std::vector<KernelActorPtr> kernel_actors_;
std::vector<CustomActorPtr> custom_actors_;
std::vector<SuperKernelActorPtr> super_kernel_actors_;
// No input kernel actors need be triggered specifically.
std::vector<AbstractActorPtr> no_input_kernel_actors_;

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/framework/actor/custom_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace runtime {
// Nothing to initialize: the custom function is resolved from the node at run time.
void CustomActor::Init() {}
// Executes the custom actor node's function via the device context, then triggers downstream actors.
void CustomActor::Run(OpContext<DeviceTensor> *const ctx) {
  // Resolve the weakly-held custom actor node; it must still be alive when the actor runs.
  auto node = kernel_.lock();
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_ZERO("device_contexts_ size", device_contexts_.size());
  MS_EXCEPTION_IF_NULL(device_contexts_[0]);
  try {
    // The device context fetches the node's custom function and invokes it.
    device_contexts_[0]->LaunchCustomFunc(node);
  } catch (const std::exception &e) {
    // In pipeline mode, record the exception so the front-end thread can rethrow it later.
    if (strategy_ == GraphExecutionStrategy::kPipeline) {
      MsException::Instance().SetException();
    }
    std::string error_info = "Launch custom kernel exception: " + node->fullname_with_scope();
    // Marks the op context as failed and returns according to the execution strategy.
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*ctx), error_info);
  }
  SendOutput(ctx);
}
} // namespace runtime
} // namespace mindspore

View File

@ -0,0 +1,77 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CUSTOM_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CUSTOM_ACTOR_H_
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include "utils/hash_map.h"
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/actor/debug_aware_actor.h"
#include "runtime/hardware/device_context.h"
#include "runtime/framework/device_tensor_store.h"
#include "backend/kernel_compiler/kernel.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "utils/anf_utils.h"
namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::device::KernelInfo;
using mindspore::kernel::Address;
using mindspore::kernel::KernelLaunchInfo;
using mindspore::tensor::TensorPtr;
// Actor that wraps a custom actor node (an AnfNode carrying CustomActorInfo user data,
// generated for dynamic shape) and launches its custom function through the device context.
class CustomActor : public AbstractActor {
 public:
  CustomActor(const std::string &name, const AnfNodePtr &kernel, const DeviceContext *device_context,
              const AID *recorder_aid)
      : AbstractActor(name, KernelTransformType::kCustomActor, recorder_aid), kernel_(kernel) {
    device_contexts_.push_back(device_context);
  }
  CustomActor(const std::string &name, const AnfNodePtr &kernel, const DeviceContext *device_context,
              const AID *recorder_aid, GraphExecutionStrategy strategy)
      : AbstractActor(name, KernelTransformType::kCustomActor, recorder_aid), kernel_(kernel), strategy_(strategy) {
    device_contexts_.push_back(device_context);
  }
  ~CustomActor() override = default;
  void Init() override;
  // The custom actor node this actor executes (weakly held, see kernel_).
  const AnfNodeWeakPtr &kernel() const { return kernel_; }

 protected:
  void Run(OpContext<DeviceTensor> *const context) override;

 private:
  friend class GraphScheduler;
  friend class ControlNodeScheduler;
  // The info of kernel. Held weakly so the actor does not extend the graph node's lifetime.
  AnfNodeWeakPtr kernel_;
  AnfUtils::CustomActorCallback custom_func_ = {};
  // Execution strategy; pipeline mode records exceptions instead of failing synchronously.
  GraphExecutionStrategy strategy_{GraphExecutionStrategy::kPipeline};
};
using CustomActorPtr = std::shared_ptr<CustomActor>;
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CUSTOM_ACTOR_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,6 +15,7 @@
*/
#include "runtime/framework/graph_scheduler.h"
#include <queue>
#include "runtime/framework/actor/memory_manager_actor.h"
#include "runtime/framework/actor/debug_actor.h"
#include "runtime/framework/actor/recorder_actor.h"
@ -23,6 +24,7 @@
#include "mindrt/include/async/async.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
#include "utils/anf_utils.h"
#include "utils/config_manager.h"
#include "utils/log_adapter.h"
#include "utils/convert_utils.h"
@ -75,6 +77,10 @@ std::vector<AbstractActorPtr> CollectActors(const ActorSet *actor_set) {
MS_EXCEPTION_IF_NULL(data_source_actor);
(void)actors.emplace_back(static_cast<AbstractActorPtr>(data_source_actor));
}
for (auto &custom_actor : actor_set->custom_actors_) {
MS_EXCEPTION_IF_NULL(custom_actor);
(void)actors.emplace_back(static_cast<AbstractActorPtr>(custom_actor));
}
for (auto &kernel_actor : actor_set->kernel_actors_) {
MS_EXCEPTION_IF_NULL(kernel_actor);
(void)actors.emplace_back(static_cast<AbstractActorPtr>(kernel_actor));
@ -486,6 +492,7 @@ ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info)
auto host_queue = std::make_shared<HostTensorQueue>();
actor_set->data_source_actors_ = BuildDataSourceActor(graph_compiler_info, host_queue);
actor_set->custom_actors_ = BuildCustomActor(graph_compiler_info);
actor_set->kernel_actors_ = BuildKernelActor(graph_compiler_info);
actor_set->super_kernel_actors_ = BuildSuperKernelActor(graph_compiler_info);
actor_set->loop_count_actor_ = BuildLoopCountActor(graph_compiler_info);
@ -682,6 +689,32 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
return data_source_actors;
}
// Builds one CustomActor per custom actor node found in the non-sink graphs.
// Custom actor nodes are discovered by topo-sorting each graph and testing the
// CustomActorInfo user-data marker; sink-mode graphs are skipped because they
// execute as a whole on device.
std::vector<CustomActorPtr> GraphScheduler::BuildCustomActor(const GraphCompilerInfo &graph_compiler_info) {
  std::vector<CustomActorPtr> custom_actors;
  for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
    const auto &device_context = graph_compiler_info.device_contexts_[i];
    // Guard against a null context before handing it to the actor constructor.
    MS_EXCEPTION_IF_NULL(device_context);
    const auto &graph = graph_compiler_info.graphs_[i];
    MS_EXCEPTION_IF_NULL(graph);
    if (graph->is_executing_sink()) {
      continue;
    }

    auto all_nodes = TopoSort(graph->get_return());
    for (const auto &node : all_nodes) {
      if (!AnfUtils::IsCustomActorNode(node)) {
        continue;
      }

      auto actor_name = AnfUtils::GetCustomActorName(node);
      auto custom_actor = std::make_shared<CustomActor>(actor_name, node, device_context, recorder_aid_);
      MS_EXCEPTION_IF_NULL(custom_actor);
      InsertActor(custom_actor.get());
      // (void) cast for consistency with the other actor-building loops in this file.
      (void)custom_actors.emplace_back(custom_actor);
    }
  }
  return custom_actors;
}
std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompilerInfo &graph_compiler_info) {
std::vector<KernelActorPtr> kernel_actors;
@ -1388,6 +1421,87 @@ void GraphScheduler::LinkGlobalControlArrow(ActorSet *const actor_set, const std
LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), actor_set,
graph_compiler_info.control_node_parser_);
// Link control arrows for custom actor
LinkControlArrowForCustomActor(actor_set, graph_compiler_info);
}
// Links control arrows for custom actors:
//  1. Builds a kernel -> actor map over custom actors, kernel actors, and the GetNext data source actor.
//  2. Turns every Depend(custom, x) / Depend(x, custom) edge in the graphs into a control arrow.
//  3. Any custom actor left without an incoming Depend edge is triggered by the data prepare actor.
void GraphScheduler::LinkControlArrowForCustomActor(ActorSet *const actor_set,
                                                    const GraphCompilerInfo &graph_compiler_info) {
  // Depend node inputs: input(1) is the pass-through value, input(2) is the depended-on node.
  constexpr size_t kDependFromIdx = 2;
  constexpr size_t kDependToIdx = 1;
  MS_EXCEPTION_IF_NULL(actor_set);
  MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
  // prepare for kernel => actor map
  HashMap<AnfNodePtr, AbstractActorPtr> kernel_to_actors = {};
  // Custom actors that have not (yet) received an explicit trigger via a Depend edge.
  HashSet<AbstractActorPtr> no_depend_custom_actors = {};
  for (const auto &actor : actor_set->custom_actors_) {
    MS_EXCEPTION_IF_NULL(actor);
    auto kernel = actor->kernel().lock();
    MS_EXCEPTION_IF_NULL(kernel);
    kernel_to_actors.emplace(kernel, actor);
    no_depend_custom_actors.insert(actor);
  }
  for (const auto &actor : actor_set->kernel_actors_) {
    MS_EXCEPTION_IF_NULL(actor);
    auto kernel = actor->kernel();
    MS_EXCEPTION_IF_NULL(kernel);
    kernel_to_actors.emplace(kernel, actor);
  }
  for (const auto &actor : actor_set->data_source_actors_) {
    MS_EXCEPTION_IF_NULL(actor);
    // Only the device-queue data source actor wraps a kernel (GetNext) that Depend edges can reference.
    auto device_data_source_actor = dynamic_cast<DeviceQueueDataSourceActor *>(actor.get());
    if (device_data_source_actor != nullptr) {
      auto kernel = device_data_source_actor->data_kernel();
      MS_EXCEPTION_IF_NULL(kernel);
      if (AnfAlgo::GetCNodeName(kernel) == kGetNextOpName) {
        kernel_to_actors.emplace(kernel, actor);
      }
    }
  }
  // find depend(custom, custom)
  for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
    const auto &graph = graph_compiler_info.graphs_[i];
    MS_EXCEPTION_IF_NULL(graph);
    // Custom actors are not built for sink-mode graphs, so there is nothing to link there.
    if (graph->is_executing_sink()) {
      continue;
    }
    auto all_nodes = TopoSort(graph->get_return());
    for (const auto &node : all_nodes) {
      if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
        continue;
      }
      auto depend_cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(depend_cnode);
      MS_EXCEPTION_IF_CHECK_FAIL(depend_cnode->size() > kDependFromIdx,
                                 "depend node " + depend_cnode->DebugString() + " input size " +
                                   std::to_string(depend_cnode->size()) + " is invalid.");
      MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependFromIdx));
      MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependToIdx));
      auto from_node = depend_cnode->input(kDependFromIdx);
      auto to_node = depend_cnode->input(kDependToIdx);
      // Only Depend edges touching at least one custom actor node become control arrows.
      if (!AnfUtils::IsCustomActorNode(from_node) && !AnfUtils::IsCustomActorNode(to_node)) {
        continue;
      }
      auto from_iter = kernel_to_actors.find(from_node);
      if (from_iter == kernel_to_actors.end()) {
        MS_LOG(INFO) << from_node->fullname_with_scope() << " is a CNode but cannot find Actor.";
        continue;
      }
      auto to_iter = kernel_to_actors.find(to_node);
      if (to_iter == kernel_to_actors.end()) {
        MS_LOG(INFO) << to_node->fullname_with_scope() << " is a CNode but cannot find Actor.";
        continue;
      }
      AddControlArrow(from_iter->second.get(), to_iter->second.get());
      // The target now has an explicit trigger; drop it from the data-prepare fallback set.
      // (The cast yields nullptr for non-custom targets, which erase() simply ignores.)
      no_depend_custom_actors.erase(std::dynamic_pointer_cast<CustomActor>(to_iter->second));
    }
  }
  for (const auto &custom_actor : no_depend_custom_actors) {
    AddControlArrow(actor_set->data_prepare_actor_.get(), custom_actor.get());
  }
}
void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes,
@ -1708,11 +1822,12 @@ void GraphScheduler::CheckActorValid(const ActorSet *actor_set) const {
<< ", actual control num: " << actor->input_control_arrow_aids_.size();
}
if ((actor->type_ != KernelTransformType::kOutputActor) && (actor->output_data_arrows_.size() == 0) &&
(actor->output_control_arrows_.size() == 0)) {
if ((actor->type_ != KernelTransformType::kOutputActor) && (actor->type_ != KernelTransformType::kCustomActor) &&
(actor->output_data_arrows_.size() == 0) && (actor->output_control_arrows_.size() == 0)) {
MS_LOG(EXCEPTION) << actor->GetAID().Name() << " has no user.";
}
if ((actor->type_ != KernelTransformType::kDataPrepareActor) && (actor->input_datas_num_ == 0) &&
if ((actor->type_ != KernelTransformType::kDataPrepareActor) &&
(actor->type_ != KernelTransformType::kCustomActor) && (actor->input_datas_num_ == 0) &&
(actor->input_controls_num_ == 0)) {
MS_LOG(EXCEPTION) << actor->GetAID().Name() << " has no source.";
}
@ -1868,6 +1983,7 @@ void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInf
DumpLoopCountActor(actor_set->loop_count_actor_, ofs);
DumpOutputActor(actor_set->output_actor_, ofs);
DumpControlActors(actor_set->control_actors_, ofs);
DumpCustomActors(actor_set->custom_actors_, ofs);
}
void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -96,6 +96,7 @@ class GraphScheduler {
std::vector<DataSourceActorPtr> BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
const HostTensorQueuePtr &host_queue);
std::vector<KernelActorPtr> BuildKernelActor(const GraphCompilerInfo &graph_compiler_info);
std::vector<CustomActorPtr> BuildCustomActor(const GraphCompilerInfo &graph_compiler_info);
std::vector<SuperKernelActorPtr> BuildSuperKernelActor(const GraphCompilerInfo &graph_compiler_info);
LoopCountActorPtr BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info);
OutputActorPtr BuildOutputActor(const GraphCompilerInfo &graph_compiler_info);
@ -152,6 +153,7 @@ class GraphScheduler {
void LinkGlobalControlArrow(ActorSet *const actor_set, const std::vector<CNodePtr> &communication_nodes,
const std::vector<AbstractActor *> &auto_monad_actors,
const GraphCompilerInfo &graph_compiler_info);
void LinkControlArrowForCustomActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
// Link the control arrows by the communication nodes in the kernel graph to ensure communication nodes running order.
void LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes,
const GraphCompilerInfo &graph_compiler_info);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -58,6 +58,7 @@
#ifndef ENABLE_SECURITY
#include "profiler/device/ascend/memory_profiling.h"
#include "runtime/device/ascend/profiling/profiling_manager.h"
#include "utils/anf_utils.h"
using Adx::AdxRegDumpProcessCallBack;
using mindspore::device::ascend::ProfilingManager;
@ -721,6 +722,14 @@ bool AscendDeviceContext::MemoryCopyAsync(const CNodePtr &node, const vector<Add
return true;
}
bool AscendDeviceContext::LaunchCustomFunc(const AnfNodePtr &kernel) const {
MS_EXCEPTION_IF_NULL(kernel);
auto custom_func = AnfUtils::GetCustomFunc(kernel);
BindDeviceToCurrentThread();
custom_func(nullptr);
return true;
}
void *AscendDeviceContext::GetKernelStream(const CNodePtr &node) const {
auto kernel_mod = AnfAlgo::GetKernelMod(node);
MS_EXCEPTION_IF_NULL(kernel_mod);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -108,6 +108,8 @@ class AscendDeviceContext : public DeviceContext {
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
bool is_dynamic_shape = false) const override;
bool LaunchCustomFunc(const AnfNodePtr &kernel) const override;
// Synchronize stream, device such as GPU and Ascend need stream to launch kernel asynchronously,
// using 'SyncStream' to block thread and wait for completing all tasks in stream.
// Devices that do not need stream could ignore the implementation of this function.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -133,6 +133,8 @@ class DeviceContext {
return true;
}
virtual bool LaunchCustomFunc(const AnfNodePtr &kernel) const { return true; }
// Synchronize stream, device such as GPU and Ascend need stream to launch kernel asynchronously,
// using 'SyncStream' to block thread and wait for completing all tasks in stream.
// Devices that do not need stream could ignore the implementation of this function.

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,6 +27,7 @@
#include "ir/tensor.h"
#include "ir/param_info.h"
#include "utils/ms_context.h"
#include "utils/anf_utils.h"
namespace mindspore {
bool ValueToBool(const ValuePtr &v, bool *value) {
@ -140,6 +141,9 @@ bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraph
MS_LOG(DEBUG) << "two parameters are not equal.";
return false;
}
if (AnfUtils::IsCustomActorNode(node1) && AnfUtils::IsCustomActorNode(node2)) {
return AnfUtils::IsCutomActorNodeSame(node1, node2);
}
if (node1->isa<CNode>() && node2->isa<CNode>()) {
return SameNode(node1, node2, equiv_func_graph, equiv_node);
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
#include "utils/anf_utils.h"
#include <map>
#include <memory>
#include <string>
#include "base/core_ops.h"
#include "utils/trace_base.h"
@ -41,6 +42,33 @@ class AbstractMutexManager {
std::map<const AnfNode *, std::recursive_mutex> mu_for_nodes_;
std::recursive_mutex mu_;
};
// Metadata attached as user data to a plain AnfNode to mark it as a custom actor node
// (the infer/init/update helper nodes generated for dynamic shape).
struct CustomActorInfo {
  CustomActorInfo(AnfUtils::CustomActorCallback func, const std::string &base_node_name, const std::string &type_name,
                  bool is_fake = false, bool is_just_sync = false)
      : actor_func(func),
        base_node_name(base_node_name),
        type_name(type_name),
        is_fake(is_fake),
        is_just_sync(is_just_sync) {}
  ~CustomActorInfo() = default;
  // Key for user data.
  constexpr static char key[] = "CustomActor";
  // Callback invoked when the corresponding custom actor runs.
  AnfUtils::CustomActorCallback actor_func = {};
  // Full name of the base kernel this actor node was derived from; used to build the actor name.
  std::string base_node_name;
  // Actor kind: kInit, kInfer, or kUpdate.
  std::string type_name;
  bool is_fake{false};      // For infer
  bool is_just_sync{false}; // For update
};
using CustomActorInfoPtr = std::shared_ptr<CustomActorInfo>;

// Creates a bare AnfNode in graph `g` and tags it with `actor_info` so that
// AnfUtils::IsCustomActorNode() recognizes it.
AnfNodePtr NewCustomActorNode(const CustomActorInfoPtr &actor_info, const FuncGraphPtr &g) {
  MS_EXCEPTION_IF_NULL(g);
  auto custom_actor_node = std::make_shared<AnfNode>(g);
  custom_actor_node->set_user_data<CustomActorInfo>(actor_info);
  return custom_actor_node;
}
} // namespace
AbstractScope::AbstractScope(std::recursive_mutex *mu) {
@ -296,6 +324,8 @@ std::pair<AnfNodePtr, size_t> AnfUtils::VisitKernel(const AnfNodePtr &anf_node,
return std::make_pair(anf_node, 0);
} else if (anf_node->isa<Parameter>()) {
return std::make_pair(anf_node, 0);
} else if (IsCustomActorNode(anf_node)) {
return std::make_pair(anf_node, 0);
} else if (anf_node->isa<CNode>()) {
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
@ -364,4 +394,95 @@ bool AnfUtils::GetDumpFlag(const AnfNodePtr &node) {
}
return false;
}
// A node is a custom actor node iff CustomActorInfo user data has been attached to it.
bool AnfUtils::IsCustomActorNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  return node->has_user_data<CustomActorInfo>();
}
// Two custom actor nodes are considered the same when their type and their
// is_fake / is_just_sync flags all match (the base node name is deliberately ignored).
bool AnfUtils::IsCutomActorNodeSame(const AnfNodePtr &node1, const AnfNodePtr &node2) {
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  if (!IsCustomActorNode(node1) || !IsCustomActorNode(node2)) {
    MS_LOG(EXCEPTION) << "Two node are not all Custom Actor Node!";
  }
  const auto &info1 = node1->user_data<CustomActorInfo>();
  const auto &info2 = node2->user_data<CustomActorInfo>();
  MS_EXCEPTION_IF_NULL(info1);
  MS_EXCEPTION_IF_NULL(info2);
  return (info1->type_name == info2->type_name) && (info1->is_fake == info2->is_fake) &&
         (info1->is_just_sync == info2->is_just_sync);
}
// Returns the actor kind string (kInit / kInfer / kUpdate) stored on a custom actor node.
std::string AnfUtils::GetCustomActorType(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsCustomActorNode(node)) {
    MS_LOG(EXCEPTION) << node->fullname_with_scope() << " is not a custom actor node!";
  }
  const auto &info = node->user_data<CustomActorInfo>();
  MS_EXCEPTION_IF_NULL(info);
  return info->type_name;
}
// Builds the unique actor name "<type>_of_<base kernel full name>" for a custom actor node.
std::string AnfUtils::GetCustomActorName(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsCustomActorNode(node)) {
    MS_LOG(EXCEPTION) << node->fullname_with_scope() << " is not a custom actor node!";
  }
  const auto &info = node->user_data<CustomActorInfo>();
  MS_EXCEPTION_IF_NULL(info);
  return info->type_name + "_of_" + info->base_node_name;
}
// Returns the is_just_sync flag of an Update custom actor node (true means the
// actor only performs a stream-sync call).
bool AnfUtils::GetCustomActorJustSyncFlag(const AnfNodePtr &node) {
  // Null check added for consistency with the other custom-actor getters; without it,
  // IsCustomActorNode would raise on null but fullname_with_scope() below could not be reached safely.
  MS_EXCEPTION_IF_NULL(node);
  if (!IsCustomActorNode(node)) {
    MS_LOG(EXCEPTION) << node->fullname_with_scope() << " is not a custom actor node!";
  }
  auto update_info = node->user_data<CustomActorInfo>();
  MS_EXCEPTION_IF_NULL(update_info);
  return update_info->is_just_sync;
}
// Returns the callback stored on a custom actor node; throws if the node is not one.
AnfUtils::CustomActorCallback AnfUtils::GetCustomFunc(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsCustomActorNode(node)) {
    MS_LOG(EXCEPTION) << node->fullname_with_scope() << " is not a custom actor node!";
  }
  const auto &info = node->user_data<CustomActorInfo>();
  MS_EXCEPTION_IF_NULL(info);
  return info->actor_func;
}
// Creates an Init custom actor node bound to `base_cnode`; `f` is run for kernel (re)initialization.
AnfNodePtr AnfUtils::NewInitActorNode(AnfUtils::CustomActorCallback f, const CNodePtr &base_cnode) {
  MS_EXCEPTION_IF_NULL(base_cnode);
  auto actor_info = std::make_shared<CustomActorInfo>(f, base_cnode->fullname_with_scope(), kInit);
  return NewCustomActorNode(actor_info, base_cnode->func_graph());
}

// Creates an Infer custom actor node; `is_fake` marks a placeholder with no real infer action.
AnfNodePtr AnfUtils::NewInferActorNode(AnfUtils::CustomActorCallback f, const CNodePtr &base_cnode, bool is_fake) {
  MS_EXCEPTION_IF_NULL(base_cnode);
  auto actor_info = std::make_shared<CustomActorInfo>(f, base_cnode->fullname_with_scope(), kInfer, is_fake);
  return NewCustomActorNode(actor_info, base_cnode->func_graph());
}

// Creates an Update custom actor node; `is_just_sync` marks a stream-sync-only action.
AnfNodePtr AnfUtils::NewUpdateActorNode(AnfUtils::CustomActorCallback f, const CNodePtr &base_cnode,
                                        bool is_just_sync) {
  MS_EXCEPTION_IF_NULL(base_cnode);
  auto actor_info =
    std::make_shared<CustomActorInfo>(f, base_cnode->fullname_with_scope(), kUpdate, false, is_just_sync);
  return NewCustomActorNode(actor_info, base_cnode->func_graph());
}
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
#ifndef MINDSPORE_CORE_UTILS_ANF_UTILS_H_
#define MINDSPORE_CORE_UTILS_ANF_UTILS_H_
#include <functional>
#include <vector>
#include <string>
#include <utility>
@ -25,6 +26,10 @@
#include "ir/primitive.h"
namespace mindspore {
constexpr auto kInfer = "DS_Infer";
constexpr auto kInit = "DS_Init";
constexpr auto kUpdate = "DS_Update";
class AbstractScope {
public:
explicit AbstractScope(std::recursive_mutex *mu);
@ -40,6 +45,7 @@ class AbstractScope {
class AnfUtils {
public:
using CustomActorCallback = std::function<void(void *args)>;
static bool IsDimUnknown(const abstract::ShapePtr &shape);
static bool IsShapeDynamic(const abstract::ShapePtr &shape);
static bool IsShapeDynamic(const std::vector<size_t> &shape);
@ -66,6 +72,20 @@ class AnfUtils {
// Get dump flag from CNode's primitive.
static bool GetDumpFlag(const AnfNodePtr &node);
static AbstractScope GetAbstractLock(const AnfNode *node);
// Custom actor node is for dynamic shape.
// Generate a Init custom actor node.
static AnfNodePtr NewInitActorNode(CustomActorCallback f, const CNodePtr &base_cnode);
// Generate a Infer custom actor node. If `is_fake` is set to true, this node is a fake node without any infer action.
static AnfNodePtr NewInferActorNode(CustomActorCallback f, const CNodePtr &base_cnode, bool is_fake);
// Generate a Update custom actor node. If `is_just_sync` is set to true, this node is just for a stream-sync call.
static AnfNodePtr NewUpdateActorNode(CustomActorCallback f, const CNodePtr &base_cnode, bool is_just_sync);
static bool IsCustomActorNode(const AnfNodePtr &node);
static std::string GetCustomActorType(const AnfNodePtr &node);
static std::string GetCustomActorName(const AnfNodePtr &node);
static bool GetCustomActorJustSyncFlag(const AnfNodePtr &node);
static CustomActorCallback GetCustomFunc(const AnfNodePtr &node);
static bool IsCutomActorNodeSame(const AnfNodePtr &node1, const AnfNodePtr &node2);
};
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_ANF_UTILS_H_

View File

@ -0,0 +1,726 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/ascend/dynamic_shape/ascend_dynamic_shape_helper.h"
#include "backend/optimizer/ascend/ascend_backend_optimization.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_mod.h"
#include "common/backend_common_test.h"
#include "debug/anf_ir_dump.h"
#include "debug/dump_proto.h"
#include "utils/utils.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace opt {
namespace {
constexpr auto kTupleFirstItemIndex = 0;
constexpr auto kTupleSecondItemIndex = 1;
constexpr auto kDependRealInputSize = 2;
// Adds a weight parameter named `name` to graph `g`, with the given abstract and a fresh KernelInfo.
// Returns nullptr (after logging) if the graph refuses to add the parameter.
ParameterPtr TestCreateParameter(const KernelGraphPtr &g, const std::string &name,
                                 const abstract::AbstractBasePtr &abstract) {
  MS_EXCEPTION_IF_NULL(g);
  auto parameter = g->AddWeightParameter(name);
  if (parameter == nullptr) {
    MS_LOG(ERROR) << "Cannot add weight parameter!";
    // Bug fix: the original fell through and dereferenced the null parameter below.
    return nullptr;
  }
  parameter->set_abstract(abstract);
  parameter->set_kernel_info(std::make_shared<device::KernelInfo>());
  return parameter;
}
// Creates a CNode "prim_name(real_inputs...)" in graph `g` with the given abstract and a
// TBE kernel-mod-backed KernelInfo. Returns nullptr (after logging) if node creation fails.
CNodePtr TestCreateCNode(const KernelGraphPtr &g, const std::string &prim_name, const AnfNodePtrList &real_inputs,
                         const abstract::AbstractBasePtr &abstract) {
  MS_EXCEPTION_IF_NULL(g);
  auto inputs = AnfNodePtrList{NewValueNode(std::make_shared<Primitive>(prim_name))};
  (void)inputs.insert(inputs.end(), real_inputs.begin(), real_inputs.end());
  auto cnode = g->NewCNode(inputs);
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "Cannot create cnode!";
    // Bug fix: the original fell through and dereferenced the null cnode below.
    return nullptr;
  }
  cnode->set_abstract(abstract);
  auto cnode_kernel_info = std::make_shared<device::KernelInfo>();
  cnode_kernel_info->set_kernel_mod(std::make_shared<kernel::TbeKernelMod>(std::make_shared<kernel::KernelPack>()));
  cnode->set_kernel_info(cnode_kernel_info);
  return cnode;
}
// Builds an AbstractTensor abstract from an element type and a shape vector.
inline abstract::AbstractTensorPtr TestCreateTensor(const TypePtr &element_type, const ShapeVector &shape) {
  return std::make_shared<abstract::AbstractTensor>(element_type, shape);
}
// Builds an AbstractTuple whose elements pair element_types[i] with shapes[i].
// Returns nullptr (after logging) if the two lists differ in size.
inline abstract::AbstractTuplePtr TestCreateTupleTensor(const std::vector<TypePtr> &element_types,
                                                        const std::vector<ShapeVector> &shapes) {
  if (element_types.size() != shapes.size()) {
    MS_LOG(ERROR) << "Sizes for element type and shape are not match.";
    // Bug fix: the original fell through and indexed shapes[i] out of range below.
    return nullptr;
  }
  AbstractBasePtrList abstract_list;
  abstract_list.reserve(element_types.size());
  for (size_t i = 0; i < element_types.size(); ++i) {
    (void)abstract_list.emplace_back(TestCreateTensor(element_types[i], shapes[i]));
  }
  return std::make_shared<abstract::AbstractTuple>(abstract_list);
}
// Creates a Depend(inputs[0], inputs[1]) node in graph `g`.
// Returns nullptr (after logging) unless exactly two inputs are supplied.
inline CNodePtr TestCreateDepend(const KernelGraphPtr &g, const AnfNodePtrList &inputs) {
  MS_EXCEPTION_IF_NULL(g);
  if (inputs.size() != kDependRealInputSize) {
    // Fixed "Depnd" typo in the message; also return early instead of indexing
    // inputs[0]/inputs[1] out of range as the original did.
    MS_LOG(ERROR) << "Input size for Depend should be 2!";
    return nullptr;
  }
  auto depend_node = g->NewCNode(std::vector<AnfNodePtr>{
    NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())), inputs[0], inputs[1]});
  MS_EXCEPTION_IF_NULL(depend_node);
  return depend_node;
}
// Pack `inputs` into a MakeTuple node whose abstract is the tuple of the
// inputs' abstracts.
inline CNodePtr TestCreateMakeTuple(const KernelGraphPtr &g, const AnfNodePtrList &inputs) {
  MS_EXCEPTION_IF_NULL(g);
  AbstractBasePtrList element_abstracts;
  element_abstracts.reserve(inputs.size());
  for (const auto &node : inputs) {
    element_abstracts.push_back(node->abstract());
  }
  auto tuple_abstract = std::make_shared<abstract::AbstractTuple>(element_abstracts);
  return TestCreateCNode(g, "MakeTuple", inputs, tuple_abstract);
}
// Dump graph `g` both as .ir text and as proto, for debugging the pass result.
inline void DumpGraph(const KernelGraphPtr &g) {
  MS_EXCEPTION_IF_NULL(g);
  const auto graph_id_str = std::to_string(g->graph_id());
  DumpIR("debug_try_down_" + graph_id_str + ".ir", g);
  DumpIRProto(g, "try_down_" + graph_id_str);
}
} // namespace
// Test fixture for the Ascend dynamic-shape conversion pass (AscendDynamicShapeConvert).
class TestDynamicShapePass : public BackendCommon {
 public:
  // `= default` matches the defaulted destructor and avoids a user-provided empty body.
  TestDynamicShapePass() = default;
  ~TestDynamicShapePass() override = default;
};
/// Feature: Dynamic shape
/// Description: Dynamic op + inherited dynamic op.
/// before:
/// a = Unique(%p) (1, -1)
/// |
/// b = A(a) (1, -1)
/// Expectation: Graph as following.
/// after:
/// Unique_Update Unique(%p) Unique_Init Unique_Infer
/// \ / | \ / \ / |
/// depend | depend depend |
/// | |
/// A A_Init A_Infer |
/// \ / \ / \ |
/// depend depend depend
TEST_F(TestDynamicShapePass, test_dynamic_shape_pass_0) {
  // construct before graph
  auto before_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(before_fg != nullptr);
  auto before_p = TestCreateParameter(before_fg, "p", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto before_uniq_node = TestCreateCNode(before_fg, "Unique", AnfNodePtrList{before_p},
                                          TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto before_a_node = TestCreateCNode(before_fg, "A", AnfNodePtrList{before_uniq_node},
                                       TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  before_fg->set_output(before_a_node);
  // run pass
  AscendDynamicShapeConvert(before_fg);
  // construct after graph: the pass is expected to add Infer/Init custom nodes for
  // every converted op, an Update node for the dynamic-output Unique, and Depend
  // edges enforcing their execution order.
  auto after_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(after_fg != nullptr);
  auto after_p = TestCreateParameter(after_fg, "p", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto after_uniq_node = TestCreateCNode(after_fg, "Unique", AnfNodePtrList{after_p},
                                         TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto after_a_node = TestCreateCNode(after_fg, "A", AnfNodePtrList{after_uniq_node},
                                      TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto infer_uniq = dynamic_shape::GenInferNode(after_uniq_node);
  auto init_uniq = dynamic_shape::GenInitNode(after_uniq_node);
  auto update_uniq = dynamic_shape::GenUpdateNode(after_uniq_node);
  auto infer_a = dynamic_shape::GenInferNode(after_a_node);
  auto init_a = dynamic_shape::GenInitNode(after_a_node);
  auto depend0 = TestCreateDepend(after_fg, AnfNodePtrList{init_uniq, infer_uniq});
  auto depend1 = TestCreateDepend(after_fg, AnfNodePtrList{after_uniq_node, init_uniq});
  auto depend2 = TestCreateDepend(after_fg, AnfNodePtrList{update_uniq, after_uniq_node});
  auto depend3 = TestCreateDepend(after_fg, AnfNodePtrList{init_a, infer_a});
  auto depend4 = TestCreateDepend(after_fg, AnfNodePtrList{after_a_node, init_a});
  auto depend5 = TestCreateDepend(after_fg, AnfNodePtrList{infer_a, infer_uniq});
  auto depend6 = TestCreateDepend(after_fg, AnfNodePtrList{infer_a, update_uniq});
  auto make_tuple = TestCreateMakeTuple(
    after_fg, AnfNodePtrList{after_a_node, depend0, depend1, depend2, depend3, depend4, depend5, depend6});
  auto get_item = TestCreateCNode(after_fg, "TupleGetItem",
                                  AnfNodePtrList{make_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                  after_a_node->abstract());
  after_fg->set_output(get_item);
  // assert
  EXPECT_TRUE(CheckEqualGraph(after_fg, before_fg));
}
/// Feature: Dynamic shape
/// Description: General op case.
/// before:
/// a = A(%p)
/// Expectation: Graph as following.
/// after:
/// A(%p) A_Init A_Infer
/// \ / \ /
/// depend depend
TEST_F(TestDynamicShapePass, test_dynamic_shape_pass_1) {
  // construct before graph
  auto before_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(before_fg != nullptr);
  auto before_p = TestCreateParameter(before_fg, "p", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto before_a_node =
    TestCreateCNode(before_fg, "A", AnfNodePtrList{before_p}, TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  before_fg->set_output(before_a_node);
  // run pass
  AscendDynamicShapeConvert(before_fg);
  // construct after graph: a purely static-shape op only gets Infer/Init nodes
  // (no Update node).
  auto after_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(after_fg != nullptr);
  auto after_p = TestCreateParameter(after_fg, "p", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto after_a_node =
    TestCreateCNode(after_fg, "A", AnfNodePtrList{after_p}, TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  // NOTE(review): the `true` flag presumably marks the infer node for a general
  // (static-shape) op — confirm against GenInferNode's declaration.
  auto infer_a = dynamic_shape::GenInferNode(after_a_node, true);
  auto init_a = dynamic_shape::GenInitNode(after_a_node);
  auto depend0 = TestCreateDepend(after_fg, AnfNodePtrList{init_a, infer_a});
  auto depend1 = TestCreateDepend(after_fg, AnfNodePtrList{after_a_node, init_a});
  auto make_tuple = TestCreateMakeTuple(after_fg, AnfNodePtrList{after_a_node, depend0, depend1});
  auto get_item = TestCreateCNode(after_fg, "TupleGetItem",
                                  AnfNodePtrList{make_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                  after_a_node->abstract());
  after_fg->set_output(get_item);
  // assert
  EXPECT_TRUE(CheckEqualGraph(after_fg, before_fg));
}
/// Feature: Dynamic shape
/// Description: Get item and make tuple case.
/// before:
/// D0(%p)
/// | |
/// [0] [1]
/// | |
/// A |
/// \ /
/// MakeTuple
/// Expectation: Graph as following.
/// after:
/// D0_Update D0(%p) D0_Init D0_Infer
/// \ / | \ \ / \ / |
/// depend [0] [1] \ / Depend |
/// | | Depend |
/// A | A_Infer |
/// | | | \ |
/// | | | Depend
/// | | | D0_Update
/// | | | /
/// \ / Depend
/// MakeTuple
TEST_F(TestDynamicShapePass, test_dynamic_shape_pass_2) {
  // construct before graph
  auto before_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(before_fg != nullptr);
  auto before_p = TestCreateParameter(before_fg, "p", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  // This Unique is used to represent a multiple-output operation instead of its original meaning.
  auto before_tuple = TestCreateCNode(
    before_fg, "Unique", AnfNodePtrList{before_p},
    TestCreateTupleTensor(std::vector<TypePtr>(2, kFloat32), std::vector<ShapeVector>(2, std::vector<int64_t>{1, -1})));
  auto before_first_item = TestCreateCNode(before_fg, "TupleGetItem",
                                           AnfNodePtrList{before_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                           TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto before_second_item = TestCreateCNode(
    before_fg, "TupleGetItem", AnfNodePtrList{before_tuple, NewValueNode(SizeToLong(kTupleSecondItemIndex))},
    TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto before_a = TestCreateCNode(before_fg, "A", AnfNodePtrList{before_first_item}, before_first_item->abstract());
  auto before_make_tuple = TestCreateMakeTuple(before_fg, AnfNodePtrList{before_a, before_second_item});
  before_fg->set_output(before_make_tuple);
  // run pass
  AscendDynamicShapeConvert(before_fg);
  // construct after graph: TupleGetItem/MakeTuple themselves get no custom nodes;
  // only the real kernels (Unique, A) are converted.
  auto after_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(after_fg != nullptr);
  auto after_p = TestCreateParameter(after_fg, "p", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto after_tuple = TestCreateCNode(
    after_fg, "Unique", AnfNodePtrList{after_p},
    TestCreateTupleTensor(std::vector<TypePtr>(2, kFloat32), std::vector<ShapeVector>(2, std::vector<int64_t>{1, -1})));
  auto after_first_item = TestCreateCNode(after_fg, "TupleGetItem",
                                          AnfNodePtrList{after_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                          TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto after_second_item = TestCreateCNode(after_fg, "TupleGetItem",
                                           AnfNodePtrList{after_tuple, NewValueNode(SizeToLong(kTupleSecondItemIndex))},
                                           TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto after_a = TestCreateCNode(after_fg, "A", AnfNodePtrList{after_first_item}, after_first_item->abstract());
  auto after_make_tuple = TestCreateMakeTuple(after_fg, AnfNodePtrList{after_a, after_second_item});
  auto infer_tuple = dynamic_shape::GenInferNode(after_tuple);
  auto init_tuple = dynamic_shape::GenInitNode(after_tuple);
  auto update_tuple = dynamic_shape::GenUpdateNode(after_tuple);
  auto depend0 = TestCreateDepend(after_fg, AnfNodePtrList{init_tuple, infer_tuple});
  auto depend1 = TestCreateDepend(after_fg, AnfNodePtrList{after_tuple, init_tuple});
  auto depend2 = TestCreateDepend(after_fg, AnfNodePtrList{update_tuple, after_tuple});
  auto infer_a = dynamic_shape::GenInferNode(after_a);
  auto init_a = dynamic_shape::GenInitNode(after_a);
  auto depend3 = TestCreateDepend(after_fg, AnfNodePtrList{init_a, infer_a});
  auto depend4 = TestCreateDepend(after_fg, AnfNodePtrList{after_a, init_a});
  auto depend5 = TestCreateDepend(after_fg, AnfNodePtrList{infer_a, infer_tuple});
  auto depend6 = TestCreateDepend(after_fg, AnfNodePtrList{infer_a, update_tuple});
  auto make_tuple = TestCreateMakeTuple(
    after_fg, AnfNodePtrList{after_make_tuple, depend0, depend1, depend2, depend3, depend4, depend5, depend6});
  auto get_item = TestCreateCNode(after_fg, "TupleGetItem",
                                  AnfNodePtrList{make_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                  after_make_tuple->abstract());
  after_fg->set_output(get_item);
  // assert
  EXPECT_TRUE(CheckEqualGraph(after_fg, before_fg));
}
/// Feature: Dynamic shape
/// Description: Complicated case mixing general ops, a multi-output dynamic op,
/// and a dynamic op whose own output shape is static.
/// Expectation: Graph as expected.
TEST_F(TestDynamicShapePass, test_dynamic_shape_pass_3) {
  // construct before graph
  auto before_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(before_fg != nullptr);
  auto before_p1 = TestCreateParameter(before_fg, "p1", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto before_p2 = TestCreateParameter(before_fg, "p2", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto before_a = TestCreateCNode(before_fg, "A", AnfNodePtrList{before_p2}, before_p2->abstract());
  // This Unique is used to represent a multiple-output operation instead of its original meaning.
  auto before_tuple = TestCreateCNode(
    before_fg, "Unique", AnfNodePtrList{before_p1, before_a},
    TestCreateTupleTensor(std::vector<TypePtr>(2, kFloat32), std::vector<ShapeVector>(2, std::vector<int64_t>{1, -1})));
  auto before_first_item = TestCreateCNode(before_fg, "TupleGetItem",
                                           AnfNodePtrList{before_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                           TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto before_second_item = TestCreateCNode(
    before_fg, "TupleGetItem", AnfNodePtrList{before_tuple, NewValueNode(SizeToLong(kTupleSecondItemIndex))},
    TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto before_b = TestCreateCNode(before_fg, "B", AnfNodePtrList{before_first_item}, before_first_item->abstract());
  auto before_c = TestCreateCNode(before_fg, "C", AnfNodePtrList{before_b}, before_b->abstract());
  auto before_d = TestCreateCNode(before_fg, "D", AnfNodePtrList{before_second_item}, before_second_item->abstract());
  // This Unique is used to represent a single-output operation instead of its original meaning,
  // and it takes a dynamic-shape input but produces a general (static) shape.
  auto before_dync_end = TestCreateCNode(before_fg, "Unique", AnfNodePtrList{before_d},
                                         TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto before_e = TestCreateCNode(before_fg, "E", AnfNodePtrList{before_dync_end}, before_dync_end->abstract());
  auto before_make_tuple = TestCreateMakeTuple(before_fg, AnfNodePtrList{before_c, before_e});
  before_fg->set_output(before_make_tuple);
  // run pass
  AscendDynamicShapeConvert(before_fg);
  // construct after graph
  auto after_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(after_fg != nullptr);
  auto after_p1 = TestCreateParameter(after_fg, "p1", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto after_p2 = TestCreateParameter(after_fg, "p2", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto after_a = TestCreateCNode(after_fg, "A", AnfNodePtrList{after_p2}, after_p2->abstract());
  // This Unique is used to represent a multiple-output operation instead of its original meaning.
  auto after_tuple = TestCreateCNode(
    after_fg, "Unique", AnfNodePtrList{after_p1, after_a},
    TestCreateTupleTensor(std::vector<TypePtr>(2, kFloat32), std::vector<ShapeVector>(2, std::vector<int64_t>{1, -1})));
  auto after_first_item = TestCreateCNode(after_fg, "TupleGetItem",
                                          AnfNodePtrList{after_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                          TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto after_second_item = TestCreateCNode(after_fg, "TupleGetItem",
                                           AnfNodePtrList{after_tuple, NewValueNode(SizeToLong(kTupleSecondItemIndex))},
                                           TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto after_b = TestCreateCNode(after_fg, "B", AnfNodePtrList{after_first_item}, after_first_item->abstract());
  auto after_c = TestCreateCNode(after_fg, "C", AnfNodePtrList{after_b}, after_b->abstract());
  auto after_d = TestCreateCNode(after_fg, "D", AnfNodePtrList{after_second_item}, after_second_item->abstract());
  // This Unique is used to represent a single-output operation instead of its original meaning,
  // and it takes a dynamic-shape input but produces a general (static) shape.
  auto after_dync_end = TestCreateCNode(after_fg, "Unique", AnfNodePtrList{after_d},
                                        TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto after_e = TestCreateCNode(after_fg, "E", AnfNodePtrList{after_dync_end}, after_dync_end->abstract());
  auto after_make_tuple = TestCreateMakeTuple(after_fg, AnfNodePtrList{after_c, after_e});
  auto infer_a = dynamic_shape::GenInferNode(after_a, true);
  auto init_a = dynamic_shape::GenInitNode(after_a);
  auto infer_tuple = dynamic_shape::GenInferNode(after_tuple);
  auto init_tuple = dynamic_shape::GenInitNode(after_tuple);
  auto update_tuple = dynamic_shape::GenUpdateNode(after_tuple);
  auto infer_b = dynamic_shape::GenInferNode(after_b);
  auto init_b = dynamic_shape::GenInitNode(after_b);
  auto infer_c = dynamic_shape::GenInferNode(after_c);
  auto init_c = dynamic_shape::GenInitNode(after_c);
  auto infer_d = dynamic_shape::GenInferNode(after_d);
  auto init_d = dynamic_shape::GenInitNode(after_d);
  auto infer_de = dynamic_shape::GenInferNode(after_dync_end);
  auto init_de = dynamic_shape::GenInitNode(after_dync_end);
  auto update_de = dynamic_shape::GenUpdateNode(after_dync_end);
  auto infer_e = dynamic_shape::GenInferNode(after_e, true);
  auto init_e = dynamic_shape::GenInitNode(after_e);
  auto depend0 = TestCreateDepend(after_fg, AnfNodePtrList{init_a, infer_a});
  auto depend1 = TestCreateDepend(after_fg, AnfNodePtrList{after_a, init_a});
  auto depend2 = TestCreateDepend(after_fg, AnfNodePtrList{init_tuple, infer_tuple});
  auto depend3 = TestCreateDepend(after_fg, AnfNodePtrList{after_tuple, init_tuple});
  auto depend4 = TestCreateDepend(after_fg, AnfNodePtrList{update_tuple, after_tuple});
  auto depend5 = TestCreateDepend(after_fg, AnfNodePtrList{infer_tuple, infer_a});
  auto depend6 = TestCreateDepend(after_fg, AnfNodePtrList{init_b, infer_b});
  auto depend7 = TestCreateDepend(after_fg, AnfNodePtrList{after_b, init_b});
  auto depend8 = TestCreateDepend(after_fg, AnfNodePtrList{infer_b, infer_tuple});
  auto depend9 = TestCreateDepend(after_fg, AnfNodePtrList{infer_b, update_tuple});
  auto depend10 = TestCreateDepend(after_fg, AnfNodePtrList{init_c, infer_c});
  auto depend11 = TestCreateDepend(after_fg, AnfNodePtrList{after_c, init_c});
  auto depend12 = TestCreateDepend(after_fg, AnfNodePtrList{infer_c, infer_b});
  auto depend13 = TestCreateDepend(after_fg, AnfNodePtrList{init_d, infer_d});
  auto depend14 = TestCreateDepend(after_fg, AnfNodePtrList{after_d, init_d});
  auto depend15 = TestCreateDepend(after_fg, AnfNodePtrList{infer_d, infer_tuple});
  auto depend16 = TestCreateDepend(after_fg, AnfNodePtrList{infer_d, update_tuple});
  auto depend17 = TestCreateDepend(after_fg, AnfNodePtrList{init_de, infer_de});
  auto depend18 = TestCreateDepend(after_fg, AnfNodePtrList{after_dync_end, init_de});
  auto depend19 = TestCreateDepend(after_fg, AnfNodePtrList{update_de, after_dync_end});
  auto depend20 = TestCreateDepend(after_fg, AnfNodePtrList{infer_de, infer_d});
  auto depend21 = TestCreateDepend(after_fg, AnfNodePtrList{init_e, infer_e});
  auto depend22 = TestCreateDepend(after_fg, AnfNodePtrList{after_e, init_e});
  auto depend23 = TestCreateDepend(after_fg, AnfNodePtrList{infer_e, infer_de});
  auto depend24 = TestCreateDepend(after_fg, AnfNodePtrList{infer_e, update_de});
  auto make_tuple = TestCreateMakeTuple(
    after_fg,
    AnfNodePtrList{after_make_tuple, depend0, depend1, depend2, depend3, depend4, depend5, depend6, depend7,
                   depend8, depend9, depend10, depend11, depend12, depend13, depend14, depend15, depend16,
                   depend17, depend18, depend19, depend20, depend21, depend22, depend23, depend24});
  auto get_item = TestCreateCNode(after_fg, "TupleGetItem",
                                  AnfNodePtrList{make_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                  after_make_tuple->abstract());
  after_fg->set_output(get_item);
  // assert
  EXPECT_TRUE(CheckEqualGraph(after_fg, before_fg));
}
/// Feature: Dynamic shape
/// Description: Dynamic op + depend.
/// before:
/// a = Unique(%p) (1, -1)
/// |
/// b = A(a) (1, -1) B(%p) (1, -1)
/// \ /
/// depend
/// Expectation: Graph as following.
/// after:
/// Unique_Update Unique(%p) Unique_Init Unique_Infer
/// \ / | \ / \ / |
/// depend | depend depend |
/// | |
/// B A A_Init A_Infer |
/// \ / \ / \ / \ |
/// depend depend depend depend
TEST_F(TestDynamicShapePass, test_dynamic_shape_pass_with_depend) {
  // construct before graph
  auto before_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(before_fg != nullptr);
  auto before_p = TestCreateParameter(before_fg, "p", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto before_uniq_node = TestCreateCNode(before_fg, "Unique", AnfNodePtrList{before_p},
                                          TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto before_a_node = TestCreateCNode(before_fg, "A", AnfNodePtrList{before_uniq_node},
                                       TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto before_b_node =
    TestCreateCNode(before_fg, "B", AnfNodePtrList{before_p}, TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto before_depend_node = TestCreateDepend(before_fg, AnfNodePtrList{before_a_node, before_b_node});
  before_fg->set_output(before_depend_node);
  // run pass
  AscendDynamicShapeConvert(before_fg);
  // construct after graph: the pre-existing Depend edge must survive the pass,
  // and B (reached only through the Depend) is still converted.
  auto after_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(after_fg != nullptr);
  auto after_p = TestCreateParameter(after_fg, "p", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto after_uniq_node = TestCreateCNode(after_fg, "Unique", AnfNodePtrList{after_p},
                                         TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto after_a_node = TestCreateCNode(after_fg, "A", AnfNodePtrList{after_uniq_node},
                                      TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto after_b_node =
    TestCreateCNode(after_fg, "B", AnfNodePtrList{after_p}, TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto after_depend_node = TestCreateDepend(after_fg, AnfNodePtrList{after_a_node, after_b_node});
  auto infer_uniq = dynamic_shape::GenInferNode(after_uniq_node);
  auto init_uniq = dynamic_shape::GenInitNode(after_uniq_node);
  auto update_uniq = dynamic_shape::GenUpdateNode(after_uniq_node);
  auto infer_a = dynamic_shape::GenInferNode(after_a_node);
  auto init_a = dynamic_shape::GenInitNode(after_a_node);
  auto infer_b = dynamic_shape::GenInferNode(after_b_node);
  auto init_b = dynamic_shape::GenInitNode(after_b_node);
  auto depend0 = TestCreateDepend(after_fg, AnfNodePtrList{init_uniq, infer_uniq});
  auto depend1 = TestCreateDepend(after_fg, AnfNodePtrList{after_uniq_node, init_uniq});
  auto depend2 = TestCreateDepend(after_fg, AnfNodePtrList{update_uniq, after_uniq_node});
  auto depend3 = TestCreateDepend(after_fg, AnfNodePtrList{init_a, infer_a});
  auto depend4 = TestCreateDepend(after_fg, AnfNodePtrList{after_a_node, init_a});
  auto depend5 = TestCreateDepend(after_fg, AnfNodePtrList{infer_a, infer_uniq});
  auto depend6 = TestCreateDepend(after_fg, AnfNodePtrList{infer_a, update_uniq});
  auto depend7 = TestCreateDepend(after_fg, AnfNodePtrList{init_b, infer_b});
  auto depend8 = TestCreateDepend(after_fg, AnfNodePtrList{after_b_node, init_b});
  auto make_tuple = TestCreateMakeTuple(after_fg, AnfNodePtrList{after_depend_node, depend0, depend1, depend2, depend3,
                                                                 depend4, depend5, depend6, depend7, depend8});
  auto get_item = TestCreateCNode(after_fg, "TupleGetItem",
                                  AnfNodePtrList{make_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                  after_a_node->abstract());
  after_fg->set_output(get_item);
  // assert
  EXPECT_TRUE(CheckEqualGraph(after_fg, before_fg));
}
/// Feature: Dynamic shape
/// Description: Dynamic op + monad.
/// before:
/// u = kUMonad()
/// / |
/// a = Assign(v, x1, u) -- c = UpdateState(u, a) -- d = Load(v, c)
/// \ /
/// e = UpdateState(c,d) -- Unique(%p)
/// \ /
/// depend
/// Expectation: Graph as following.
/// after:
/// u = kUMonad()
/// / |
/// Assign_Infer Assign_Init Assign -- UpdateState(u, a) -- Load(v, c)
/// \ / \ / \ /
/// depend depend e = UpdateState(c,d) -- Unique(%p) Unique_Init Unique_Infer
/// \ / | \ / \ /
/// depend | depend depend
/// | Unique_Update
/// | /
/// depend
TEST_F(TestDynamicShapePass, test_dynamic_shape_pass_with_monad) {
  // construct before graph
  auto before_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(before_fg != nullptr);
  auto before_v = TestCreateParameter(before_fg, "v", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto before_x1 = TestCreateParameter(before_fg, "x1", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto before_u = NewValueNode(kUMonad);
  auto before_assign = TestCreateCNode(before_fg, "Assign", AnfNodePtrList{before_v, before_x1, before_u},
                                       TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto before_update_state_1 =
    TestCreateCNode(before_fg, "UpdateState", AnfNodePtrList{before_u, before_assign}, kUMonad->ToAbstract());
  auto before_load =
    TestCreateCNode(before_fg, "Load", AnfNodePtrList{before_v, before_update_state_1}, before_v->abstract());
  auto before_update_state_2 = TestCreateCNode(
    before_fg, "UpdateState", AnfNodePtrList{before_update_state_1, before_load}, kUMonad->ToAbstract());
  auto before_uniq_node = TestCreateCNode(before_fg, "Unique", AnfNodePtrList{before_load},
                                          TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto before_depend_node = TestCreateDepend(before_fg, AnfNodePtrList{before_uniq_node, before_update_state_2});
  before_fg->set_output(before_depend_node);
  // run pass
  AscendDynamicShapeConvert(before_fg);
  // construct after graph: monad bookkeeping nodes (UpdateState/Load) get no
  // custom nodes; only Assign and Unique are converted.
  auto after_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(after_fg != nullptr);
  auto after_v = TestCreateParameter(after_fg, "v", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto after_x1 = TestCreateParameter(after_fg, "x1", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto after_u = NewValueNode(kUMonad);
  auto after_assign = TestCreateCNode(after_fg, "Assign", AnfNodePtrList{after_v, after_x1, after_u},
                                      TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto after_update_state_1 =
    TestCreateCNode(after_fg, "UpdateState", AnfNodePtrList{after_u, after_assign}, kUMonad->ToAbstract());
  auto after_load =
    TestCreateCNode(after_fg, "Load", AnfNodePtrList{after_v, after_update_state_1}, after_v->abstract());
  auto after_update_state_2 =
    TestCreateCNode(after_fg, "UpdateState", AnfNodePtrList{after_update_state_1, after_load}, kUMonad->ToAbstract());
  auto after_uniq_node = TestCreateCNode(after_fg, "Unique", AnfNodePtrList{after_load},
                                         TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto after_depend_node = TestCreateDepend(after_fg, AnfNodePtrList{after_uniq_node, after_update_state_2});
  auto infer_uniq = dynamic_shape::GenInferNode(after_uniq_node);
  auto init_uniq = dynamic_shape::GenInitNode(after_uniq_node);
  auto update_uniq = dynamic_shape::GenUpdateNode(after_uniq_node);
  auto infer_assign = dynamic_shape::GenInferNode(after_assign);
  auto init_assign = dynamic_shape::GenInitNode(after_assign);
  auto depend0 = TestCreateDepend(after_fg, AnfNodePtrList{init_assign, infer_assign});
  auto depend1 = TestCreateDepend(after_fg, AnfNodePtrList{after_assign, init_assign});
  auto depend2 = TestCreateDepend(after_fg, AnfNodePtrList{init_uniq, infer_uniq});
  auto depend3 = TestCreateDepend(after_fg, AnfNodePtrList{after_uniq_node, init_uniq});
  auto depend4 = TestCreateDepend(after_fg, AnfNodePtrList{update_uniq, after_uniq_node});
  auto make_tuple =
    TestCreateMakeTuple(after_fg, AnfNodePtrList{after_depend_node, depend0, depend1, depend2, depend3, depend4});
  auto get_item = TestCreateCNode(after_fg, "TupleGetItem",
                                  AnfNodePtrList{make_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                  after_uniq_node->abstract());
  after_fg->set_output(get_item);
  // assert
  EXPECT_TRUE(CheckEqualGraph(after_fg, before_fg));
}
/// Feature: Dynamic shape
/// Description: Need-sync case (contains ops such as Tile whose launch depends on
/// input values, so extra Update nodes / sync depends are expected).
/// Expectation: Graph as expected.
TEST_F(TestDynamicShapePass, test_dynamic_shape_pass_sync) {
  // construct before graph
  auto before_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(before_fg != nullptr);
  const auto &kTile = prim::kPrimTile->name();
  auto before_p1 = TestCreateParameter(before_fg, "p1", TestCreateTensor(kFloat32, std::vector<int64_t>{2, 10}));
  auto before_p2 = TestCreateParameter(before_fg, "p2", TestCreateTensor(kFloat32, std::vector<int64_t>{2}));
  auto before_uniq_node = TestCreateCNode(before_fg, "Unique", AnfNodePtrList{before_p1},
                                          TestCreateTensor(kFloat32, std::vector<int64_t>{2, -1}));
  auto before_tile1_node = TestCreateCNode(before_fg, kTile, AnfNodePtrList{before_uniq_node},
                                           TestCreateTensor(kFloat32, std::vector<int64_t>{2, 10}));
  auto before_a_node =
    TestCreateCNode(before_fg, "A", AnfNodePtrList{before_p2}, TestCreateTensor(kFloat32, std::vector<int64_t>{2}));
  auto before_b_node =
    TestCreateCNode(before_fg, "B", AnfNodePtrList{before_a_node}, TestCreateTensor(kFloat32, std::vector<int64_t>{2}));
  auto before_tile2_node = TestCreateCNode(before_fg, kTile, AnfNodePtrList{before_a_node, before_b_node},
                                           TestCreateTensor(kFloat32, std::vector<int64_t>{2, 10}));
  auto before_add_node = TestCreateCNode(before_fg, "Add", AnfNodePtrList{before_tile1_node, before_tile2_node},
                                         TestCreateTensor(kFloat32, std::vector<int64_t>{2, 10}));
  before_fg->set_output(before_add_node);
  // run pass
  AscendDynamicShapeConvert(before_fg);
  // construct after graph: note B feeds a Tile input, so B additionally gets an
  // Update node (second arg `true` — presumably the "need sync" flag; confirm
  // against GenUpdateNode's declaration).
  auto after_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(after_fg != nullptr);
  auto after_p1 = TestCreateParameter(after_fg, "p1", TestCreateTensor(kFloat32, std::vector<int64_t>{2, 10}));
  auto after_p2 = TestCreateParameter(after_fg, "p2", TestCreateTensor(kFloat32, std::vector<int64_t>{2}));
  auto after_uniq_node = TestCreateCNode(after_fg, "Unique", AnfNodePtrList{after_p1},
                                         TestCreateTensor(kFloat32, std::vector<int64_t>{2, -1}));
  auto after_tile1_node = TestCreateCNode(after_fg, kTile, AnfNodePtrList{after_uniq_node},
                                          TestCreateTensor(kFloat32, std::vector<int64_t>{2, 10}));
  auto after_a_node =
    TestCreateCNode(after_fg, "A", AnfNodePtrList{after_p2}, TestCreateTensor(kFloat32, std::vector<int64_t>{2}));
  auto after_b_node =
    TestCreateCNode(after_fg, "B", AnfNodePtrList{after_a_node}, TestCreateTensor(kFloat32, std::vector<int64_t>{2}));
  auto after_tile2_node = TestCreateCNode(after_fg, kTile, AnfNodePtrList{after_a_node, after_b_node},
                                          TestCreateTensor(kFloat32, std::vector<int64_t>{2, 10}));
  auto after_add_node = TestCreateCNode(after_fg, "Add", AnfNodePtrList{after_tile1_node, after_tile2_node},
                                        TestCreateTensor(kFloat32, std::vector<int64_t>{2, 10}));
  auto infer_uniq = dynamic_shape::GenInferNode(after_uniq_node);
  auto init_uniq = dynamic_shape::GenInitNode(after_uniq_node);
  auto update_uniq = dynamic_shape::GenUpdateNode(after_uniq_node);
  auto infer_tile1 = dynamic_shape::GenInferNode(after_tile1_node);
  auto init_tile1 = dynamic_shape::GenInitNode(after_tile1_node);
  auto infer_a = dynamic_shape::GenInferNode(after_a_node, true);
  auto init_a = dynamic_shape::GenInitNode(after_a_node);
  auto infer_b = dynamic_shape::GenInferNode(after_b_node, true);
  auto init_b = dynamic_shape::GenInitNode(after_b_node);
  auto update_b = dynamic_shape::GenUpdateNode(after_b_node, true);
  auto infer_tile2 = dynamic_shape::GenInferNode(after_tile2_node, true);
  auto init_tile2 = dynamic_shape::GenInitNode(after_tile2_node);
  auto infer_add = dynamic_shape::GenInferNode(after_add_node, true);
  auto init_add = dynamic_shape::GenInitNode(after_add_node);
  auto depend0 = TestCreateDepend(after_fg, AnfNodePtrList{init_uniq, infer_uniq});
  auto depend1 = TestCreateDepend(after_fg, AnfNodePtrList{after_uniq_node, init_uniq});
  auto depend2 = TestCreateDepend(after_fg, AnfNodePtrList{update_uniq, after_uniq_node});
  auto depend3 = TestCreateDepend(after_fg, AnfNodePtrList{init_tile1, infer_tile1});
  auto depend4 = TestCreateDepend(after_fg, AnfNodePtrList{after_tile1_node, init_tile1});
  auto depend5 = TestCreateDepend(after_fg, AnfNodePtrList{infer_tile1, infer_uniq});
  auto depend6 = TestCreateDepend(after_fg, AnfNodePtrList{infer_tile1, update_uniq});
  auto depend7 = TestCreateDepend(after_fg, AnfNodePtrList{init_a, infer_a});
  auto depend8 = TestCreateDepend(after_fg, AnfNodePtrList{after_a_node, init_a});
  auto depend9 = TestCreateDepend(after_fg, AnfNodePtrList{init_b, infer_b});
  auto depend10 = TestCreateDepend(after_fg, AnfNodePtrList{after_b_node, init_b});
  auto depend11 = TestCreateDepend(after_fg, AnfNodePtrList{infer_b, infer_a});
  auto depend12 = TestCreateDepend(after_fg, AnfNodePtrList{init_tile2, infer_tile2});
  auto depend13 = TestCreateDepend(after_fg, AnfNodePtrList{after_tile2_node, init_tile2});
  auto depend14 = TestCreateDepend(after_fg, AnfNodePtrList{infer_tile2, infer_a});
  auto depend15 = TestCreateDepend(after_fg, AnfNodePtrList{infer_tile2, infer_b});
  auto depend16 = TestCreateDepend(after_fg, AnfNodePtrList{update_b, after_b_node});
  auto depend17 = TestCreateDepend(after_fg, AnfNodePtrList{infer_tile2, update_b});
  auto depend18 = TestCreateDepend(after_fg, AnfNodePtrList{init_add, infer_add});
  auto depend19 = TestCreateDepend(after_fg, AnfNodePtrList{after_add_node, init_add});
  auto depend20 = TestCreateDepend(after_fg, AnfNodePtrList{infer_add, infer_tile1});
  auto depend21 = TestCreateDepend(after_fg, AnfNodePtrList{infer_add, infer_tile2});
  auto make_tuple = TestCreateMakeTuple(
    after_fg, AnfNodePtrList{after_add_node, depend0, depend1, depend2, depend3, depend4, depend5, depend6,
                             depend7, depend8, depend9, depend10, depend11, depend12, depend13, depend14,
                             depend15, depend16, depend17, depend18, depend19, depend20, depend21});
  auto get_item = TestCreateCNode(after_fg, "TupleGetItem",
                                  AnfNodePtrList{make_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                  after_add_node->abstract());
  after_fg->set_output(get_item);
  // assert
  EXPECT_TRUE(CheckEqualGraph(after_fg, before_fg));
}
} // namespace opt
} // namespace mindspore