Add Custom actor nodes and custom actor support for dynamic shape case
This commit is contained in:
parent
1b08f35ef1
commit
58f386fe2a
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -153,6 +153,10 @@
|
|||
#include "backend/optimizer/ascend/mindir/dynamic_reshape_unify_mindir.h"
|
||||
#include "backend/optimizer/ascend/mindir/all_to_all_unify_mindir.h"
|
||||
#include "backend/optimizer/ascend/mindir/neighbor_exchange_v2_unify_mindir.h"
|
||||
#include "backend/optimizer/ascend/dynamic_shape/convert_dynamic_op.h"
|
||||
#include "backend/optimizer/ascend/dynamic_shape/convert_general_op.h"
|
||||
#include "backend/optimizer/ascend/dynamic_shape/convert_inherited_dynamic_op.h"
|
||||
#include "backend/optimizer/ascend/dynamic_shape/link_custom_op.h"
|
||||
#include "backend/optimizer/pass/adjust_depend_for_parallel_optimizer_recompute_all_gather.h"
|
||||
#include "backend/kernel_compiler/tbe/tbe_kernel_compile.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
@ -618,5 +622,34 @@ void AscendUnifyMindIR(const std::shared_ptr<session::KernelGraph> &graph) {
|
|||
#endif
|
||||
}
|
||||
|
||||
void AscendDynamicShapeConvert(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
|
||||
if (save_graphs) {
|
||||
std::string file_name =
|
||||
"hwopt_d_before_dynamic_shape_convert_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
|
||||
DumpIR(file_name, kernel_graph);
|
||||
DumpIRProto(kernel_graph, "before_dynamic_shape_convert_hwopt_" + std::to_string(kernel_graph->graph_id()));
|
||||
}
|
||||
#endif
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto dynamic_shape_convert_pm = std::make_shared<opt::PassManager>("dynamic_shape_convert_pm");
|
||||
dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::ConvertDynamicOp>());
|
||||
dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::ConvertGeneralOp>());
|
||||
dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::ConvertInheritedDynamicOp>());
|
||||
dynamic_shape_convert_pm->AddPass(std::make_shared<opt::dynamic_shape::LinkCustomOp>());
|
||||
optimizer->AddPassManager(dynamic_shape_convert_pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
#ifdef ENABLE_DUMP_IR
|
||||
if (save_graphs) {
|
||||
std::string file_name =
|
||||
"hwopt_d_after_dynamic_shape_convert_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
|
||||
DumpIR(file_name, kernel_graph);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -28,6 +28,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
|
|||
void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendUnifyMindIR(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendDynamicShapeConvert(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -0,0 +1,130 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/ascend/dynamic_shape/ascend_dynamic_shape_helper.h"
|
||||
|
||||
#include <memory>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/anf_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace {
|
||||
bool IsRealCNode(const BaseRef &n) {
|
||||
if (utils::isa<CNodePtr>(n)) {
|
||||
CNodePtr cnode = utils::cast<CNodePtr>(n);
|
||||
return AnfUtils::IsRealKernel(cnode);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace opt::dynamic_shape {
|
||||
bool IsGeneralOp(const BaseRef &n) {
|
||||
if (IsDynamicOp(n)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (IsInheritedDynamicOp(n)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return IsRealCNode(n);
|
||||
}
|
||||
|
||||
bool IsDynamicOp(const BaseRef &n) {
|
||||
if (!IsRealCNode(n)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
CNodePtr cnode = utils::cast<CNodePtr>(n);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto op_name = AnfAlgo::GetCNodeName(cnode);
|
||||
return kComputeDepend.find(op_name) != kComputeDepend.end();
|
||||
}
|
||||
|
||||
bool IsInheritedDynamicOp(const BaseRef &n) {
|
||||
if (IsDynamicOp(n)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!IsRealCNode(n)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
CNodePtr cnode = utils::cast<CNodePtr>(n);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
return AnfAlgo::IsNodeInputDynamicShape(cnode) || AnfUtils::IsNodeOutputDynamicShape(cnode);
|
||||
}
|
||||
|
||||
AnfNodePtr GenInferNode(const AnfNodePtr &node, bool fake_flag) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
AnfUtils::CustomActorCallback actor_func;
|
||||
if (fake_flag) {
|
||||
actor_func = [](void *) -> void { return; };
|
||||
} else {
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
actor_func = [kernel_mod](void *) { kernel_mod->InferOp(); };
|
||||
}
|
||||
|
||||
auto infer_node = AnfUtils::NewInferActorNode(actor_func, cnode, fake_flag);
|
||||
infer_node->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
return infer_node;
|
||||
}
|
||||
|
||||
AnfNodePtr GenInitNode(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
auto init_node = AnfUtils::NewInitActorNode([kernel_mod](void *) { kernel_mod->InitOp(); }, cnode);
|
||||
init_node->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
return init_node;
|
||||
}
|
||||
|
||||
AnfNodePtr GenUpdateNode(const AnfNodePtr &node, bool just_sync_flag) {
|
||||
// Some not dynamic shape node should sync after launch for latter node.
|
||||
// Use a flag `just_sync_flag` to distinguish them with dynamic ones.
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
auto update_node =
|
||||
AnfUtils::NewUpdateActorNode([kernel_mod](void *) { kernel_mod->UpdateOp(); }, cnode, just_sync_flag);
|
||||
update_node->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
return update_node;
|
||||
}
|
||||
|
||||
bool IsDynUpdate(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto custom_actor_type = AnfUtils::GetCustomActorType(node);
|
||||
if (custom_actor_type != kUpdate) {
|
||||
MS_LOG(EXCEPTION) << node->fullname_with_scope() << " is not a custom update node!";
|
||||
}
|
||||
return !AnfUtils::GetCustomActorJustSyncFlag(node);
|
||||
}
|
||||
|
||||
CustomActorNodeManager &CustomActorNodeManager::Instance() {
|
||||
static CustomActorNodeManager instance{};
|
||||
return instance;
|
||||
}
|
||||
} // namespace opt::dynamic_shape
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_ASCEND_DYNAMIC_SHAPE_HELPER_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_ASCEND_DYNAMIC_SHAPE_HELPER_H
|
||||
|
||||
#include <string>
|
||||
#include "ir/anf.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore::opt::dynamic_shape {
|
||||
bool IsGeneralOp(const BaseRef &n);
|
||||
bool IsDynamicOp(const BaseRef &n);
|
||||
bool IsInheritedDynamicOp(const BaseRef &n);
|
||||
AnfNodePtr GenInferNode(const AnfNodePtr &node, bool fake_flag = false);
|
||||
AnfNodePtr GenInitNode(const AnfNodePtr &node);
|
||||
AnfNodePtr GenUpdateNode(const AnfNodePtr &node, bool just_sync_flag = false);
|
||||
bool IsDynUpdate(const AnfNodePtr &node);
|
||||
|
||||
struct RelatedCustomActorNode {
|
||||
AnfNodePtr infer_node;
|
||||
AnfNodePtr init_node;
|
||||
AnfNodePtr update_node;
|
||||
};
|
||||
|
||||
class CustomActorNodeManager {
|
||||
public:
|
||||
static CustomActorNodeManager &Instance();
|
||||
void Reset() { custom_nodes_map_.clear(); }
|
||||
void Register(const AnfNodePtr &node, const RelatedCustomActorNode &custom_nodes) {
|
||||
custom_nodes_map_.emplace(node, custom_nodes);
|
||||
}
|
||||
bool IsRegistered(const AnfNodePtr &node) const { return custom_nodes_map_.find(node) != custom_nodes_map_.end(); }
|
||||
const RelatedCustomActorNode &GetCustomActorNodes(const AnfNodePtr &node) const {
|
||||
if (auto iter = custom_nodes_map_.find(node); iter != custom_nodes_map_.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
MS_LOG(EXCEPTION) << "Not registered node!";
|
||||
}
|
||||
|
||||
private:
|
||||
CustomActorNodeManager() = default;
|
||||
~CustomActorNodeManager() = default;
|
||||
DISABLE_COPY_AND_ASSIGN(CustomActorNodeManager)
|
||||
OrderedMap<AnfNodePtr, RelatedCustomActorNode> custom_nodes_map_;
|
||||
};
|
||||
} // namespace mindspore::opt::dynamic_shape
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_ASCEND_DYNAMIC_SHAPE_HELPER_H
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/ascend/dynamic_shape/convert_dynamic_op.h"
|
||||
|
||||
#include <memory>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/optimizer/ascend/dynamic_shape/ascend_dynamic_shape_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt::dynamic_shape {
|
||||
const BaseRef ConvertDynamicOp::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<CondVar>(IsDynamicOp);
|
||||
return BaseRef({X});
|
||||
}
|
||||
|
||||
const AnfNodePtr ConvertDynamicOp::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto infer_node = GenInferNode(node);
|
||||
auto init_node = GenInitNode(node);
|
||||
auto update_node = GenUpdateNode(node);
|
||||
RelatedCustomActorNode custom_nodes = {infer_node, init_node, update_node};
|
||||
CustomActorNodeManager::Instance().Register(node, custom_nodes);
|
||||
return node;
|
||||
}
|
||||
} // namespace opt::dynamic_shape
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_CONVERT_DYNAMIC_OP_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_CONVERT_DYNAMIC_OP_H
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore::opt::dynamic_shape {
|
||||
class ConvertDynamicOp : public PatternProcessPass {
|
||||
public:
|
||||
explicit ConvertDynamicOp(bool multigraph = true) : PatternProcessPass("convert_dynamic_op", multigraph) {}
|
||||
~ConvertDynamicOp() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace mindspore::opt::dynamic_shape
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_CONVERT_DYNAMIC_OP_H
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/ascend/dynamic_shape/convert_general_op.h"
|
||||
|
||||
#include <memory>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/optimizer/ascend/dynamic_shape/ascend_dynamic_shape_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt::dynamic_shape {
|
||||
const BaseRef ConvertGeneralOp::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<CondVar>(IsGeneralOp);
|
||||
return BaseRef({X});
|
||||
}
|
||||
|
||||
const AnfNodePtr ConvertGeneralOp::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto stub_infer_node = GenInferNode(node, true);
|
||||
auto init_node = GenInitNode(node);
|
||||
auto sync_node = GenUpdateNode(node, true); // Use to call sync if needed.
|
||||
RelatedCustomActorNode custom_nodes = {stub_infer_node, init_node, sync_node};
|
||||
CustomActorNodeManager::Instance().Register(node, custom_nodes);
|
||||
return node;
|
||||
}
|
||||
} // namespace opt::dynamic_shape
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_CONVERT_GENERAL_OP_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_CONVERT_GENERAL_OP_H
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore::opt::dynamic_shape {
|
||||
class ConvertGeneralOp : public PatternProcessPass {
|
||||
public:
|
||||
explicit ConvertGeneralOp(bool multigraph = true) : PatternProcessPass("convert_general_op", multigraph) {}
|
||||
~ConvertGeneralOp() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace mindspore::opt::dynamic_shape
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_CONVERT_GENERAL_OP_H
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/ascend/dynamic_shape/convert_inherited_dynamic_op.h"
|
||||
|
||||
#include <memory>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/optimizer/ascend/dynamic_shape/ascend_dynamic_shape_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt::dynamic_shape {
|
||||
const BaseRef ConvertInheritedDynamicOp::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<CondVar>(IsInheritedDynamicOp);
|
||||
return BaseRef({X});
|
||||
}
|
||||
|
||||
const AnfNodePtr ConvertInheritedDynamicOp::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto infer_node = GenInferNode(node);
|
||||
auto init_node = GenInitNode(node);
|
||||
auto sync_node = GenUpdateNode(node, true); // Use to call sync if needed.
|
||||
RelatedCustomActorNode custom_nodes = {infer_node, init_node, sync_node};
|
||||
CustomActorNodeManager::Instance().Register(node, custom_nodes);
|
||||
return node;
|
||||
}
|
||||
} // namespace opt::dynamic_shape
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_CONVERT_INHERITED_DYNAMIC_OP_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_CONVERT_INHERITED_DYNAMIC_OP_H
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore::opt::dynamic_shape {
|
||||
class ConvertInheritedDynamicOp : public PatternProcessPass {
|
||||
public:
|
||||
explicit ConvertInheritedDynamicOp(bool multigraph = true)
|
||||
: PatternProcessPass("convert_inherited_dynamic_op", multigraph) {}
|
||||
~ConvertInheritedDynamicOp() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace mindspore::opt::dynamic_shape
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_CONVERT_GENERAL_OP_H
|
|
@ -0,0 +1,212 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/ascend/dynamic_shape/link_custom_op.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/optimizer/ascend/dynamic_shape/ascend_dynamic_shape_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt::dynamic_shape {
|
||||
namespace {
|
||||
constexpr size_t kTupleFirstItemIndex = 0;
|
||||
constexpr size_t kFirstDataInputIndex = 1;
|
||||
|
||||
AnfNodePtr InsertDepend(const FuncGraphPtr &g, const AnfNodePtr &prev, const AnfNodePtr &next) {
|
||||
MS_EXCEPTION_IF_NULL(g);
|
||||
MS_EXCEPTION_IF_NULL(prev);
|
||||
MS_EXCEPTION_IF_NULL(next);
|
||||
// add depend from prev to next
|
||||
auto depend_node = g->NewCNode(
|
||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())), next, prev});
|
||||
MS_EXCEPTION_IF_NULL(depend_node);
|
||||
return depend_node;
|
||||
}
|
||||
|
||||
bool LinkInternalOp(const FuncGraphPtr &g, const AnfNodePtr &node, AnfNodePtrList *depend_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(g);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(depend_nodes);
|
||||
bool changed = false;
|
||||
auto custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(node);
|
||||
if (custom_nodes.infer_node != nullptr) {
|
||||
if (custom_nodes.init_node == nullptr) {
|
||||
MS_LOG(WARNING) << "Node " << node->DebugString() << " has infer node but init node is null.";
|
||||
} else {
|
||||
depend_nodes->push_back(InsertDepend(g, custom_nodes.infer_node, custom_nodes.init_node)); // link infer => init
|
||||
depend_nodes->push_back(InsertDepend(g, custom_nodes.init_node, node)); // link init => launch
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (IsDynUpdate(custom_nodes.update_node)) {
|
||||
depend_nodes->push_back(InsertDepend(g, node, custom_nodes.update_node)); // link launch => update
|
||||
changed = true;
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool LinkInputOp(const FuncGraphPtr &g, const CNodePtr &cnode, AnfNodePtrList *depend_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(g);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(depend_nodes);
|
||||
bool changed = false;
|
||||
auto custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(cnode);
|
||||
if (custom_nodes.infer_node == nullptr) {
|
||||
return changed;
|
||||
}
|
||||
size_t input_num = AnfAlgo::GetInputNum(cnode);
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto prev = AnfAlgo::GetPrevNodeOutput(cnode, i);
|
||||
const auto &prev_node = prev.first;
|
||||
if (prev_node == nullptr || !CustomActorNodeManager::Instance().IsRegistered(prev_node)) {
|
||||
continue;
|
||||
}
|
||||
auto prev_custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(prev_node);
|
||||
if (prev_custom_nodes.infer_node != nullptr) {
|
||||
depend_nodes->push_back(
|
||||
InsertDepend(g, prev_custom_nodes.infer_node, custom_nodes.infer_node)); // link prev.infer => curr.infer
|
||||
MS_LOG(DEBUG) << "Link from " << prev_node->fullname_with_scope() << " infer "
|
||||
<< prev_custom_nodes.infer_node->fullname_with_scope() << " to " << cnode->fullname_with_scope()
|
||||
<< " infer " << custom_nodes.infer_node->fullname_with_scope();
|
||||
changed = true;
|
||||
}
|
||||
if (IsDynUpdate(prev_custom_nodes.update_node)) {
|
||||
depend_nodes->push_back(
|
||||
InsertDepend(g, prev_custom_nodes.update_node, custom_nodes.infer_node)); // link prev.update => curr.infer
|
||||
MS_LOG(DEBUG) << "Link from " << prev_node->fullname_with_scope() << " update "
|
||||
<< prev_custom_nodes.update_node->fullname_with_scope() << " to " << cnode->fullname_with_scope()
|
||||
<< " infer " << custom_nodes.infer_node->fullname_with_scope();
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool LinkDependSync(const FuncGraphPtr &g, const CNodePtr &cnode, AnfNodePtrList *depend_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(g);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(depend_nodes);
|
||||
bool changed = false;
|
||||
auto custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(cnode);
|
||||
if (custom_nodes.infer_node == nullptr) {
|
||||
return changed;
|
||||
}
|
||||
|
||||
auto dynamic_shape_depends = abstract::GetDependsFormMap(cnode);
|
||||
if (dynamic_shape_depends.empty()) {
|
||||
return changed;
|
||||
}
|
||||
|
||||
for (auto depend_index : dynamic_shape_depends) {
|
||||
auto prev = AnfAlgo::GetPrevNodeOutput(cnode, depend_index);
|
||||
const auto &prev_node = prev.first;
|
||||
if (prev_node == nullptr || !CustomActorNodeManager::Instance().IsRegistered(prev_node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If previous node is dynamic, so it was already link.
|
||||
auto prev_custom_nodes = CustomActorNodeManager::Instance().GetCustomActorNodes(prev_node);
|
||||
if (IsDynUpdate(prev_custom_nodes.update_node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// 1. Link prev_node => prev_node.update if its update is just sync.
|
||||
depend_nodes->push_back(InsertDepend(g, prev_node, prev_custom_nodes.update_node));
|
||||
// 2. Link prev_node.update => cur_node.infer.
|
||||
depend_nodes->push_back(InsertDepend(g, prev_custom_nodes.update_node, custom_nodes.infer_node));
|
||||
changed = true;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Attach Custom's Depend nodes with additional MakeTuple and TupleGetItem before graph return.
|
||||
*
|
||||
* %0 = A
|
||||
* return %0
|
||||
* ---->
|
||||
* %0 = A
|
||||
* %1 = MakeTuple(%0, %depend0, %depend1...)
|
||||
* %2 = TupleGetItem(%1, 0)
|
||||
* return %2
|
||||
*
|
||||
* @param g Graph.
|
||||
* @param depend_nodes Custom's Depend nodes.
|
||||
*/
|
||||
void AttachDependNodes(const FuncGraphPtr &g, const AnfNodePtrList &depend_nodes) {
|
||||
if (depend_nodes.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(g);
|
||||
auto return_node = g->get_return();
|
||||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
|
||||
// New MakeTuple node
|
||||
auto mk_inputs = AnfNodePtrList{NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())),
|
||||
return_node->input(kFirstDataInputIndex)};
|
||||
mk_inputs.insert(mk_inputs.end(), depend_nodes.begin(), depend_nodes.end());
|
||||
auto make_tuple_node = g->NewCNode(mk_inputs);
|
||||
|
||||
// Get first element item form that maketuple and return.
|
||||
auto get_1st_item = g->NewCNode(AnfNodePtrList{NewValueNode(std::make_shared<Primitive>(prim::kTupleGetItem)),
|
||||
make_tuple_node, NewValueNode(SizeToLong(kTupleFirstItemIndex))});
|
||||
|
||||
// Attach back.
|
||||
return_node->set_input(kFirstDataInputIndex, get_1st_item);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool LinkCustomOp::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
bool changed = false;
|
||||
AnfNodePtrList depend_nodes;
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (const auto &node : node_list) {
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr || !CustomActorNodeManager::Instance().IsRegistered(cnode)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
changed = LinkInternalOp(func_graph, cnode, &depend_nodes) || changed;
|
||||
changed = LinkInputOp(func_graph, cnode, &depend_nodes) || changed;
|
||||
changed = LinkDependSync(func_graph, cnode, &depend_nodes) || changed;
|
||||
}
|
||||
|
||||
CustomActorNodeManager::Instance().Reset();
|
||||
|
||||
if (changed) {
|
||||
AttachDependNodes(func_graph, depend_nodes);
|
||||
|
||||
// Rebuild graph's edge.
|
||||
auto mng = func_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(func_graph, true);
|
||||
func_graph->set_manager(mng);
|
||||
}
|
||||
mng->RemoveRoots();
|
||||
mng->KeepRoots({func_graph});
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
} // namespace opt::dynamic_shape
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_LINK_CUSTOM_OP_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_LINK_CUSTOM_OP_H

#include "ir/anf.h"
#include "backend/optimizer/common/optimizer.h"

namespace mindspore::opt::dynamic_shape {
// Graph pass that links custom actor nodes (infer/init/update nodes created for
// dynamic-shape kernels) into the kernel graph via Depend edges, so their
// execution order relative to the launch kernels is preserved at runtime.
class LinkCustomOp : public Pass {
 public:
  LinkCustomOp() : Pass("link_custom_op") {}
  ~LinkCustomOp() override = default;
  // Returns true if any linking edge was added (i.e. the graph changed).
  bool Run(const FuncGraphPtr &func_graph) override;
};
}  // namespace mindspore::opt::dynamic_shape
// Fixed: the closing-guard comment previously said ..._CONVERT_GENERAL_OP_H
// (copy-paste from a sibling header); it must match the guard defined above.
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_DYNAMIC_SHAPE_LINK_CUSTOM_OP_H
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -29,6 +29,7 @@
|
|||
#include "runtime/device/kernel_runtime_manager.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "utils/anf_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace session {
|
||||
|
@ -878,7 +879,7 @@ void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
|
|||
auto node = que.front();
|
||||
que.pop();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<Parameter>() || node->isa<ValueNode>()) {
|
||||
if (node->isa<Parameter>() || node->isa<ValueNode>() || AnfUtils::IsCustomActorNode(node)) {
|
||||
seed_nodes->push(node);
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -30,6 +30,7 @@
|
|||
#include "pipeline/jit/base.h"
|
||||
#include "debug/trace.h"
|
||||
#include "utils/trace_base.h"
|
||||
#include "utils/anf_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
const std::string ToShortString(const TypeId &typeId) {
|
||||
|
@ -257,6 +258,32 @@ void DumpOperator(const AnfNodePtr &node, const std::shared_ptr<SubGraphIRInfo>
|
|||
}
|
||||
}
|
||||
|
||||
void DumpParamterInOperand(const AnfNodePtr &node, const AnfNodePtr &in, OrderedMap<AnfNodePtr, int32_t> *para_map,
|
||||
const std::shared_ptr<SubGraphIRInfo> &gsub) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||
MS_EXCEPTION_IF_NULL(in);
|
||||
MS_EXCEPTION_IF_NULL(para_map);
|
||||
MS_EXCEPTION_IF_NULL(gsub);
|
||||
if (in->func_graph() == nullptr) {
|
||||
MS_LOG(ERROR) << "Parameter should belong to a func graph. Check func graph: " << node->func_graph();
|
||||
}
|
||||
if (in->func_graph() != nullptr && in->func_graph() != node->func_graph()) {
|
||||
gsub->buffer << "$(@" << in->func_graph()->ToString() << ":";
|
||||
} else {
|
||||
gsub->buffer << "%";
|
||||
}
|
||||
auto iter = para_map->find(in);
|
||||
if (iter == para_map->end()) {
|
||||
gsub->buffer << "para_" << in->ToString();
|
||||
} else {
|
||||
gsub->buffer << "para" << iter->second << "_" << in->ToString();
|
||||
}
|
||||
if (in->func_graph() != nullptr && in->func_graph() != node->func_graph()) {
|
||||
gsub->buffer << ")";
|
||||
}
|
||||
}
|
||||
|
||||
void DumpOperands(const AnfNodePtr &node, OrderedMap<AnfNodePtr, int32_t> *para_map,
|
||||
const std::shared_ptr<SubGraphIRInfo> &gsub) {
|
||||
if (node == nullptr || para_map == nullptr || gsub == nullptr) {
|
||||
|
@ -275,24 +302,7 @@ void DumpOperands(const AnfNodePtr &node, OrderedMap<AnfNodePtr, int32_t> *para_
|
|||
gsub->buffer << ", ";
|
||||
}
|
||||
if (in->isa<Parameter>()) {
|
||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||
if (in->func_graph() == nullptr) {
|
||||
MS_LOG(ERROR) << "Parameter should belong to a func graph. Check func graph: " << node->func_graph();
|
||||
}
|
||||
if (in->func_graph() != nullptr && in->func_graph() != node->func_graph()) {
|
||||
gsub->buffer << "$(@" << in->func_graph()->ToString() << ":";
|
||||
} else {
|
||||
gsub->buffer << "%";
|
||||
}
|
||||
auto iter = para_map->find(in);
|
||||
if (iter == para_map->end()) {
|
||||
gsub->buffer << "para_" << in->ToString();
|
||||
} else {
|
||||
gsub->buffer << "para" << iter->second << "_" << in->ToString();
|
||||
}
|
||||
if (in->func_graph() != nullptr && in->func_graph() != node->func_graph()) {
|
||||
gsub->buffer << ")";
|
||||
}
|
||||
DumpParamterInOperand(node, in, para_map, gsub);
|
||||
} else if (in->isa<CNode>()) {
|
||||
auto iter = gsub->local_var_map.find(in);
|
||||
if (iter != gsub->local_var_map.end()) {
|
||||
|
@ -308,6 +318,8 @@ void DumpOperands(const AnfNodePtr &node, OrderedMap<AnfNodePtr, int32_t> *para_
|
|||
} else if (IsValueNode<FuncGraph>(in)) {
|
||||
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(in);
|
||||
gsub->buffer << "@" << fg->ToString();
|
||||
} else if (AnfUtils::IsCustomActorNode(in)) {
|
||||
gsub->buffer << "%" << AnfUtils::GetCustomActorName(in);
|
||||
} else {
|
||||
gsub->buffer << in->ToString();
|
||||
}
|
||||
|
@ -570,6 +582,8 @@ void DumpIRInSubgraph(const std::vector<AnfNodePtr> &nodes, OrderedMap<AnfNodePt
|
|||
if (node->isa<CNode>()) {
|
||||
// Print and record output of operator if it is not 'Return'
|
||||
DumpCNode(node->cast<CNodePtr>(), sub_graph, para_map, gsub, dump_full_name, dump_location);
|
||||
} else if (AnfUtils::IsCustomActorNode(node)) {
|
||||
continue;
|
||||
} else {
|
||||
gsub->buffer << " " << node->ToString() << std::endl;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -29,6 +29,7 @@
|
|||
#include "utils/symbolic.h"
|
||||
#include "utils/utils.h"
|
||||
#include "pipeline/jit/base.h"
|
||||
#include "utils/anf_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
class ProtoExporter {
|
||||
|
@ -399,6 +400,10 @@ std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodeP
|
|||
return node->ToString();
|
||||
}
|
||||
|
||||
if (AnfUtils::IsCustomActorNode(node)) {
|
||||
return AnfUtils::GetCustomActorName(node);
|
||||
}
|
||||
|
||||
if (node->isa<ValueNode>()) {
|
||||
auto iter = const_map_ptr->find(node);
|
||||
if (iter == const_map_ptr->end()) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -52,6 +52,7 @@ enum class KernelTransformType {
|
|||
kDeviceDataSourceActor,
|
||||
kHostDataSourceActor,
|
||||
kKernelActor,
|
||||
kCustomActor,
|
||||
// Super kernel actor represents the sink executing of graph which is the combination of kernels.
|
||||
kSuperKernelActor,
|
||||
kCopyActor,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -152,6 +152,13 @@ void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) {
|
|||
ofs << "\n";
|
||||
}
|
||||
|
||||
// Dump a single custom actor: its name line followed by the common
// abstract-actor fields, terminated by a blank line (same layout as the
// other Dump*Actor helpers in this file).
void DumpCustomActor(const CustomActor *actor, std::ofstream &ofs) {
  MS_EXCEPTION_IF_NULL(actor);
  const auto &actor_name = actor->GetAID().Name();
  ofs << "\tactor_name:" << actor_name << "\n";
  DumpAbstractActor(actor, ofs);
  ofs << "\n";
}
|
||||
|
||||
void DumpSuperKernelActor(const SuperKernelActor *actor, std::ofstream &ofs) {
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
ofs << "\tactor_name:" << actor->GetAID().Name() << "\n";
|
||||
|
@ -448,6 +455,13 @@ void DumpKernelActors(const std::vector<KernelActorPtr> &actors, std::ofstream &
|
|||
}
|
||||
}
|
||||
|
||||
void DumpCustomActors(const std::vector<CustomActorPtr> &actors, std::ofstream &ofs) {
|
||||
ofs << "\n\n[Custom actors:" << actors.size() << "]\n";
|
||||
for (const auto &custom_actor : actors) {
|
||||
DumpCustomActor(custom_actor.get(), ofs);
|
||||
}
|
||||
}
|
||||
|
||||
void DumpSuperKernelActors(const std::vector<SuperKernelActorPtr> &actors, std::ofstream &ofs) {
|
||||
ofs << "\n\n[Super kernel actors:" << actors.size() << "]\n";
|
||||
for (const auto &super_kernel_actor : actors) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -30,6 +30,7 @@
|
|||
#include "runtime/framework/actor/super_kernel_actor.h"
|
||||
#include "runtime/framework/actor/output_actor.h"
|
||||
#include "runtime/framework/actor/copy_actor.h"
|
||||
#include "runtime/framework/actor/custom_actor.h"
|
||||
#include "runtime/framework/actor/control_flow/control_actor.h"
|
||||
#include "runtime/framework/actor/control_flow/switch_actor.h"
|
||||
#include "runtime/framework/actor/control_flow/gather_actor.h"
|
||||
|
@ -49,6 +50,7 @@ void DumpSuperKernelActors(const std::vector<SuperKernelActorPtr> &actors, std::
|
|||
void DumpNoInputKernelActors(const std::vector<AbstractActorPtr> &actors, std::ofstream &ofs);
|
||||
void DumpCopyActors(const std::vector<CopyActorPtr> &actors, std::ofstream &ofs);
|
||||
void DumpControlActors(const ControlActorSetPtr &control_actor_set, std::ofstream &ofs);
|
||||
void DumpCustomActors(const std::vector<CustomActorPtr> &actors, std::ofstream &ofs);
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -29,6 +29,7 @@
|
|||
#include "runtime/framework/actor/data_source_actor.h"
|
||||
#include "runtime/framework/actor/loop_count_actor.h"
|
||||
#include "runtime/framework/actor/kernel_actor.h"
|
||||
#include "runtime/framework/actor/custom_actor.h"
|
||||
#include "runtime/framework/actor/super_kernel_actor.h"
|
||||
#include "runtime/framework/actor/output_actor.h"
|
||||
#include "runtime/framework/actor/copy_actor.h"
|
||||
|
@ -77,6 +78,7 @@ struct ActorSet {
|
|||
DataPrepareActorPtr data_prepare_actor_{nullptr};
|
||||
std::vector<DataSourceActorPtr> data_source_actors_;
|
||||
std::vector<KernelActorPtr> kernel_actors_;
|
||||
std::vector<CustomActorPtr> custom_actors_;
|
||||
std::vector<SuperKernelActorPtr> super_kernel_actors_;
|
||||
// No input kernel actors need be triggered specifically.
|
||||
std::vector<AbstractActorPtr> no_input_kernel_actors_;
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "runtime/framework/actor/custom_actor.h"
|
||||
#include "runtime/framework/actor/memory_manager_actor.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
// No extra initialization needed: the kernel weak-ptr and execution strategy
// are already bound at construction time.
void CustomActor::Init() {}
|
||||
|
||||
// Execute the custom function attached to the bound kernel node on the actor's
// (single) device context, then forward control to downstream actors.
void CustomActor::Run(OpContext<DeviceTensor> *const ctx) {
  // The kernel is held weakly; it must still be alive when the actor runs.
  auto node = kernel_.lock();
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_ZERO("device_contexts_ size", device_contexts_.size());
  MS_EXCEPTION_IF_NULL(device_contexts_[0]);
  try {
    device_contexts_[0]->LaunchCustomFunc(node);
  } catch (const std::exception &e) {
    // In pipeline mode the exception is recorded globally so the front end can
    // re-raise it; in either mode the op context is failed via the macro below
    // (which returns from this function on the pipeline path).
    if (strategy_ == GraphExecutionStrategy::kPipeline) {
      MsException::Instance().SetException();
    }
    std::string error_info = "Launch custom kernel exception: " + node->fullname_with_scope();
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*ctx), error_info);
  }
  // Notify downstream actors that this actor has finished.
  SendOutput(ctx);
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,77 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CUSTOM_ACTOR_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CUSTOM_ACTOR_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "utils/hash_map.h"
|
||||
#include "runtime/framework/actor/actor_common.h"
|
||||
#include "runtime/framework/actor/debug_aware_actor.h"
|
||||
#include "runtime/hardware/device_context.h"
|
||||
#include "runtime/framework/device_tensor_store.h"
|
||||
#include "backend/kernel_compiler/kernel.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "utils/anf_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using mindspore::device::DeviceContext;
|
||||
using mindspore::device::KernelInfo;
|
||||
using mindspore::kernel::Address;
|
||||
using mindspore::kernel::KernelLaunchInfo;
|
||||
using mindspore::tensor::TensorPtr;
|
||||
|
||||
class CustomActor : public AbstractActor {
|
||||
public:
|
||||
CustomActor(const std::string &name, const AnfNodePtr &kernel, const DeviceContext *device_context,
|
||||
const AID *recorder_aid)
|
||||
: AbstractActor(name, KernelTransformType::kCustomActor, recorder_aid), kernel_(kernel) {
|
||||
device_contexts_.push_back(device_context);
|
||||
}
|
||||
CustomActor(const std::string &name, const AnfNodePtr &kernel, const DeviceContext *device_context,
|
||||
const AID *recorder_aid, GraphExecutionStrategy strategy)
|
||||
: AbstractActor(name, KernelTransformType::kCustomActor, recorder_aid), kernel_(kernel), strategy_(strategy) {
|
||||
device_contexts_.push_back(device_context);
|
||||
}
|
||||
~CustomActor() override = default;
|
||||
|
||||
void Init() override;
|
||||
|
||||
const AnfNodeWeakPtr &kernel() const { return kernel_; }
|
||||
|
||||
protected:
|
||||
void Run(OpContext<DeviceTensor> *const context) override;
|
||||
|
||||
private:
|
||||
friend class GraphScheduler;
|
||||
friend class ControlNodeScheduler;
|
||||
|
||||
// The info of kernel.
|
||||
AnfNodeWeakPtr kernel_;
|
||||
AnfUtils::CustomActorCallback custom_func_ = {};
|
||||
GraphExecutionStrategy strategy_{GraphExecutionStrategy::kPipeline};
|
||||
};
|
||||
|
||||
using CustomActorPtr = std::shared_ptr<CustomActor>;
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CUSTOM_ACTOR_H_
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "runtime/framework/graph_scheduler.h"
|
||||
#include <queue>
|
||||
#include "runtime/framework/actor/memory_manager_actor.h"
|
||||
#include "runtime/framework/actor/debug_actor.h"
|
||||
#include "runtime/framework/actor/recorder_actor.h"
|
||||
|
@ -23,6 +24,7 @@
|
|||
#include "mindrt/include/async/async.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/config_manager.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/convert_utils.h"
|
||||
|
@ -75,6 +77,10 @@ std::vector<AbstractActorPtr> CollectActors(const ActorSet *actor_set) {
|
|||
MS_EXCEPTION_IF_NULL(data_source_actor);
|
||||
(void)actors.emplace_back(static_cast<AbstractActorPtr>(data_source_actor));
|
||||
}
|
||||
for (auto &custom_actor : actor_set->custom_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(custom_actor);
|
||||
(void)actors.emplace_back(static_cast<AbstractActorPtr>(custom_actor));
|
||||
}
|
||||
for (auto &kernel_actor : actor_set->kernel_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_actor);
|
||||
(void)actors.emplace_back(static_cast<AbstractActorPtr>(kernel_actor));
|
||||
|
@ -486,6 +492,7 @@ ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info)
|
|||
|
||||
auto host_queue = std::make_shared<HostTensorQueue>();
|
||||
actor_set->data_source_actors_ = BuildDataSourceActor(graph_compiler_info, host_queue);
|
||||
actor_set->custom_actors_ = BuildCustomActor(graph_compiler_info);
|
||||
actor_set->kernel_actors_ = BuildKernelActor(graph_compiler_info);
|
||||
actor_set->super_kernel_actors_ = BuildSuperKernelActor(graph_compiler_info);
|
||||
actor_set->loop_count_actor_ = BuildLoopCountActor(graph_compiler_info);
|
||||
|
@ -682,6 +689,32 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
|
|||
return data_source_actors;
|
||||
}
|
||||
|
||||
std::vector<CustomActorPtr> GraphScheduler::BuildCustomActor(const GraphCompilerInfo &graph_compiler_info) {
|
||||
std::vector<CustomActorPtr> custom_actors;
|
||||
for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
|
||||
const auto &device_context = graph_compiler_info.device_contexts_[i];
|
||||
const auto &graph = graph_compiler_info.graphs_[i];
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (graph->is_executing_sink()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto all_nodes = TopoSort(graph->get_return());
|
||||
for (const auto &node : all_nodes) {
|
||||
if (!AnfUtils::IsCustomActorNode(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto actor_name = AnfUtils::GetCustomActorName(node);
|
||||
auto custom_actor = std::make_shared<CustomActor>(actor_name, node, device_context, recorder_aid_);
|
||||
MS_EXCEPTION_IF_NULL(custom_actor);
|
||||
InsertActor(custom_actor.get());
|
||||
custom_actors.emplace_back(custom_actor);
|
||||
}
|
||||
}
|
||||
return custom_actors;
|
||||
}
|
||||
|
||||
std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompilerInfo &graph_compiler_info) {
|
||||
std::vector<KernelActorPtr> kernel_actors;
|
||||
|
||||
|
@ -1388,6 +1421,87 @@ void GraphScheduler::LinkGlobalControlArrow(ActorSet *const actor_set, const std
|
|||
|
||||
LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), actor_set,
|
||||
graph_compiler_info.control_node_parser_);
|
||||
|
||||
// Link control arrows for custom actor
|
||||
LinkControlArrowForCustomActor(actor_set, graph_compiler_info);
|
||||
}
|
||||
|
||||
// Link control arrows for custom actors:
//   1. Build a map from kernel node -> actor for custom, kernel and GetNext
//      data-source actors.
//   2. For every Depend(custom, custom) node in the non-sink graphs, add a
//      control arrow from the "from" actor to the "to" actor.
//   3. Any custom actor left without an incoming Depend edge is triggered
//      directly by the data-prepare actor so it is not orphaned.
void GraphScheduler::LinkControlArrowForCustomActor(ActorSet *const actor_set,
                                                    const GraphCompilerInfo &graph_compiler_info) {
  // Depend(input(1)=to_node, input(2)=from_node): index 2 is the dependency
  // source, index 1 is the value that must wait for it.
  constexpr size_t kDependFromIdx = 2;
  constexpr size_t kDependToIdx = 1;
  MS_EXCEPTION_IF_NULL(actor_set);
  MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
  // prepare for kernel => actor map
  HashMap<AnfNodePtr, AbstractActorPtr> kernel_to_actors = {};
  // Custom actors start as "no incoming depend"; entries are erased when a
  // Depend edge targeting them is linked below.
  HashSet<AbstractActorPtr> no_depend_custom_actors = {};
  for (const auto &actor : actor_set->custom_actors_) {
    MS_EXCEPTION_IF_NULL(actor);
    auto kernel = actor->kernel().lock();
    MS_EXCEPTION_IF_NULL(kernel);
    kernel_to_actors.emplace(kernel, actor);
    no_depend_custom_actors.insert(actor);
  }
  for (const auto &actor : actor_set->kernel_actors_) {
    MS_EXCEPTION_IF_NULL(actor);
    auto kernel = actor->kernel();
    MS_EXCEPTION_IF_NULL(kernel);
    kernel_to_actors.emplace(kernel, actor);
  }
  for (const auto &actor : actor_set->data_source_actors_) {
    MS_EXCEPTION_IF_NULL(actor);
    // Only device-queue data sources backed by a GetNext kernel participate.
    auto device_data_source_actor = dynamic_cast<DeviceQueueDataSourceActor *>(actor.get());
    if (device_data_source_actor != nullptr) {
      auto kernel = device_data_source_actor->data_kernel();
      MS_EXCEPTION_IF_NULL(kernel);
      if (AnfAlgo::GetCNodeName(kernel) == kGetNextOpName) {
        kernel_to_actors.emplace(kernel, actor);
      }
    }
  }
  // find depend(custom, custom)
  for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
    const auto &graph = graph_compiler_info.graphs_[i];
    MS_EXCEPTION_IF_NULL(graph);
    // Sink-mode graphs have no per-node actors to link.
    if (graph->is_executing_sink()) {
      continue;
    }

    auto all_nodes = TopoSort(graph->get_return());
    for (const auto &node : all_nodes) {
      if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
        continue;
      }
      auto depend_cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(depend_cnode);
      MS_EXCEPTION_IF_CHECK_FAIL(depend_cnode->size() > kDependFromIdx,
                                 "depend node " + depend_cnode->DebugString() + " input size " +
                                   std::to_string(depend_cnode->size()) + " is invalid.");
      MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependFromIdx));
      MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependToIdx));
      auto from_node = depend_cnode->input(kDependFromIdx);
      auto to_node = depend_cnode->input(kDependToIdx);
      // Only Depend edges that touch a custom actor node are relevant here;
      // ordinary data/control edges are linked elsewhere.
      if (!AnfUtils::IsCustomActorNode(from_node) && !AnfUtils::IsCustomActorNode(to_node)) {
        continue;
      }
      auto from_iter = kernel_to_actors.find(from_node);
      if (from_iter == kernel_to_actors.end()) {
        MS_LOG(INFO) << from_node->fullname_with_scope() << " is a CNode but cannot find Actor.";
        continue;
      }
      auto to_iter = kernel_to_actors.find(to_node);
      if (to_iter == kernel_to_actors.end()) {
        MS_LOG(INFO) << to_node->fullname_with_scope() << " is a CNode but cannot find Actor.";
        continue;
      }
      AddControlArrow(from_iter->second.get(), to_iter->second.get());
      // The target now has an incoming edge, so it no longer needs the
      // data-prepare fallback trigger (erase is a no-op for non-custom actors).
      no_depend_custom_actors.erase(std::dynamic_pointer_cast<CustomActor>(to_iter->second));
    }
  }

  // Trigger the remaining (edge-less) custom actors from data prepare so that
  // every custom actor has at least one control source.
  for (const auto &custom_actor : no_depend_custom_actors) {
    AddControlArrow(actor_set->data_prepare_actor_.get(), custom_actor.get());
  }
}
|
||||
|
||||
void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes,
|
||||
|
@ -1708,11 +1822,12 @@ void GraphScheduler::CheckActorValid(const ActorSet *actor_set) const {
|
|||
<< ", actual control num: " << actor->input_control_arrow_aids_.size();
|
||||
}
|
||||
|
||||
if ((actor->type_ != KernelTransformType::kOutputActor) && (actor->output_data_arrows_.size() == 0) &&
|
||||
(actor->output_control_arrows_.size() == 0)) {
|
||||
if ((actor->type_ != KernelTransformType::kOutputActor) && (actor->type_ != KernelTransformType::kCustomActor) &&
|
||||
(actor->output_data_arrows_.size() == 0) && (actor->output_control_arrows_.size() == 0)) {
|
||||
MS_LOG(EXCEPTION) << actor->GetAID().Name() << " has no user.";
|
||||
}
|
||||
if ((actor->type_ != KernelTransformType::kDataPrepareActor) && (actor->input_datas_num_ == 0) &&
|
||||
if ((actor->type_ != KernelTransformType::kDataPrepareActor) &&
|
||||
(actor->type_ != KernelTransformType::kCustomActor) && (actor->input_datas_num_ == 0) &&
|
||||
(actor->input_controls_num_ == 0)) {
|
||||
MS_LOG(EXCEPTION) << actor->GetAID().Name() << " has no source.";
|
||||
}
|
||||
|
@ -1868,6 +1983,7 @@ void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInf
|
|||
DumpLoopCountActor(actor_set->loop_count_actor_, ofs);
|
||||
DumpOutputActor(actor_set->output_actor_, ofs);
|
||||
DumpControlActors(actor_set->control_actors_, ofs);
|
||||
DumpCustomActors(actor_set->custom_actors_, ofs);
|
||||
}
|
||||
|
||||
void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -96,6 +96,7 @@ class GraphScheduler {
|
|||
std::vector<DataSourceActorPtr> BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
|
||||
const HostTensorQueuePtr &host_queue);
|
||||
std::vector<KernelActorPtr> BuildKernelActor(const GraphCompilerInfo &graph_compiler_info);
|
||||
std::vector<CustomActorPtr> BuildCustomActor(const GraphCompilerInfo &graph_compiler_info);
|
||||
std::vector<SuperKernelActorPtr> BuildSuperKernelActor(const GraphCompilerInfo &graph_compiler_info);
|
||||
LoopCountActorPtr BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info);
|
||||
OutputActorPtr BuildOutputActor(const GraphCompilerInfo &graph_compiler_info);
|
||||
|
@ -152,6 +153,7 @@ class GraphScheduler {
|
|||
void LinkGlobalControlArrow(ActorSet *const actor_set, const std::vector<CNodePtr> &communication_nodes,
|
||||
const std::vector<AbstractActor *> &auto_monad_actors,
|
||||
const GraphCompilerInfo &graph_compiler_info);
|
||||
void LinkControlArrowForCustomActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
|
||||
// Link the control arrows by the communication nodes in the kernel graph to ensure communication nodes running order.
|
||||
void LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes,
|
||||
const GraphCompilerInfo &graph_compiler_info);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -58,6 +58,7 @@
|
|||
#ifndef ENABLE_SECURITY
|
||||
#include "profiler/device/ascend/memory_profiling.h"
|
||||
#include "runtime/device/ascend/profiling/profiling_manager.h"
|
||||
#include "utils/anf_utils.h"
|
||||
|
||||
using Adx::AdxRegDumpProcessCallBack;
|
||||
using mindspore::device::ascend::ProfilingManager;
|
||||
|
@ -721,6 +722,14 @@ bool AscendDeviceContext::MemoryCopyAsync(const CNodePtr &node, const vector<Add
|
|||
return true;
|
||||
}
|
||||
|
||||
bool AscendDeviceContext::LaunchCustomFunc(const AnfNodePtr &kernel) const {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
auto custom_func = AnfUtils::GetCustomFunc(kernel);
|
||||
BindDeviceToCurrentThread();
|
||||
custom_func(nullptr);
|
||||
return true;
|
||||
}
|
||||
|
||||
void *AscendDeviceContext::GetKernelStream(const CNodePtr &node) const {
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -108,6 +108,8 @@ class AscendDeviceContext : public DeviceContext {
|
|||
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs,
|
||||
bool is_dynamic_shape = false) const override;
|
||||
|
||||
bool LaunchCustomFunc(const AnfNodePtr &kernel) const override;
|
||||
|
||||
// Synchronize stream, device such as GPU and Ascend need stream to launch kernel asynchronously,
|
||||
// using 'SyncStream' to block thread and wait for completing all tasks in stream.
|
||||
// Devices that do not need stream could ignore the implementation of this function.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -133,6 +133,8 @@ class DeviceContext {
|
|||
return true;
|
||||
}
|
||||
|
||||
virtual bool LaunchCustomFunc(const AnfNodePtr &kernel) const { return true; }
|
||||
|
||||
// Synchronize stream, device such as GPU and Ascend need stream to launch kernel asynchronously,
|
||||
// using 'SyncStream' to block thread and wait for completing all tasks in stream.
|
||||
// Devices that do not need stream could ignore the implementation of this function.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -27,6 +27,7 @@
|
|||
#include "ir/tensor.h"
|
||||
#include "ir/param_info.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/anf_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
bool ValueToBool(const ValuePtr &v, bool *value) {
|
||||
|
@ -140,6 +141,9 @@ bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraph
|
|||
MS_LOG(DEBUG) << "two parameters are not equal.";
|
||||
return false;
|
||||
}
|
||||
if (AnfUtils::IsCustomActorNode(node1) && AnfUtils::IsCustomActorNode(node2)) {
|
||||
return AnfUtils::IsCutomActorNodeSame(node1, node2);
|
||||
}
|
||||
if (node1->isa<CNode>() && node2->isa<CNode>()) {
|
||||
return SameNode(node1, node2, equiv_func_graph, equiv_node);
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "utils/anf_utils.h"
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "base/core_ops.h"
|
||||
#include "utils/trace_base.h"
|
||||
|
@ -41,6 +42,33 @@ class AbstractMutexManager {
|
|||
std::map<const AnfNode *, std::recursive_mutex> mu_for_nodes_;
|
||||
std::recursive_mutex mu_;
|
||||
};
|
||||
|
||||
// Per-node metadata that marks an AnfNode as a custom actor node and carries
// the callback the runtime will execute for it. Stored on the node as user
// data under `key`; presence of this user data is what
// AnfUtils::IsCustomActorNode tests for.
struct CustomActorInfo {
  CustomActorInfo(AnfUtils::CustomActorCallback func, const std::string &base_node_name, const std::string &type_name,
                  bool is_fake = false, bool is_just_sync = false)
      : actor_func(func),
        base_node_name(base_node_name),
        type_name(type_name),
        is_fake(is_fake),
        is_just_sync(is_just_sync) {}
  ~CustomActorInfo() = default;

  // Key for user data.
  constexpr static char key[] = "CustomActor";
  // Callback executed when the corresponding custom actor runs.
  AnfUtils::CustomActorCallback actor_func = {};
  // Name of the original kernel node this actor was derived from.
  std::string base_node_name;
  // Actor kind discriminator (e.g. infer/init/update) used when comparing nodes.
  std::string type_name;
  bool is_fake{false};       // For infer
  bool is_just_sync{false};  // For update
};
using CustomActorInfoPtr = std::shared_ptr<CustomActorInfo>;
|
||||
|
||||
// Creates a bare AnfNode in graph `g` and tags it with `actor_info` so that
// AnfUtils::IsCustomActorNode() and the accessors can recognize it.
AnfNodePtr NewCustomActorNode(const CustomActorInfoPtr &actor_info, const FuncGraphPtr &g) {
  MS_EXCEPTION_IF_NULL(g);
  auto node = std::make_shared<AnfNode>(g);
  node->set_user_data<CustomActorInfo>(actor_info);
  return node;
}
|
||||
} // namespace
|
||||
|
||||
AbstractScope::AbstractScope(std::recursive_mutex *mu) {
|
||||
|
@ -296,6 +324,8 @@ std::pair<AnfNodePtr, size_t> AnfUtils::VisitKernel(const AnfNodePtr &anf_node,
|
|||
return std::make_pair(anf_node, 0);
|
||||
} else if (anf_node->isa<Parameter>()) {
|
||||
return std::make_pair(anf_node, 0);
|
||||
} else if (IsCustomActorNode(anf_node)) {
|
||||
return std::make_pair(anf_node, 0);
|
||||
} else if (anf_node->isa<CNode>()) {
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
@ -364,4 +394,95 @@ bool AnfUtils::GetDumpFlag(const AnfNodePtr &node) {
|
|||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool AnfUtils::IsCustomActorNode(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
return node->has_user_data<CustomActorInfo>();
|
||||
}
|
||||
|
||||
bool AnfUtils::IsCutomActorNodeSame(const AnfNodePtr &node1, const AnfNodePtr &node2) {
|
||||
MS_EXCEPTION_IF_NULL(node1);
|
||||
MS_EXCEPTION_IF_NULL(node2);
|
||||
if (!IsCustomActorNode(node1) || !IsCustomActorNode(node2)) {
|
||||
MS_LOG(EXCEPTION) << "Two node are not all Custom Actor Node!";
|
||||
}
|
||||
|
||||
auto actor_info1 = node1->user_data<CustomActorInfo>();
|
||||
MS_EXCEPTION_IF_NULL(actor_info1);
|
||||
std::string actor_type1 = actor_info1->type_name;
|
||||
bool is_fake1 = actor_info1->is_fake;
|
||||
bool is_just_sync1 = actor_info1->is_just_sync;
|
||||
|
||||
auto actor_info2 = node2->user_data<CustomActorInfo>();
|
||||
MS_EXCEPTION_IF_NULL(actor_info2);
|
||||
std::string actor_type2 = actor_info2->type_name;
|
||||
bool is_fake2 = actor_info2->is_fake;
|
||||
bool is_just_sync2 = actor_info2->is_just_sync;
|
||||
|
||||
return (actor_type1 == actor_type2) && (is_fake1 == is_fake2) && (is_just_sync1 == is_just_sync2);
|
||||
}
|
||||
|
||||
std::string AnfUtils::GetCustomActorType(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!IsCustomActorNode(node)) {
|
||||
MS_LOG(EXCEPTION) << node->fullname_with_scope() << " is not a custom actor node!";
|
||||
}
|
||||
|
||||
auto actor_info = node->user_data<CustomActorInfo>();
|
||||
MS_EXCEPTION_IF_NULL(actor_info);
|
||||
return actor_info->type_name;
|
||||
}
|
||||
|
||||
std::string AnfUtils::GetCustomActorName(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!IsCustomActorNode(node)) {
|
||||
MS_LOG(EXCEPTION) << node->fullname_with_scope() << " is not a custom actor node!";
|
||||
}
|
||||
|
||||
auto actor_info = node->user_data<CustomActorInfo>();
|
||||
MS_EXCEPTION_IF_NULL(actor_info);
|
||||
std::string actor_name = actor_info->type_name + "_of_" + actor_info->base_node_name;
|
||||
return actor_name;
|
||||
}
|
||||
|
||||
bool AnfUtils::GetCustomActorJustSyncFlag(const AnfNodePtr &node) {
|
||||
if (!IsCustomActorNode(node)) {
|
||||
MS_LOG(EXCEPTION) << node->fullname_with_scope() << " is not a custom actor node!";
|
||||
}
|
||||
|
||||
auto update_info = node->user_data<CustomActorInfo>();
|
||||
MS_EXCEPTION_IF_NULL(update_info);
|
||||
return update_info->is_just_sync;
|
||||
}
|
||||
|
||||
// Returns the runtime callback registered on the custom actor node.
// Raises if `node` is not a custom actor node.
AnfUtils::CustomActorCallback AnfUtils::GetCustomFunc(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsCustomActorNode(node)) {
    MS_LOG(EXCEPTION) << node->fullname_with_scope() << " is not a custom actor node!";
  }

  const auto &info = node->user_data<CustomActorInfo>();
  MS_EXCEPTION_IF_NULL(info);
  return info->actor_func;
}
|
||||
|
||||
// Creates an Init-stage custom actor node bound to `base_cnode`, running `f`.
AnfNodePtr AnfUtils::NewInitActorNode(AnfUtils::CustomActorCallback f, const CNodePtr &base_cnode) {
  MS_EXCEPTION_IF_NULL(base_cnode);
  return NewCustomActorNode(std::make_shared<CustomActorInfo>(f, base_cnode->fullname_with_scope(), kInit),
                            base_cnode->func_graph());
}
|
||||
|
||||
// Creates an Infer-stage custom actor node bound to `base_cnode`. When
// `is_fake` is true the node is a placeholder without any real infer action.
AnfNodePtr AnfUtils::NewInferActorNode(AnfUtils::CustomActorCallback f, const CNodePtr &base_cnode, bool is_fake) {
  MS_EXCEPTION_IF_NULL(base_cnode);
  return NewCustomActorNode(std::make_shared<CustomActorInfo>(f, base_cnode->fullname_with_scope(), kInfer, is_fake),
                            base_cnode->func_graph());
}
|
||||
|
||||
// Creates an Update-stage custom actor node bound to `base_cnode`. When
// `is_just_sync` is true the node only performs a stream-sync call.
AnfNodePtr AnfUtils::NewUpdateActorNode(AnfUtils::CustomActorCallback f, const CNodePtr &base_cnode,
                                        bool is_just_sync) {
  MS_EXCEPTION_IF_NULL(base_cnode);
  return NewCustomActorNode(
    std::make_shared<CustomActorInfo>(f, base_cnode->fullname_with_scope(), kUpdate, false, is_just_sync),
    base_cnode->func_graph());
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,6 +16,7 @@
|
|||
|
||||
#ifndef MINDSPORE_CORE_UTILS_ANF_UTILS_H_
|
||||
#define MINDSPORE_CORE_UTILS_ANF_UTILS_H_
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
@ -25,6 +26,10 @@
|
|||
#include "ir/primitive.h"
|
||||
|
||||
namespace mindspore {
|
||||
constexpr auto kInfer = "DS_Infer";
|
||||
constexpr auto kInit = "DS_Init";
|
||||
constexpr auto kUpdate = "DS_Update";
|
||||
|
||||
class AbstractScope {
|
||||
public:
|
||||
explicit AbstractScope(std::recursive_mutex *mu);
|
||||
|
@ -40,6 +45,7 @@ class AbstractScope {
|
|||
|
||||
class AnfUtils {
|
||||
public:
|
||||
using CustomActorCallback = std::function<void(void *args)>;
|
||||
static bool IsDimUnknown(const abstract::ShapePtr &shape);
|
||||
static bool IsShapeDynamic(const abstract::ShapePtr &shape);
|
||||
static bool IsShapeDynamic(const std::vector<size_t> &shape);
|
||||
|
@ -66,6 +72,20 @@ class AnfUtils {
|
|||
// Get dump flag from CNode's primitive.
|
||||
static bool GetDumpFlag(const AnfNodePtr &node);
|
||||
static AbstractScope GetAbstractLock(const AnfNode *node);
|
||||
|
||||
// Custom actor node is for dynamic shape.
|
||||
// Generate a Init custom actor node.
|
||||
static AnfNodePtr NewInitActorNode(CustomActorCallback f, const CNodePtr &base_cnode);
|
||||
// Generate a Infer custom actor node. If `is_fake` is set to true, this node is a fake node without any infer action.
|
||||
static AnfNodePtr NewInferActorNode(CustomActorCallback f, const CNodePtr &base_cnode, bool is_fake);
|
||||
// Generate a Update custom actor node. If `is_just_sync` is set to true, this node is just for a stream-sync call.
|
||||
static AnfNodePtr NewUpdateActorNode(CustomActorCallback f, const CNodePtr &base_cnode, bool is_just_sync);
|
||||
static bool IsCustomActorNode(const AnfNodePtr &node);
|
||||
static std::string GetCustomActorType(const AnfNodePtr &node);
|
||||
static std::string GetCustomActorName(const AnfNodePtr &node);
|
||||
static bool GetCustomActorJustSyncFlag(const AnfNodePtr &node);
|
||||
static CustomActorCallback GetCustomFunc(const AnfNodePtr &node);
|
||||
static bool IsCutomActorNodeSame(const AnfNodePtr &node1, const AnfNodePtr &node2);
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_UTILS_ANF_UTILS_H_
|
||||
|
|
|
@ -0,0 +1,726 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <string>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/ascend/dynamic_shape/ascend_dynamic_shape_helper.h"
|
||||
#include "backend/optimizer/ascend/ascend_backend_optimization.h"
|
||||
#include "backend/kernel_compiler/tbe/tbe_kernel_mod.h"
|
||||
#include "common/backend_common_test.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "debug/dump_proto.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr auto kTupleFirstItemIndex = 0;
|
||||
constexpr auto kTupleSecondItemIndex = 1;
|
||||
constexpr auto kDependRealInputSize = 2;
|
||||
|
||||
// Adds a weight parameter named `name` with the given abstract to graph `g`
// and attaches an empty KernelInfo so later passes can select/build kernels.
ParameterPtr TestCreateParameter(const KernelGraphPtr &g, const std::string &name,
                                 const abstract::AbstractBasePtr &abstract) {
  MS_EXCEPTION_IF_NULL(g);
  auto parameter = g->AddWeightParameter(name);
  if (parameter == nullptr) {
    MS_LOG(ERROR) << "Cannot add weight parameter!";
  }
  // The original logged the failure and then dereferenced the null pointer;
  // fail loudly before the dereference instead.
  MS_EXCEPTION_IF_NULL(parameter);
  parameter->set_abstract(abstract);
  parameter->set_kernel_info(std::make_shared<device::KernelInfo>());
  return parameter;
}
|
||||
|
||||
// Builds a CNode `prim_name(real_inputs...)` in graph `g`, sets its abstract,
// and attaches a KernelInfo holding a stub TBE kernel mod so the node looks
// like a selected/compiled kernel to the passes under test.
CNodePtr TestCreateCNode(const KernelGraphPtr &g, const std::string &prim_name, const AnfNodePtrList &real_inputs,
                         const abstract::AbstractBasePtr &abstract) {
  MS_EXCEPTION_IF_NULL(g);
  auto inputs = AnfNodePtrList{NewValueNode(std::make_shared<Primitive>(prim_name))};
  inputs.insert(inputs.end(), real_inputs.begin(), real_inputs.end());
  auto cnode = g->NewCNode(inputs);
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "Cannot create cnode!";
  }
  // The original logged the failure and then dereferenced the null pointer;
  // fail loudly before the dereference instead.
  MS_EXCEPTION_IF_NULL(cnode);
  cnode->set_abstract(abstract);
  auto cnode_kernel_info = std::make_shared<device::KernelInfo>();
  cnode_kernel_info->set_kernel_mod(std::make_shared<kernel::TbeKernelMod>(std::make_shared<kernel::KernelPack>()));
  cnode->set_kernel_info(cnode_kernel_info);
  return cnode;
}
|
||||
|
||||
// Convenience wrapper: abstract tensor of the given element type and shape.
inline abstract::AbstractTensorPtr TestCreateTensor(const TypePtr &element_type, const ShapeVector &shape) {
  auto tensor_abs = std::make_shared<abstract::AbstractTensor>(element_type, shape);
  return tensor_abs;
}
|
||||
|
||||
// Builds an abstract tuple whose i-th element is an abstract tensor with
// element_types[i] / shapes[i]. The two vectors are expected to be the same
// length; a mismatch is reported and the extra entries are ignored.
inline abstract::AbstractTuplePtr TestCreateTupleTensor(const std::vector<TypePtr> &element_types,
                                                        const std::vector<ShapeVector> &shapes) {
  if (element_types.size() != shapes.size()) {
    MS_LOG(ERROR) << "Sizes for element type and shape are not match.";
  }

  // Iterate over the smaller of the two sizes: the original indexed shapes[i]
  // by element_types.size(), reading out of bounds on a mismatch.
  const size_t num = element_types.size() < shapes.size() ? element_types.size() : shapes.size();
  AbstractBasePtrList abstract_list;
  abstract_list.reserve(num);
  for (size_t i = 0; i < num; ++i) {
    abstract_list.emplace_back(TestCreateTensor(element_types[i], shapes[i]));
  }

  return std::make_shared<abstract::AbstractTuple>(abstract_list);
}
|
||||
|
||||
// Builds Depend(inputs[0], inputs[1]) in graph `g`; exactly two real inputs
// are required. Returns nullptr on a wrong input count.
inline CNodePtr TestCreateDepend(const KernelGraphPtr &g, const AnfNodePtrList &inputs) {
  MS_EXCEPTION_IF_NULL(g);
  if (inputs.size() != kDependRealInputSize) {
    // Fixed message typo ("Depnd") and stop here: the original went on to read
    // inputs[0]/inputs[1] out of bounds after merely logging the mismatch.
    MS_LOG(ERROR) << "Input size for Depend should be 2!";
    return nullptr;
  }

  auto depend_node = g->NewCNode(std::vector<AnfNodePtr>{
    NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())), inputs[0], inputs[1]});
  MS_EXCEPTION_IF_NULL(depend_node);
  return depend_node;
}
|
||||
|
||||
// Builds MakeTuple(inputs...) whose abstract is the tuple of the inputs' abstracts.
inline CNodePtr TestCreateMakeTuple(const KernelGraphPtr &g, const AnfNodePtrList &inputs) {
  MS_EXCEPTION_IF_NULL(g);
  AbstractBasePtrList element_abstracts;
  for (const auto &element : inputs) {
    element_abstracts.emplace_back(element->abstract());
  }
  return TestCreateCNode(g, "MakeTuple", inputs, std::make_shared<abstract::AbstractTuple>(element_abstracts));
}
|
||||
|
||||
// Debug helper: dumps graph `g` as both .ir text and proto, keyed by graph id.
inline void DumpGraph(const KernelGraphPtr &g) {
  MS_EXCEPTION_IF_NULL(g);
  const auto graph_id = std::to_string(g->graph_id());
  DumpIR("debug_try_down_" + graph_id + ".ir", g);
  DumpIRProto(g, "try_down_" + graph_id);
}
|
||||
} // namespace
|
||||
|
||||
class TestDynamicShapePass : public BackendCommon {
|
||||
public:
|
||||
TestDynamicShapePass() {}
|
||||
~TestDynamicShapePass() override = default;
|
||||
};
|
||||
|
||||
/// Feature: Dynamic shape
|
||||
/// Description: Dynamic op + inherited dynamic op.
|
||||
/// before:
|
||||
/// a = Unique(%p) (1, -1)
|
||||
/// |
|
||||
/// b = A(a) (1, -1)
|
||||
/// Expectation: Graph as following.
|
||||
/// after:
|
||||
/// Unique_Update Unique(%p) Unique_Init Unique_Infer
|
||||
/// \ / | \ / \ / |
|
||||
/// depend | depend depend |
|
||||
/// | |
|
||||
/// A A_Init A_Infer |
|
||||
/// \ / \ / \ |
|
||||
/// depend depend depend
|
||||
TEST_F(TestDynamicShapePass, test_dynamic_shape_pass_0) {
  // construct before graph: p -> Unique (dynamic output) -> A (inherits dynamic shape)
  auto before_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(before_fg != nullptr);

  auto before_p = TestCreateParameter(before_fg, "p", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto before_uniq_node = TestCreateCNode(before_fg, "Unique", AnfNodePtrList{before_p},
                                          TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto before_a_node = TestCreateCNode(before_fg, "A", AnfNodePtrList{before_uniq_node},
                                       TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  before_fg->set_output(before_a_node);

  // run pass
  AscendDynamicShapeConvert(before_fg);

  // construct after graph: same kernels plus the expected custom actor nodes
  auto after_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(after_fg != nullptr);
  auto after_p = TestCreateParameter(after_fg, "p", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto after_uniq_node = TestCreateCNode(after_fg, "Unique", AnfNodePtrList{after_p},
                                         TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto after_a_node = TestCreateCNode(after_fg, "A", AnfNodePtrList{after_uniq_node},
                                      TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));

  // Unique is a real dynamic op: it gets Infer, Init and Update actors.
  auto infer_uniq = dynamic_shape::GenInferNode(after_uniq_node);
  auto init_uniq = dynamic_shape::GenInitNode(after_uniq_node);
  auto update_uniq = dynamic_shape::GenUpdateNode(after_uniq_node);

  // A only inherits the dynamic shape: Infer and Init actors, no Update.
  auto infer_a = dynamic_shape::GenInferNode(after_a_node);
  auto init_a = dynamic_shape::GenInitNode(after_a_node);

  // Expected ordering edges (Depend(a, b) means a must run after b).
  auto depend0 = TestCreateDepend(after_fg, AnfNodePtrList{init_uniq, infer_uniq});
  auto depend1 = TestCreateDepend(after_fg, AnfNodePtrList{after_uniq_node, init_uniq});
  auto depend2 = TestCreateDepend(after_fg, AnfNodePtrList{update_uniq, after_uniq_node});
  auto depend3 = TestCreateDepend(after_fg, AnfNodePtrList{init_a, infer_a});
  auto depend4 = TestCreateDepend(after_fg, AnfNodePtrList{after_a_node, init_a});
  auto depend5 = TestCreateDepend(after_fg, AnfNodePtrList{infer_a, infer_uniq});
  auto depend6 = TestCreateDepend(after_fg, AnfNodePtrList{infer_a, update_uniq});
  // Anchor all depend edges in the graph output via MakeTuple/TupleGetItem.
  auto make_tuple = TestCreateMakeTuple(
    after_fg, AnfNodePtrList{after_a_node, depend0, depend1, depend2, depend3, depend4, depend5, depend6});
  auto get_item = TestCreateCNode(after_fg, "TupleGetItem",
                                  AnfNodePtrList{make_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                  after_a_node->abstract());
  after_fg->set_output(get_item);

  // assert
  EXPECT_TRUE(CheckEqualGraph(after_fg, before_fg));
}
|
||||
|
||||
/// Feature: Dynamic shape
|
||||
/// Description: General op case.
|
||||
/// before:
|
||||
/// a = A(%p)
|
||||
/// Expectation: Graph as following.
|
||||
/// after:
|
||||
/// A(%p) A_Init A_Infer
|
||||
/// \ / \ /
|
||||
/// depend depend
|
||||
TEST_F(TestDynamicShapePass, test_dynamic_shape_pass_1) {
  // construct before graph: a single static-shape general op A(p)
  auto before_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(before_fg != nullptr);

  auto before_p = TestCreateParameter(before_fg, "p", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto before_a_node =
    TestCreateCNode(before_fg, "A", AnfNodePtrList{before_p}, TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  before_fg->set_output(before_a_node);

  // run pass
  AscendDynamicShapeConvert(before_fg);

  // construct after graph
  auto after_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(after_fg != nullptr);
  auto after_p = TestCreateParameter(after_fg, "p", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto after_a_node =
    TestCreateCNode(after_fg, "A", AnfNodePtrList{after_p}, TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));

  // A general (static-shape) op gets a fake Infer actor (is_fake=true) and an Init actor.
  auto infer_a = dynamic_shape::GenInferNode(after_a_node, true);
  auto init_a = dynamic_shape::GenInitNode(after_a_node);

  // Expected ordering: Init after Infer, kernel after Init.
  auto depend0 = TestCreateDepend(after_fg, AnfNodePtrList{init_a, infer_a});
  auto depend1 = TestCreateDepend(after_fg, AnfNodePtrList{after_a_node, init_a});

  // Anchor the depend edges in the graph output via MakeTuple/TupleGetItem.
  auto make_tuple = TestCreateMakeTuple(after_fg, AnfNodePtrList{after_a_node, depend0, depend1});
  auto get_item = TestCreateCNode(after_fg, "TupleGetItem",
                                  AnfNodePtrList{make_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                  after_a_node->abstract());
  after_fg->set_output(get_item);

  // assert
  EXPECT_TRUE(CheckEqualGraph(after_fg, before_fg));
}
|
||||
|
||||
/// Feature: Dynamic shape
|
||||
/// Description: Get item and make tuple case.
|
||||
/// before:
|
||||
/// D0(%p)
|
||||
/// | |
|
||||
/// [0] [1]
|
||||
/// | |
|
||||
/// A |
|
||||
/// \ /
|
||||
/// MakeTuple
|
||||
/// Expectation: Graph as following.
|
||||
/// after:
|
||||
/// D0_Update D0(%p) D0_Init D0_Infer
|
||||
/// \ / | \ \ / \ / |
|
||||
/// depend [0] [1] \ / Depend |
|
||||
/// | | Depend |
|
||||
/// A | A_Infer |
|
||||
/// | | | \ |
|
||||
/// | | | Depend
|
||||
/// | | | D0_Update
|
||||
/// | | | /
|
||||
/// \ / Depend
|
||||
/// MakeTuple
|
||||
TEST_F(TestDynamicShapePass, test_dynamic_shape_pass_2) {
  // construct before graph: Unique -> (TupleGetItem[0] -> A, TupleGetItem[1]) -> MakeTuple
  auto before_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(before_fg != nullptr);

  auto before_p = TestCreateParameter(before_fg, "p", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  // This Unique is used to represent a multi-output operation instead of its original meaning.
  auto before_tuple = TestCreateCNode(
    before_fg, "Unique", AnfNodePtrList{before_p},
    TestCreateTupleTensor(std::vector<TypePtr>(2, kFloat32), std::vector<ShapeVector>(2, std::vector<int64_t>{1, -1})));
  auto before_first_item = TestCreateCNode(before_fg, "TupleGetItem",
                                           AnfNodePtrList{before_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                           TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto before_second_item = TestCreateCNode(
    before_fg, "TupleGetItem", AnfNodePtrList{before_tuple, NewValueNode(SizeToLong(kTupleSecondItemIndex))},
    TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto before_a = TestCreateCNode(before_fg, "A", AnfNodePtrList{before_first_item}, before_first_item->abstract());
  auto before_make_tuple = TestCreateMakeTuple(before_fg, AnfNodePtrList{before_a, before_second_item});
  before_fg->set_output(before_make_tuple);

  // run pass
  AscendDynamicShapeConvert(before_fg);

  // construct after graph
  auto after_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(after_fg != nullptr);

  auto after_p = TestCreateParameter(after_fg, "p", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto after_tuple = TestCreateCNode(
    after_fg, "Unique", AnfNodePtrList{after_p},
    TestCreateTupleTensor(std::vector<TypePtr>(2, kFloat32), std::vector<ShapeVector>(2, std::vector<int64_t>{1, -1})));
  auto after_first_item = TestCreateCNode(after_fg, "TupleGetItem",
                                          AnfNodePtrList{after_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                          TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto after_second_item = TestCreateCNode(after_fg, "TupleGetItem",
                                           AnfNodePtrList{after_tuple, NewValueNode(SizeToLong(kTupleSecondItemIndex))},
                                           TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto after_a = TestCreateCNode(after_fg, "A", AnfNodePtrList{after_first_item}, after_first_item->abstract());
  auto after_make_tuple = TestCreateMakeTuple(after_fg, AnfNodePtrList{after_a, after_second_item});

  // The multi-output dynamic op gets Infer/Init/Update actors; TupleGetItem
  // nodes themselves get none.
  auto infer_tuple = dynamic_shape::GenInferNode(after_tuple);
  auto init_tuple = dynamic_shape::GenInitNode(after_tuple);
  auto update_tuple = dynamic_shape::GenUpdateNode(after_tuple);
  auto depend0 = TestCreateDepend(after_fg, AnfNodePtrList{init_tuple, infer_tuple});
  auto depend1 = TestCreateDepend(after_fg, AnfNodePtrList{after_tuple, init_tuple});
  auto depend2 = TestCreateDepend(after_fg, AnfNodePtrList{update_tuple, after_tuple});

  // A inherits the dynamic shape through TupleGetItem: Infer/Init actors,
  // ordered after the producer's Infer and Update.
  auto infer_a = dynamic_shape::GenInferNode(after_a);
  auto init_a = dynamic_shape::GenInitNode(after_a);
  auto depend3 = TestCreateDepend(after_fg, AnfNodePtrList{init_a, infer_a});
  auto depend4 = TestCreateDepend(after_fg, AnfNodePtrList{after_a, init_a});
  auto depend5 = TestCreateDepend(after_fg, AnfNodePtrList{infer_a, infer_tuple});
  auto depend6 = TestCreateDepend(after_fg, AnfNodePtrList{infer_a, update_tuple});

  // Anchor all depend edges in the graph output via MakeTuple/TupleGetItem.
  auto make_tuple = TestCreateMakeTuple(
    after_fg, AnfNodePtrList{after_make_tuple, depend0, depend1, depend2, depend3, depend4, depend5, depend6});
  auto get_item = TestCreateCNode(after_fg, "TupleGetItem",
                                  AnfNodePtrList{make_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                  after_make_tuple->abstract());
  after_fg->set_output(get_item);

  // assert
  EXPECT_TRUE(CheckEqualGraph(after_fg, before_fg));
}
|
||||
|
||||
/// Feature: Dynamic shape
|
||||
/// Description: Complicated case.
|
||||
/// Expectation: Graph as expected.
|
||||
TEST_F(TestDynamicShapePass, test_dynamic_shape_pass_3) {
  // construct before graph: mixes a general op (A), a multi-output dynamic op
  // (Unique tuple), inherited-dynamic ops (B, C, D), a dynamic op whose output
  // is static again (second Unique), and a general consumer (E).
  auto before_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(before_fg != nullptr);

  auto before_p1 = TestCreateParameter(before_fg, "p1", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto before_p2 = TestCreateParameter(before_fg, "p2", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto before_a = TestCreateCNode(before_fg, "A", AnfNodePtrList{before_p2}, before_p2->abstract());
  // This Unique is used to represent a multi-output operation instead of its original meaning.
  auto before_tuple = TestCreateCNode(
    before_fg, "Unique", AnfNodePtrList{before_p1, before_a},
    TestCreateTupleTensor(std::vector<TypePtr>(2, kFloat32), std::vector<ShapeVector>(2, std::vector<int64_t>{1, -1})));
  auto before_first_item = TestCreateCNode(before_fg, "TupleGetItem",
                                           AnfNodePtrList{before_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                           TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto before_second_item = TestCreateCNode(
    before_fg, "TupleGetItem", AnfNodePtrList{before_tuple, NewValueNode(SizeToLong(kTupleSecondItemIndex))},
    TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));

  auto before_b = TestCreateCNode(before_fg, "B", AnfNodePtrList{before_first_item}, before_first_item->abstract());
  auto before_c = TestCreateCNode(before_fg, "C", AnfNodePtrList{before_b}, before_b->abstract());

  auto before_d = TestCreateCNode(before_fg, "D", AnfNodePtrList{before_second_item}, before_second_item->abstract());
  // This Unique is used to represent a single-output operation instead of its original
  // meaning; it takes a dynamic-shape input but produces a static shape.
  auto before_dync_end = TestCreateCNode(before_fg, "Unique", AnfNodePtrList{before_d},
                                         TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto before_e = TestCreateCNode(before_fg, "E", AnfNodePtrList{before_dync_end}, before_dync_end->abstract());

  auto before_make_tuple = TestCreateMakeTuple(before_fg, AnfNodePtrList{before_c, before_e});
  before_fg->set_output(before_make_tuple);

  // run pass
  AscendDynamicShapeConvert(before_fg);

  // construct after graph: same topology plus expected actor nodes and edges
  auto after_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(after_fg != nullptr);

  auto after_p1 = TestCreateParameter(after_fg, "p1", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto after_p2 = TestCreateParameter(after_fg, "p2", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto after_a = TestCreateCNode(after_fg, "A", AnfNodePtrList{after_p2}, after_p2->abstract());
  // This Unique is used to represent a multi-output operation instead of its original meaning.
  auto after_tuple = TestCreateCNode(
    after_fg, "Unique", AnfNodePtrList{after_p1, after_a},
    TestCreateTupleTensor(std::vector<TypePtr>(2, kFloat32), std::vector<ShapeVector>(2, std::vector<int64_t>{1, -1})));
  auto after_first_item = TestCreateCNode(after_fg, "TupleGetItem",
                                          AnfNodePtrList{after_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                          TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto after_second_item = TestCreateCNode(after_fg, "TupleGetItem",
                                           AnfNodePtrList{after_tuple, NewValueNode(SizeToLong(kTupleSecondItemIndex))},
                                           TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));

  auto after_b = TestCreateCNode(after_fg, "B", AnfNodePtrList{after_first_item}, after_first_item->abstract());
  auto after_c = TestCreateCNode(after_fg, "C", AnfNodePtrList{after_b}, after_b->abstract());

  auto after_d = TestCreateCNode(after_fg, "D", AnfNodePtrList{after_second_item}, after_second_item->abstract());
  // This Unique is used to represent a single-output operation instead of its original
  // meaning; it takes a dynamic-shape input but produces a static shape.
  auto after_dync_end = TestCreateCNode(after_fg, "Unique", AnfNodePtrList{after_d},
                                        TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto after_e = TestCreateCNode(after_fg, "E", AnfNodePtrList{after_dync_end}, after_dync_end->abstract());

  auto after_make_tuple = TestCreateMakeTuple(after_fg, AnfNodePtrList{after_c, after_e});

  // A is a general op: fake Infer + Init.
  auto infer_a = dynamic_shape::GenInferNode(after_a, true);
  auto init_a = dynamic_shape::GenInitNode(after_a);

  // The tuple-output Unique is a real dynamic op: Infer + Init + Update.
  auto infer_tuple = dynamic_shape::GenInferNode(after_tuple);
  auto init_tuple = dynamic_shape::GenInitNode(after_tuple);
  auto update_tuple = dynamic_shape::GenUpdateNode(after_tuple);

  // B, C, D inherit dynamic shapes: Infer + Init each.
  auto infer_b = dynamic_shape::GenInferNode(after_b);
  auto init_b = dynamic_shape::GenInitNode(after_b);

  auto infer_c = dynamic_shape::GenInferNode(after_c);
  auto init_c = dynamic_shape::GenInitNode(after_c);

  auto infer_d = dynamic_shape::GenInferNode(after_d);
  auto init_d = dynamic_shape::GenInitNode(after_d);

  // The second Unique ends the dynamic chain but is itself dynamic: Infer + Init + Update.
  auto infer_de = dynamic_shape::GenInferNode(after_dync_end);
  auto init_de = dynamic_shape::GenInitNode(after_dync_end);
  auto update_de = dynamic_shape::GenUpdateNode(after_dync_end);

  // E consumes a static shape again: fake Infer + Init.
  auto infer_e = dynamic_shape::GenInferNode(after_e, true);
  auto init_e = dynamic_shape::GenInitNode(after_e);

  // Expected ordering edges (Depend(x, y) means x must run after y).
  auto depend0 = TestCreateDepend(after_fg, AnfNodePtrList{init_a, infer_a});
  auto depend1 = TestCreateDepend(after_fg, AnfNodePtrList{after_a, init_a});

  auto depend2 = TestCreateDepend(after_fg, AnfNodePtrList{init_tuple, infer_tuple});
  auto depend3 = TestCreateDepend(after_fg, AnfNodePtrList{after_tuple, init_tuple});
  auto depend4 = TestCreateDepend(after_fg, AnfNodePtrList{update_tuple, after_tuple});
  auto depend5 = TestCreateDepend(after_fg, AnfNodePtrList{infer_tuple, infer_a});

  auto depend6 = TestCreateDepend(after_fg, AnfNodePtrList{init_b, infer_b});
  auto depend7 = TestCreateDepend(after_fg, AnfNodePtrList{after_b, init_b});
  auto depend8 = TestCreateDepend(after_fg, AnfNodePtrList{infer_b, infer_tuple});
  auto depend9 = TestCreateDepend(after_fg, AnfNodePtrList{infer_b, update_tuple});

  auto depend10 = TestCreateDepend(after_fg, AnfNodePtrList{init_c, infer_c});
  auto depend11 = TestCreateDepend(after_fg, AnfNodePtrList{after_c, init_c});
  auto depend12 = TestCreateDepend(after_fg, AnfNodePtrList{infer_c, infer_b});

  auto depend13 = TestCreateDepend(after_fg, AnfNodePtrList{init_d, infer_d});
  auto depend14 = TestCreateDepend(after_fg, AnfNodePtrList{after_d, init_d});
  auto depend15 = TestCreateDepend(after_fg, AnfNodePtrList{infer_d, infer_tuple});
  auto depend16 = TestCreateDepend(after_fg, AnfNodePtrList{infer_d, update_tuple});

  auto depend17 = TestCreateDepend(after_fg, AnfNodePtrList{init_de, infer_de});
  auto depend18 = TestCreateDepend(after_fg, AnfNodePtrList{after_dync_end, init_de});
  auto depend19 = TestCreateDepend(after_fg, AnfNodePtrList{update_de, after_dync_end});
  auto depend20 = TestCreateDepend(after_fg, AnfNodePtrList{infer_de, infer_d});

  auto depend21 = TestCreateDepend(after_fg, AnfNodePtrList{init_e, infer_e});
  auto depend22 = TestCreateDepend(after_fg, AnfNodePtrList{after_e, init_e});
  auto depend23 = TestCreateDepend(after_fg, AnfNodePtrList{infer_e, infer_de});
  auto depend24 = TestCreateDepend(after_fg, AnfNodePtrList{infer_e, update_de});

  // Anchor all depend edges in the graph output via MakeTuple/TupleGetItem.
  auto make_tuple = TestCreateMakeTuple(
    after_fg,
    AnfNodePtrList{after_make_tuple, depend0, depend1, depend2, depend3, depend4, depend5, depend6, depend7,
                   depend8, depend9, depend10, depend11, depend12, depend13, depend14, depend15, depend16,
                   depend17, depend18, depend19, depend20, depend21, depend22, depend23, depend24});
  auto get_item = TestCreateCNode(after_fg, "TupleGetItem",
                                  AnfNodePtrList{make_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                  after_make_tuple->abstract());
  after_fg->set_output(get_item);

  // assert
  EXPECT_TRUE(CheckEqualGraph(after_fg, before_fg));
}
|
||||
|
||||
/// Feature: Dynamic shape
|
||||
/// Description: Dynamic op + depend.
|
||||
/// before:
|
||||
/// a = Unique(%p) (1, -1)
|
||||
/// |
|
||||
/// b = A(a) (1, -1) B(%p) (1, -1)
|
||||
/// \ /
|
||||
/// depend
|
||||
/// Expectation: Graph as following.
|
||||
/// after:
|
||||
/// Unique_Update Unique(%p) Unique_Init Unique_Infer
|
||||
/// \ / | \ / \ / |
|
||||
/// depend | depend depend |
|
||||
/// | |
|
||||
/// B A A_Init A_Infer |
|
||||
/// \ / \ / \ / \ |
|
||||
/// depend depend depend depend
|
||||
TEST_F(TestDynamicShapePass, test_dynamic_shape_pass_with_depend) {
|
||||
// construct before graph
|
||||
auto before_fg = std::make_shared<session::KernelGraph>();
|
||||
ASSERT_TRUE(before_fg != nullptr);
|
||||
|
||||
auto before_p = TestCreateParameter(before_fg, "p", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
|
||||
auto before_uniq_node = TestCreateCNode(before_fg, "Unique", AnfNodePtrList{before_p},
|
||||
TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
|
||||
auto before_a_node = TestCreateCNode(before_fg, "A", AnfNodePtrList{before_uniq_node},
|
||||
TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
|
||||
auto before_b_node =
|
||||
TestCreateCNode(before_fg, "B", AnfNodePtrList{before_p}, TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
|
||||
auto before_depend_node = TestCreateDepend(before_fg, AnfNodePtrList{before_a_node, before_b_node});
|
||||
before_fg->set_output(before_depend_node);
|
||||
|
||||
// run pass
|
||||
AscendDynamicShapeConvert(before_fg);
|
||||
|
||||
// construct after graph
|
||||
auto after_fg = std::make_shared<session::KernelGraph>();
|
||||
ASSERT_TRUE(after_fg != nullptr);
|
||||
auto after_p = TestCreateParameter(after_fg, "p", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
|
||||
auto after_uniq_node = TestCreateCNode(after_fg, "Unique", AnfNodePtrList{after_p},
|
||||
TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
|
||||
auto after_a_node = TestCreateCNode(after_fg, "A", AnfNodePtrList{after_uniq_node},
|
||||
TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
|
||||
auto after_b_node =
|
||||
TestCreateCNode(after_fg, "B", AnfNodePtrList{after_p}, TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
|
||||
auto after_depend_node = TestCreateDepend(after_fg, AnfNodePtrList{after_a_node, after_b_node});
|
||||
|
||||
auto infer_uniq = dynamic_shape::GenInferNode(after_uniq_node);
|
||||
auto init_uniq = dynamic_shape::GenInitNode(after_uniq_node);
|
||||
auto update_uniq = dynamic_shape::GenUpdateNode(after_uniq_node);
|
||||
|
||||
auto infer_a = dynamic_shape::GenInferNode(after_a_node);
|
||||
auto init_a = dynamic_shape::GenInitNode(after_a_node);
|
||||
|
||||
auto infer_b = dynamic_shape::GenInferNode(after_b_node);
|
||||
auto init_b = dynamic_shape::GenInitNode(after_b_node);
|
||||
|
||||
auto depend0 = TestCreateDepend(after_fg, AnfNodePtrList{init_uniq, infer_uniq});
|
||||
auto depend1 = TestCreateDepend(after_fg, AnfNodePtrList{after_uniq_node, init_uniq});
|
||||
auto depend2 = TestCreateDepend(after_fg, AnfNodePtrList{update_uniq, after_uniq_node});
|
||||
auto depend3 = TestCreateDepend(after_fg, AnfNodePtrList{init_a, infer_a});
|
||||
auto depend4 = TestCreateDepend(after_fg, AnfNodePtrList{after_a_node, init_a});
|
||||
auto depend5 = TestCreateDepend(after_fg, AnfNodePtrList{infer_a, infer_uniq});
|
||||
auto depend6 = TestCreateDepend(after_fg, AnfNodePtrList{infer_a, update_uniq});
|
||||
auto depend7 = TestCreateDepend(after_fg, AnfNodePtrList{init_b, infer_b});
|
||||
auto depend8 = TestCreateDepend(after_fg, AnfNodePtrList{after_b_node, init_b});
|
||||
|
||||
auto make_tuple = TestCreateMakeTuple(after_fg, AnfNodePtrList{after_depend_node, depend0, depend1, depend2, depend3,
|
||||
depend4, depend5, depend6, depend7, depend8});
|
||||
auto get_item = TestCreateCNode(after_fg, "TupleGetItem",
|
||||
AnfNodePtrList{make_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
|
||||
after_a_node->abstract());
|
||||
after_fg->set_output(get_item);
|
||||
|
||||
// assert
|
||||
EXPECT_TRUE(CheckEqualGraph(after_fg, before_fg));
|
||||
}
|
||||
|
||||
/// Feature: Dynamic shape
/// Description: Dynamic op + monad.
/// before:
///     u = kUMond()
///    /      |
///   a = Assign(v, x1, u) -- c = UpdateState(u, a) -- d = Load(v, c)
///                                \                  /
///                                 e = UpdateState(c,d) -- Unique(%p)
///                                              \           /
///                                                 depend
/// Expectation: Graph as following.
/// after:
///     u = kUMond()
///    /      |
///   Assign_Infer  Assign_Init  Assign -- UpdateState(u, a) -- Load(v, c)
///       \         /   \        /  \                          /
///        depend        depend      e = UpdateState(c,d) -- Unique(%p)  Unique_Init  Unique_Infer
///                                        \                /  |   \     /        \   /
///                                           depend           |    depend        depend
///                                                            |  Unique_Update
///                                                            |  /
///                                                          depend
TEST_F(TestDynamicShapePass, test_dynamic_shape_pass_with_monad) {
  // construct before graph: Assign guarded by a UMonad/UpdateState/Load chain,
  // with the dynamic-shape Unique consuming the loaded value.
  auto before_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(before_fg != nullptr);

  auto before_v = TestCreateParameter(before_fg, "v", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto before_x1 = TestCreateParameter(before_fg, "x1", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto before_u = NewValueNode(kUMonad);
  auto before_assign = TestCreateCNode(before_fg, "Assign", AnfNodePtrList{before_v, before_x1, before_u},
                                       TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto before_update_state_1 =
    TestCreateCNode(before_fg, "UpdateState", AnfNodePtrList{before_u, before_assign}, kUMonad->ToAbstract());
  auto before_load =
    TestCreateCNode(before_fg, "Load", AnfNodePtrList{before_v, before_update_state_1}, before_v->abstract());
  auto before_update_state_2 = TestCreateCNode(
    before_fg, "UpdateState", AnfNodePtrList{before_update_state_1, before_load}, kUMonad->ToAbstract());
  auto before_uniq_node = TestCreateCNode(before_fg, "Unique", AnfNodePtrList{before_load},
                                          TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));

  auto before_depend_node = TestCreateDepend(before_fg, AnfNodePtrList{before_uniq_node, before_update_state_2});
  before_fg->set_output(before_depend_node);

  // run pass (rewrites before_fg in place)
  AscendDynamicShapeConvert(before_fg);

  // construct after graph: mirror of the input graph ...
  auto after_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(after_fg != nullptr);
  auto after_v = TestCreateParameter(after_fg, "v", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto after_x1 = TestCreateParameter(after_fg, "x1", TestCreateTensor(kFloat32, std::vector<int64_t>{1, 10}));
  auto after_u = NewValueNode(kUMonad);
  auto after_assign = TestCreateCNode(after_fg, "Assign", AnfNodePtrList{after_v, after_x1, after_u},
                                      TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));
  auto after_update_state_1 =
    TestCreateCNode(after_fg, "UpdateState", AnfNodePtrList{after_u, after_assign}, kUMonad->ToAbstract());
  auto after_load =
    TestCreateCNode(after_fg, "Load", AnfNodePtrList{after_v, after_update_state_1}, after_v->abstract());
  auto after_update_state_2 =
    TestCreateCNode(after_fg, "UpdateState", AnfNodePtrList{after_update_state_1, after_load}, kUMonad->ToAbstract());
  auto after_uniq_node = TestCreateCNode(after_fg, "Unique", AnfNodePtrList{after_load},
                                         TestCreateTensor(kFloat32, std::vector<int64_t>{1, -1}));

  auto after_depend_node = TestCreateDepend(after_fg, AnfNodePtrList{after_uniq_node, after_update_state_2});

  // ... plus the custom actor nodes the pass is expected to insert.
  auto infer_uniq = dynamic_shape::GenInferNode(after_uniq_node);
  auto init_uniq = dynamic_shape::GenInitNode(after_uniq_node);
  auto update_uniq = dynamic_shape::GenUpdateNode(after_uniq_node);

  auto infer_assign = dynamic_shape::GenInferNode(after_assign);
  auto init_assign = dynamic_shape::GenInitNode(after_assign);

  // Depend edges pinning the expected ordering (see diagram above).
  auto depend0 = TestCreateDepend(after_fg, AnfNodePtrList{init_assign, infer_assign});
  auto depend1 = TestCreateDepend(after_fg, AnfNodePtrList{after_assign, init_assign});
  auto depend2 = TestCreateDepend(after_fg, AnfNodePtrList{init_uniq, infer_uniq});
  auto depend3 = TestCreateDepend(after_fg, AnfNodePtrList{after_uniq_node, init_uniq});
  auto depend4 = TestCreateDepend(after_fg, AnfNodePtrList{update_uniq, after_uniq_node});

  // Keep every Depend alive via MakeTuple, then take the first item back out
  // as the single graph output.
  auto make_tuple =
    TestCreateMakeTuple(after_fg, AnfNodePtrList{after_depend_node, depend0, depend1, depend2, depend3, depend4});
  auto get_item = TestCreateCNode(after_fg, "TupleGetItem",
                                  AnfNodePtrList{make_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                  after_uniq_node->abstract());
  after_fg->set_output(get_item);

  // assert: pass output must be structurally equal to the hand-built graph
  EXPECT_TRUE(CheckEqualGraph(after_fg, before_fg));
}
|
||||
|
||||
/// Feature: Dynamic shape
/// Description: Need sync case(contain op such as Tile...).
/// Expectation: Graph as expected.
TEST_F(TestDynamicShapePass, test_dynamic_shape_pass_sync) {
  // construct before graph: two Tile branches (one fed by dynamic Unique,
  // one fed by static A/B) merged by Add.
  auto before_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(before_fg != nullptr);

  const auto &kTile = prim::kPrimTile->name();

  auto before_p1 = TestCreateParameter(before_fg, "p1", TestCreateTensor(kFloat32, std::vector<int64_t>{2, 10}));
  auto before_p2 = TestCreateParameter(before_fg, "p2", TestCreateTensor(kFloat32, std::vector<int64_t>{2}));
  auto before_uniq_node = TestCreateCNode(before_fg, "Unique", AnfNodePtrList{before_p1},
                                          TestCreateTensor(kFloat32, std::vector<int64_t>{2, -1}));
  auto before_tile1_node = TestCreateCNode(before_fg, kTile, AnfNodePtrList{before_uniq_node},
                                           TestCreateTensor(kFloat32, std::vector<int64_t>{2, 10}));
  auto before_a_node =
    TestCreateCNode(before_fg, "A", AnfNodePtrList{before_p2}, TestCreateTensor(kFloat32, std::vector<int64_t>{2}));
  auto before_b_node =
    TestCreateCNode(before_fg, "B", AnfNodePtrList{before_a_node}, TestCreateTensor(kFloat32, std::vector<int64_t>{2}));
  auto before_tile2_node = TestCreateCNode(before_fg, kTile, AnfNodePtrList{before_a_node, before_b_node},
                                           TestCreateTensor(kFloat32, std::vector<int64_t>{2, 10}));
  auto before_add_node = TestCreateCNode(before_fg, "Add", AnfNodePtrList{before_tile1_node, before_tile2_node},
                                         TestCreateTensor(kFloat32, std::vector<int64_t>{2, 10}));
  before_fg->set_output(before_add_node);

  // run pass (rewrites before_fg in place)
  AscendDynamicShapeConvert(before_fg);

  // construct after graph: mirror of the input graph ...
  auto after_fg = std::make_shared<session::KernelGraph>();
  ASSERT_TRUE(after_fg != nullptr);

  auto after_p1 = TestCreateParameter(after_fg, "p1", TestCreateTensor(kFloat32, std::vector<int64_t>{2, 10}));
  auto after_p2 = TestCreateParameter(after_fg, "p2", TestCreateTensor(kFloat32, std::vector<int64_t>{2}));
  auto after_uniq_node = TestCreateCNode(after_fg, "Unique", AnfNodePtrList{after_p1},
                                         TestCreateTensor(kFloat32, std::vector<int64_t>{2, -1}));
  auto after_tile1_node = TestCreateCNode(after_fg, kTile, AnfNodePtrList{after_uniq_node},
                                          TestCreateTensor(kFloat32, std::vector<int64_t>{2, 10}));
  auto after_a_node =
    TestCreateCNode(after_fg, "A", AnfNodePtrList{after_p2}, TestCreateTensor(kFloat32, std::vector<int64_t>{2}));
  auto after_b_node =
    TestCreateCNode(after_fg, "B", AnfNodePtrList{after_a_node}, TestCreateTensor(kFloat32, std::vector<int64_t>{2}));
  auto after_tile2_node = TestCreateCNode(after_fg, kTile, AnfNodePtrList{after_a_node, after_b_node},
                                          TestCreateTensor(kFloat32, std::vector<int64_t>{2, 10}));
  auto after_add_node = TestCreateCNode(after_fg, "Add", AnfNodePtrList{after_tile1_node, after_tile2_node},
                                        TestCreateTensor(kFloat32, std::vector<int64_t>{2, 10}));

  // ... plus the custom actor nodes the pass is expected to insert.
  // NOTE(review): the `true` flag on some Gen*Node calls presumably marks the
  // need-sync path this test targets — confirm against the dynamic_shape helpers.
  auto infer_uniq = dynamic_shape::GenInferNode(after_uniq_node);
  auto init_uniq = dynamic_shape::GenInitNode(after_uniq_node);
  auto update_uniq = dynamic_shape::GenUpdateNode(after_uniq_node);

  auto infer_tile1 = dynamic_shape::GenInferNode(after_tile1_node);
  auto init_tile1 = dynamic_shape::GenInitNode(after_tile1_node);

  auto infer_a = dynamic_shape::GenInferNode(after_a_node, true);
  auto init_a = dynamic_shape::GenInitNode(after_a_node);

  auto infer_b = dynamic_shape::GenInferNode(after_b_node, true);
  auto init_b = dynamic_shape::GenInitNode(after_b_node);
  auto update_b = dynamic_shape::GenUpdateNode(after_b_node, true);

  auto infer_tile2 = dynamic_shape::GenInferNode(after_tile2_node, true);
  auto init_tile2 = dynamic_shape::GenInitNode(after_tile2_node);

  auto infer_add = dynamic_shape::GenInferNode(after_add_node, true);
  auto init_add = dynamic_shape::GenInitNode(after_add_node);

  // Depend edges pinning the expected ordering, grouped per kernel.
  auto depend0 = TestCreateDepend(after_fg, AnfNodePtrList{init_uniq, infer_uniq});
  auto depend1 = TestCreateDepend(after_fg, AnfNodePtrList{after_uniq_node, init_uniq});
  auto depend2 = TestCreateDepend(after_fg, AnfNodePtrList{update_uniq, after_uniq_node});

  auto depend3 = TestCreateDepend(after_fg, AnfNodePtrList{init_tile1, infer_tile1});
  auto depend4 = TestCreateDepend(after_fg, AnfNodePtrList{after_tile1_node, init_tile1});
  auto depend5 = TestCreateDepend(after_fg, AnfNodePtrList{infer_tile1, infer_uniq});
  auto depend6 = TestCreateDepend(after_fg, AnfNodePtrList{infer_tile1, update_uniq});

  auto depend7 = TestCreateDepend(after_fg, AnfNodePtrList{init_a, infer_a});
  auto depend8 = TestCreateDepend(after_fg, AnfNodePtrList{after_a_node, init_a});

  auto depend9 = TestCreateDepend(after_fg, AnfNodePtrList{init_b, infer_b});
  auto depend10 = TestCreateDepend(after_fg, AnfNodePtrList{after_b_node, init_b});
  auto depend11 = TestCreateDepend(after_fg, AnfNodePtrList{infer_b, infer_a});

  auto depend12 = TestCreateDepend(after_fg, AnfNodePtrList{init_tile2, infer_tile2});
  auto depend13 = TestCreateDepend(after_fg, AnfNodePtrList{after_tile2_node, init_tile2});
  auto depend14 = TestCreateDepend(after_fg, AnfNodePtrList{infer_tile2, infer_a});
  auto depend15 = TestCreateDepend(after_fg, AnfNodePtrList{infer_tile2, infer_b});
  auto depend16 = TestCreateDepend(after_fg, AnfNodePtrList{update_b, after_b_node});
  auto depend17 = TestCreateDepend(after_fg, AnfNodePtrList{infer_tile2, update_b});

  auto depend18 = TestCreateDepend(after_fg, AnfNodePtrList{init_add, infer_add});
  auto depend19 = TestCreateDepend(after_fg, AnfNodePtrList{after_add_node, init_add});
  auto depend20 = TestCreateDepend(after_fg, AnfNodePtrList{infer_add, infer_tile1});
  auto depend21 = TestCreateDepend(after_fg, AnfNodePtrList{infer_add, infer_tile2});

  // Keep every Depend alive via MakeTuple, then take the first item back out
  // as the single graph output.
  auto make_tuple = TestCreateMakeTuple(
    after_fg, AnfNodePtrList{after_add_node, depend0, depend1, depend2, depend3, depend4, depend5, depend6,
                             depend7, depend8, depend9, depend10, depend11, depend12, depend13, depend14,
                             depend15, depend16, depend17, depend18, depend19, depend20, depend21});
  auto get_item = TestCreateCNode(after_fg, "TupleGetItem",
                                  AnfNodePtrList{make_tuple, NewValueNode(SizeToLong(kTupleFirstItemIndex))},
                                  after_add_node->abstract());
  after_fg->set_output(get_item);

  // assert: pass output must be structurally equal to the hand-built graph
  EXPECT_TRUE(CheckEqualGraph(after_fg, before_fg));
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue