extract common as an independent shared library

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>

parent dfc6cbb6df
commit f49b195c39
@@ -90,7 +90,7 @@ install(
 )
 
 install(
-    TARGETS mindspore_core
+    TARGETS mindspore_core mindspore_common
     DESTINATION ${INSTALL_LIB_DIR}
     COMPONENT mindspore
 )

@@ -67,7 +67,7 @@ install(
 )
 
 install(
-    TARGETS mindspore_core
+    TARGETS mindspore_core mindspore_common
     DESTINATION ${INSTALL_LIB_DIR}
     COMPONENT mindspore
 )

@@ -27,7 +27,7 @@ install(
 )
 
 install(
-    TARGETS mindspore_core
+    TARGETS mindspore_core mindspore_common
    DESTINATION ${INSTALL_LIB_DIR}
     COMPONENT mindspore
 )

@@ -74,7 +74,7 @@ install(
 )
 
 install(
-    TARGETS mindspore_core
+    TARGETS mindspore_core mindspore_common
     DESTINATION ${INSTALL_LIB_DIR}
     COMPONENT mindspore
 )
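The four packaging scripts above all gain the same change: the new mindspore_common target is installed next to mindspore_core. A quick runtime check that both shared objects resolve after packaging might look like the following sketch (not part of the commit; the lib*.so names assume the default Linux target naming, and the program must be linked with -ldl):

#include <dlfcn.h>
#include <cstdio>
#include <initializer_list>

int main() {
  // Try to load both packaged libraries the way _c_expression eventually will.
  for (const char *name : {"libmindspore_core.so", "libmindspore_common.so"}) {
    void *handle = dlopen(name, RTLD_NOW | RTLD_GLOBAL);
    if (handle == nullptr) {
      std::printf("failed to load %s: %s\n", name, dlerror());
      return 1;
    }
    std::printf("loaded %s\n", name);
  }
  return 0;
}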
@@ -230,21 +230,21 @@ endif()
 ## make sub objects
 set(SUB_COMP
-    transform/graph_ir
     transform/express_ir
     frontend/optimizer
     frontend/parallel
     frontend/operator
     pipeline/jit
     pipeline/pynative
-    debug pybind_api utils profiler ps fl distributed
+    debug pybind_api
+    profiler ps fl distributed
     kernel
-    common
     common/mem_reuse
     backend/common/optimizer
     backend/common/pass
     backend/common/session
     backend/common/somas
     common/graph_kernel
     backend/graph_compiler
     runtime/collective
     runtime/device
@@ -278,6 +278,31 @@ endforeach()
 set_property(SOURCE ${SUB_OBJECTS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ME)
 add_library(mindspore STATIC ${SUB_OBJECTS_SRC})
 
+set(COMMON_SUB_COMP
+    transform/graph_ir
+    utils
+    common
+)
+
+foreach(_comp ${COMMON_SUB_COMP})
+    add_subdirectory(${_comp})
+    string(REPLACE "/" "_" sub ${_comp})
+    if(TARGET _mindspore_${sub}_obj)
+        list(APPEND COMMON_SUB_OBJECTS_SRC $<TARGET_OBJECTS:_mindspore_${sub}_obj>)
+        add_dependencies(_mindspore_${sub}_obj proto_input mindspore_core)
+        if(CMAKE_SYSTEM_NAME MATCHES "Windows")
+            target_compile_definitions(_mindspore_${sub}_obj PRIVATE COMMON_DLL)
+        endif()
+    endif()
+endforeach()
+
+add_library(mindspore_common SHARED ${COMMON_SUB_OBJECTS_SRC})
+target_link_libraries(mindspore_common PRIVATE mindspore_core proto_input mindspore::protobuf)
+set_target_properties(mindspore_common PROPERTIES INSTALL_RPATH $ORIGIN)
+if(CMAKE_SYSTEM_NAME MATCHES "Windows")
+    target_link_libraries(mindspore_common PRIVATE mindspore::pybind11_module)
+endif()
+
 if(ENABLE_DEBUGGER)
     # debugger: link grpc
     if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
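COMMON_DLL is defined only for the object libraries that feed mindspore_common, and only on Windows. The usual pattern such a define drives is a dllexport/dllimport visibility macro; a minimal sketch follows (the macro name COMMON_EXPORT, the example class, and the exact header layout are illustrative assumptions, not taken from this commit):

// Hypothetical visibility header: exporters compile with COMMON_DLL set
// (as done via target_compile_definitions above), importers compile without it.
#if defined(_WIN32)
#if defined(COMMON_DLL)
#define COMMON_EXPORT __declspec(dllexport)
#else
#define COMMON_EXPORT __declspec(dllimport)
#endif
#else
#define COMMON_EXPORT __attribute__((visibility("default")))
#endif

// Example use on a class that must cross the mindspore_common DLL boundary:
class COMMON_EXPORT SomeCommonHelper {
 public:
  static int Version();
};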
@@ -402,12 +427,14 @@ set_target_properties(_c_expression PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH})
 
 if(CMAKE_SYSTEM_NAME MATCHES "Windows")
     target_link_libraries(mindspore PUBLIC mindspore::pybind11_module)
-    target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_core)
+    target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_core
+                          mindspore_common)
 elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin")
     target_link_libraries(mindspore PUBLIC proto_input mindspore::protobuf
         mindspore::event mindspore::event_pthreads mindspore::event_openssl mindspore::eigen mindspore::json)
     target_link_libraries(mindspore PUBLIC mindspore::event_core ps_cache)
-    target_link_libraries(_c_expression PRIVATE -Wl,-all_load mindspore proto_input -Wl,-noall_load mindspore_core)
+    target_link_libraries(_c_expression PRIVATE -Wl,-all_load mindspore proto_input -Wl,-noall_load mindspore_core
+                          mindspore_common)
     target_link_libraries(_c_expression PRIVATE mindspore::pybind11_module)
 else()
     if(ENABLE_CPU AND NOT WIN32)
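Note where mindspore_common lands in these link lines: the -Wl,--whole-archive ... -Wl,--no-whole-archive bracket (and -Wl,-all_load / -Wl,-noall_load on macOS) applies only to the static mindspore archive, while mindspore_core and the new mindspore_common sit outside it as ordinary shared-library dependencies. The whole-archive treatment matters for self-registering translation units, as in this hedged sketch (an illustrative registry, not MindSpore's actual API):

#include <string>
#include <vector>

// Global registry the registration objects append to.
std::vector<std::string> &PassRegistry() {
  static std::vector<std::string> registry;
  return registry;
}

struct RegisterPass {
  explicit RegisterPass(const char *name) { PassRegistry().push_back(name); }
};

// No external symbol references this object, so a normal static link would
// drop its translation unit and the constructor would never run;
// --whole-archive forces it into _c_expression so the registration happens.
static RegisterPass g_my_pass_registration("my_pass");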
@@ -419,7 +446,7 @@ else()
         endif()
     endif()
     target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore proto_input -Wl,--no-whole-archive
-                          mindspore_core)
+                          mindspore_core mindspore_common)
     target_link_libraries(_c_expression PRIVATE mindspore::pybind11_module)
 endif()
@@ -477,3 +504,4 @@ if(ENABLE_D)
 endif()
 
 add_subdirectory(cxx_api)
+# include(${CMAKE_CURRENT_SOURCE_DIR}/sharedlib.cmake)
@@ -17,7 +17,7 @@
 
 #include <vector>
 #include "utils/anf_utils.h"
-#include "utils/utils.h"
+#include "include/common/utils/utils.h"
 #include "utils/log_adapter.h"
 #include "base/core_ops.h"
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "backend/common/optimizer/fusion_id_allocator.h"
-#include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 
 namespace mindspore {
 namespace opt {
@@ -35,19 +35,19 @@ bool FusionIdAllocator::HasFusionIdAttr(const AnfNodePtr &node) const {
     return false;
   }
   auto cnode = node->cast<CNodePtr>();
-  return AnfAlgo::HasNodeAttr(kAttrFusionId, cnode);
+  return common::AnfAlgo::HasNodeAttr(kAttrFusionId, cnode);
 }
 
 int64_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) {
   if (HasFusionIdAttr(node)) {
-    return AnfAlgo::GetNodeAttr<int64_t>(node, kAttrFusionId);
+    return common::AnfAlgo::GetNodeAttr<int64_t>(node, kAttrFusionId);
   }
   return -1;
 }
 
 void FusionIdAllocator::SetFusionId(const AnfNodePtr &node, int64_t id) {
   ValuePtr fusion_id_v = MakeValue(id);
-  AnfAlgo::SetNodeAttr(kAttrFusionId, fusion_id_v, node);
+  common::AnfAlgo::SetNodeAttr(kAttrFusionId, fusion_id_v, node);
 }
 }  // namespace opt
 }  // namespace mindspore
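The fusion-id helpers above show the mechanical pattern repeated through the rest of this commit: device-independent attribute helpers keep their signatures but move behind the new shared library under the common namespace, so a call site changes only by the namespace prefix and the include. A reduced sketch of the resulting call shape (assuming the MindSpore headers it names; not new API):

#include "include/common/utils/anfalgo.h"  // new home of the IR-only helpers

int64_t ReadFusionId(const CNodePtr &cnode) {
  // before this commit: AnfAlgo::HasNodeAttr / AnfAlgo::GetNodeAttr, declared
  // in backend/common/session/anf_runtime_algorithm.h
  if (common::AnfAlgo::HasNodeAttr(kAttrFusionId, cnode)) {
    return common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusionId);
  }
  return -1;  // same "no id" sentinel the surrounding code uses
}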
@@ -22,14 +22,13 @@
 #include <set>
 #include <deque>
 #include "utils/hash_set.h"
-#include "utils/utils.h"
+#include "include/common/utils/utils.h"
 #include "base/base_ref.h"
-#include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "base/core_ops.h"
 #include "plugin/device/ascend/kernel/tbe/tbe_dynaminc_shape_util.h"
 #include "frontend/operator/ops.h"
 #include "utils/ms_utils.h"
-#include "utils/convert_utils.h"
+#include "include/common/utils/convert_utils.h"
 #include "runtime/device/kernel_info.h"
 #include "utils/ms_context.h"
 #include "utils/trace_base.h"
@@ -47,8 +46,8 @@ void UpdateDumpFlagAndDebugInfo(const CNodePtr &node, const std::vector<AnfNodePtr> &orig_nodes) {
   for (auto &orig_node : orig_nodes) {
     if (AnfUtils::IsRealCNodeKernel(orig_node)) {
       auto orig_cnode = orig_node->cast<CNodePtr>();
-      if (AnfAlgo::HasNodeAttr(kAttrDump, orig_cnode)) {
-        AnfAlgo::CopyNodeAttr(kAttrDump, orig_cnode, node);
+      if (common::AnfAlgo::HasNodeAttr(kAttrDump, orig_cnode)) {
+        common::AnfAlgo::CopyNodeAttr(kAttrDump, orig_cnode, node);
       }
       orig_real_cnodes.push_back(orig_node);
     }
@@ -155,7 +154,7 @@ CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_size) {
 
 void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_tensor_size) {
   MS_EXCEPTION_IF_NULL(cnode);
-  auto real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode);
+  auto real_input_tensor_num = common::AnfAlgo::GetInputTensorNum(cnode);
   if (real_input_tensor_num != input_tensor_size) {
     MS_LOG(EXCEPTION) << "The input tensor size[" << real_input_tensor_num
                       << "] of node [" + cnode->DebugString() + "] is not equal to " << input_tensor_size
@@ -227,8 +226,9 @@ void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
     idx->set_abstract(abstract_scalar);
     auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
     MS_EXCEPTION_IF_NULL(tuple_getitem);
-    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(type_ptr, i)},
-                                        {AnfAlgo::GetOutputInferShape(node, shape_ptr, i)}, tuple_getitem.get());
+    common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetOutputInferDataType(type_ptr, i)},
+                                                {common::AnfAlgo::GetOutputInferShape(node, shape_ptr, i)},
+                                                tuple_getitem.get());
     (*outputs).push_back(tuple_getitem);
   }
 }
@@ -295,54 +295,12 @@ tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
   return tensor;
 }
 
-bool IsNopNode(const AnfNodePtr &node) {
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  auto target = GetCNodeTarget(node);
-  if (target == kCPUDevice) {
-    return false;
-  }
-  if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice &&
-      context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
-    return false;
-  }
-
-  static mindspore::HashSet<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName,
-                                                      prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(),
-                                                      kFlattenGradOpName, prim::kPrimReformat->name()};
-  if (node == nullptr || !node->isa<CNode>()) {
-    return false;
-  }
-  CNodePtr cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  if (cnode->inputs().empty()) {
-    return false;
-  }
-  auto input0 = cnode->input(0);
-  MS_EXCEPTION_IF_NULL(input0);
-  if (!input0->isa<ValueNode>()) {
-    return false;
-  }
-  bool is_nop_node = false;
-  if (AnfAlgo::HasNodeAttr(kAttrNopOp, cnode)) {
-    is_nop_node = AnfAlgo::GetNodeAttr<bool>(cnode, kAttrNopOp);
-  }
-  if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end() && !is_nop_node) {
-    return false;
-  }
-  const size_t kNopNodeInputSize = 2;
-  if (cnode->size() != kNopNodeInputSize) {
-    return false;
-  }
-  return true;
-}
-
 bool IsAllNopNode(const session::KernelGraph *const graph) {
   MS_EXCEPTION_IF_NULL(graph);
   auto execution_order = graph->execution_order();
   for (auto &cnode : execution_order) {
     MS_EXCEPTION_IF_NULL(cnode);
-    if (!IsNopNode(cnode)) {
+    if (!common::AnfAlgo::IsNopNode(cnode)) {
       return false;
     }
   }
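IsNopNode is deleted here rather than renamed in place: its body moves into the shared library and every caller switches to common::AnfAlgo::IsNopNode. Assuming it now sits with the other helpers declared in include/common/utils/anfalgo.h (the header itself is not shown in this diff), the relocated declaration would look roughly like:

// Hedged sketch of the relocated declaration.
#include "ir/anf.h"

namespace mindspore {
namespace common {
class AnfAlgo {
 public:
  // Same semantics as the removed free function in helper.cc: true for
  // Reshape/ExpandDims/Squeeze/Flatten/FlattenGrad/Reformat-style two-input
  // nodes on Ascend/GPU targets (and for nodes carrying the nop_op attribute).
  static bool IsNopNode(const AnfNodePtr &node);
};
}  // namespace common
}  // namespace mindspore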
@@ -352,7 +310,7 @@ bool IsAllNopNode(const session::KernelGraph *const graph) {
 bool NeedHideNode(const std::vector<AnfNodePtr> &outputs, const AnfNodePtr &node, bool is_dynamic_graph) {
   MS_EXCEPTION_IF_NULL(node);
   // if node is not a nop node, keep it in execution order
-  if (!IsNopNode(node)) {
+  if (!common::AnfAlgo::IsNopNode(node)) {
     return false;
   }
   // if node is nop node and the graph is dynamic graph, check if the nop node is graph's output.
@@ -378,7 +336,7 @@ void HideNopNode(session::KernelGraph *const graph) {
   for (auto &cnode : execution_order) {
     MS_EXCEPTION_IF_NULL(cnode);
     if (NeedHideNode(outputs, cnode, is_dynamic_graph)) {
-      AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(true), cnode);
+      common::AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(true), cnode);
     } else {
       new_nodes.push_back(cnode);
     }
@@ -402,7 +360,7 @@ void RemoveNopNode(session::KernelGraph *const graph) {
     MS_EXCEPTION_IF_NULL(cnode);
     // ignore nop node itself
     if (NeedHideNode(outputs, cnode, is_dynamic_graph)) {
-      AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(true), cnode);
+      common::AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(true), cnode);
       continue;
     }
     // Replace the input which is nop node
@@ -413,7 +371,7 @@ void RemoveNopNode(session::KernelGraph *const graph) {
       auto input = cnode->input(i);
       MS_EXCEPTION_IF_NULL(input);
       auto cinput = input->cast<CNodePtr>();
-      if (cinput == nullptr || !IsNopNode(cinput)) {
+      if (cinput == nullptr || !common::AnfAlgo::IsNopNode(cinput)) {
         new_inputs.push_back(input);
         continue;
       }
@@ -454,7 +412,7 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const
   }
   auto output_info_list = iter->second;
   for (const auto &output_info : output_info_list) {
-    auto cnode_name = AnfAlgo::GetCNodeName(output_info.first);
+    auto cnode_name = common::AnfAlgo::GetCNodeName(output_info.first);
     if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
         (cnode_name == prim::kPrimUpdateState->name())) {
       continue;
@@ -477,20 +435,20 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(
   }
   auto output_info_list = iter->second;
   for (const auto &output_info : output_info_list) {
-    auto cnode_name = AnfAlgo::GetCNodeName(output_info.first);
+    auto cnode_name = common::AnfAlgo::GetCNodeName(output_info.first);
     if ((cnode_name == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) ||
         (cnode_name == prim::kPrimUpdateState->name())) {
       continue;
     }
     size_t used_output_index;
     if (cnode_name == prim::kPrimTupleGetItem->name()) {
-      used_output_index = AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
-    } else if (AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) {
+      used_output_index = common::AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
+    } else if (common::AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) {
       used_output_index = output_index;
     } else {
-      auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(output_info.first, IntToSize(output_info.second - 1));
+      auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(output_info.first, IntToSize(output_info.second - 1));
       if (kernel_with_index.first.get() != node.get()) {
-        MS_LOG(EXCEPTION) << "Get used node failed for op[" << AnfAlgo::GetCNodeName(node) << "]";
+        MS_LOG(EXCEPTION) << "Get used node failed for op[" << common::AnfAlgo::GetCNodeName(node) << "]";
       }
       used_output_index = kernel_with_index.second;
     }
@@ -519,7 +477,7 @@ bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
   }
   for (const auto &output : *output_node_list) {
     auto out_node = output.first;
-    auto name = AnfAlgo::GetCNodeName(out_node);
+    auto name = common::AnfAlgo::GetCNodeName(out_node);
     if (name == prim::kPrimDepend->name() || name == prim::kPrimMakeTuple->name() ||
         name == prim::kPrimTupleGetItem->name() || name == prim::kPrimLoad->name()) {
       auto result = IsNotRealUsedByOthers(graph, out_node);
@@ -731,7 +689,7 @@ AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive,
   if (!opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(primitive->name(), &reg)) {
     return input_abstract;
   }
-  if (AnfAlgo::HasDynamicShapeFlag(primitive)) {
+  if (common::AnfAlgo::HasDynamicShapeFlag(primitive)) {
     return input_abstract;
   }
   auto ms_context = MsContext::GetInstance();
@@ -895,12 +853,12 @@ bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) {
   }
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  return AnfAlgo::HasNodeAttr(attr_name, cnode) && AnfAlgo::GetNodeAttr<bool>(node, attr_name);
+  return common::AnfAlgo::HasNodeAttr(attr_name, cnode) && common::AnfAlgo::GetNodeAttr<bool>(node, attr_name);
 }
 
 bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &supported_data_type_set) {
   MS_EXCEPTION_IF_NULL(node);
-  TypeId data_type = AnfAlgo::GetOutputInferDataType(node, 0);
+  TypeId data_type = common::AnfAlgo::GetOutputInferDataType(node, 0);
   if (supported_data_type_set.find(data_type) != supported_data_type_set.end()) {
     return true;
   }
@@ -923,7 +881,7 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
   kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
   // set value node initial device data type = infer data type
   std::vector<TypeId> types;
-  size_t output_num = AnfAlgo::GetOutputTensorNum(value_node);
+  size_t output_num = common::AnfAlgo::GetOutputTensorNum(value_node);
   for (size_t index = 0; index < output_num; ++index) {
     types.push_back(kTypeUnknown);
   }
@@ -942,8 +900,8 @@ void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph,
   for (const auto &node_index : node_users) {
     AnfNodePtr output = node_index.first;
     MS_EXCEPTION_IF_NULL(output);
-    if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend) ||
-        AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
+    if (common::AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend) ||
+        common::AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
       auto depend = output->cast<CNodePtr>();
       MS_EXCEPTION_IF_NULL(depend);
       manager->SetEdge(depend, node_index.second, new_node);
@@ -987,16 +945,16 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list) {
   for (size_t idx = 0; idx < node_list.size(); ++idx) {
     auto cnode = utils::cast<CNodePtr>(node_list[idx]);
     MS_EXCEPTION_IF_NULL(cnode);
-    size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
+    size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
     for (size_t input_index = 0; input_index < input_num; ++input_index) {
       (void)inputs_device_format.emplace_back(kOpFormat_DEFAULT);
-      (void)inputs_device_type.emplace_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
+      (void)inputs_device_type.emplace_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
     }
-    size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
+    size_t output_num = common::AnfAlgo::GetOutputTensorNum(cnode);
     for (size_t output_index = 0; output_index < output_num; ++output_index) {
       (void)outputs_device_format.emplace_back(kOpFormat_DEFAULT);
-      (void)outputs_device_type.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
-      (void)outputs_shape.emplace_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
+      (void)outputs_device_type.emplace_back(common::AnfAlgo::GetOutputInferDataType(cnode, output_index));
+      (void)outputs_shape.emplace_back(common::AnfAlgo::GetOutputInferShape(cnode, output_index));
     }
   }
   builder.SetInputsFormat(inputs_device_format);
@@ -1010,14 +968,14 @@ std::vector<int64_t> GetNodeOutputUsedNum(const session::KernelGraph &kernel_graph,
   MS_EXCEPTION_IF_NULL(node);
   auto manager = kernel_graph.manager();
   MS_EXCEPTION_IF_NULL(manager);
-  auto output_num = AnfAlgo::GetOutputTensorNum(node);
+  auto output_num = common::AnfAlgo::GetOutputTensorNum(node);
   std::vector<int64_t> output_used_num(output_num, 0);
   if (output_num == 1) {
     output_used_num[0] = SizeToLong(manager->node_users()[node].size());
   } else {
     for (auto out_getitem : manager->node_users()[node]) {
       MS_EXCEPTION_IF_NULL(out_getitem.first);
-      if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
+      if (!common::AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
         continue;
       }
       auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
@@ -174,8 +174,6 @@ tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple);
 
 bool IsAllNopNode(const session::KernelGraph *const graph);
 
-bool IsNopNode(const AnfNodePtr &node);
-
 void HideNopNode(session::KernelGraph *const graph);
 
 void RemoveNopNode(session::KernelGraph *const graph);
@@ -22,7 +22,8 @@
 #include "ir/manager.h"
 #include "utils/hash_map.h"
 #include "utils/hash_set.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
 #include "backend/common/session/kernel_graph.h"
+#include "include/common/utils/anfalgo.h"
 
 namespace mindspore {
 namespace opt {
@@ -34,7 +35,7 @@ void AddOutputAndCallerToMap(const CNodePtr &cnode, mindspore::HashMap<AnfNodePtr,
   MS_EXCEPTION_IF_NULL(cnode);
   MS_EXCEPTION_IF_NULL(out_caller_map);
   auto inputs = cnode->inputs();
-  if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
+  if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
     auto partial_node = dyn_cast<CNode>(inputs.at(kSwitchBranchIndex));
     MS_EXCEPTION_IF_NULL(partial_node);
     const auto &partial_inputs = partial_node->inputs();
@@ -44,7 +45,7 @@ void AddOutputAndCallerToMap(const CNodePtr &cnode, mindspore::HashMap<AnfNodePtr,
     auto switch_subgraph = GetValueNode<FuncGraphPtr>(partial_inputs.at(kPartialArgsIndex));
     MS_EXCEPTION_IF_NULL(switch_subgraph);
     (*out_caller_map)[switch_subgraph->output()] = cnode;
-  } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
+  } else if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
     auto call_subgraph = GetValueNode<FuncGraphPtr>(inputs.at(kCallArgsIndex));
     MS_EXCEPTION_IF_NULL(call_subgraph);
     (*out_caller_map)[call_subgraph->output()] = cnode;
@@ -85,7 +86,7 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) {
       auto cnode = node->cast<CNodePtr>();
       MS_EXCEPTION_IF_NULL(cnode);
       auto end_label = kernel_graph->get_end_goto();
-      if (cnode == end_label && AnfAlgo::GetCNodeName(cnode) == kLabelSwitchOpName) {
+      if (cnode == end_label && common::AnfAlgo::GetCNodeName(cnode) == kLabelSwitchOpName) {
         kernel_graph->set_end_goto(new_node->cast<CNodePtr>());
       }
     }
@@ -100,7 +101,7 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) {
         (void)todo.emplace_back(const_func_graph->output(), const_func_graph);
       }
     } else if (new_node && new_node->isa<CNode>()) {
-      if (AnfAlgo::IsGraphKernel(new_node)) {
+      if (common::AnfAlgo::IsGraphKernel(new_node)) {
         (void)todo.emplace_back(new_node, func_graph);
       }
       auto cnode = new_node->cast<CNodePtr>();
@@ -22,6 +22,7 @@
 
 #include "backend/common/optimizer/pass_manager.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "ir/manager.h"
 
 namespace mindspore {
@@ -19,11 +19,10 @@
 #include <deque>
 #include <string>
 #include "ir/anf.h"
 #include "ir/func_graph.h"
 #include "ir/manager.h"
 #include "utils/ms_context.h"
 #include "debug/anf_ir_dump.h"
-#include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 
 namespace mindspore {
 namespace opt {
@@ -50,11 +49,11 @@ TypeId CacheManager::GetOutputType(const AnfNodePtr &node, size_t index) {
     }
     return kTypeUnknown;
   }
-  auto output_nums = AnfAlgo::GetOutputTensorNum(node);
+  auto output_nums = common::AnfAlgo::GetOutputTensorNum(node);
   std::map<size_t, TypeId> index_to_types;
   TypeId result = kTypeUnknown;
   for (size_t i = 0; i < output_nums; i++) {
-    auto output_type = AnfAlgo::GetOutputInferDataType(node, i);
+    auto output_type = common::AnfAlgo::GetOutputInferDataType(node, i);
     (void)index_to_types.emplace(i, output_type);
     if (index == i) {
       result = output_type;
@@ -75,11 +74,11 @@ std::vector<size_t> CacheManager::GetOutputShape(const AnfNodePtr &node, size_t index) {
     }
     return {};
   }
-  auto output_nums = AnfAlgo::GetOutputTensorNum(node);
+  auto output_nums = common::AnfAlgo::GetOutputTensorNum(node);
   std::map<size_t, std::vector<size_t>> index_to_shapes;
   std::vector<size_t> result = {};
   for (size_t i = 0; i < output_nums; i++) {
-    auto output_shape = AnfAlgo::GetOutputInferShape(node, i);
+    auto output_shape = common::AnfAlgo::GetOutputInferShape(node, i);
     (void)index_to_shapes.emplace(i, output_shape);
     if (index == i) {
       result = output_shape;
@@ -20,6 +20,7 @@
 #include <string>
 #include "base/core_ops.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 
 namespace mindspore {
 namespace opt {
@@ -28,7 +29,7 @@ void ClonePrimitive(const AnfNodePtr &node) {
   // Several CNode may share a primitive pointer, so we clone the primitive before setting attr.
   auto cnode = node->cast<CNodePtr>();
   if (cnode == nullptr) return;
-  auto prim_node = NewValueNode(AnfAlgo::GetCNodePrimitive(cnode)->Clone());
+  auto prim_node = NewValueNode(common::AnfAlgo::GetCNodePrimitive(cnode)->Clone());
   cnode->set_input(kAnfPrimitiveIndex, prim_node);
 }
 }  // namespace
@@ -38,20 +39,20 @@ void ProcessCast(const AnfNodePtr &node) {
   std::vector<std::string> input_names = {"x", kAttrDstType};
   std::vector<std::string> output_names = {"output"};
   ClonePrimitive(node);
-  AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), node);
-  AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), node);
+  common::AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), node);
+  common::AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), node);
   TypeId output_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
-  AnfAlgo::SetNodeAttr(kAttrDstType, TypeIdToType(output_type), node);
+  common::AnfAlgo::SetNodeAttr(kAttrDstType, TypeIdToType(output_type), node);
 }
 
 void ProcessMatMul(const AnfNodePtr &node) {
   ClonePrimitive(node);
   TypeId output_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
-  AnfAlgo::SetNodeAttr(kAttrDstType, TypeIdToType(output_type), node);
+  common::AnfAlgo::SetNodeAttr(kAttrDstType, TypeIdToType(output_type), node);
   auto left_format = AnfAlgo::GetInputFormat(node, 0);
   auto right_format = AnfAlgo::GetInputFormat(node, 1);
-  AnfAlgo::SetNodeAttr("left_format", MakeValue(left_format), node);
-  AnfAlgo::SetNodeAttr("right_format", MakeValue(right_format), node);
+  common::AnfAlgo::SetNodeAttr("left_format", MakeValue(left_format), node);
+  common::AnfAlgo::SetNodeAttr("right_format", MakeValue(right_format), node);
 }
 
 const AnfNodePtr AddAkgKernelAttrs::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
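This hunk makes the intended split visible: format and device-data-type queries (AnfAlgo::GetOutputDeviceDataType, AnfAlgo::GetInputFormat) stay on the session-side AnfAlgo because they depend on kernel selection results, while pure IR attribute writes move to common::AnfAlgo. A condensed sketch of the resulting mixed call site (assuming the two MindSpore headers it names):

#include "backend/common/session/anf_runtime_algorithm.h"  // device-side queries
#include "include/common/utils/anfalgo.h"                  // IR-side helpers

void AnnotateOutputType(const AnfNodePtr &node) {
  // unchanged: reads the selected kernel build info, so it cannot move
  // into the device-independent shared library
  TypeId output_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
  // migrated: writes a plain attribute on the IR node
  common::AnfAlgo::SetNodeAttr(kAttrDstType, TypeIdToType(output_type), node);
}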
@@ -16,9 +16,8 @@
 
 #include "backend/common/pass/add_dynamic_shape_attr.h"
 #include "ir/anf.h"
-#include "utils/convert_utils.h"
 #include "backend/common/optimizer/optimizer.h"
-#include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 
 namespace mindspore {
 namespace opt {
@@ -26,7 +25,7 @@ const AnfNodePtr AddDynamicShapeAttr::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                               const EquivPtr &) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(node);
-  if (AnfAlgo::IsDynamicShape(node)) {
+  if (common::AnfAlgo::IsDynamicShape(node)) {
     MS_LOG(DEBUG) << "Set Dynamic Shape Attr to Node:" << node->fullname_with_scope();
     auto kernel_graph = func_graph->cast<KernelGraphPtr>();
     MS_EXCEPTION_IF_NULL(kernel_graph);
@@ -18,7 +18,7 @@
 #define MINDSPORE_ADD_DYNAMIC_SHAPE_ATTR_H
 #include <string>
 #include "ir/anf.h"
-#include "utils/convert_utils.h"
+#include "include/common/utils/convert_utils.h"
 #include "backend/common/optimizer/optimizer.h"
 namespace mindspore {
 namespace opt {
@@ -24,9 +24,7 @@
 #include "utils/hash_set.h"
 #include "ir/graph_utils.h"
 #include "backend/common/optimizer/helper.h"
-#include "backend/common/session/anf_runtime_algorithm.h"
-#include "backend/common/session/kernel_graph.h"
-#include "kernel/common_utils.h"
+#include "include/common/utils/anfalgo.h"
 
 namespace mindspore {
 namespace opt {
@@ -38,12 +36,12 @@ bool CheckOP(const FuncGraphManagerPtr &manager, const AnfNodePtr &cnode, const
   for (const auto &node_index : manager->node_users()[cnode]) {
     auto output = node_index.first;
     MS_EXCEPTION_IF_NULL(output);
-    if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimTupleGetItem)) {
+    if (common::AnfAlgo::CheckPrimitiveType(output, prim::kPrimTupleGetItem)) {
       if (CheckOP(manager, output, set)) {
         return true;
       }
     } else if (output->isa<CNode>()) {
-      auto name = AnfAlgo::GetCNodeName(output);
+      auto name = common::AnfAlgo::GetCNodeName(output);
       if (set.find(name) != set.end()) {
         return true;
       }
@@ -59,7 +57,7 @@ void AddAttrTraining(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
   if (manager->node_users().find(cnode) == manager->node_users().end()) {
     return;
   }
-  auto set = MarkOp[AnfAlgo::GetCNodeName(cnode)];
+  auto set = MarkOp[common::AnfAlgo::GetCNodeName(cnode)];
   if (CheckOP(manager, cnode, set)) {
     cnode->AddAttr(kAttrIsTraining, MakeValue(true));
   } else {
@@ -70,14 +68,14 @@ void AddAttrTraining(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
 
 const AnfNodePtr AddTrainingAttr::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                           const EquivPtr &) const {
-  if (node == nullptr || func_graph == nullptr || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem) ||
-      AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
+  if (node == nullptr || func_graph == nullptr || common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem) ||
+      common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
     return nullptr;
   }
   if (!node->isa<CNode>()) {
     return nullptr;
   }
-  auto name = AnfAlgo::GetCNodeName(node);
+  auto name = common::AnfAlgo::GetCNodeName(node);
   auto iter = MarkOp.find(name);
   if (iter == MarkOp.end()) {
     return nullptr;
@@ -19,7 +19,7 @@
 #include <string>
 
 #include "ir/anf.h"
-#include "utils/convert_utils.h"
+#include "include/common/utils/convert_utils.h"
 #include "backend/common/optimizer/optimizer.h"
 
 namespace mindspore {
@@ -16,7 +16,7 @@
 
 #include "backend/common/pass/adjust_depend_for_parallel_optimizer_recompute_all_gather.h"
 #include <algorithm>
-#include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 
 namespace mindspore {
 namespace opt {
@@ -35,11 +35,12 @@ bool AdjustDependForParallelOptimizerRecomputeAllGather::Run(const FuncGraphPtr &graph) {
       continue;
     }
     auto cnode = node->cast<CNodePtr>();
-    if (!AnfAlgo::IsAllgather(cnode) || !AnfAlgo::IsFusion(cnode) || !AnfAlgo::IsFromParallelOptimizer(cnode)) {
+    if (!common::AnfAlgo::IsAllgather(cnode) || !common::AnfAlgo::IsFusion(cnode) ||
+        !common::AnfAlgo::IsFromParallelOptimizer(cnode)) {
       continue;
     }
-    if (AnfAlgo::IsRecompute(cnode)) {
-      int64_t fusion_id = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion);
+    if (common::AnfAlgo::IsRecompute(cnode)) {
+      int64_t fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion);
       if (std::find(parallel_optimizer_recompute_allgather_fusion_ids.begin(),
                     parallel_optimizer_recompute_allgather_fusion_ids.end(),
                     fusion_id) == parallel_optimizer_recompute_allgather_fusion_ids.end()) {
@@ -52,10 +53,10 @@ bool AdjustDependForParallelOptimizerRecomputeAllGather::Run(const FuncGraphPtr &graph) {
         parallel_optimizer_recompute_allgathers.push_back(node);
       }
     } else {
-      int64_t unrecompute_fusion_id = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion);
+      int64_t unrecompute_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion);
       unrecompute_max_fusion_id = std::max(unrecompute_fusion_id, unrecompute_max_fusion_id);
-      bool would_be_recomputed =
-          AnfAlgo::HasNodeAttr(kAttrRecompute, cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, kAttrRecompute);
+      bool would_be_recomputed = common::AnfAlgo::HasNodeAttr(kAttrRecompute, cnode) &&
+                                 common::AnfAlgo::GetNodeAttr<bool>(cnode, kAttrRecompute);
       auto [iter, inserted] =
           forward_allgather_recompute_value_in_fusion_group.emplace(unrecompute_fusion_id, would_be_recomputed);
       if (!inserted && iter->second != would_be_recomputed) {
@@ -79,14 +80,14 @@ void AdjustDependForParallelOptimizerRecomputeAllGather::IncreaseAllgatherFusionId(
   if (recompute_min_fusion_id <= unrecompute_max_fusion_id) {
     MS_LOG(WARNING) << "Increase the duplicated allgather fusion id";
     for (auto &adjust_node : parallel_optimizer_recompute_first_fusion_allgathers) {
-      int64_t current_fusion_id = AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
+      int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
       int64_t destination_fusion_id = current_fusion_id + unrecompute_max_fusion_id - recompute_min_fusion_id + 2;
-      AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(destination_fusion_id), adjust_node);
+      common::AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(destination_fusion_id), adjust_node);
     }
     for (auto &adjust_node : parallel_optimizer_recompute_allgathers) {
-      int64_t current_fusion_id = AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
+      int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
       int64_t destination_fusion_id = current_fusion_id + unrecompute_max_fusion_id - recompute_min_fusion_id + 2;
-      AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(destination_fusion_id), adjust_node);
+      common::AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(destination_fusion_id), adjust_node);
     }
   }
 }
@@ -97,7 +98,7 @@ bool AdjustDependForParallelOptimizerRecomputeAllGather::AdjustAllgatherDepend(
   bool changed = false;
   for (auto &node : parallel_optimizer_recompute_allgathers) {
     auto cnode = node->cast<CNodePtr>();
-    auto depend_node = AnfAlgo::GetInputNode(cnode, 0);
+    auto depend_node = common::AnfAlgo::GetInputNode(cnode, 0);
     if (IsPrimitiveCNode(depend_node, prim::kPrimDepend)) {
       auto depend_cnode = depend_node->cast<CNodePtr>();
       AnfNodeIndexSet allgather_node_set = manager->node_users()[cnode];
@@ -108,17 +109,17 @@ bool AdjustDependForParallelOptimizerRecomputeAllGather::AdjustAllgatherDepend(
           continue;
         }
         std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
-                                          allgather_next_node, AnfAlgo::GetInputNode(depend_cnode, 1)};
+                                          allgather_next_node, common::AnfAlgo::GetInputNode(depend_cnode, 1)};
         auto new_depend = graph->NewCNode(inputs);
         new_depend->set_abstract(depend_node->abstract());
-        manager->SetEdge(node, 1, AnfAlgo::GetInputNode(depend_cnode, 0));
+        manager->SetEdge(node, 1, common::AnfAlgo::GetInputNode(depend_cnode, 0));
         (void)manager->Replace(allgather_next_node, new_depend);
         changed = true;
       }
     } else if (IsPrimitiveCNode(depend_node, prim::kPrimCast) &&
-               IsPrimitiveCNode(AnfAlgo::GetInputNode(depend_node->cast<CNodePtr>(), 0), prim::kPrimDepend)) {
+               IsPrimitiveCNode(common::AnfAlgo::GetInputNode(depend_node->cast<CNodePtr>(), 0), prim::kPrimDepend)) {
       auto cast_cnode = depend_node->cast<CNodePtr>();
-      auto cast_depend_node = AnfAlgo::GetInputNode(cast_cnode, 0);
+      auto cast_depend_node = common::AnfAlgo::GetInputNode(cast_cnode, 0);
       auto cast_depend_cnode = cast_depend_node->cast<CNodePtr>();
       AnfNodeIndexSet allgather_node_set = manager->node_users()[cnode];
       for (auto &node_pair : allgather_node_set) {
@@ -128,10 +129,10 @@ bool AdjustDependForParallelOptimizerRecomputeAllGather::AdjustAllgatherDepend(
           continue;
         }
         std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
-                                          allgather_next_node, AnfAlgo::GetInputNode(cast_depend_cnode, 1)};
+                                          allgather_next_node, common::AnfAlgo::GetInputNode(cast_depend_cnode, 1)};
         auto new_depend = graph->NewCNode(inputs);
         new_depend->set_abstract(cast_depend_node->abstract());
-        manager->SetEdge(depend_node, 1, AnfAlgo::GetInputNode(cast_depend_cnode, 0));
+        manager->SetEdge(depend_node, 1, common::AnfAlgo::GetInputNode(cast_depend_cnode, 0));
         (void)manager->Replace(allgather_next_node, new_depend);
         changed = true;
       }
@@ -16,10 +16,10 @@
 #include "backend/common/pass/common_subexpression_elimination.h"
 
 #include <memory>
-#include "backend/common/session/anf_runtime_algorithm.h"
 #include "runtime/device/kernel_info.h"
 #include "utils/flags.h"
 #include "utils/ms_context.h"
+#include "include/common/utils/anfalgo.h"
 
 namespace mindspore {
 namespace opt {
@@ -28,15 +28,15 @@ bool HasSideEffectAttr(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  if (!AnfAlgo::HasNodeAttr(GRAPH_FLAG_SIDE_EFFECT, cnode)) {
+  if (!common::AnfAlgo::HasNodeAttr(GRAPH_FLAG_SIDE_EFFECT, cnode)) {
     return false;
   }
-  return AnfAlgo::GetNodeAttr<bool>(cnode, GRAPH_FLAG_SIDE_EFFECT);
+  return common::AnfAlgo::GetNodeAttr<bool>(cnode, GRAPH_FLAG_SIDE_EFFECT);
 }
 
 bool CheckIgnoreCase(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
-  if (AnfAlgo::GetCNodeName(node) != kTransDataOpName) {
+  if (common::AnfAlgo::GetCNodeName(node) != kTransDataOpName) {
     return false;
   }
   auto cnode = node->cast<CNodePtr>();
@@ -44,7 +44,7 @@ bool CheckIgnoreCase(const AnfNodePtr &node) {
   bool need_ignore = true;
   auto input_size = cnode->inputs().size() - 1;
   for (size_t k = 0; k < input_size; ++k) {
-    auto input = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, k), 0).first;
+    auto input = common::AnfAlgo::VisitKernelWithReturnType(common::AnfAlgo::GetInputNode(cnode, k), 0).first;
     if (input != nullptr && input->isa<CNode>()) {
       need_ignore = false;
       break;
@@ -58,7 +58,7 @@ bool BackendCSE::CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const {
   MS_EXCEPTION_IF_NULL(main);
   MS_EXCEPTION_IF_NULL(node);
   if (main->isa<CNode>()) {
-    auto main_name = AnfAlgo::GetCNodeName(main);
+    auto main_name = common::AnfAlgo::GetCNodeName(main);
     if (main_name == prim::kPrimTensorMove->name() || main_name == prim::kPrimMemCpyAsync->name()) {
       return false;
     }
@@ -24,9 +24,10 @@
 #include "base/core_ops.h"
 #include "runtime/device/kernel_info.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "kernel/kernel_build_info.h"
 #include "backend/common/optimizer/helper.h"
-#include "frontend/parallel/context.h"
+#include "include/common/utils/parallel_context.h"
 
 namespace mindspore {
 namespace opt {
@@ -49,29 +50,30 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &communication_op_info,
   for (size_t idx = start_index; idx <= end_index; ++idx) {
     auto cnode = communication_op_info.communication_op_nodes[idx];
     int64_t rank_size = 1;
-    if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode) && AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName) {
-      rank_size = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize);
+    if (common::AnfAlgo::HasNodeAttr(kAttrRankSize, cnode) &&
+        common::AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName) {
+      rank_size = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize);
     }
     size_t rank_size_t = LongToSize(rank_size);
     if (rank_size_t == 0) {
       MS_LOG(EXCEPTION) << "Rank size should not be zero.";
     }
     MS_EXCEPTION_IF_NULL(cnode);
-    size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
+    size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
     for (size_t input_index = 0; input_index < input_num; ++input_index) {
       inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
       inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index));
     }
     for (size_t rank_index = 0; rank_index < rank_size_t; ++rank_index) {
-      size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
+      size_t output_num = common::AnfAlgo::GetOutputTensorNum(cnode);
       for (size_t output_index = 0; output_index < output_num; ++output_index) {
         outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index));
         outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index));
-        std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(cnode, output_index);
+        std::vector<size_t> shape = common::AnfAlgo::GetOutputInferShape(cnode, output_index);
         if (!shape.empty()) {
           shape[0] /= rank_size_t;
         }
-        outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
+        outputs_shape.push_back(common::AnfAlgo::GetOutputInferShape(cnode, output_index));
       }
     }
     builder.SetFusionType(AnfAlgo::GetFusionType(cnode));
@@ -86,7 +88,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &communication_op_info,
 }
 
 std::string GetFusionGroupKey(const AnfNodePtr &node) {
-  auto primitive = AnfAlgo::GetCNodePrimitive(node);
+  auto primitive = common::AnfAlgo::GetCNodePrimitive(node);
   MS_EXCEPTION_IF_NULL(primitive);
   ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion);
   if (attr_fusion == nullptr) {
@@ -106,7 +108,7 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) {
   if (attr_op != nullptr) {
     op = GetValue<std::string>(attr_op);
   }
-  auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
+  auto dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
   return group + op + std::to_string(fusion) + TypeIdLabel(dtype);
 }
@@ -338,22 +340,23 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph,
   auto final_node = communication_op_info.communication_op_nodes[end_index];
   size_t node_num = end_index - start_index + 1;
   int64_t rank_size = 1;
-  if (AnfAlgo::HasNodeAttr(kAttrRankSize, final_node) && AnfAlgo::GetCNodeName(final_node) == kAllGatherOpName) {
-    rank_size = AnfAlgo::GetNodeAttr<int64_t>(final_node, kAttrRankSize);
+  if (common::AnfAlgo::HasNodeAttr(kAttrRankSize, final_node) &&
+      common::AnfAlgo::GetCNodeName(final_node) == kAllGatherOpName) {
+    rank_size = common::AnfAlgo::GetNodeAttr<int64_t>(final_node, kAttrRankSize);
   }
   size_t rank_size_t = LongToSize(rank_size);
   if (rank_size_t == 0) {
     MS_LOG(EXCEPTION) << "Rank size should not be zero.";
   }
   size_t output_num = node_num * rank_size_t;
-  std::vector<TypeId> dtypes(output_num, AnfAlgo::GetOutputInferDataType(final_node, 0));
+  std::vector<TypeId> dtypes(output_num, common::AnfAlgo::GetOutputInferDataType(final_node, 0));
   std::vector<std::vector<size_t>> shapes;
   int64_t fusion_total_size = 0;
   for (size_t i = 0; i < rank_size_t; ++i) {
     for (size_t idx = start_index; idx <= end_index; ++idx) {
       auto input_node = communication_op_info.communication_op_nodes[idx];
       MS_EXCEPTION_IF_NULL(input_node);
-      std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(input_node, 0);
+      std::vector<size_t> shape = common::AnfAlgo::GetOutputInferShape(input_node, 0);
       if (!shape.empty()) {
         shape[0] /= rank_size_t;
       }

@@ -368,32 +371,32 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph,
       fusion_total_size += static_cast<int64_t>(tensor_size);
     }
   }
-  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, fused_node.get());
+  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, fused_node.get());
   auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index);
   AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_node.get());
   const std::vector<std::string> kHcclFusionAttrs = {kAttrFusion, kAttrGroup, kAttrGroupBack,
                                                      kAttrSrTag, kAttrDestRank, kAttrSrcRank,
                                                      kAttrDType, kAttrOp, kAttrRankSize};
   for (const auto &attr : kHcclFusionAttrs) {
-    if (AnfAlgo::HasNodeAttr(attr, final_node)) {
-      AnfAlgo::CopyNodeAttr(attr, final_node, fused_node);
+    if (common::AnfAlgo::HasNodeAttr(attr, final_node)) {
+      common::AnfAlgo::CopyNodeAttr(attr, final_node, fused_node);
     }
   }
-  if (AnfAlgo::HasNodeAttr(kAttrShape, final_node)) {
+  if (common::AnfAlgo::HasNodeAttr(kAttrShape, final_node)) {
     std::vector<int64_t> fusion_total_shape{fusion_total_size};
-    AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(fusion_total_shape), fused_node);
+    common::AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(fusion_total_shape), fused_node);
   }
   bool is_recompute =
       final_node->GetAttr(kAttrDuplicated) != nullptr && GetValue<bool>(final_node->GetAttr(kAttrDuplicated));
-  if (AnfAlgo::GetCNodeName(final_node) == kAllGatherOpName && is_recompute) {
+  if (common::AnfAlgo::GetCNodeName(final_node) == kAllGatherOpName && is_recompute) {
     auto fused_cnode = fused_node->cast<CNodePtr>();
     fused_cnode->AddAttr("duplicated", MakeValue(true));
     auto fused_prim = GetCNodePrimitive(fused_cnode);
     auto final_node_prim = GetCNodePrimitive(final_node);
     fused_prim->set_instance_name(final_node_prim->instance_name());
   }
-  if (AnfAlgo::HasNodeAttr(kAttrNotDelayFusion, final_node)) {
-    AnfAlgo::CopyNodeAttr(kAttrNotDelayFusion, final_node, fused_node);
+  if (common::AnfAlgo::HasNodeAttr(kAttrNotDelayFusion, final_node)) {
+    common::AnfAlgo::CopyNodeAttr(kAttrNotDelayFusion, final_node, fused_node);
   }
   return fused_node;
 }
@@ -457,7 +460,7 @@ bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
   mindspore::HashMap<std::string, CommunicationOpInfo> candidate_groups;
   std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
   for (auto &node : node_list) {
-    if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == op_name_) {
+    if (node != nullptr && node->isa<CNode>() && common::AnfAlgo::GetCNodeName(node) == op_name_) {
       std::string key = GetFusionGroupKey(node);
       if (key.empty()) {
         continue;
@@ -479,11 +482,12 @@ bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
     }
     auto first_node = it.second.communication_op_nodes[0];
     TraceGuard guard(std::make_shared<TraceOpt>(first_node->debug_info()));
-    if (AnfAlgo::HasNodeAttr(kAttrIndex, first_node) && AnfAlgo::GetNodeAttr<int64_t>(first_node, kAttrIndex) > 0) {
+    if (common::AnfAlgo::HasNodeAttr(kAttrIndex, first_node) &&
+        common::AnfAlgo::GetNodeAttr<int64_t>(first_node, kAttrIndex) > 0) {
       std::stable_sort(it.second.communication_op_nodes.begin(), it.second.communication_op_nodes.end(),
                        [](const CNodePtr &a, const CNodePtr &b) {
-                         return AnfAlgo::GetNodeAttr<int64_t>(a, kAttrIndex) <
-                                AnfAlgo::GetNodeAttr<int64_t>(b, kAttrIndex);
+                         return common::AnfAlgo::GetNodeAttr<int64_t>(a, kAttrIndex) <
+                                common::AnfAlgo::GetNodeAttr<int64_t>(b, kAttrIndex);
                        });
     }
     size_t segment_num = 0;
@@ -22,7 +22,7 @@
 #include "backend/common/optimizer/pass.h"
 #include "ir/func_graph.h"
 #include "ir/anf.h"
-#include "utils/utils.h"
+#include "include/common/utils/utils.h"
 
 namespace mindspore {
 namespace opt {
@@ -17,12 +17,12 @@
 #include <memory>
 #include <vector>
 #include <string>
-#include "backend/common/session/anf_runtime_algorithm.h"
 #include "ir/primitive.h"
 #include "utils/ms_context.h"
-#include "utils/utils.h"
+#include "include/common/utils/utils.h"
 #include "abstract/abstract_value.h"
 #include "backend/common/optimizer/const_input_to_attr.h"
+#include "include/common/utils/anfalgo.h"
 
 namespace mindspore {
 namespace opt {
@@ -92,13 +92,13 @@ bool CheckValues(const ValuePtrList &strides_values) {
 
 bool CheckAttrs(const CNodePtr &strided_slice_grad) {
   MS_EXCEPTION_IF_NULL(strided_slice_grad);
-  if (!AnfAlgo::HasNodeAttr(kAttrNewAxisMask, strided_slice_grad) ||
-      !AnfAlgo::HasNodeAttr(kAttrShrinkAxisMask, strided_slice_grad)) {
+  if (!common::AnfAlgo::HasNodeAttr(kAttrNewAxisMask, strided_slice_grad) ||
+      !common::AnfAlgo::HasNodeAttr(kAttrShrinkAxisMask, strided_slice_grad)) {
     MS_LOG(INFO) << "new_axis_mask or shrink_axis_mask not exist in cnode[" + strided_slice_grad->DebugString() + "]";
     return false;
   }
-  auto new_axis_mask = AnfAlgo::GetNodeAttr<int64_t>(strided_slice_grad, kAttrNewAxisMask);
-  auto shrink_axis_mask = AnfAlgo::GetNodeAttr<int64_t>(strided_slice_grad, kAttrShrinkAxisMask);
+  auto new_axis_mask = common::AnfAlgo::GetNodeAttr<int64_t>(strided_slice_grad, kAttrNewAxisMask);
+  auto shrink_axis_mask = common::AnfAlgo::GetNodeAttr<int64_t>(strided_slice_grad, kAttrShrinkAxisMask);
   if (new_axis_mask != 0 || shrink_axis_mask != 0) {
     MS_LOG(INFO) << "new_axis_mask or shrink_axis_mask not equal 0";
     return false;
@@ -17,7 +17,7 @@
 #include <memory>
 #include <vector>
 #include "ir/primitive.h"
-#include "utils/utils.h"
+#include "include/common/utils/utils.h"
 #include "utils/trace_base.h"
 #include "backend/common/optimizer/helper.h"
@@ -18,6 +18,7 @@
 #include <string>
 #include "utils/check_convert_utils.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "kernel/common_utils.h"
 
 namespace mindspore {
@@ -15,11 +15,10 @@
  */
 #include "backend/common/pass/convert_const_input_to_attr.h"
 #include "backend/common/optimizer/const_input_to_attr.h"
-#include "utils/utils.h"
+#include "include/common/utils/utils.h"
 #include "utils/ms_context.h"
 #include "base/core_ops.h"
-#include "backend/common/session/anf_runtime_algorithm.h"
 #include "kernel/common_utils.h"
+#include "include/common/utils/anfalgo.h"
 
 namespace mindspore {
 namespace opt {
@@ -32,49 +31,52 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node,
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   ConstInputToAttrInfoRegister reg;
-  if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), &reg)) {
+  if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(common::AnfAlgo::GetCNodeName(cnode), &reg)) {
     return nullptr;
   }
-  if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookup->name() ||
-      AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookupCommGrad->name()) {
-    if (!AnfAlgo::HasNodeAttr(kAttrPrimitiveTarget, cnode)) {
+  if (common::AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookup->name() ||
+      common::AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookupCommGrad->name()) {
+    if (!common::AnfAlgo::HasNodeAttr(kAttrPrimitiveTarget, cnode)) {
       return nullptr;
     }
   }
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   auto device = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
-  if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimGatherD->name()) {
+  if (common::AnfAlgo::GetCNodeName(cnode) == prim::kPrimGatherD->name()) {
     if (device != kGPUDevice) {
       return nullptr;
     }
   }
-  if (AnfAlgo::IsDynamicShape(cnode)) {
+  if (common::AnfAlgo::IsDynamicShape(cnode)) {
     if (device == kGPUDevice) {
-      if (DynamicShapeConstInputToAttrGPU.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttrGPU.end()) {
+      if (DynamicShapeConstInputToAttrGPU.find(common::AnfAlgo::GetCNodeName(cnode)) ==
+          DynamicShapeConstInputToAttrGPU.end()) {
         MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope();
         return nullptr;
       }
     } else if (device == kCPUDevice) {
-      if (DynamicShapeConstInputToAttrCPU.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttrCPU.end()) {
+      if (DynamicShapeConstInputToAttrCPU.find(common::AnfAlgo::GetCNodeName(cnode)) ==
+          DynamicShapeConstInputToAttrCPU.end()) {
        MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope();
         return nullptr;
      }
     } else {
-      if (DynamicShapeConstInputToAttr.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttr.end()) {
+      if (DynamicShapeConstInputToAttr.find(common::AnfAlgo::GetCNodeName(cnode)) ==
+          DynamicShapeConstInputToAttr.end()) {
         MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope();
         return nullptr;
       }
     }
   }
   if (device == kAscendDevice &&
-      NeedConvertToValueNodeSet.find(AnfAlgo::GetCNodeName(cnode)) != NeedConvertToValueNodeSet.end() &&
-      !AnfAlgo::HasNodeAttr(kAttrNeedConvertToValueNode, cnode)) {
+      NeedConvertToValueNodeSet.find(common::AnfAlgo::GetCNodeName(cnode)) != NeedConvertToValueNodeSet.end() &&
+      !common::AnfAlgo::HasNodeAttr(kAttrNeedConvertToValueNode, cnode)) {
     auto input_attrs = reg.GetConstInputAttrInfo();
     std::vector<size_t> need_convert_to_constant;
     std::transform(input_attrs.begin(), input_attrs.end(), std::back_inserter(need_convert_to_constant),
                    [](size_t i) { return i + 1; });
-    AnfAlgo::SetNodeAttr(kAttrNeedConvertToValueNode, MakeValue(need_convert_to_constant), cnode);
+    common::AnfAlgo::SetNodeAttr(kAttrNeedConvertToValueNode, MakeValue(need_convert_to_constant), cnode);
     return nullptr;
   }
@@ -22,6 +22,7 @@
 #include "ir/graph_utils.h"
 #include "backend/common/optimizer/helper.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "backend/common/session/kernel_graph.h"
 #include "kernel/common_utils.h"
@@ -65,7 +66,7 @@ AnfNodePtr ConvertConstInputToTensorInput::ConstInputToTensorInput(const FuncGraphPtr &func_graph,
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(cnode);
   const std::set<std::string> no_need_to_convert_nodes = {kStackOpName};
-  auto node_type = AnfAlgo::GetCNodeName(cnode);
+  auto node_type = common::AnfAlgo::GetCNodeName(cnode);
   if (no_need_to_convert_nodes.find(node_type) != no_need_to_convert_nodes.end()) {
     return nullptr;
   }
@ -93,13 +94,13 @@ AnfNodePtr ConvertConstInputToTensorInput::ConstInputToTensorInput(const FuncGra
|
|||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto new_cnode = NewCNode(new_inputs, func_graph);
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimDepend)) {
|
||||
if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimDepend)) {
|
||||
new_cnode->set_abstract(new_inputs[1]->abstract());
|
||||
} else {
|
||||
new_cnode->set_abstract(cnode->abstract());
|
||||
}
|
||||
new_cnode->set_scope(cnode->scope());
|
||||
AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
|
||||
common::AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
|
||||
if (kernel_graph != nullptr) {
|
||||
kernel_graph->FrontBackendlMapUpdate(cnode, new_cnode);
|
||||
}
|
||||
|
@ -110,8 +111,8 @@ AnfNodePtr ConvertConstInputToTensorInput::ConstInputToTensorInput(const FuncGra
|
|||
|
||||
const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
if (node == nullptr || func_graph == nullptr || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem) ||
|
||||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
|
||||
if (node == nullptr || func_graph == nullptr || common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem) ||
|
||||
common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!node->isa<CNode>()) {
|
||||
|
|
|
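The hunks above are one mechanical rewrite: graph-level queries move from the session-layer AnfAlgo alias to common::AnfAlgo, now exported by the mindspore_common shared library. A minimal sketch of the call-site change (illustrative only; IsSkippableNode is a hypothetical helper, and the calls are the ones shown in this diff):

// Sketch, assuming the MindSpore tree: the signatures are unchanged, only the
// namespace moves from session::AnfRuntimeAlgorithm (alias AnfAlgo) to common::AnfAlgo.
#include "base/core_ops.h"                 // prim::kPrim* definitions (assumed)
#include "include/common/utils/anfalgo.h"  // new home of the graph-level helpers

bool IsSkippableNode(const AnfNodePtr &node) {  // hypothetical helper
  // Same guard as ConvertConstInputToTensorInput::Process above, post-migration.
  return node == nullptr || common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem) ||
         common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple);
}
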
@@ -18,7 +18,7 @@
 #include <string>

 #include "ir/anf.h"
-#include "utils/convert_utils.h"
+#include "include/common/utils/convert_utils.h"
 #include "backend/common/optimizer/optimizer.h"

 namespace mindspore {

@@ -16,9 +16,10 @@
 #include "backend/common/pass/convert_const_scalar_to_tensor.h"
 #include <memory>
 #include <utility>
-#include "utils/convert_utils.h"
+#include "include/common/utils/convert_utils.h"
 #include "backend/common/optimizer/helper.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "backend/common/session/kernel_graph.h"

 namespace mindspore {

@@ -57,7 +58,7 @@ AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePt

 const AnfNodePtr ConvertConstScalarToTensor::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                      const EquivPtr &) const {
-  if (node == nullptr || func_graph == nullptr || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
+  if (node == nullptr || func_graph == nullptr || common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
     return nullptr;
   }
   // input is scalar, and link to graph return

@@ -18,32 +18,29 @@
 #include <algorithm>
 #include <memory>

-#include "backend/common/session/anf_runtime_algorithm.h"
 #include "backend/common/optimizer/helper.h"
 #include "backend/common/session/kernel_graph.h"
 #include "kernel/common_utils.h"
 #include "runtime/device/kernel_info.h"
+#include "include/common/utils/anfalgo.h"

 namespace mindspore {
 namespace opt {
 namespace {
 int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
                          std::vector<AnfNodePtr> *plant_inputs) {
-  if (!AnfAlgo::IsTupleOutput(tuple_input)) {
+  if (!common::AnfAlgo::IsTupleOutput(tuple_input)) {
     auto abs = tuple_input->abstract();
     MS_EXCEPTION_IF_NULL(abs);
     MS_LOG(WARNING) << "The Function only split the output type is tuple type but got" << abs->ToString();
     return -1;
   }
   MS_EXCEPTION_IF_NULL(plant_inputs);
-  auto input_size = AnfAlgo::GetOutputTensorNum(tuple_input);
-  if (tuple_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(tuple_input, prim::kPrimMakeTuple)) {
+  auto input_size = common::AnfAlgo::GetOutputTensorNum(tuple_input);
+  if (tuple_input->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(tuple_input, prim::kPrimMakeTuple)) {
     auto make_tuple = tuple_input->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(make_tuple);
-    size_t tuple_input_num = AnfAlgo::GetInputTensorNum(make_tuple);
+    size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(make_tuple);
     for (size_t j = 0; j < tuple_input_num; ++j) {
       // using for graph kernel
-      auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j);
+      auto dyn_input_node = common::AnfAlgo::GetInputNode(make_tuple, j);
       MS_EXCEPTION_IF_NULL(dyn_input_node);
       (void)plant_inputs->emplace_back(dyn_input_node);
     }

@@ -59,18 +56,18 @@ int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_inpu
 void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
   MS_EXCEPTION_IF_NULL(cnode_ptr);
   MS_EXCEPTION_IF_NULL(graph);
-  if (AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimCall) ||
-      AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPartial)) {
+  if (common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimCall) ||
+      common::AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPartial)) {
     return;
   }
   std::vector<AnfNodePtr> plant_inputs;
   std::vector<int64_t> dyn_input_sizes;
-  plant_inputs.push_back(AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
+  plant_inputs.push_back(common::AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
   size_t input_num = cnode_ptr->inputs().size() - 1;
   for (size_t i = 0; i < input_num; ++i) {
-    auto input_node = AnfAlgo::GetInputNode(cnode_ptr, i);
+    auto input_node = common::AnfAlgo::GetInputNode(cnode_ptr, i);
     MS_EXCEPTION_IF_NULL(input_node);
-    if (AnfAlgo::IsTupleOutput(input_node)) {
+    if (common::AnfAlgo::IsTupleOutput(input_node)) {
       (void)dyn_input_sizes.emplace_back(SplitTupleInputs(graph, input_node, &plant_inputs));
     } else {
       dyn_input_sizes.push_back(-1);

@@ -79,7 +76,7 @@ void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePt
   }
   // If there is dynamic input, set the dyn_input_sizes as an attribute and update the inputs.
   if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int64_t s) { return s >= 0; })) {
-    AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode_ptr);
+    common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode_ptr);
     cnode_ptr->set_inputs(plant_inputs);
   }
 }

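For context on the pass above: SplitTupleInputs flattens a tuple-typed input into its element nodes, and ConvertMakeTupleInputToPlantInputs records, per original input, how many flattened inputs it produced (-1 for a plain input) in the kAttrDynInputSizes attribute. A hedged sketch of the resulting encoding (values and helper name are illustrative, not from the commit):

// Illustrative only: a node whose inputs are (a, MakeTuple(b, c), d) is rewritten
// to take (a, b, c, d) directly; the attribute preserves the original grouping so
// a consumer can re-group the flattened inputs.
#include <cstdint>
#include <vector>
#include "include/common/utils/anfalgo.h"

void TagDynInputSizes(const CNodePtr &cnode_ptr) {  // hypothetical helper
  std::vector<int64_t> dyn_input_sizes = {-1, 2, -1};  // -1: plain input, >= 0: tuple arity
  common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode_ptr);
}
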
@@ -18,10 +18,9 @@
 #include <algorithm>
 #include <memory>

 #include "utils/hash_map.h"
-#include "backend/common/session/anf_runtime_algorithm.h"
 #include "backend/common/optimizer/helper.h"
 #include "backend/common/session/kernel_graph.h"
+#include "include/common/utils/anfalgo.h"

 namespace mindspore {
 namespace opt {

@@ -30,7 +29,7 @@ AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNode
   MS_EXCEPTION_IF_NULL(tuple_anf);
   MS_EXCEPTION_IF_NULL(graph);

-  if (!AnfAlgo::IsTupleOutput(tuple_anf)) {
+  if (!common::AnfAlgo::IsTupleOutput(tuple_anf)) {
     return tuple_anf;
   }
   auto kernel_graph = graph->cast<KernelGraphPtr>();

@@ -65,7 +64,7 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
-    auto real_input = AnfAlgo::GetTupleGetItemRealInput(cnode);
+    auto real_input = common::AnfAlgo::GetTupleGetItemRealInput(cnode);
     MS_EXCEPTION_IF_NULL(real_input);
     if (!real_input->isa<Parameter>() && !real_input->isa<ValueNode>()) {
       return nullptr;

@@ -77,8 +76,8 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
   bool cnode_input_changed = false;
   for (size_t i = 0; i < cnode->inputs().size(); ++i) {
     const auto &input = cnode->inputs()[i];
-    if (input->Type() != nullptr && AnfUtils::IsRealKernel(input) && AnfAlgo::IsTupleOutput(input) &&
-        !AnfAlgo::CheckPrimitiveType(input, prim::kPrimCall)) {
+    if (input->Type() != nullptr && AnfUtils::IsRealKernel(input) && common::AnfAlgo::IsTupleOutput(input) &&
+        !common::AnfAlgo::CheckPrimitiveType(input, prim::kPrimCall)) {
       cnode->set_input(i, ConvertTupleInputToMakeTuple(func_graph, input));
       cnode_input_changed = true;
     }

@@ -19,7 +19,7 @@

 #include "utils/hash_set.h"
 #include "backend/common/optimizer/const_input_to_attr.h"
-#include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"

 namespace mindspore {
 namespace opt {

@@ -36,7 +36,7 @@ const AnfNodePtr CustomOpConstInputToAttr::Process(const FuncGraphPtr &, const A
   if (!IsPrimitiveCNode(cnode, prim::kPrimCustom)) {
     return nullptr;
   }
-  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
+  auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode);
   MS_EXCEPTION_IF_NULL(primitive);
   mindspore::HashSet<size_t> attr_indices;
   GetCustomOpAttrIndex(primitive, &attr_indices);

@@ -20,9 +20,9 @@
 #include <vector>
 #include <unordered_set>

-#include "backend/common/session/anf_runtime_algorithm.h"
 #include "kernel/oplib/opinfo.h"
 #include "kernel/oplib/oplib.h"
+#include "include/common/utils/anfalgo.h"

 namespace mindspore {
 namespace opt {

@@ -87,10 +87,10 @@ void ParseAttrDefaultValue(const std::string &op_name, const std::string &attr_n
 void AddMissingAttrs(const CNodePtr &cnode, kernel::OpImplyType imply_type,
                      const std::unordered_set<std::string> &missing_attrs) {
   MS_EXCEPTION_IF_NULL(cnode);
-  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
+  auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode);
   MS_EXCEPTION_IF_NULL(primitive);
   primitive = primitive->Clone();
-  auto op_name = AnfAlgo::GetCNodeName(cnode);
+  auto op_name = common::AnfAlgo::GetCNodeName(cnode);
   auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, imply_type);
   MS_EXCEPTION_IF_NULL(op_info_ptr);
   auto all_attrs = op_info_ptr->attrs_ptr();

@@ -131,9 +131,9 @@ const AnfNodePtr CustomOpRegInfoToAttr::Process(const FuncGraphPtr &, const AnfN
   if (!IsPrimitiveCNode(cnode, prim::kPrimCustom)) {
     return nullptr;
   }
-  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
+  auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode);
   MS_EXCEPTION_IF_NULL(primitive);
-  auto func_type = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrFuncType);
+  auto func_type = common::AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrFuncType);
   // AKG/AICPU need to process attr, TBE will process later in the json creating phase.
   if (kCustomTypeAkg.find(func_type) == kCustomTypeAkg.end() || func_type == kCustomTypeAICPU) {
     return nullptr;

@@ -18,9 +18,9 @@
 #include <memory>
 #include <utility>
 #include "ir/anf.h"
-#include "backend/common/session/anf_runtime_algorithm.h"
 #include "backend/common/session/kernel_graph.h"
 #include "backend/common/optimizer/helper.h"
+#include "include/common/utils/anfalgo.h"

 namespace mindspore::opt {
 namespace {

@@ -107,7 +107,7 @@ const AnfNodePtr EliminateFuncDataType::Process(const FuncGraphPtr &func_graph,
     if (abs->isa<abstract::AbstractTuple>() && FuncDataTypeExistsInAbstractTuple(abs)) {
       RemoveInputFuncNodeForKernelGraph(kernel_graph, param);
       (void)tr.Replace(param, constant_);
-    } else if (AnfAlgo::GetOutputInferDataType(param, 0) == kObjectTypeFunction) {
+    } else if (common::AnfAlgo::GetOutputInferDataType(param, 0) == kObjectTypeFunction) {
       RemoveInputFuncNodeForKernelGraph(kernel_graph, param);
       (void)tr.Replace(param, constant_);
     } else {

@@ -125,7 +125,7 @@ const AnfNodePtr EliminateFuncDataType::Process(const FuncGraphPtr &func_graph,
     if (abs->isa<abstract::AbstractTuple>()) {
       auto abs_tuple = dyn_cast<abstract::AbstractTuple>(abs);
       node->set_abstract(std::make_shared<abstract::AbstractTuple>(EliminateFuncDataTypeForAbstractTuple(abs_tuple)));
-    } else if (AnfAlgo::GetOutputInferDataType(node, 0) == kObjectTypeFunction) {
+    } else if (common::AnfAlgo::GetOutputInferDataType(node, 0) == kObjectTypeFunction) {
       node->set_abstract(constant_abs_);
     }
   }

@@ -19,7 +19,8 @@
 #include <utility>
 #include "utils/hash_map.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
-#include "utils/utils.h"
+#include "include/common/utils/anfalgo.h"
+#include "include/common/utils/utils.h"
 #include "backend/common/optimizer/helper.h"
 #include "base/core_ops.h"
 #include "kernel/common_utils.h"

@@ -94,7 +95,7 @@ const AnfNodePtr EliminateRedundantOp::ProcessMatchedNodes(const FuncGraphPtr &f
   auto pass_size = pass_vector->size();
   for (size_t idx = 1; idx <= pass_size - 1; ++idx) {
     auto nd = (*pass_vector)[idx].first;
-    if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimDepend)) {
+    if (common::AnfAlgo::CheckPrimitiveType(nd, prim::kPrimDepend)) {
       has_depend_node = true;
     }
     if (users[nd].size() > 1) {

@@ -140,7 +141,7 @@ void EliminateRedundantOp::Init() {

 const AnfNodePtr EliminateRedundantOp::DoEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const {
   // match the first name
-  auto name1 = AnfAlgo::GetCNodeName(cnode);
+  auto name1 = common::AnfAlgo::GetCNodeName(cnode);
   auto it = redundant_process_map_.find(name1);
   if (it == redundant_process_map_.end()) {
     return nullptr;

@@ -152,7 +153,7 @@ const AnfNodePtr EliminateRedundantOp::DoEliminate(const FuncGraphPtr &func_grap
     return nullptr;
   }
   // match the second name
-  auto name2 = AnfAlgo::GetCNodeName(prev_cnode);
+  auto name2 = common::AnfAlgo::GetCNodeName(prev_cnode);
   if (name2 != it->second.first) {
     return nullptr;
   }

@@ -18,6 +18,7 @@
 #include <memory>
 #include "kernel/common_utils.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "backend/common/optimizer/helper.h"

 namespace mindspore {

@@ -29,7 +30,7 @@ const BaseRef EraseVisitAttr::DefinePattern() const {
 }

 const AnfNodePtr EraseVisitAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const {
-  AnfAlgo::EraseNodeAttr(kAttrVisited, node);
+  common::AnfAlgo::EraseNodeAttr(kAttrVisited, node);
   return nullptr;
 }
 } // namespace opt

@@ -17,7 +17,7 @@

 #include <memory>
 #include "base/core_ops.h"
-#include "utils/utils.h"
+#include "include/common/utils/utils.h"

 namespace mindspore {
 namespace opt {

@@ -20,7 +20,7 @@
 #include <string>
 #include <regex>
 #include "backend/common/optimizer/helper.h"
-#include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"

 namespace mindspore {
 namespace opt {

@@ -30,14 +30,14 @@ constexpr auto kCustomAttrInplaceAssignOutput = "inplace_assign_output";

 // Used to find Custom op outputs' inplace assign index
 std::vector<std::vector<int64_t>> GetHybridInplaceIndex(const CNodePtr &cnode) {
-  if (AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrFuncType) != kCustomTypeHybrid) {
+  if (common::AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrFuncType) != kCustomTypeHybrid) {
     return {};
   }

-  if (!AnfAlgo::HasNodeAttr(kCustomAttrInplaceAssignOutput, cnode)) {
+  if (!common::AnfAlgo::HasNodeAttr(kCustomAttrInplaceAssignOutput, cnode)) {
     return {};
   }
-  auto inplace_index_str = AnfAlgo::GetNodeAttr<std::string>(cnode, kCustomAttrInplaceAssignOutput);
+  auto inplace_index_str = common::AnfAlgo::GetNodeAttr<std::string>(cnode, kCustomAttrInplaceAssignOutput);
   std::regex delimiters(" ");
   std::vector<std::string> index(
     std::sregex_token_iterator(inplace_index_str.begin(), inplace_index_str.end(), delimiters, -1),

@@ -80,7 +80,7 @@ CNodePtr InsertAssign(const FuncGraphPtr &func_graph, const AnfNodePtr &src, con
 CNodePtr InsertAssignAfterCustom(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
   auto inplace_info = GetHybridInplaceIndex(cnode);
   if (inplace_info.size() != 1) return nullptr;
-  auto input_size = AnfAlgo::GetInputTensorNum(cnode);
+  auto input_size = common::AnfAlgo::GetInputTensorNum(cnode);
   if (auto i = LongToSize(inplace_info[0][kCustomInput]); i < input_size) {
     return InsertAssign(func_graph, cnode->input(i + 1), cnode);
   } else {

@@ -102,7 +102,7 @@ CNodePtr InsertAssignAfterTupleGetItem(const FuncGraphPtr &func_graph, const CNo
   auto inplace_info = GetHybridInplaceIndex(real_input);
   for (auto index : inplace_info) {
     if (index[kCustomOutput] == gt_idx && index[kCustomInput] >= 0) {
-      auto custom_input_size = AnfAlgo::GetInputTensorNum(real_input);
+      auto custom_input_size = common::AnfAlgo::GetInputTensorNum(real_input);
       if (auto i = LongToSize(index[kCustomInput]); i < custom_input_size) {
         return InsertAssign(func_graph, real_input->input(i + 1), cnode);
       }

@@ -18,13 +18,12 @@
 #include <memory>
 #include <vector>
 #include <string>
 #include <utility>
 #include "backend/common/optimizer/helper.h"
 #include "base/core_ops.h"
-#include "utils/utils.h"
+#include "include/common/utils/utils.h"
+#include "include/common/utils/anfalgo.h"
 #include "utils/trace_base.h"
 #include "backend/common/session/kernel_graph.h"
-#include "backend/common/session/anf_runtime_algorithm.h"

 namespace mindspore {
 namespace opt {

@@ -52,15 +51,15 @@ CNodePtr CreateNewDependNode(const FuncGraphPtr &func_graph, const CNodePtr &cno

 CNodePtr CheckIsolatedVirtualNode(const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
-  if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDepend->name() &&
-      AnfAlgo::GetCNodeName(cnode) != prim::kPrimLoad->name()) {
+  if (common::AnfAlgo::GetCNodeName(cnode) != prim::kPrimDepend->name() &&
+      common::AnfAlgo::GetCNodeName(cnode) != prim::kPrimLoad->name()) {
     return nullptr;
   }
-  auto virtual_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependVirtualInputIndex);
+  auto virtual_input_op = common::AnfAlgo::GetInputNode(cnode, kIsolatedDependVirtualInputIndex);
   if (!HasAbstractMonad(virtual_input_op)) {
     return nullptr;
   }
-  auto real_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependRealInputIndex);
+  auto real_input_op = common::AnfAlgo::GetInputNode(cnode, kIsolatedDependRealInputIndex);
   MS_EXCEPTION_IF_NULL(real_input_op);
   if (!real_input_op->isa<CNode>()) {
     return nullptr;

@@ -96,7 +95,7 @@ AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node
   if (isolated_cnode != nullptr) {
     replace_cnode = isolated_cnode;
   }
-  string op_name = AnfAlgo::GetCNodeName(replace_cnode);
+  string op_name = common::AnfAlgo::GetCNodeName(replace_cnode);
   // Currently we only eliminate transdata or cast nodes.
   if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) {
     return nullptr;

@@ -115,14 +114,14 @@ AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node
 AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(cnode);
-  if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) {
+  if (common::AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) {
     return nullptr;
   }
-  std::vector<AnfNodePtr> new_make_tuple_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
+  std::vector<AnfNodePtr> new_make_tuple_inputs = {common::AnfAlgo::GetCNodePrimitiveNode(cnode)};
   bool need_update = false;
-  size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
+  size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
   for (size_t index = 0; index < input_num; ++index) {
-    auto input = AnfAlgo::GetInputNode(cnode, index);
+    auto input = common::AnfAlgo::GetInputNode(cnode, index);
     AnfNodePtr replace_input = GetReplaceNode(func_graph, input);
     // If replace input is not null, it will be the input of the TransData or Cast.
     if (replace_input == nullptr) {

@@ -166,9 +165,9 @@ std::vector<size_t> SearchTransDataAndCast(const CNodePtr &cnode) {
   std::vector<size_t> result;
   for (size_t i = 1; i < cnode->size(); ++i) {
     auto &input = cnode->input(i);
-    if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimCast) ||
-        AnfAlgo::CheckPrimitiveType(input, prim::kPrimTransData) ||
-        AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) {
+    if (common::AnfAlgo::CheckPrimitiveType(input, prim::kPrimCast) ||
+        common::AnfAlgo::CheckPrimitiveType(input, prim::kPrimTransData) ||
+        common::AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) {
       (void)result.emplace_back(i);
     }
   }

@@ -19,7 +19,7 @@
 #include <vector>
 #include <string>
 #include "base/core_ops.h"
-#include "utils/utils.h"
+#include "include/common/utils/utils.h"

 namespace mindspore {
 namespace opt {

@@ -16,7 +16,7 @@

 #include "backend/common/pass/reduce_sum_optimizer.h"
 #include <vector>
-#include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"

 namespace mindspore {
 namespace opt {

@@ -55,7 +55,7 @@ AnfNodePtr ReduceSumOptimizer::NewRankOp(const AnfNodePtr &cnode, const KernelGr
   std::vector<AnfNodePtr> rank_inputs;
   auto prim = std::make_shared<Primitive>(prim::kPrimRank->name());
   rank_inputs.push_back(NewValueNode(prim));
-  auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, 1);
+  auto prev_node = common::AnfAlgo::GetPrevNodeOutput(cnode, 1);
   rank_inputs.push_back(prev_node.first);
   auto rank_op = NewCNode(rank_inputs, kernel_graph);
   MS_EXCEPTION_IF_NULL(rank_op);

@@ -99,7 +99,7 @@ AnfNodePtr ReduceSumOptimizer::InsertAssistNode(const CNodePtr &cnode, const Ker
   auto rank_op = NewRankOp(cnode, kernel_graph);
   // new range op
   auto range_op = NewRangeOp(rank_op, kernel_graph);
-  std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
+  std::vector<AnfNodePtr> new_inputs = {common::AnfAlgo::GetCNodePrimitiveNode(cnode)};
   new_inputs.push_back(cnode->input(1));
   new_inputs.push_back(range_op);
   auto new_node = NewCNode(cnode, kernel_graph);

@@ -118,7 +118,7 @@ AnfNodePtr ReduceSumOptimizer::NewAssistValueNode(const CNodePtr &cnode, const K
   // axis is a tuple ,maybe empty or contain a value less 0;
   auto axis_input = cnode->input(axis_input_index);
   if (IsValueNode<ValueTuple>(axis_input)) {
-    std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
+    std::vector<AnfNodePtr> new_inputs = {common::AnfAlgo::GetCNodePrimitiveNode(cnode)};
     new_inputs.push_back(cnode->input(1));
     auto value_node = axis_input->cast<ValueNodePtr>();
     MS_EXCEPTION_IF_NULL(value_node);

@@ -167,16 +167,16 @@ const AnfNodePtr ReduceSumOptimizer::Process(const FuncGraphPtr &func_graph, con
   MS_EXCEPTION_IF_NULL(node);
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  auto op_name = AnfAlgo::GetCNodeName(cnode);
+  auto op_name = common::AnfAlgo::GetCNodeName(cnode);
   if (op_name != kReduceSumOpName) {
     MS_LOG(DEBUG) << "Current node is not: " << kReduceSumOpName << ", skip!";
     return nullptr;
   }
-  if (!AnfAlgo::IsDynamicShape(cnode)) {
+  if (!common::AnfAlgo::IsDynamicShape(cnode)) {
     MS_LOG(DEBUG) << "Current node is not dynamic shape, skip!";
     return nullptr;
   }
-  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
+  common::AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
   auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
   MS_EXCEPTION_IF_NULL(kernel_graph);
   if (AnfUtils::IsDimUnknown(cnode) && IsNeedComputeRank(cnode)) {

@@ -18,6 +18,7 @@
 #include <memory>
 #include "runtime/device/kernel_info.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "kernel/kernel_build_info.h"

 namespace mindspore {

@@ -30,16 +31,16 @@ kernel::KernelBuildInfoPtr ReplaceNodeByProxy::GenerateKernelBuildInfo(const CNo
   std::vector<TypeId> outputs_device_type;
   std::vector<std::vector<size_t>> outputs_shape;
   kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
-  size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
+  size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
   for (size_t input_index = 0; input_index < input_num; ++input_index) {
     inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
     inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index));
   }
-  size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
+  size_t output_num = common::AnfAlgo::GetOutputTensorNum(cnode);
   for (size_t output_index = 0; output_index < output_num; ++output_index) {
     outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index));
     outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index));
-    outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
+    outputs_shape.push_back(common::AnfAlgo::GetOutputInferShape(cnode, output_index));
   }
   builder.SetFusionType(AnfAlgo::GetFusionType(cnode));
   builder.SetProcessor(AnfAlgo::GetProcessor(cnode));

@@ -58,7 +59,7 @@ bool ReplaceNodeByProxy::Run(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(manager);
   std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
   for (auto node : node_list) {
-    if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) {
+    if (node != nullptr && node->isa<CNode>() && common::AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) {
       TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info()));
       auto cnode = node->cast<CNodePtr>();
       MS_EXCEPTION_IF_NULL(cnode);

@@ -74,8 +75,8 @@ bool ReplaceNodeByProxy::Run(const FuncGraphPtr &func_graph) {
       proxy_node->set_kernel_info(kernel_info);

       AbstractBasePtrList abstract_list;
-      AnfAlgo::CopyNodeAttr(kAttrPsKey, cnode, proxy_node);
-      AnfAlgo::CopyNodeAttr("offset", cnode, proxy_node);
+      common::AnfAlgo::CopyNodeAttr(kAttrPsKey, cnode, proxy_node);
+      common::AnfAlgo::CopyNodeAttr("offset", cnode, proxy_node);
       abstract_list.push_back(cnode->abstract());
       auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
       MS_EXCEPTION_IF_NULL(abstract_tuple);

@@ -22,7 +22,7 @@
 #include "backend/common/optimizer/pass.h"
 #include "ir/func_graph.h"
 #include "ir/anf.h"
-#include "utils/utils.h"
+#include "include/common/utils/utils.h"
 #include "kernel/kernel_build_info.h"

 namespace mindspore {

@@ -18,11 +18,10 @@
 #include <memory>
 #include "backend/common/pass/sparse_process.h"
 #include "ir/anf.h"
-#include "utils/convert_utils.h"
+#include "include/common/utils/convert_utils.h"
 #include "utils/anf_utils.h"
 #include "backend/common/optimizer/helper.h"
 #include "backend/common/optimizer/optimizer.h"
-#include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"

 namespace mindspore {
 namespace opt {

@@ -102,7 +101,7 @@ bool SplitParameter(const AnfNodePtr &node, std::vector<AnfNodePtr> *new_inputs,

 bool SplitCNode(const AnfNodePtr &node, std::vector<AnfNodePtr> *new_inputs) {
   auto cnode = node->cast<CNodePtr>();
-  auto sparse_prim = AnfAlgo::GetCNodePrimitive(cnode);
+  auto sparse_prim = common::AnfAlgo::GetCNodePrimitive(cnode);
   MS_EXCEPTION_IF_NULL(sparse_prim);
   // Currently, only MakeCSR and MakeTuple nodes can be split.
   if (make_sparse_set.count(sparse_prim->name()) <= 0 && sparse_prim->name().compare(prim::kPrimMakeTuple->name()) != 0)

@@ -140,7 +139,7 @@ const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const An
   }
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  auto prim = AnfAlgo::GetCNodePrimitive(cnode);
+  auto prim = common::AnfAlgo::GetCNodePrimitive(cnode);
   MS_EXCEPTION_IF_NULL(prim);
   std::string prim_name = prim->name();
   auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();

@@ -18,7 +18,7 @@
 #include <string>

 #include "ir/anf.h"
-#include "utils/convert_utils.h"
+#include "include/common/utils/convert_utils.h"
 #include "backend/common/optimizer/optimizer.h"

 namespace mindspore {

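The attribute helpers that recur through these passes keep their signatures across the move; only the namespace changes. A minimal round-trip sketch against the new namespace (the helper name is hypothetical; the three calls use exactly the signatures shown in this diff):

#include "include/common/utils/anfalgo.h"

void MarkVisited(const CNodePtr &cnode) {  // hypothetical helper
  // Set, probe, then read back a node attribute through common::AnfAlgo.
  common::AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cnode);
  if (common::AnfAlgo::HasNodeAttr(kAttrVisited, cnode)) {
    bool visited = common::AnfAlgo::GetNodeAttr<bool>(cnode, kAttrVisited);
    (void)visited;
  }
}
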
File diff suppressed because it is too large
@@ -34,13 +34,12 @@
 #include "kernel/kernel.h"
 #include "kernel/kernel_build_info.h"
 #include "base/core_ops.h"
-#include "utils/contract.h"
+#include "include/common/utils/contract.h"
 #include "utils/anf_utils.h"
 #include "backend/common/session/kernel_graph.h"

 namespace mindspore {
 namespace session {
 using AnfVisitFuncion = std::function<Any(const AnfNodePtr &node, int index)>;
 using DeviceAddress = device::DeviceAddress;
 using DeviceAddressPtr = device::DeviceAddressPtr;
 using Address = kernel::Address;

@@ -72,67 +71,6 @@ class AnfRuntimeAlgorithm {
  public:
   static AnfNodePtr MakeMonadValueNode(const KernelGraphPtr &kg);
   static void KeepOrder(const KernelGraphPtr &kg, const AnfNodePtr &former, const AnfNodePtr &latter);
-  // get real input node of tuple_get_item
-  static AnfNodePtr GetTupleGetItemRealInput(const CNodePtr &tuple_get_item);
-  static size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
-  // get input_anf_node's real kernel by recurse
-  static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index);
-  static KernelWithIndex VisitKernelWithReturnType(
-    const AnfNodePtr &input_anf_node, size_t output_index, bool skip_nop_node = false,
-    const std::vector<PrimitivePtr> &return_types = {prim::kPrimMakeTuple},
-    abstract::AbstractBasePtr *abstract = nullptr);
-  static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node,
-                                              const std::vector<PrimitivePtr> &return_types = {});
-  static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node);
-  // get cnode primitive
-  static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node);
-  static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index);
-  static PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);
-  // check whether anf node is a node of 'primitive_type',such as make_tuple is a cnode of kPrimMakeTuple
-  static bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type);
-  // get cnode primitive
-  static FuncGraphPtr GetCNodeFuncGraphPtr(const AnfNodePtr &node);
-  // get kernel_name of anf node
-  static std::string GetCNodeName(const AnfNodePtr &node);
-  // get detail info of anf node
-  static std::string GetNodeDebugString(const AnfNodePtr &node);
-  // get attr of anf node
-  template <typename T>
-  static T GetNodeAttr(const AnfNodePtr &node, const std::string &key) {
-    MS_EXCEPTION_IF_NULL(node);
-    if (!node->isa<CNode>()) {
-      std::string node_debug_log = node->DebugString();
-      MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node_debug_log.c_str();
-    }
-    // single op cnode.
-    if (auto primitive = GetCNodePrimitive(node); primitive != nullptr) {
-      return GetValue<T>(primitive->GetAttr(key));
-    }
-    // graph kernel cnode.
-    auto fg = GetCNodeFuncGraphPtr(node);
-    MS_EXCEPTION_IF_NULL(fg);
-    return GetValue<T>(fg->get_attr(key));
-  }
-  static bool IsTupleOutput(const AnfNodePtr &anf);
-  // set attr of anf node
-  static void SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
-  // set attr of key from 'from' node to 'to' node
-  static void CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to);
-  // set a new key for attr from 'from' node to 'to' node
-  static void CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
-                           const AnfNodePtr &to);
-  // set all attrs from 'from' node to 'to' node
-  static void CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to);
-  // check whether a cnode has the specified attr.
-  static bool HasNodeAttr(const std::string &key, const CNodePtr &node);
-  // delete attr of anf node
-  static void EraseNodeAttr(const std::string &key, AnfNodePtr node);
-  // get the num of inputs include monads for a cnode
-  static size_t GetInputNum(const CNodePtr &cnode);
-  // get the num of inputs exclude monads for real_kernel (which can be build and run in device)
-  static size_t GetInputTensorNum(const AnfNodePtr &node);
-  // get the num of output real_kernel(which can be build and run in device)
-  static size_t GetOutputTensorNum(const AnfNodePtr &node);
   // Get the memory size of output tensor of node.
   static size_t GetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index);
   // get all outputs format select of anf node

@@ -149,18 +87,10 @@ class AnfRuntimeAlgorithm {
   static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx);
   // get input format select of anf node
   static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx);
-  // get prev node output width output index
-  static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node = false);
   // get output format from prev node,input_index is the input index of current node related to prev node
   static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
   // get reshape_type of from the output of input node.
   static std::string GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
-  // get output shapes inferred by ME from input nodes.
-  static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
-  static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, const abstract::BaseShapePtr &base_shape,
-                                                 size_t output_idx);
-  // get input shapes inferred by ME from input nodes.
-  static std::vector<size_t> GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx);
   // get output shapes which will built and run in device
   static std::vector<size_t> GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx);
   // get input shapes which will built and run in device

@@ -175,11 +105,6 @@ class AnfRuntimeAlgorithm {
   static std::string GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
   // Get Output Padding Axis
   static std::string GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
-  // get output data type inferred by ME of anf node
-  static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);
-  static TypeId GetOutputInferDataType(const TypePtr &type_ptr, size_t output_idx);
-  // get output original data type from prev node,input_index is the input index of current node related to prev node
-  static TypeId GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx);
   // get output select data type of anf node
   static TypeId GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx);
   // get input select data type of anf node

@@ -211,15 +136,6 @@ class AnfRuntimeAlgorithm {
   static DeviceAddress *GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx);
   // get workspace device mutable addr of anf_node
   static DeviceAddressPtr GetMutableWorkspaceAddr(const AnfNodePtr &node, size_t index);
-  // set infer shapes and types of anf node
-  static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
-                                         const std::vector<std::vector<size_t>> &shapes, AnfNode *node);
-  // get and set output shape ptr
-  static abstract::BaseShapePtr GetOutputDetailShape(const AnfNodePtr &node, size_t output_idx);
-  static abstract::BaseShapePtr GetPrevNodeOutputDetailShape(const AnfNodePtr &node, size_t input_idx);
-  static void SetOutputTypeAndDetailShape(const std::vector<TypeId> &types,
-                                          const std::vector<abstract::BaseShapePtr> &shapes, AnfNode *node);
-  static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node);
   // get op pattern of the node
   static kernel::OpPattern GetOpPattern(const AnfNodePtr &node);
   // get KernelBuildType of node ,such as ATT,RT,FWK and so on

@@ -239,18 +155,6 @@ class AnfRuntimeAlgorithm {
   static kernel::KernelMod *GetKernelMod(const AnfNodePtr &node);
   // set kernel mod
   static void SetKernelMod(const kernel::KernelModPtr &kernel_mod, AnfNode *node);
-  // checkout whether the anf node is a graph kernel.
-  static bool IsGraphKernel(const AnfNodePtr &node);
-  // checkout whether the anf node is an inner node of graph kernel.
-  static bool IsNodeInGraphKernel(const AnfNodePtr &node);
-  // get the real output of GraphKernel.
-  static AnfNodePtr GetOutputOfGraphkernel(const KernelWithIndex &kernel_with_index);
-  // check parameter is weight or data
-  static bool IsParameterWeight(const ParameterPtr &node);
-  // checkout whether the anf node is include the label_index.
-  static bool IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index);
-  // Check whether the cnode update parameter
-  static bool IsUpdateParameterKernel(const CNodePtr &node);
   // set stream id of kernel,which will be set in stream assign and be used in stream generate
   static void SetStreamId(uint32_t stream_id, AnfNode *node);
   // get stream id

@@ -263,7 +167,6 @@ class AnfRuntimeAlgorithm {
   static void SetGraphId(uint32_t graph_id, AnfNode *node);
   // get graph id
   static uint32_t GetGraphId(const AnfNode *node);
-  static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index);
   // charge if the node's output is a feature map output
   static bool IsFeatureMapOutput(const AnfNodePtr &node);
   // charge if the node's input is from a feature map output

@@ -272,103 +175,19 @@ class AnfRuntimeAlgorithm {
  static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
  // get me input index for some tbe ops which input order is different between me and tbe impl
  static size_t GetOriginalInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
  static bool IsCommunicationOp(const AnfNodePtr &node);
  static bool IsFusedCommunicationOp(const AnfNodePtr &node);
  static bool IsInplaceNode(const AnfNodePtr &node, const string &type);
  static bool IsGetNext(const NotNull<AnfNodePtr> &node);
  static bool IsNeedSkipNopOpAddr(const AnfNodePtr &node);
  static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
  static std::vector<KernelGraphPtr> GetCallSwitchKernelGraph(const CNodePtr &cnode);
  static bool IsSwitchCall(const CNodePtr &call_node);
  static bool IsScalarInput(const CNodePtr &cnode, size_t index);
  static bool IsScalarOutput(const CNodePtr &cnode, size_t index);
  static void ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list);
  static void ReorderPosteriorExecList(NotNull<std::vector<CNodePtr> *> node_list);
  // get fix output precision of cnode.
  static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node);
  // get fix output precision from prev node, input_idx is the input index of current node related to prev node.
  static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx);
  static bool IsNodeInputDynamicShape(const CNodePtr &anf_node_ptr);
  static bool IsDynamicShape(const AnfNodePtr &node);
  static bool HasDynamicShapeFlag(const PrimitivePtr &prim);
  static bool IsCondControlKernel(const CNodePtr &node);
  static bool IsIndependentNode(const CNodePtr &node);
  static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr);
  static std::optional<string> GetDumpFlag(const AnfNodePtr &node);
  static void GetRealDynamicShape(const std::vector<size_t> &shape, NotNull<std::vector<int64_t> *> dynamic_shape);
  static std::vector<int64_t> GetInputMaxShape(const AnfNodePtr &anf_node, size_t index);
  static std::vector<int64_t> GetInputMinShape(const AnfNodePtr &anf_node, size_t index);
  static std::vector<int64_t> GetOutputMaxShape(const AnfNodePtr &anf_node, size_t index);
  static std::vector<int64_t> GetOutputMinShape(const AnfNodePtr &anf_node, size_t index);
  static bool IsHostKernel(const CNodePtr &node);
  static void InferShape(const CNodePtr &node, std::map<uint32_t, tensor::TensorPtr> *depend_tensors = nullptr);
  // return true if use cnode_input's abstract, false if use real_input's abstract
  static void AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &cnode_input,
                         const AnfNodePtr &real_input, size_t index);
  static std::vector<size_t> GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
  static std::vector<size_t> GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
  // Find real input nodes.
  static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
                                   std::set<AnfNodePtr> *visited);
  static void GetAllVisitedCNode(const CNodePtr &cnode, std::vector<AnfNodePtr> *used_kernels,
                                 std::set<AnfNodePtr> *visited);
  static AnfNodePtr FetchFrontNodeByBackendNode(const AnfNodePtr &backend_node, const KernelGraph &graph);
  static void InsertMakeTupleForOutput(const NotNull<KernelGraphPtr> &root_graph);
  static AnfNodeIndexSet GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node);
  // Get node real inputs, skip `MakeTuple`, `TupleGetItem`, `Depend`, `Load`, `UpdateState` etc.
  static void GetRealInputs(const AnfNodePtr &anf_node, std::vector<session::KernelWithIndex> *inputs);
  // Check whether tensors need broadcast or not.
  static bool IsTensorBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs);
  // Calc tensor size in byte.
  template <typename T>
  static size_t TensorSizeInByte(const std::vector<size_t> &shape) {
    size_t result = sizeof(T);
    for (size_t i = 0; i < shape.size(); i++) {
      result *= shape[i];
    }
    return result;
  }

  // Judge a control operator need be compiled into kernel graph rather than be cut into single op and
  // executed in vm. For example, the operator "bprop_cut" will be compiled into kernel graph and be launch
  // in backend in PyNative mode.
  static bool IsControlOpExecInBackend(const AnfNodePtr &node);

  static bool IsNodeInputContainMonad(const AnfNodePtr &node);
  // Check if node is non-task op.
  static bool IsNonTaskOp(const CNodePtr &node);
  // Check if node has none input after IR fusion.
  static bool IsNoneInput(const AnfNodePtr &node, size_t index);
  // Save inputs/outputs/workspace address in kernel_mod.
  static void CacheAddrForGraph(const KernelGraphPtr &kernel_graph);
  static void CacheAddrForKernel(const AnfNodePtr &node, kernel::KernelMod *kernel_mod);
  static void CacheAddrForAtomicClean(const AnfNodePtr &node, kernel::KernelMod *kernel_mod);
  // Check whether node is a call node, call nodes are those cnodes whose first input is not primitive node.
  static bool IsCallNode(const AnfNodePtr &node);
  // Get the output number according to abstract, when there is a tuple in abstract, it needs to get recursively.
  static size_t GetOutputNumByAbstract(const AbstractBasePtr &node_abstract);
  // Get attr groups
  static int64_t GetAttrGroups(const AnfNodePtr &node, size_t index);

  static inline bool IsAllgather(const CNodePtr &cnode) { return GetCNodeName(cnode) == kAllGatherOpName; }

  static inline bool IsFusion(const CNodePtr &cnode) {
    return HasNodeAttr(kAttrFusion, cnode) && GetNodeAttr<int64_t>(cnode, kAttrFusion) > 0;
  }

  static inline bool IsFromParallelOptimizer(const CNodePtr &cnode) {
    auto primitive = GetCNodePrimitive(cnode);
    return (primitive != nullptr) && primitive->instance_name().find("parallel_optimizer") != std::string::npos;
  }

  static inline bool IsRecompute(const CNodePtr &cnode) {
    auto attr_dup = cnode->GetAttr(kAttrDuplicated);
    return attr_dup != nullptr && GetValue<bool>(attr_dup);
  }

  static void UpdateGraphValidRefPair(const KernelGraphPtr &graph);
  // Get the real output node and indexes of get item, make tuple, depend, load.
  static AnfNodePtr GetTupleIndexes(const AnfNodePtr &node, std::vector<size_t> *index_stack);
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;

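Net effect of the header hunks above: the graph-level declarations removed from session::AnfRuntimeAlgorithm now live on common::AnfAlgo, while device- and kernel-facing queries (formats, device types, addresses) stay behind. A hedged sketch of how a backend pass mixes the two after the split (DumpOutputInfo is a hypothetical helper; the calls mirror ReplaceNodeByProxy::GenerateKernelBuildInfo above):

#include "backend/common/session/anf_runtime_algorithm.h"  // session AnfAlgo: device-side queries
#include "include/common/utils/anfalgo.h"                   // common::AnfAlgo: graph-level queries

void DumpOutputInfo(const CNodePtr &cnode) {  // hypothetical helper
  size_t output_num = common::AnfAlgo::GetOutputTensorNum(cnode);  // moved helper
  for (size_t i = 0; i < output_num; ++i) {
    auto infer_shape = common::AnfAlgo::GetOutputInferShape(cnode, i);  // moved helper
    auto device_format = AnfAlgo::GetOutputFormat(cnode, i);  // stays in the session layer
    (void)infer_shape;
    (void)device_format;
  }
}
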
@@ -31,6 +31,7 @@
 #include "debug/anf_ir_dump.h"
 #include "pipeline/jit/base.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "plugin/device/ascend/hal/device/kernel_select_ascend.h"

 namespace mindspore {

@@ -121,14 +122,14 @@ void DumpExecuteOrder(const NotNull<KernelGraphPtr> kg) {
   for (auto &cnode : kg->execution_order()) {
     MS_EXCEPTION_IF_NULL(cnode);
     if (IsPrimitiveCNode(cnode, prim::kPrimLabelSet)) {
-      fout << "L" << AnfAlgo::GetNodeAttr<uint32_t>(cnode, kAttrLabelIndex) << ":\n";
+      fout << "L" << common::AnfAlgo::GetNodeAttr<uint32_t>(cnode, kAttrLabelIndex) << ":\n";
     }
     fout << " [" << index << "], " << cnode->DebugString();
-    if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) {
-      fout << " : L" << AnfAlgo::GetNodeAttr<uint32_t>(cnode, kAttrLabelIndex);
+    if (common::AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) {
+      fout << " : L" << common::AnfAlgo::GetNodeAttr<uint32_t>(cnode, kAttrLabelIndex);
     }
-    if (AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cnode)) {
-      auto labels = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cnode, kAttrLabelSwitchList);
+    if (common::AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cnode)) {
+      auto labels = common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cnode, kAttrLabelSwitchList);
       fout << " : ";
       for (size_t i = 0; i < labels.size(); ++i) {
         fout << ((i > 0) ? ", L" : "L") << labels[i];

@@ -419,11 +420,11 @@ class CallInfoFinder {
       // Found a node with UMonad abstract, set it as the last monad.
       last_monad = node;
       call_info->return_monad_ = last_monad;
-    } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
+    } else if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
       MakeCallSite(node->cast<CNodePtr>(), last_monad, call_info);
       call_info->return_monad_ = nullptr;
-    } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) ||
-               AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitchLayer)) {
+    } else if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) ||
+               common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitchLayer)) {
       MakeSwitchCallSite(node->cast<CNodePtr>(), last_monad, call_info);
       call_info->return_monad_ = nullptr;
     }

@@ -731,8 +732,8 @@ class AscendAutoMonadConverter {
     kernel_graph_->SetExecOrderByDefault();
     if (call_info_.recursive) {
       const auto &nodes = kernel_graph_->execution_order();
-      AnfAlgo::SetNodeAttr(kAttrRecursiveStart, prim::kValueOne, *nodes.begin());
-      AnfAlgo::SetNodeAttr(kAttrRecursiveEnd, prim::kValueOne, *nodes.rbegin());
+      common::AnfAlgo::SetNodeAttr(kAttrRecursiveStart, prim::kValueOne, *nodes.begin());
+      common::AnfAlgo::SetNodeAttr(kAttrRecursiveEnd, prim::kValueOne, *nodes.rbegin());
     }
     for (auto &call_site : call_info_.call_sites) {
       if (need_stackops_ && call_site.recursive) {

@@ -762,7 +763,7 @@ class AscendAutoMonadConverter {
       if (!AnfUtils::IsRealCNodeKernel(*iter)) {
         continue;
       }
-      if (AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelSet)) {
+      if (common::AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelSet)) {
         const auto &last_call_site = context->call_info_map[kg].call_sites.back();
         for (auto &branch : last_call_site.callees) {
           if (memo.find(branch.graph) != memo.end()) {

@@ -772,7 +773,7 @@ class AscendAutoMonadConverter {
         }
         break;
       }
-      AnfAlgo::SetNodeAttr(ITEREND, prim::kValueOne, *iter);
+      common::AnfAlgo::SetNodeAttr(ITEREND, prim::kValueOne, *iter);
       MS_LOG(INFO) << "Set profiling iter-end points: " << (*iter)->DebugString();
       return;
     }

@@ -792,16 +793,17 @@ class AscendAutoMonadConverter {
      if (!AnfUtils::IsRealCNodeKernel(*iter)) {
        continue;
      }
-      if (AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelGoto) && AnfAlgo::HasNodeAttr(kAttrReturn, *iter)) {
+      if (common::AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelGoto) &&
+          common::AnfAlgo::HasNodeAttr(kAttrReturn, *iter)) {
        continue;
      }
-      if (AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelGoto) ||
-          AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelSwitch) ||
-          AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelSet)) {
+      if (common::AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelGoto) ||
+          common::AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelSwitch) ||
+          common::AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelSet)) {
        MS_LOG(INFO) << "this node is Labelxxxx, do not found iter end.";
        break;
      }
-      AnfAlgo::SetNodeAttr(ITEREND, prim::kValueOne, *iter);
+      common::AnfAlgo::SetNodeAttr(ITEREND, prim::kValueOne, *iter);
      MS_LOG(INFO) << "Set profiling iter-end points: " << (*iter)->DebugString();
      return;
    }

@@ -858,7 +860,7 @@ class AscendAutoMonadConverter {
     std::vector<CNodePtr> stack_pushs;
     bool find_call_point = false;
     for (auto &node : exec_order) {
-      auto node_name = AnfAlgo::GetCNodeName(node);
+      auto node_name = common::AnfAlgo::GetCNodeName(node);
       if (node == call_point) {
         find_call_point = true;
         continue;

@@ -887,7 +889,7 @@ class AscendAutoMonadConverter {
                     std::vector<CNodePtr> *stack_pushs) {
     MS_EXCEPTION_IF_NULL(node);
     uint32_t start_index = 1;
-    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimAssign)) {
+    if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimAssign)) {
       start_index = kInputIndex;
     }
     for (uint32_t i = start_index; i < node->inputs().size(); i++) {

@@ -998,7 +1000,7 @@ class AscendAutoMonadConverter {

     // For Switch, we reverse the graphes and labels, so that the false branch
     // is the first one, since for kernel LabelSwitch, false is the first branch.
-    if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
+    if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
       std::reverse(graphes.begin(), graphes.end());
       std::reverse(labels.begin(), labels.end());
     }

@@ -1007,7 +1009,7 @@ class AscendAutoMonadConverter {
     auto label_goto_switch = MakeLabelGotoSwitch(cnode, graphes, labels);
     call_site->conversion_cnode = label_goto_switch;
     if (call_site->recursive) {
-      AnfAlgo::SetNodeAttr(kAttrRecursive, prim::kValueOne, label_goto_switch);
+      common::AnfAlgo::SetNodeAttr(kAttrRecursive, prim::kValueOne, label_goto_switch);
     }

     // Setup return label and output if required.

@@ -1082,15 +1084,15 @@ class AscendAutoMonadConverter {
   CNodePtr MakeLabelGotoSwitch(const CNodePtr &cnode, const std::vector<KernelGraphPtr> &graphes,
                                const std::vector<uint32_t> &labels) {
     // Create LabelGoto or LabelSwitch according the cnode type.
-    const bool is_call = AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall);
+    const bool is_call = common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall);
     auto label_goto_switch = (is_call ? LabelGoto(labels.front()) : LabelSwitch(cnode->input(1), labels));

     // Set child graph attribute for the LabelGoto or LabelSwitch node.
     SetChildGrapAttr(label_goto_switch, graphes);

     // Mark the label_switch node is for 'switch_layer' if it is.
-    if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
-      AnfAlgo::SetNodeAttr(kAttrSwitchLayer, prim::kValueOne, label_goto_switch);
+    if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
+      common::AnfAlgo::SetNodeAttr(kAttrSwitchLayer, prim::kValueOne, label_goto_switch);
     }
     return label_goto_switch;
   }

@@ -1114,7 +1116,7 @@ class AscendAutoMonadConverter {
     // Insert label_goto for return.
     auto &return_point = return_points.front();
     auto return_goto = LabelGoto(return_point.call_site->return_label);
-    AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto);
+    common::AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto);
     kernel_graph_->set_end_goto(return_goto);
     return;
   }

@@ -1128,9 +1130,9 @@ class AscendAutoMonadConverter {
     auto &label_param = call_info_.label_param;
     MS_EXCEPTION_IF_NULL(label_param);
     auto return_switch = LabelSwitch(label_param, return_labels);
-    AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_switch);
+    common::AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_switch);
     if (!call_info_.recursive) {
-      AnfAlgo::SetNodeAttr(kAttrMultiCallEnd, prim::kValueOne, return_switch);
+      common::AnfAlgo::SetNodeAttr(kAttrMultiCallEnd, prim::kValueOne, return_switch);
     }
     kernel_graph_->set_end_goto(return_switch);
     context_.SetSubGraphMultiCall(true);

@@ -1249,29 +1251,29 @@ class AscendAutoMonadConverter {

   // AissgnAll support tuple to tuple assign.
   AnfNodePtr AssignAll(const AnfNodePtr &target, const AnfNodePtr &source, bool link, bool keep, bool output) {
-    if (!AnfAlgo::CheckPrimitiveType(target, prim::kPrimMakeTuple)) {
+    if (!common::AnfAlgo::CheckPrimitiveType(target, prim::kPrimMakeTuple)) {
       // Assign single value.
       return Assign(target, source, link, keep, output);
     }
     // Assign tuple.
-    std::vector<AnfNodePtr> targets = AnfAlgo::GetAllOutput(target);
-    std::vector<AnfNodePtr> sources = AnfAlgo::GetAllOutput(source);
+    std::vector<AnfNodePtr> targets = common::AnfAlgo::GetAllOutput(target);
+    std::vector<AnfNodePtr> sources = common::AnfAlgo::GetAllOutput(source);
     if (targets.size() != sources.size()) {
       MS_LOG(EXCEPTION) << "Target size " << targets.size() << " != source size " << sources.size();
     }
     AnfNodePtrList tuple_inputs;
-    auto source_item_with_index = AnfAlgo::VisitKernelWithReturnType(source, 0);
+    auto source_item_with_index = common::AnfAlgo::VisitKernelWithReturnType(source, 0);
     MS_EXCEPTION_IF_NULL(source_item_with_index.first);
     auto source_cnode = source_item_with_index.first->cast<CNodePtr>();
     auto target_cnode = target->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(source_cnode);
     MS_EXCEPTION_IF_NULL(target_cnode);
-    if (!AnfAlgo::CheckPrimitiveType(source_cnode, prim::kPrimMakeTuple)) {
+    if (!common::AnfAlgo::CheckPrimitiveType(source_cnode, prim::kPrimMakeTuple)) {
      MS_LOG(WARNING) << "Source : " << source_cnode->DebugString() << " is not MakeTuple.";
    }
    tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    for (size_t i = 1; i < target_cnode->inputs().size(); ++i) {
-      if (AnfAlgo::IsTupleOutput(target_cnode->input(i))) {
+      if (common::AnfAlgo::IsTupleOutput(target_cnode->input(i))) {
        tuple_inputs.emplace_back(AssignAll(target_cnode->input(i), source_cnode->input(i), link, keep, output));
      } else {
        tuple_inputs.emplace_back(Assign(target_cnode->input(i), source_cnode->input(i), link, keep, output));

@ -1354,7 +1356,7 @@ class AscendAutoMonadConverter {
|
|||
auto monad = GetMonad();
|
||||
auto label_goto = NewPrimitive(prim::kPrimLabelGoto);
|
||||
auto cnode = kernel_graph_->NewCNode({label_goto, monad});
|
||||
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(label_id), cnode);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(label_id), cnode);
|
||||
cnode->set_abstract(monad->abstract());
|
||||
monad_ = cnode;
|
||||
return cnode;
|
||||
|
@ -1365,7 +1367,7 @@ class AscendAutoMonadConverter {
|
|||
auto monad = GetMonad();
|
||||
auto label_set = NewPrimitive(prim::kPrimLabelSet);
|
||||
auto cnode = kernel_graph_->NewCNode({label_set, monad});
|
||||
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(label_id), cnode);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(label_id), cnode);
|
||||
cnode->set_abstract(monad->abstract());
|
||||
monad_ = cnode;
|
||||
return cnode;
|
||||
|
@ -1377,7 +1379,7 @@ class AscendAutoMonadConverter {
|
|||
auto label_switch = NewPrimitive(prim::kPrimLabelSwitch);
|
||||
auto cnode = kernel_graph_->NewCNode({label_switch, cond, monad});
|
||||
auto label_list = MakeValue(labels);
|
||||
AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, label_list, cnode);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, label_list, cnode);
|
||||
cnode->set_abstract(monad->abstract());
|
||||
monad_ = cnode;
|
||||
return cnode;
|
||||
|
@ -1385,7 +1387,7 @@ class AscendAutoMonadConverter {
|
|||
|
||||
// Set child graph attribute for label_goto/label_switch node.
|
||||
void SetChildGrapAttr(const AnfNodePtr &node, const std::vector<KernelGraphPtr> &graphs) {
|
||||
AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue(graphs), node);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue(graphs), node);
|
||||
}
|
||||
|
||||
// Make a StackInit node.
|
||||
|
@ -1393,7 +1395,7 @@ class AscendAutoMonadConverter {
|
|||
auto monad = AnfAlgo::MakeMonadValueNode(kg);
|
||||
auto stack_init = NewPrimitive(prim::kPrimStackInit);
|
||||
auto cnode = kg->NewCNode({stack_init, monad});
|
||||
AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
|
||||
cnode->set_abstract(monad->abstract());
|
||||
return cnode;
|
||||
}
|
||||
|
@ -1403,7 +1405,7 @@ class AscendAutoMonadConverter {
|
|||
auto monad = AnfAlgo::MakeMonadValueNode(kg);
|
||||
auto stack_destroy = NewPrimitive(prim::kPrimStackDestroy);
|
||||
auto cnode = kg->NewCNode({stack_destroy, monad});
|
||||
AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
|
||||
cnode->set_abstract(monad->abstract());
|
||||
return cnode;
|
||||
}
|
||||
|
@ -1413,9 +1415,9 @@ class AscendAutoMonadConverter {
|
|||
auto monad = AnfAlgo::MakeMonadValueNode(kernel_graph_);
|
||||
auto stack_push = NewPrimitive(prim::kPrimStackPush);
|
||||
auto cnode = kernel_graph_->NewCNode({stack_push, input, monad});
|
||||
AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
|
||||
auto op_name = std::to_string(kernel_graph_->graph_id()) + "_stack_push_" + std::to_string(name_index_++);
|
||||
AnfAlgo::SetNodeAttr(kAttrStackOpName, MakeValue(op_name), cnode);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrStackOpName, MakeValue(op_name), cnode);
|
||||
cnode->set_abstract(monad->abstract());
|
||||
return cnode;
|
||||
}
|
||||
|
@ -1425,9 +1427,9 @@ class AscendAutoMonadConverter {
|
|||
auto monad = AnfAlgo::MakeMonadValueNode(kernel_graph_);
|
||||
auto stack_pop = NewPrimitive(prim::kPrimStackPop);
|
||||
auto cnode = kernel_graph_->NewCNode({stack_pop, monad});
|
||||
AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
|
||||
auto op_name = std::to_string(kernel_graph_->graph_id()) + "_stack_pop_" + std::to_string(name_index_++);
|
||||
AnfAlgo::SetNodeAttr(kAttrStackOpName, MakeValue(op_name), cnode);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrStackOpName, MakeValue(op_name), cnode);
|
||||
cnode->set_abstract(monad->abstract()); // need to refresh output's abstract().
|
||||
return cnode;
|
||||
}
|
||||
|
@ -1482,8 +1484,8 @@ class ExecuteOrderGenerator {
|
|||
uint32_t FindMaxLabelId(const std::vector<CNodePtr> &nodes) {
|
||||
uint32_t max_label = 0;
|
||||
for (auto &node : nodes) {
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSet)) {
|
||||
auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
|
||||
if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSet)) {
|
||||
auto label_id = common::AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
|
||||
max_label = std::max(label_id, max_label);
|
||||
}
|
||||
}
|
||||
|
@ -1493,7 +1495,7 @@ class ExecuteOrderGenerator {
|
|||
void HandleLabelSwitch(const AnfNodePtr &node, std::vector<uint32_t> *labels, std::vector<uint32_t> *switch_labels,
|
||||
std::multimap<uint32_t, uint32_t> *labels_multimap) {
|
||||
bool is_new_labels = false;
|
||||
auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList);
|
||||
auto label_list = common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList);
|
||||
std::vector<uint32_t> new_labels;
|
||||
new_labels.reserve(label_list.size());
|
||||
for (auto label_id : label_list) {
|
||||
|
@ -1511,19 +1513,19 @@ class ExecuteOrderGenerator {
|
|||
}
|
||||
(void)switch_labels->insert(switch_labels->end(), new_labels.begin(), new_labels.end());
|
||||
if (is_new_labels) {
|
||||
AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue(new_labels), node);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue(new_labels), node);
|
||||
}
|
||||
}
|
||||
|
||||
void HandleLabelGoto(const AnfNodePtr &node, std::vector<uint32_t> *labels, std::vector<uint32_t> *switch_labels,
|
||||
std::multimap<uint32_t, uint32_t> *labels_multimap) {
|
||||
auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
|
||||
auto label_id = common::AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
|
||||
auto iter = std::find(switch_labels->begin(), switch_labels->end(), label_id);
|
||||
if (iter == switch_labels->end()) {
|
||||
(void)labels->emplace_back(label_id);
|
||||
return;
|
||||
}
|
||||
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(++max_label_), node);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(++max_label_), node);
|
||||
(void)labels_multimap->emplace(*iter, max_label_);
|
||||
(void)labels->emplace_back(max_label_);
|
||||
}
|
||||
|
@ -1536,11 +1538,11 @@ class ExecuteOrderGenerator {
|
|||
std::multimap<uint32_t, uint32_t> labels_multimap;
|
||||
max_label_ = FindMaxLabelId(nodes);
|
||||
for (auto &node : nodes) {
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
|
||||
if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
|
||||
HandleLabelSwitch(node, &labels, &switch_labels, &labels_multimap);
|
||||
continue;
|
||||
}
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
|
||||
if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
|
||||
HandleLabelGoto(node, &labels, &switch_labels, &labels_multimap);
|
||||
continue;
|
||||
}
|
||||
|
@ -1555,10 +1557,10 @@ class ExecuteOrderGenerator {
|
|||
auto old_label = labels.first;
|
||||
auto new_label = labels.second;
|
||||
auto iter = std::find_if(nodes->begin(), nodes->end(), [old_label](auto node) {
|
||||
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSet)) {
|
||||
if (!common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSet)) {
|
||||
return false;
|
||||
}
|
||||
auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
|
||||
auto label_id = common::AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
|
||||
return label_id == old_label;
|
||||
});
|
||||
if (iter == nodes->end()) {
|
||||
|
@ -1566,8 +1568,8 @@ class ExecuteOrderGenerator {
|
|||
}
|
||||
auto label_set = NewValueNode(std::make_shared<Primitive>(prim::kPrimLabelSet->name()));
|
||||
auto cnode = graph_->NewCNode({label_set});
|
||||
AnfAlgo::CopyNodeAttrs(*iter, cnode);
|
||||
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(new_label), cnode);
|
||||
common::AnfAlgo::CopyNodeAttrs(*iter, cnode);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(new_label), cnode);
|
||||
auto monad = graph_->NewValueNode(kUMonad->ToAbstract(), kUMonad);
|
||||
cnode->set_abstract(monad->abstract());
|
||||
(void)device::ascend::SelectKernelInfo(cnode);
|
||||
|
@ -1580,10 +1582,10 @@ class ExecuteOrderGenerator {
|
|||
execution_order->insert(execution_order->end(), order.begin(), order.end());
|
||||
}
|
||||
|
||||
bool HasSubGraphs(const CNodePtr &cnode) { return (cnode && AnfAlgo::HasNodeAttr(kAttrChildGraph, cnode)); }
|
||||
bool HasSubGraphs(const CNodePtr &cnode) { return (cnode && common::AnfAlgo::HasNodeAttr(kAttrChildGraph, cnode)); }
|
||||
|
||||
std::vector<KernelGraphPtr> GetSubGraphs(const CNodePtr &cnode) {
|
||||
return AnfAlgo::GetNodeAttr<std::vector<KernelGraphPtr>>(cnode, kAttrChildGraph);
|
||||
return common::AnfAlgo::GetNodeAttr<std::vector<KernelGraphPtr>>(cnode, kAttrChildGraph);
|
||||
}
|
||||
|
||||
void EraseNodeFromExecOrder(const AnfNodePtr &node, const NotNull<std::vector<CNodePtr> *> exec_order) {
|
||||
|
@ -1612,7 +1614,7 @@ class ExecuteOrderGenerator {
|
|||
// and then append them to current execution order list.
|
||||
if (HasSubGraphs(cnode)) {
|
||||
auto sub_graphs = GetSubGraphs(cnode);
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrSwitchLayer, cnode)) {
|
||||
if (!common::AnfAlgo::HasNodeAttr(kAttrSwitchLayer, cnode)) {
|
||||
// For Switch, we use reversed order to generate sub-graph's execution order,
|
||||
// because the true branch of LabelSwitch is the second one, but
|
||||
// we want to make true branch ahead of false branch in the generated
|
||||
|
@ -1628,7 +1630,7 @@ class ExecuteOrderGenerator {
|
|||
AppendGraphOrder(&execution_order, sub_graph);
|
||||
}
|
||||
// Clear ChildGraph attribute after execute order generated.
|
||||
AnfAlgo::EraseNodeAttr(kAttrChildGraph, cnode);
|
||||
common::AnfAlgo::EraseNodeAttr(kAttrChildGraph, cnode);
|
||||
}
|
||||
}
|
||||
// Save generated execution order into the graph.
|
||||
|
@ -1762,10 +1764,10 @@ class ExecuteOrderGenerator {
|
|||
return {p.first.first, {p.first.second, p.second.first, p.second.second}};
|
||||
});
|
||||
auto validate_ref_parameter = [](AnfNodePtr node) -> AnfNodePtr {
|
||||
if (node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(node, prim::kPrimTransData)) {
|
||||
if (node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimTransData)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto first_input = AnfAlgo::VisitKernelWithReturnType(cnode->input(kFirstDataInputIndex), 0, true);
|
||||
auto first_input = common::AnfAlgo::VisitKernelWithReturnType(cnode->input(kFirstDataInputIndex), 0, true);
|
||||
MS_EXCEPTION_IF_NULL(first_input.first);
|
||||
return first_input.first;
|
||||
}
|
||||
|
@ -1794,7 +1796,7 @@ class ExecuteOrderGenerator {
|
|||
(void)refed_parameters.insert(validate_ref_parameter(std::get<1>(iter->second)));
|
||||
}
|
||||
for (auto &in : node->inputs()) {
|
||||
auto visit_node = AnfAlgo::VisitKernelWithReturnType(in, 0).first;
|
||||
auto visit_node = common::AnfAlgo::VisitKernelWithReturnType(in, 0).first;
|
||||
visit_node = validate_ref_parameter(visit_node);
|
||||
if (!visit_node->isa<Parameter>() || root_inputs.find(visit_node) != root_inputs.end()) {
|
||||
continue;
|
||||
|
@ -1836,16 +1838,16 @@ class ExecuteOrderGenerator {
|
|||
for (auto iter = exec_order.begin(); iter != exec_order.end();) {
|
||||
auto &node = *iter;
|
||||
if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch)) {
|
||||
auto labels = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList);
|
||||
auto labels = common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList);
|
||||
for (auto label : labels) {
|
||||
label_used.insert(label);
|
||||
}
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimLabelGoto)) {
|
||||
auto label = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
|
||||
auto label = common::AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
|
||||
auto next = std::next(iter);
|
||||
if (next != exec_order.end() && IsPrimitiveCNode(*next, prim::kPrimLabelSet)) {
|
||||
// The LabelGoto that jump to next node can be removed.
|
||||
auto next_label = AnfAlgo::GetNodeAttr<uint32_t>(*next, kAttrLabelIndex);
|
||||
auto next_label = common::AnfAlgo::GetNodeAttr<uint32_t>(*next, kAttrLabelIndex);
|
||||
if (next_label == label) {
|
||||
iter = exec_order.erase(iter);
|
||||
continue;
|
||||
|
@ -1859,7 +1861,7 @@ class ExecuteOrderGenerator {
|
|||
for (auto iter = exec_order.begin(); iter != exec_order.end();) {
|
||||
auto &node = *iter;
|
||||
if (IsPrimitiveCNode(node, prim::kPrimLabelSet)) {
|
||||
auto label = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
|
||||
auto label = common::AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
|
||||
if (label_used.find(label) == label_used.end()) {
|
||||
iter = exec_order.erase(iter);
|
||||
continue;
|
||||
|

@@ -22,9 +22,10 @@
 #include "ir/param_info.h"
 #include "runtime/device/kernel_runtime.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "utils/ms_utils.h"
-#include "utils/ms_device_shape_transfer.h"
-#include "utils/config_manager.h"
+#include "runtime/device/ms_device_shape_transfer.h"
+#include "include/common/utils/config_manager.h"

 namespace mindspore {
 namespace session {
@@ -45,7 +46,7 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &k
 MS_EXCEPTION_IF_NULL(pk_node);
 auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
 MS_EXCEPTION_IF_NULL(device_address);
-if (!AnfAlgo::IsParameterWeight(pk_node)) {
+if (!common::AnfAlgo::IsParameterWeight(pk_node)) {
 tensor = inputs[no_weight_input++];
 if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
 LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c(),
@@ -71,7 +72,7 @@ GraphId AscendInferenceSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_grap
 MS_EXCEPTION_IF_NULL(pk_node);
 auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
 MS_EXCEPTION_IF_NULL(device_address);
-if (AnfAlgo::IsParameterWeight(pk_node)) {
+if (common::AnfAlgo::IsParameterWeight(pk_node)) {
 const auto &param_value = pk_node->default_param();
 MS_EXCEPTION_IF_NULL(param_value);
 auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value);
@@ -101,7 +102,7 @@ bool AscendInferenceSession::CheckModelInputs(uint32_t graph_id, const std::vect
 continue;
 }
 auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
-if (!AnfAlgo::IsParameterWeight(parameter)) {
+if (!common::AnfAlgo::IsParameterWeight(parameter)) {
 paras.push_back(parameter);
 }
 }

@@ -26,7 +26,7 @@
 #include "base/base_ref_utils.h"
 #include "ir/tensor.h"
 #include "ir/anf.h"
-#include "utils/ms_device_shape_transfer.h"
+#include "runtime/device/ms_device_shape_transfer.h"
 #include "runtime/device/kernel_runtime.h"
 #include "plugin/device/ascend/hal/device/kernel_select_ascend.h"
 #include "plugin/device/ascend/hal/device/kernel_build_ascend.h"
@@ -38,12 +38,13 @@
 #include "runtime/device/kernel_adjust.h"
 #include "plugin/device/ascend/hal/device/ascend_stream_assign.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "utils/ms_utils.h"
-#include "utils/utils.h"
-#include "utils/context/graph_kernel_flags.h"
+#include "include/common/utils/utils.h"
+#include "include/common/utils/context/graph_kernel_flags.h"
 #include "backend/common/optimizer/helper.h"
 #include "runtime/device/kernel_runtime_manager.h"
-#include "utils/config_manager.h"
+#include "include/common/utils/config_manager.h"
 #ifndef ENABLE_SECURITY
 #include "debug/data_dump/dump_json_parser.h"
 #include "debug/data_dump/e2e_dump.h"
@@ -179,7 +180,7 @@ void GenOpOutputStubTensor(const KernelGraphPtr &single_op_graph, const CNodePtr
 if (cnode_refcount.find(kernel_with_index) == cnode_refcount.end()) {
 continue;
 }
-const auto &output_kernel_with_index = AnfAlgo::VisitKernel(output, 0);
+const auto &output_kernel_with_index = common::AnfAlgo::VisitKernel(output, 0);
 const auto &output_node = output_kernel_with_index.first;
 const auto &output_index = output_kernel_with_index.second;
 auto out_abstract = output_node->abstract();
@@ -190,7 +191,7 @@ void GenOpOutputStubTensor(const KernelGraphPtr &single_op_graph, const CNodePtr
 }
 abstract::AbstractTensorPtr tensor_abstract = out_abstract->cast<abstract::AbstractTensorPtr>();
 MS_EXCEPTION_IF_NULL(tensor_abstract);
-const auto &infer_type = AnfAlgo::GetOutputInferDataType(output_node, output_index);
+const auto &infer_type = common::AnfAlgo::GetOutputInferDataType(output_node, output_index);
 tensor::TensorPtr stub_output_tensor =
 std::make_shared<tensor::Tensor>(infer_type, tensor_abstract->shape()->shape(), nullptr);
 const auto &output_type = AnfAlgo::GetOutputDeviceDataType(output_node, output_index);
@@ -255,7 +256,7 @@ bool TensorNeedSync(const std::shared_ptr<KernelGraph> &kernel_graph, const AnfN
 #endif
 auto input_param = parameter->cast<ParameterPtr>();
 MS_EXCEPTION_IF_NULL(input_param);
-if (AnfAlgo::IsParameterWeight(input_param) || kernel_graph->IsUpdatedParameter(input_param)) {
+if (common::AnfAlgo::IsParameterWeight(input_param) || kernel_graph->IsUpdatedParameter(input_param)) {
 tensor->set_device_address(device_address);
 }
 if (kernel_graph->IsUpdatedParameter(input_param)) {
@@ -340,8 +341,8 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra
 auto tensor_shape = tensor->shape();
 std::vector<size_t> shape_tmp;
 (void)std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape_tmp), LongToSize);
-AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {shape_tmp},
-input_node.get());
+common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetOutputInferDataType(input_node, 0)}, {shape_tmp},
+input_node.get());
 size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type());
 }
 if (AnfAlgo::OutputAddrExist(input_node, 0) &&
@@ -362,7 +363,7 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra
 auto ms_context = MsContext::GetInstance();
 MS_EXCEPTION_IF_NULL(ms_context);
 if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
-AnfAlgo::IsParameterWeight(input_param) || kernel_graph->IsUpdatedParameter(input_param)) {
+common::AnfAlgo::IsParameterWeight(input_param) || kernel_graph->IsUpdatedParameter(input_param)) {
 tensor->set_device_address(device_address);
 }
 if (kernel_graph->IsUpdatedParameter(input_param)) {
@@ -771,7 +772,7 @@ void AscendSession::PrepareForOutputTensor(const KernelGraphPtr &graph,
 void StoreCNodePrimitive(const KernelGraphPtr &graph) {
 const auto &nodes = graph->execution_order();
 for (auto &node : nodes) {
-auto primitive = AnfAlgo::GetCNodePrimitive(node);
+auto primitive = common::AnfAlgo::GetCNodePrimitive(node);
 MS_EXCEPTION_IF_NULL(primitive);
 auto new_primitive = std::make_shared<Primitive>(*primitive);
 node->set_input(kAnfPrimitiveIndex, NewValueNode(new_primitive));
@@ -908,7 +909,7 @@ void AscendSession::CacheCNodeOutputInfo(const KernelGraph &graph) const {
 std::vector<std::string> formats;
 std::vector<TypeId> types;
 std::vector<size_t> tensor_sizes;
-auto output_num = AnfAlgo::GetOutputTensorNum(node);
+auto output_num = common::AnfAlgo::GetOutputTensorNum(node);
 for (size_t i = 0; i < output_num; ++i) {
 std::string output_format = AnfAlgo::GetOutputFormat(node, i);
 auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
@@ -930,12 +931,12 @@ void AscendSession::CacheCNodeOutputInfo(const KernelGraph &graph) const {
 std::vector<std::string> formats;
 std::vector<TypeId> types;
 std::vector<size_t> tensor_sizes;
-auto output_size = AnfAlgo::GetOutputTensorNum(input);
+auto output_size = common::AnfAlgo::GetOutputTensorNum(input);
 for (size_t index = 0; index < output_size; index++) {
 auto format = AnfAlgo::GetOutputFormat(input, index);
 auto type_id = AnfAlgo::GetOutputDeviceDataType(input, index);
 if (type_id == kTypeUnknown) {
-type_id = AnfAlgo::GetOutputInferDataType(input, index);
+type_id = common::AnfAlgo::GetOutputInferDataType(input, index);
 }
 auto tensor_size = AnfAlgo::GetOutputTensorMemSize(input, index);
 formats.emplace_back(format);
@@ -952,10 +953,10 @@ void AscendSession::GetOpInputStubTensors(const CNodePtr &cnode, const std::map<
 InputTensorInfo *input_tensor_info) {
 MS_EXCEPTION_IF_NULL(cnode);
 MS_EXCEPTION_IF_NULL(input_tensor_info);
-const auto input_tensor_num = AnfAlgo::GetInputTensorNum(cnode);
+const auto input_tensor_num = common::AnfAlgo::GetInputTensorNum(cnode);
 for (size_t i = 1; i <= input_tensor_num; i += 1) {
 const auto &input = cnode->input(i);
-auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
+auto kernel_with_index = common::AnfAlgo::VisitKernel(input, 0);
 auto real_input = kernel_with_index.first;
 MS_EXCEPTION_IF_NULL(real_input);
 tensor::TensorPtr tensor = nullptr;
@@ -1109,7 +1110,7 @@ void AscendSession::AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_grap
 MS_LOG(INFO) << "Status record: start adjust kernel. graph id: " << kernel_graph->graph_id();
 opt::HideNopNode(kernel_graph.get());
 auto execution_order = kernel_graph->execution_order();
-AnfAlgo::ReorderExecList(NOT_NULL(&execution_order));
+common::AnfAlgo::ReorderExecList(NOT_NULL(&execution_order));
 kernel_graph->set_execution_order(execution_order);
 // Insert CLearZero op
 // prepare for next step from json get atomic info
@@ -1170,7 +1171,7 @@ void AscendSession::BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kerne
 MS_EXCEPTION_IF_NULL(kernel_graph);
 const auto &kernels = kernel_graph->execution_order();
 auto iter = std::find_if(kernels.begin(), kernels.end(), [](const CNodePtr &kernel) {
-return AnfAlgo::GetBooleanAttr(kernel, kAttrOutputIsDynamicShape);
+return common::AnfAlgo::GetBooleanAttr(kernel, kAttrOutputIsDynamicShape);
 });
 if (iter == kernels.end()) {
 return;
@@ -1189,7 +1190,7 @@ static CNodePtr GetNextLabelSet(const std::vector<CNodePtr> &kernel_nodes, uint3
 MS_LOG(EXCEPTION) << "there is no node after this node:" << kernel_nodes[index]->DebugString();
 }
 auto kernel = kernel_nodes[index + 1];
-if (AnfAlgo::GetCNodeName(kernel) != kLabelSetOpName) {
+if (common::AnfAlgo::GetCNodeName(kernel) != kLabelSetOpName) {
 MS_LOG(EXCEPTION) << "the node is not labelset follow labelgoto/labelswitch, node: "
 << kernel_nodes[index]->DebugString();
 }
@@ -1210,14 +1211,14 @@ static std::vector<CNodePtr> HandleRecursiveCall(const std::vector<CNodePtr> &ke
 } else {
 back->emplace_back(kernel_cnodes[i]);
 }
-if (AnfAlgo::HasNodeAttr(kAttrRecursiveEnd, kernel_cnodes[i])) {
+if (common::AnfAlgo::HasNodeAttr(kAttrRecursiveEnd, kernel_cnodes[i])) {
 *index = i;
 back->insert(back->end(), back_temp.begin(), back_temp.end());
 return front;
 }
-if (AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i])) {
+if (common::AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i])) {
 back_flag = true;
-if (!AnfAlgo::IsLabelIndexInNode(kernel_cnodes[i], back_label)) {
+if (!common::AnfAlgo::IsLabelIndexInNode(kernel_cnodes[i], back_label)) {
 auto temp = HandleRecursiveCall(kernel_cnodes, back_label, &(++i), &back_temp);
 front.insert(front.end(), temp.begin(), temp.end());
 }
@@ -1236,11 +1237,11 @@ static void UnfoldRecursiveExecOrder(KernelGraph *kernel_graph) {
 std::vector<CNodePtr> mem_reuse_order;
 mem_reuse_order.reserve(kernel_cnodes.size());
 for (uint32_t i = 0; i < kernel_cnodes.size(); i++) {
-if (!AnfAlgo::HasNodeAttr(kAttrRecursiveStart, kernel_cnodes[i])) {
+if (!common::AnfAlgo::HasNodeAttr(kAttrRecursiveStart, kernel_cnodes[i])) {
 mem_reuse_order.emplace_back(kernel_cnodes[i]);
 continue;
 }
-auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex);
+auto label_id = common::AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex);
 std::vector<CNodePtr> back;
 auto front = HandleRecursiveCall(kernel_cnodes, label_id, &i, &back);
 mem_reuse_order.insert(mem_reuse_order.end(), front.begin(), front.end());
@@ -1253,11 +1254,11 @@ static void GetSubGraphExecOrder(const KernelGraph *kernel_graph, uint32_t index
 std::vector<CNodePtr> *mem_reuse_order) {
 MS_EXCEPTION_IF_NULL(kernel_graph);
 MS_EXCEPTION_IF_NULL(mem_reuse_order);
-auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(back_node, kAttrLabelIndex);
+auto label_id = common::AnfAlgo::GetNodeAttr<uint32_t>(back_node, kAttrLabelIndex);
 auto kernel_cnodes = kernel_graph->execution_order();
 for (auto i = index; i < kernel_cnodes.size(); i++) {
 mem_reuse_order->emplace_back(kernel_cnodes[i]);
-if (AnfAlgo::IsLabelIndexInNode(kernel_cnodes[i], label_id)) {
+if (common::AnfAlgo::IsLabelIndexInNode(kernel_cnodes[i], label_id)) {
 return;
 }
 }
@@ -1273,10 +1274,10 @@ void InitMemReuseExecOrder(KernelGraph *kernel_graph) {
 std::vector<CNodePtr> mem_reuse_order;
 for (uint32_t i = 0; i < kernel_cnodes.size(); i++) {
 mem_reuse_order.emplace_back(kernel_cnodes[i]);
-if (AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelSwitch) &&
-!AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i]) &&
-!AnfAlgo::HasNodeAttr(kAttrReturn, kernel_cnodes[i])) {
-auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(kernel_cnodes[i], kAttrLabelSwitchList);
+if (common::AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelSwitch) &&
+!common::AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i]) &&
+!common::AnfAlgo::HasNodeAttr(kAttrReturn, kernel_cnodes[i])) {
+auto label_list = common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(kernel_cnodes[i], kAttrLabelSwitchList);
 for (auto label_id : label_list) {
 if (label_id_index_map.find(label_id) == label_id_index_map.end()) {
 continue;
@@ -1286,10 +1287,10 @@ void InitMemReuseExecOrder(KernelGraph *kernel_graph) {
 }
 continue;
 }
-if (AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelGoto) &&
-!AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i]) &&
-!AnfAlgo::HasNodeAttr(kAttrReturn, kernel_cnodes[i])) {
-auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex);
+if (common::AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelGoto) &&
+!common::AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i]) &&
+!common::AnfAlgo::HasNodeAttr(kAttrReturn, kernel_cnodes[i])) {
+auto label_id = common::AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex);
 if (label_id_index_map.find(label_id) == label_id_index_map.end()) {
 continue;
 }
@@ -1297,9 +1298,9 @@ void InitMemReuseExecOrder(KernelGraph *kernel_graph) {
 GetSubGraphExecOrder(kernel_graph, label_id_index_map[label_id], back_node, &mem_reuse_order);
 continue;
 }
-if (AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelSet) &&
-!AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i])) {
-auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex);
+if (common::AnfAlgo::CheckPrimitiveType(kernel_cnodes[i], prim::kPrimLabelSet) &&
+!common::AnfAlgo::HasNodeAttr(kAttrRecursive, kernel_cnodes[i])) {
+auto label_id = common::AnfAlgo::GetNodeAttr<uint32_t>(kernel_cnodes[i], kAttrLabelIndex);
 if (label_id_index_map.find(label_id) != label_id_index_map.end()) {
 MS_LOG(EXCEPTION) << "Two labelsets with same label id.";
 }
@@ -1584,8 +1585,8 @@ void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<st
 void AscendSession::SetOperatorInfo(const std::vector<CNodePtr> &nodes) const {
 for (const auto &node : nodes) {
 auto status = device::ascend::SelectKernelInfo(node);
-AnfAlgo::EraseNodeAttr(kAttrPynativeNextOpName, node);
-AnfAlgo::EraseNodeAttr(kAttrPynativeNextIndex, node);
+common::AnfAlgo::EraseNodeAttr(kAttrPynativeNextOpName, node);
+common::AnfAlgo::EraseNodeAttr(kAttrPynativeNextIndex, node);
 if (status == device::ascend::kStatusRaisePrecision) {
 raise_precision_count_++;
 } else if (status == device::ascend::kStatusReducePrecision) {
@@ -1830,8 +1831,8 @@ void AscendSession::UpdateOutputTensors(const VectorRef *outputs,
 tensor_device_addr_map_[tensor] = dst_device_address;
 }

-if (AnfAlgo::IsDynamicShape(node)) {
-const auto &updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
+if (common::AnfAlgo::IsDynamicShape(node)) {
+const auto &updated_shape = common::AnfAlgo::GetOutputInferShape(node, output_index);
 ShapeVector int_shape;
 (void)std::transform(updated_shape.begin(), updated_shape.end(), std::back_inserter(int_shape), SizeToInt);
 (void)tensor->set_shape(int_shape);

@@ -21,8 +21,9 @@
 #include "ir/anf.h"
 #include "utils/ms_utils.h"
 #include "utils/trace_base.h"
-#include "utils/context/graph_kernel_flags.h"
+#include "include/common/utils/context/graph_kernel_flags.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "runtime/device/kernel_runtime.h"
 #include "plugin/device/cpu/kernel/akg/akg_cpu_kernel_build.h"
 #include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
@@ -79,7 +80,9 @@ ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf,
 }

 // Remove after PS feature finish adapting push/pull in auto_monad.
-void CPUSession::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderPosteriorExecList(NOT_NULL(node_list)); }
+void CPUSession::Reorder(std::vector<CNodePtr> *node_list) {
+common::AnfAlgo::ReorderPosteriorExecList(NOT_NULL(node_list));
+}

 void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
 auto optimizer = std::make_shared<opt::GraphOptimizer>();
@@ -178,7 +181,7 @@ void CPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
 continue;
 }
 auto input_param = input_node->cast<ParameterPtr>();
-if (AnfAlgo::IsParameterWeight(input_param) && !tensor->IsUpdatedByDevice()) {
+if (common::AnfAlgo::IsParameterWeight(input_param) && !tensor->IsUpdatedByDevice()) {
 continue;
 }
 if (std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() !=
@@ -249,10 +252,10 @@ void CPUSession::SetOutputFlags(const VectorRef &base_ref) {

 void CPUSession::UpdateDynamicOutputShape(const std::map<tensor::TensorPtr, KernelWithIndex> &tensor_to_node) {
 for (const auto &tensor_node : tensor_to_node) {
-if (AnfAlgo::IsDynamicShape(tensor_node.second.first)) {
+if (common::AnfAlgo::IsDynamicShape(tensor_node.second.first)) {
 const auto &kernel = tensor_node.second.first;
 const auto &output_index = tensor_node.second.second;
-const auto &shape = AnfAlgo::GetOutputInferShape(kernel, output_index);
+const auto &shape = common::AnfAlgo::GetOutputInferShape(kernel, output_index);
 std::vector<int64_t> refresh_shape;
 (void)std::copy(shape.begin(), shape.end(), std::back_inserter(refresh_shape));
 MS_EXCEPTION_IF_NULL(tensor_node.first);
@@ -314,7 +317,7 @@ void CPUSession::SetKernelInfo(const KernelGraph *kernel_graph) {

 namespace {
 void KernelNotSupportException(const AnfNodePtr &kernel_node) {
-std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
+std::string kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
 std::stringstream operator_info;
 operator_info << "Operator[" << kernel_name << "] ";
 auto kernel_info = dynamic_cast<device::KernelInfo *>(kernel_node->kernel_info());
@@ -362,7 +365,7 @@ void CPUSession::BuildKernel(const KernelGraph *kernel_graph) {
 std::vector<AnfNodePtr> akg_nodes;
 for (const auto &kernel_node : kernel_nodes) {
 MS_EXCEPTION_IF_NULL(kernel_node);
-std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
+std::string kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
 MS_LOG(INFO) << "Cpu building operator[" << kernel_name << "].";
 if (session::AnfRuntimeAlgorithm::GetKernelType(kernel_node) == KernelType::AKG_KERNEL) {
 if (!bin_map->initialized()) {

@@ -20,8 +20,8 @@
 #include <exception>
 #include <set>
 #include "runtime/device/kernel_runtime_manager.h"
-#include "utils/comm_manager.h"
-#include "utils/scoped_long_running.h"
+#include "include/common/utils/comm_manager.h"
+#include "include/common/utils/scoped_long_running.h"
 #include "pybind_api/ir/tensor_py.h"
 #if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
 #include "ps/ps_cache/ps_cache_manager.h"

@@ -31,8 +31,8 @@
 #include "ir/anf.h"
 #include "ir/tensor.h"
 #include "utils/any.h"
-#include "utils/comm_manager.h"
-#include "utils/contract.h"
+#include "include/common/utils/comm_manager.h"
+#include "include/common/utils/contract.h"

 namespace mindspore {
 namespace session {

@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "backend/common/session/executor_manager.h"
-#include "common/thread_pool.h"
+#include "include/common/thread_pool.h"
 namespace mindspore {
 namespace session {
 std::shared_ptr<Executor> ExecutorManager::GetExecutor(const std::string &device_name, uint32_t device_id) {

@@ -21,9 +21,10 @@
 #include "ir/param_info.h"
 #include "runtime/device/kernel_runtime.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "utils/ms_utils.h"
-#include "utils/ms_device_shape_transfer.h"
-#include "utils/config_manager.h"
+#include "runtime/device/ms_device_shape_transfer.h"
+#include "include/common/utils/config_manager.h"

 namespace mindspore {
 namespace session {
@@ -44,7 +45,7 @@ void GpuInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &kern
 MS_EXCEPTION_IF_NULL(pk_node);
 auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
 MS_EXCEPTION_IF_NULL(device_address);
-if (!AnfAlgo::IsParameterWeight(pk_node)) {
+if (!common::AnfAlgo::IsParameterWeight(pk_node)) {
 tensor = inputs[no_weight_input++];
 if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
 LongToSize(tensor->data().nbytes()), tensor->data_type(),
@@ -70,7 +71,7 @@ GraphId GpuInferenceSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph)
 MS_EXCEPTION_IF_NULL(pk_node);
 auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
 MS_EXCEPTION_IF_NULL(device_address);
-if (AnfAlgo::IsParameterWeight(pk_node)) {
+if (common::AnfAlgo::IsParameterWeight(pk_node)) {
 const auto &param_value = pk_node->default_param();
 MS_EXCEPTION_IF_NULL(param_value);
 auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value);
@@ -100,7 +101,7 @@ bool GpuInferenceSession::CheckModelInputs(uint32_t graph_id, const std::vector<
 continue;
 }
 auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
-if (!AnfAlgo::IsParameterWeight(parameter)) {
+if (!common::AnfAlgo::IsParameterWeight(parameter)) {
 paras.push_back(parameter);
 }
 }

@@ -58,7 +58,7 @@
 #include "backend/common/pass/getitem_tuple.h"
 #include "backend/common/pass/optimize_updatestate.h"
 #include "backend/common/pass/adjust_depend_for_parallel_optimizer_recompute_all_gather.h"
-#include "utils/ms_device_shape_transfer.h"
+#include "runtime/device/ms_device_shape_transfer.h"
 #include "debug/anf_ir_dump.h"
 #include "debug/dump_proto.h"
 #ifdef ENABLE_DEBUGGER
@@ -80,10 +80,10 @@
 #include "plugin/device/gpu/hal/device/gpu_bucket.h"
 #include "plugin/device/gpu/hal/device/gpu_device_address.h"
 #include "utils/ms_utils.h"
-#include "utils/config_manager.h"
+#include "include/common/utils/config_manager.h"
 #include "utils/ms_context.h"
-#include "utils/context/graph_kernel_flags.h"
-#include "utils/utils.h"
+#include "include/common/utils/context/graph_kernel_flags.h"
+#include "include/common/utils/utils.h"
 #include "abstract/utils.h"
 #if ENABLE_CPU && ENABLE_GPU
 #include "ps/util.h"
@@ -313,8 +313,8 @@ size_t UpdateGraphInputAbstract(const AnfNodePtr input_node, const tensor::Tenso
 auto tensor_shape = tensor->shape();
 std::vector<size_t> shape_tmp;
 (void)std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape_tmp), LongToSize);
-AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {shape_tmp},
-input_node.get());
+common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetOutputInferDataType(input_node, 0)}, {shape_tmp},
+input_node.get());
 size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type());
 }
 return size;
@@ -372,7 +372,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
 MS_EXCEPTION_IF_NULL(device_address);
 bool need_sync = CheckIfNeedSync(tensor, device_address, pk_node);
 if (need_sync) {
-if (AnfAlgo::IsParameterWeight(pk_node) || UpdatedByAssign(kernel_graph, input_node) ||
+if (common::AnfAlgo::IsParameterWeight(pk_node) || UpdatedByAssign(kernel_graph, input_node) ||
 ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
 tensor->set_device_address(device_address);
 }
@@ -611,11 +611,11 @@ void GPUSession::UpdateOutputTensors(const VectorRef *outputs,
 #if ((defined ENABLE_CPU) && (!defined _WIN32))
 ps_mode = ps::PSContext::instance()->is_ps_mode();
 #endif
-if (node->isa<CNode>() && !AnfAlgo::IsCommunicationOp(node) && !ps_mode) {
+if (node->isa<CNode>() && !common::AnfAlgo::IsCommunicationOp(node) && !ps_mode) {
 auto new_address = std::make_shared<device::gpu::GPUDeviceAddress>(nullptr, address->GetSize());
 // If a nop node is output, its previous node should be set.
-if (opt::IsNopNode(node)) {
-auto pre_node = AnfAlgo::GetPrevNodeOutput(node, 0, true);
+if (common::AnfAlgo::IsNopNode(node)) {
+auto pre_node = common::AnfAlgo::GetPrevNodeOutput(node, 0, true);
 if (!pre_node.first->isa<Parameter>()) {
 AnfAlgo::SetOutputAddr(new_address, pre_node.second, pre_node.first.get());
 }
@@ -632,8 +632,8 @@ void GPUSession::UpdateOutputTensors(const VectorRef *outputs,
 }
 }

-if (AnfAlgo::IsDynamicShape(node)) {
-const auto &updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
+if (common::AnfAlgo::IsDynamicShape(node)) {
+const auto &updated_shape = common::AnfAlgo::GetOutputInferShape(node, output_index);
 ShapeVector int_shape;
 std::transform(updated_shape.begin(), updated_shape.end(), std::back_inserter(int_shape), SizeToInt);
 tensor->set_shape(int_shape);

@@ -24,7 +24,7 @@
 #include <memory>
 #include <mutex>

-#include "common/duplex_pipe.h"
+#include "include/common/duplex_pipe.h"
 #include "utils/log_adapter.h"
 #include "utils/ms_context.h"

@ -21,9 +21,10 @@
|
|||
#include "utils/hash_set.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "ir/param_info.h"
|
||||
#include "utils/utils.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "runtime/device/kernel_info.h"
|
||||
#include "kernel/kernel_build_info.h"
|
||||
#include "runtime/device/kernel_runtime_manager.h"
|
||||
|
@ -53,11 +54,11 @@ void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
|
|||
|
||||
std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) {
|
||||
auto item_with_index =
|
||||
AnfAlgo::VisitKernelWithReturnType(call_node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple});
|
||||
common::AnfAlgo::VisitKernelWithReturnType(call_node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple});
|
||||
AnfNodePtr node = item_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
|
||||
auto outputs = AnfAlgo::GetAllOutput(node);
|
||||
if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
|
||||
auto outputs = common::AnfAlgo::GetAllOutput(node);
|
||||
std::set<AnfNodePtr> memo;
|
||||
std::vector<AnfNodePtr> new_output;
|
||||
for (auto &output : outputs) {
|
||||
|
@ -67,11 +68,11 @@ std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) {
|
|||
memo.insert(output);
|
||||
new_output.push_back(output);
|
||||
}
|
||||
if (new_output.size() == 1 && AnfAlgo::CheckPrimitiveType(new_output[0], prim::kPrimCall)) {
|
||||
if (new_output.size() == 1 && common::AnfAlgo::CheckPrimitiveType(new_output[0], prim::kPrimCall)) {
|
||||
node = new_output[0];
|
||||
}
|
||||
}
|
||||
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
|
||||
if (!common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
|
||||
return {node};
|
||||
}
|
||||
std::vector<AnfNodePtr> real_inputs;
|
||||
|
@ -95,9 +96,9 @@ bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) {
|
|||
if (!IsPrimitiveCNode(left, GetCNodePrimitive(right))) {
|
||||
return false;
|
||||
}
|
||||
if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, left) && AnfAlgo::HasNodeAttr(kAttrLabelIndex, right)) {
|
||||
return AnfAlgo::GetNodeAttr<uint32_t>(left, kAttrLabelIndex) ==
|
||||
AnfAlgo::GetNodeAttr<uint32_t>(right, kAttrLabelIndex);
|
||||
if (common::AnfAlgo::HasNodeAttr(kAttrLabelIndex, left) && common::AnfAlgo::HasNodeAttr(kAttrLabelIndex, right)) {
|
||||
return common::AnfAlgo::GetNodeAttr<uint32_t>(left, kAttrLabelIndex) ==
|
||||
common::AnfAlgo::GetNodeAttr<uint32_t>(right, kAttrLabelIndex);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
@ -132,25 +133,25 @@ void SyncDeviceInfoToValueNode(const ValueNodePtr &value_node, std::vector<std::
|
|||
std::string GetNodeGroup(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
|
||||
return AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
|
||||
if (common::AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
|
||||
return common::AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
void SetInternalOutputAttr(const AnfNodePtr &node) {
|
||||
if (!opt::IsNopNode(node)) {
|
||||
if (!common::AnfAlgo::IsNopNode(node)) {
|
||||
return;
|
||||
}
|
||||
auto p = GetCNodePrimitive(node);
|
||||
if (p == nullptr) return;
|
||||
auto prim_node = NewValueNode(p->Clone());
|
||||
node->cast<CNodePtr>()->set_input(kAnfPrimitiveIndex, prim_node);
|
||||
AnfAlgo::SetNodeAttr(kAttrIsInternalOutputNopNode, MakeValue(true), node);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrIsInternalOutputNopNode, MakeValue(true), node);
|
||||
}
|
||||
|
||||
bool NeedOptimize(const AnfNodePtr &node, const std::string &optimized_comm_group) {
|
||||
bool is_fused_comm = AnfAlgo::IsFusedCommunicationOp(node);
|
||||
bool is_fused_comm = common::AnfAlgo::IsFusedCommunicationOp(node);
|
||||
if (!is_fused_comm) {
|
||||
return false;
|
||||
}
|
||||
|
@ -219,8 +220,8 @@ void KernelGraph::EnqueueReadyNodes(const AnfNodePtr &node, std::queue<AnfNodePt
|
|||
// allreduce first
|
||||
if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) {
|
||||
(void)visited_nodes->insert(next_node);
|
||||
bool is_comm_node = AnfAlgo::IsCommunicationOp(next_node);
|
||||
if (AnfAlgo::CheckPrimitiveType(next_node, prim::kPrimLoad)) {
|
||||
bool is_comm_node = common::AnfAlgo::IsCommunicationOp(next_node);
|
||||
if (common::AnfAlgo::CheckPrimitiveType(next_node, prim::kPrimLoad)) {
|
||||
EnqueueReadyNodes(next_node, visit_queue, visited_nodes);
|
||||
} else if ((is_comm_node && comm_first) || (!is_comm_node && !comm_first)) {
|
||||
MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString();
|
||||
|
@ -272,7 +273,7 @@ void KernelGraph::SetExecOrderByDefault() {
|
|||
execution_order_.push_back(node->cast<CNodePtr>());
|
||||
}
|
||||
// delay execute comm ops that need optimize
|
||||
bool is_comm = AnfAlgo::IsCommunicationOp(node);
|
||||
bool is_comm = common::AnfAlgo::IsCommunicationOp(node);
|
||||
bool optimize_comm = NeedOptimize(node, optimized_comm_group);
|
||||
if (optimize_comm) {
|
||||
optimized_comm_group = GetNodeGroup(node);
|
||||
|
@ -451,11 +452,11 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
|
|||
void KernelGraph::PostNewCNode(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
|
||||
if (AnfAlgo::IsGraphKernel(cnode)) {
|
||||
if (common::AnfAlgo::IsGraphKernel(cnode)) {
|
||||
CreateKernelInfoFromNewParameter(cnode);
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
|
||||
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
|
||||
if (common::AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
|
||||
common::AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
|
||||
}
|
||||
SetKernelInfoForNode(cnode);
|
||||
AnfAlgo::SetGraphId(graph_id_, cnode.get());
|
||||
|
@ -472,7 +473,7 @@ CNodePtr KernelGraph::NewCNodeWithInfos(const std::vector<AnfNodePtr> &inputs, c
|
|||
}
|
||||
|
||||
void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
|
||||
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
|
||||
auto func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(cnode);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
||||
std::vector<AnfNodePtr> node_list;
|
||||
|
@ -486,7 +487,7 @@ void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
|
|||
}
|
||||
auto anf_cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(anf_cnode);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(anf_cnode);
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(anf_cnode);
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto input_node = anf_cnode->input(i + 1);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
|
@ -507,14 +508,14 @@ void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
|
|||
}
|
||||
|
||||
void KernelGraph::ResetAssignInputFeatureMapFlag(const CNodePtr &cnode) const {
|
||||
if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(cnode)) == kOpAssignKernelNameList.end()) {
|
||||
if (kOpAssignKernelNameList.find(common::AnfAlgo::GetCNodeName(cnode)) == kOpAssignKernelNameList.end()) {
|
||||
MS_LOG(EXCEPTION) << "Only supported to change the node [Assign , AssignSub, AssignAdd] node's input feature map "
|
||||
"flag but got the node :"
|
||||
<< cnode->DebugString();
|
||||
}
|
||||
auto input_node = AnfAlgo::GetInputNode(cnode, 0);
|
||||
auto input_node = common::AnfAlgo::GetInputNode(cnode, 0);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
auto assign_value_node = AnfAlgo::GetInputNode(cnode, 1);
|
||||
auto assign_value_node = common::AnfAlgo::GetInputNode(cnode, 1);
|
||||
if (AnfAlgo::IsFeatureMapOutput(input_node)) {
|
||||
return;
|
||||
}
|
||||
|
@ -531,7 +532,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
|
|||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
node->set_kernel_info(kernel_info);
|
||||
if (node->isa<CNode>()) {
|
||||
if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(node)) != kOpAssignKernelNameList.end()) {
|
||||
if (kOpAssignKernelNameList.find(common::AnfAlgo::GetCNodeName(node)) != kOpAssignKernelNameList.end()) {
|
||||
ResetAssignInputFeatureMapFlag(node->cast<CNodePtr>());
|
||||
}
|
||||
#if defined(__APPLE__)
|
||||
|
@ -540,21 +541,21 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
|
|||
std::vector<size_t> feature_map_input_indexs;
|
||||
#endif
|
||||
kernel_info->set_feature_map_flag(false);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(node);
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
|
||||
for (size_t index = 0; index < input_num; ++index) {
|
||||
if (AnfAlgo::IsFeatureMapInput(node, index)) {
|
||||
kernel_info->set_feature_map_flag(true);
|
||||
feature_map_input_indexs.push_back(index);
|
||||
}
|
||||
}
|
||||
if (AnfAlgo::GetInputTensorNum(node) == 0) {
|
||||
if (common::AnfAlgo::GetInputTensorNum(node) == 0) {
|
||||
kernel_info->set_feature_map_flag(true);
|
||||
}
|
||||
if (AnfUtils::IsRealKernel(node)) {
|
||||
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
|
||||
// then the node's output is a feature map output
|
||||
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), node);
|
||||
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), node);
|
||||
common::AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), node);
|
||||
common::AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), node);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
@ -572,9 +573,9 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
|
|||
if (node->isa<Parameter>()) {
|
||||
auto parameter = node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
bool is_weight = AnfAlgo::IsParameterWeight(parameter);
|
||||
bool is_weight = common::AnfAlgo::IsParameterWeight(parameter);
|
||||
kernel_info->set_feature_map_flag(!is_weight);
|
||||
types.push_back(is_weight ? kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0));
|
||||
types.push_back(is_weight ? kTypeUnknown : common::AnfAlgo::GetOutputInferDataType(parameter, 0));
|
||||
}
|
||||
// set parameter initaial device data type
|
||||
kernel_build_info_builder->SetOutputsFormat(formats);
|
||||
|
@ -599,7 +600,7 @@ ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) {
|
|||
// if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter
|
||||
if (parameter != nullptr) {
|
||||
new_parameter->set_name(parameter->name());
|
||||
if (AnfAlgo::IsParameterWeight(parameter)) {
|
||||
if (common::AnfAlgo::IsParameterWeight(parameter)) {
|
||||
new_parameter->set_default_param(parameter->default_param());
|
||||
}
|
||||
}
|
||||
|
@ -712,9 +713,9 @@ AnfNodePtr KernelGraph::CreatTupleGetItemNode(const AnfNodePtr &node, size_t out
|
|||
AnfNodePtr tuple_getitem = NewCNode({mindspore::NewValueNode(prim::kPrimTupleGetItem), node, idx});
|
||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||
tuple_getitem->set_scope(node->scope());
|
||||
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
|
||||
TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx);
|
||||
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get());
|
||||
std::vector<size_t> origin_shape = common::AnfAlgo::GetOutputInferShape(node, output_idx);
|
||||
TypeId origin_type = common::AnfAlgo::GetOutputInferDataType(node, output_idx);
|
||||
common::AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get());
|
||||
return tuple_getitem;
|
||||
}
@ -723,20 +724,20 @@ AnfNodePtr KernelGraph::TransCNodeTuple(const CNodePtr &node) {
std::vector<TypeId> types;
std::vector<std::vector<size_t>> shapes;
std::vector<AnfNodePtr> make_tuple_inputs_list = {mindspore::NewValueNode(prim::kPrimMakeTuple)};
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
size_t output_num = common::AnfAlgo::GetOutputTensorNum(node);
for (size_t tuple_out_index = 0; tuple_out_index < output_num; ++tuple_out_index) {
make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(node, tuple_out_index));
types.push_back(AnfAlgo::GetOutputInferDataType(node, tuple_out_index));
shapes.emplace_back(AnfAlgo::GetOutputInferShape(node, tuple_out_index));
types.push_back(common::AnfAlgo::GetOutputInferDataType(node, tuple_out_index));
shapes.emplace_back(common::AnfAlgo::GetOutputInferShape(node, tuple_out_index));
}
auto make_tuple = NewCNode(std::move(make_tuple_inputs_list));
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get());
common::AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get());
return make_tuple;
}

AnfNodePtr KernelGraph::TransTupleToMakeTuple(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsTupleOutput(node)) {
if (!common::AnfAlgo::IsTupleOutput(node)) {
return node;
}
if (node->isa<Parameter>()) {
@ -930,11 +931,11 @@ void KernelGraph::SetOutputNodeToTensor(const KernelMapTensor &node_to_tensor) {
for (const auto &item : output_node_to_tensor_) {
auto node = item.first.first;
auto out_index = item.first.second;
if (!opt::IsNopNode(node)) {
if (!common::AnfAlgo::IsNopNode(node)) {
continue;
}
while (opt::IsNopNode(node)) {
const auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, 0);
while (common::AnfAlgo::IsNopNode(node)) {
const auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(node, 0);
node = kernel_with_index.first;
out_index = kernel_with_index.second;
}
@ -1007,7 +1008,7 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi
std::vector<CNodePtr> result;
for (const auto &anf : execution_order_) {
MS_EXCEPTION_IF_NULL(anf);
if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
if (common::AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
result.push_back(anf->cast<CNodePtr>());
}
}
@ -1019,7 +1020,7 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const std::vector<Primiti
for (const auto &anf : execution_order_) {
MS_EXCEPTION_IF_NULL(anf);
for (const auto &primitive : primitive_list) {
if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
if (common::AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
result.push_back(anf->cast<CNodePtr>());
}
}
@ -1037,17 +1038,19 @@ void KernelGraph::PrintGraphExecuteOrder() const {
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);

std::string event_str;
if (AnfAlgo::HasNodeAttr(kAttrEventId, cur_cnode_ptr)) {
event_str = ", event id[" + std::to_string(AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrEventId)) + "]";
if (common::AnfAlgo::HasNodeAttr(kAttrEventId, cur_cnode_ptr)) {
event_str =
", event id[" + std::to_string(common::AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrEventId)) + "]";
}

std::string label_str;
if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_cnode_ptr)) {
label_str = ", label id[" + std::to_string(AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrLabelIndex)) + "]";
if (common::AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_cnode_ptr)) {
label_str =
", label id[" + std::to_string(common::AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrLabelIndex)) + "]";
}

if (AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cur_cnode_ptr)) {
auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cur_cnode_ptr, kAttrLabelSwitchList);
if (common::AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cur_cnode_ptr)) {
auto label_list = common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cur_cnode_ptr, kAttrLabelSwitchList);
label_str = ", label id[";
for (size_t j = 0; j < label_list.size(); ++j) {
label_str += std::to_string(label_list[j]) + (j + 1 < label_list.size() ? ", " : "]");
@ -1055,8 +1058,8 @@ void KernelGraph::PrintGraphExecuteOrder() const {
}

std::string active_stream_str;
if (AnfAlgo::HasNodeAttr(kAttrActiveStreamList, cur_cnode_ptr)) {
auto stream_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cur_cnode_ptr, kAttrActiveStreamList);
if (common::AnfAlgo::HasNodeAttr(kAttrActiveStreamList, cur_cnode_ptr)) {
auto stream_list = common::AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cur_cnode_ptr, kAttrActiveStreamList);
active_stream_str = ", active stream id[";
for (size_t j = 0; j < stream_list.size(); ++j) {
active_stream_str += std::to_string(stream_list[j]) + (j + 1 < stream_list.size() ? ", " : "]");
@ -1064,8 +1067,9 @@ void KernelGraph::PrintGraphExecuteOrder() const {
}

std::string group_str;
if (AnfAlgo::GetKernelType(cur_cnode_ptr) == HCCL_KERNEL && AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode_ptr)) {
group_str = ", group[" + AnfAlgo::GetNodeAttr<std::string>(cur_cnode_ptr, kAttrGroup) + "]";
if (AnfAlgo::GetKernelType(cur_cnode_ptr) == HCCL_KERNEL &&
common::AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode_ptr)) {
group_str = ", group[" + common::AnfAlgo::GetNodeAttr<std::string>(cur_cnode_ptr, kAttrGroup) + "]";
}

MS_LOG(INFO) << "Index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
@ -1084,8 +1088,8 @@ void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodeP
MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString();
front_to_internal_outputs_map_[front_node] = node;
SetInternalOutputAttr(node);
if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
output_idx = AnfAlgo::GetTupleGetItemOutIndex(front_node->cast<CNodePtr>());
if (common::AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
output_idx = common::AnfAlgo::GetTupleGetItemOutIndex(front_node->cast<CNodePtr>());
}
internal_outputs_to_front_map_[node][output_idx] = std::pair<AnfNodePtr, bool>(front_node, unique_target);
}
@ -1191,7 +1195,7 @@ void KernelGraph::CacheInternalParameterToFrontNode(const AnfNodePtr &parameter,
return;
}

auto front_outputs = AnfAlgo::GetAllOutputWithIndex(front_node_with_index.first);
auto front_outputs = common::AnfAlgo::GetAllOutputWithIndex(front_node_with_index.first);
AnfWithOutIndex new_front_node_with_index;
if (front_node_with_index.second < front_outputs.size()) {
new_front_node_with_index = front_outputs[front_node_with_index.second];
@ -1234,7 +1238,7 @@ void KernelGraph::CacheGraphOutputToFrontNodeWithIndex(const std::vector<AnfNode
MS_LOG(INFO) << "Get graph backend output nodes.";
std::vector<KernelWithIndex> backend_output_nodes;
for (auto &backend_output : backend_outputs) {
auto temp_backend_outputs = AnfAlgo::GetAllOutputWithIndex(backend_output);
auto temp_backend_outputs = common::AnfAlgo::GetAllOutputWithIndex(backend_output);
(void)backend_output_nodes.insert(backend_output_nodes.end(), temp_backend_outputs.begin(),
temp_backend_outputs.end());
}
@ -1242,7 +1246,7 @@ void KernelGraph::CacheGraphOutputToFrontNodeWithIndex(const std::vector<AnfNode
MS_LOG(INFO) << "Get graph front output nodes.";
std::vector<KernelWithIndex> front_output_nodes;
for (auto &front_output : front_outputs) {
auto temp_front_outputs = AnfAlgo::GetAllOutputWithIndex(front_output);
auto temp_front_outputs = common::AnfAlgo::GetAllOutputWithIndex(front_output);
(void)front_output_nodes.insert(front_output_nodes.end(), temp_front_outputs.begin(), temp_front_outputs.end());
}

@ -1357,7 +1361,7 @@ void KernelGraph::RemoveNodeFromGraph(const AnfNodePtr &node) {

void KernelGraph::UpdateGraphDynamicAttr() {
for (const auto &cnode : execution_order_) {
if (AnfAlgo::IsDynamicShape(cnode)) {
if (common::AnfAlgo::IsDynamicShape(cnode)) {
MS_LOG(INFO) << "Update Graph Dynamic Attr";
is_dynamic_shape_ = true;
return;
@ -1369,7 +1373,7 @@ void KernelGraph::UpdateGraphDynamicAttr() {
void KernelGraph::SetInputNodes() {
input_nodes_.clear();
for (const auto &input_node : inputs()) {
auto params = AnfAlgo::GetAllOutput(input_node);
auto params = common::AnfAlgo::GetAllOutput(input_node);
if (params.size() == 1) {
FrontBackendlMapUpdate(input_node, params[0]);
} else {
@ -1390,7 +1394,7 @@ void KernelGraph::SetInputNodes() {

void KernelGraph::UpdateGraphAquireGilAttr() {
for (const auto &cnode : execution_order_) {
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPyFunc)) {
if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPyFunc)) {
MS_LOG(INFO) << "The Graph require GIL. Graph id: " << graph_id_;
is_need_gil_ = true;
return;
@ -1402,12 +1406,12 @@ void KernelGraph::SetOptimizerFlag() {
has_optimizer_ = false;
for (const auto &cnode : execution_order_) {
MS_EXCEPTION_IF_NULL(cnode);
if (!AnfAlgo::IsUpdateParameterKernel(cnode)) {
if (!common::AnfAlgo::IsUpdateParameterKernel(cnode)) {
continue;
}
for (auto &input : cnode->inputs()) {
MS_EXCEPTION_IF_NULL(input);
auto real_node = AnfAlgo::VisitKernel(input, 0).first;
auto real_node = common::AnfAlgo::VisitKernel(input, 0).first;
MS_EXCEPTION_IF_NULL(real_node);
if (!real_node->isa<Parameter>()) {
continue;
@ -1431,7 +1435,7 @@ bool KernelGraph::IsDatasetGraph() const {
return false;
}
for (const auto &node : nodes) {
auto node_name = AnfAlgo::GetCNodeName(node);
auto node_name = common::AnfAlgo::GetCNodeName(node);
if (node_name == prim::kPrimInitDataSetQueue->name()) {
return true;
}
@ -1445,7 +1449,7 @@ bool KernelGraph::IsChildGraphResult(const AnfNodePtr &node) {
std::vector<AnfNodePtr> child_graph_results;
for (const auto &child_graph_result : child_graph_result_) {
MS_EXCEPTION_IF_NULL(child_graph_result);
auto outputs = AnfAlgo::GetAllOutput(child_graph_result);
auto outputs = common::AnfAlgo::GetAllOutput(child_graph_result);
(void)child_graph_results.insert(child_graph_results.end(), outputs.begin(), outputs.end());
}

@ -30,7 +30,7 @@
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "ir/graph_utils.h"
#include "utils/contract.h"
#include "include/common/utils/contract.h"
#include "runtime/device/kernel_info.h"

namespace mindspore {

@ -25,6 +25,7 @@
#include <utility>
#include "backend/common/session/kernel_graph.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"

namespace mindspore {
namespace session {

@ -29,9 +29,10 @@
#include "kernel/common_utils.h"
#include "base/core_ops.h"
#include "base/base_ref_utils.h"
#include "utils/ms_device_shape_transfer.h"
#include "utils/config_manager.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "include/common/utils/config_manager.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "backend/common/session/executor_manager.h"
#include "backend/common/optimizer/common_backend_optimization.h"
#include "backend/common/optimizer/helper.h"
@ -39,12 +40,12 @@
#include "utils/ms_utils.h"
#include "ir/anf.h"
#include "ir/func_graph_cloner.h"
#include "utils/utils.h"
#include "include/common/utils/utils.h"
#include "debug/anf_ir_dump.h"
#include "debug/dump_proto.h"
#include "utils/file_utils.h"
#include "utils/trace_base.h"
#include "frontend/parallel/context.h"
#include "include/common/utils/parallel_context.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
#include "ps/ps_cache/ps_cache_manager.h"
#include "ps/constants.h"
@ -90,11 +91,11 @@ bool RecursiveCheck(const FuncGraphManagerPtr &manager, const std::pair<AnfNodeP
auto node = kernel.first;
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(node);
if (kernel.second > 1 &&
(AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad))) {
if (kernel.second > 1 && (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) ||
common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad))) {
return false;
}
if (AnfUtils::IsRealKernel(node) && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
if (AnfUtils::IsRealKernel(node) && !common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
return true;
}
(*idx) += 1;
@ -201,13 +202,13 @@ BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
}
TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
if (type_id == kTypeUnknown) {
type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
type_id = common::AnfAlgo::GetOutputInferDataType(node, output_index);
}
std::vector<int64_t> temp_shape;
auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
auto shape = common::AnfAlgo::GetOutputInferShape(node, output_index);
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
if (AnfAlgo::IsDynamicShape(node)) {
auto max_shape = AnfAlgo::GetOutputMaxShape(node, output_index);
if (common::AnfAlgo::IsDynamicShape(node)) {
auto max_shape = common::AnfAlgo::GetOutputMaxShape(node, output_index);
temp_shape = abstract::ShapeSize(max_shape) > abstract::ShapeSize(temp_shape) ? max_shape : temp_shape;
}
tensor::TensorPtr tensor;
@ -249,11 +250,11 @@ BaseRef CreateNodeOutputTensors(const AnfNodePtr &anf, const KernelGraphPtr &gra
MS_EXCEPTION_IF_NULL(tensor_to_node);
MS_EXCEPTION_IF_NULL(node_to_tensor);
MS_LOG(DEBUG) << "Create tensor for output[" << anf->DebugString() << "]";
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
auto item_with_index = common::AnfAlgo::VisitKernelWithReturnType(anf, 0);
MS_EXCEPTION_IF_NULL(item_with_index.first);
MS_LOG(DEBUG) << "Create tensor for output after visit:" << item_with_index.first->DebugString();
// special handle for maketuple
if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
if (common::AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
auto cnode = item_with_index.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
VectorRef ret;
@ -264,7 +265,7 @@ BaseRef CreateNodeOutputTensors(const AnfNodePtr &anf, const KernelGraphPtr &gra
return ret;
}
// if the graph returns nothing, the function should return a null anylist
size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
size_t size = common::AnfAlgo::GetOutputTensorNum(item_with_index.first);
if (size == 0) {
return VectorRef();
}
@ -341,7 +342,7 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
if (NeedDiscardTensorProperties(op_run_info.device_target, device_address)) {
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
TypeId param_init_data_type = common::AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type});
} else {
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{device_address->format()});
@ -439,11 +440,11 @@ BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr
MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(output_indexes);
MS_LOG(DEBUG) << "Create placeholder for output[" << anf->DebugString() << "]";
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
auto item_with_index = common::AnfAlgo::VisitKernelWithReturnType(anf, 0);
MS_EXCEPTION_IF_NULL(item_with_index.first);
MS_LOG(DEBUG) << "Create placeholder for output after visit:" << item_with_index.first->DebugString();
// special handle for maketuple
if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
if (common::AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
auto cnode = item_with_index.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
VectorRef ret;
@ -456,7 +457,7 @@ BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr
return ret;
}
// if the graph returns nothing, the function should return a null anylist
size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
size_t size = common::AnfAlgo::GetOutputTensorNum(item_with_index.first);
if (size == 0) {
return VectorRef();
}
@ -466,17 +467,17 @@ BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr
void CheckInputTensorShape(const TensorPtr &tensor, const CNodePtr &kernel, size_t input_index) {
MS_EXCEPTION_IF_NULL(tensor);
const auto &tensor_shape = tensor->shape();
const auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel, input_index);
const auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel, input_index);
if (tensor_shape.size() != input_shape.size()) {
MS_LOG(EXCEPTION) << "The input tensor's shape size: " << tensor_shape.size()
<< " is not equal to expected size: " << input_shape.size() << " for input[" << input_index
<< "] of kernel: " << AnfAlgo::GetCNodeName(kernel) << trace::DumpSourceLines(kernel);
<< "] of kernel: " << common::AnfAlgo::GetCNodeName(kernel) << trace::DumpSourceLines(kernel);
}
for (size_t i = 0; i < tensor_shape.size(); i++) {
if (tensor_shape[i] < 0 || static_cast<size_t>(tensor_shape[i]) != input_shape[i]) {
MS_LOG(EXCEPTION) << "The input tensor's shape: " << tensor_shape
<< " is not equal to expected shape: " << input_shape << " for input[" << input_index
<< "] of kernel: " << AnfAlgo::GetCNodeName(kernel) << trace::DumpSourceLines(kernel);
<< "] of kernel: " << common::AnfAlgo::GetCNodeName(kernel) << trace::DumpSourceLines(kernel);
}
}
}
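Editor's note: CheckInputTensorShape above only changes which namespace supplies the expected infer shape; the validation itself is untouched. For clarity, a standalone sketch of the same rank-then-dimension check with hypothetical plain types (CheckShape is not a MindSpore API):

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical free function mirroring the check: compare rank first,
// then each dimension, rejecting negative dims as mismatches.
void CheckShape(const std::vector<int64_t> &tensor_shape, const std::vector<size_t> &input_shape) {
  if (tensor_shape.size() != input_shape.size()) {
    throw std::runtime_error("shape rank mismatch");
  }
  for (size_t i = 0; i < tensor_shape.size(); ++i) {
    if (tensor_shape[i] < 0 || static_cast<size_t>(tensor_shape[i]) != input_shape[i]) {
      throw std::runtime_error("shape dimension mismatch");
    }
  }
}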
@ -497,7 +498,7 @@ void SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);

if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
constexpr auto kReturnInputIdx = 1;
auto return_node = node->cast<CNodePtr>();
graph->set_return(return_node);
@ -507,7 +508,7 @@ void SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
// If return's input is a value node, the graph has no kernel, and the pass 'trans tuple to make_tuple' cannot
// match this pattern because that pass begins with the output node rather than the return node. So we add the
// value-tuple-to-make_tuple transform here.
if (AnfAlgo::IsTupleOutput(graph_output) && graph_output->isa<ValueNode>()) {
if (common::AnfAlgo::IsTupleOutput(graph_output) && graph_output->isa<ValueNode>()) {
return_node->set_input(kReturnInputIdx, graph->TransTupleToMakeTuple(graph_output));
}
}
@ -535,7 +536,7 @@ void GetNodeUsedList(const FuncGraphPtr &kernel_graph, const AnfNodePtr &node,

auto node_users = iter->second;
for (const auto &node_user : node_users) {
if (AnfAlgo::GetCNodeName(node_user.first) == prim::kPrimLoad->name()) {
if (common::AnfAlgo::GetCNodeName(node_user.first) == prim::kPrimLoad->name()) {
GetNodeUsedList(kernel_graph, node_user.first, node_users_list);
} else {
node_users_list->push_back(node_user.first);
@ -611,10 +612,10 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
return;
}
size_t output_idx = 0;
if (AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
output_idx = AnfAlgo::GetTupleGetItemOutIndex(out_node->cast<CNodePtr>());
if (common::AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
output_idx = common::AnfAlgo::GetTupleGetItemOutIndex(out_node->cast<CNodePtr>());
}
auto real_kernel = AnfAlgo::VisitKernel(ref_node, output_idx);
auto real_kernel = common::AnfAlgo::VisitKernel(ref_node, output_idx);
auto ref_real_node = real_kernel.first;
auto ref_real_node_index = real_kernel.second;
if (ref_real_node->isa<CNode>() && node_graph->IsUniqueTargetInternalOutput(ref_real_node, ref_real_node_index)) {
@ -623,7 +624,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
MS_LOG(INFO) << "No kernel info";
return;
}
if (!opt::IsNopNode(ref_real_node) && !AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index)) {
if (!common::AnfAlgo::IsNopNode(ref_real_node) && !AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index)) {
MS_LOG(INFO) << "No kernel address";
return;
}
@ -648,11 +649,11 @@ AnfNodePtr SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, Kernel
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
auto new_parameter = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract()));
auto parameters = AnfAlgo::GetAllOutput(new_parameter);
auto parameters = common::AnfAlgo::GetAllOutput(new_parameter);
std::vector<AnfNodePtr> pre_graph_out = {node};
// If a cnode is a call, its input0 is a cnode too, so it doesn't have a primitive
if (!pre_graph_out.empty() && !AnfUtils::IsRealKernel(node)) {
pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem, prim::kPrimUpdateState});
pre_graph_out = common::AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem, prim::kPrimUpdateState});
}

for (size_t i = 0; i < parameters.size(); ++i) {
@ -673,7 +674,7 @@ AnfNodePtr SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, Kernel
}
size_t param_index = 0;
for (const auto &out_node : pre_graph_out) {
size_t output_size = AnfAlgo::GetOutputTensorNum(out_node);
size_t output_size = common::AnfAlgo::GetOutputTensorNum(out_node);
for (size_t i = 0; i < output_size; i++) {
if (param_index >= parameters.size()) {
MS_LOG(EXCEPTION) << "Parameters size:" << parameters.size() << "out of range.Node:" << node->DebugString()
@ -742,7 +743,7 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, Kern
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
if (IsPrimitiveCNode(anf, prim::kPrimLoad)) {
auto input = AnfAlgo::GetInputNode(anf->cast<CNodePtr>(), 0);
auto input = common::AnfAlgo::GetInputNode(anf->cast<CNodePtr>(), 0);
MS_EXCEPTION_IF_NULL(input);
if (input->isa<Parameter>()) {
auto new_param = CreateNewParameterFromParameter(input, graph);
@ -760,12 +761,12 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, Kern
void SessionBasic::GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs) const {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(cnode_inputs);
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
auto prim = common::AnfAlgo::GetCNodePrimitive(cnode);
if (prim != nullptr) {
// push attr to inputs[0] of new cnode
cnode_inputs->push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim)));
} else {
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
auto fg = common::AnfAlgo::GetCNodeFuncGraphPtr(cnode);
MS_EXCEPTION_IF_NULL(fg);
auto new_fg = BasicClone(fg);
cnode_inputs->push_back(std::make_shared<ValueNode>(new_fg));
@ -842,7 +843,7 @@ CNodePtr SessionBasic::CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr
MS_EXCEPTION_IF_NULL(graph);
// switch input generalizes partial
std::vector<AnfNodePtr> partial_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))};
if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial)) {
if (common::AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial)) {
auto backend_node = graph->GetBackendAnfByFrontAnf(node_input);
return backend_node->cast<CNodePtr>();
} else if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) {
@ -883,7 +884,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno
for (size_t index = kSwitchTrueBranchIndex; index < switch_cnode->inputs().size(); index++) {
auto node = switch_cnode->input(index);
// a real input of the call should be put into both the true and false branches of the switch
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
auto partial_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_node);
std::vector<AnfNodePtr> partial_inputs = partial_node->inputs();
@ -920,7 +921,7 @@ void SessionBasic::ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph,
// return node is a function
std::vector<AnfNodePtr> call_inputs = {
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) {
if (common::AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) {
auto return_input_cnode = return_input->cast<CNodePtr>();
auto partial_inputs = return_input_cnode->inputs();
call_inputs.insert(call_inputs.end(), partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end());
@ -991,7 +992,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
std::vector<AnfNodePtr> new_partial_inputs;
KernelGraphPtr partial_kernel_graph;
// switch_layer node input is partial cnode
if (AnfAlgo::CheckPrimitiveType(partial_idx, prim::kPrimPartial)) {
if (common::AnfAlgo::CheckPrimitiveType(partial_idx, prim::kPrimPartial)) {
auto partial_node = partial_idx->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_node);
auto partial_input = partial_node->input(kFirstDataInputIndex);
@ -1007,7 +1008,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
auto ret = partial_kernel_graph->get_return();
MS_EXCEPTION_IF_NULL(ret);
auto return_input = ret->input(kFirstDataInputIndex);
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || return_input->isa<ValueNode>()) {
if (common::AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || return_input->isa<ValueNode>()) {
ProcessNodeRetFunc(cnode, partial_kernel_graph.get(), real_inputs);
}
// partial node add input args
@ -1042,7 +1043,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &
return {};
}
// if the node is partial, insert the inputs of partial to the call
if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) {
if (common::AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) {
auto partial_node = attr_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_node);
auto partial_inputs = partial_node->inputs();
@ -1052,9 +1053,9 @@ std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &
return graph->GetBackendAnfByFrontAnf(node);
});
return cnode_inputs;
} else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
} else if (common::AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
return CreateCallSwitchInputs(cnode, graph);
} else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitchLayer)) {
} else if (common::AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitchLayer)) {
return CreateCallSwitchLayerInputs(cnode, graph);
}
MS_LOG(ERROR) << "CNode:" << cnode->DebugString() << " input[0]" << cnode_input->DebugString()
@ -1068,8 +1069,8 @@ std::vector<AnfNodePtr> SessionBasic::CreateValueNode(const CNodePtr &cnode, Ker
std::vector<AnfNodePtr> cnode_inputs;
auto attr_input = cnode->input(kAnfPrimitiveIndex);
MS_EXCEPTION_IF_NULL(attr_input);
if (AnfAlgo::IsGraphKernel(cnode)) {
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
if (common::AnfAlgo::IsGraphKernel(cnode)) {
auto fg = common::AnfAlgo::GetCNodeFuncGraphPtr(cnode);
MS_EXCEPTION_IF_NULL(fg);
auto new_fg = BasicClone(fg);
cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
@ -1092,7 +1093,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateValueNode(const CNodePtr &cnode, Ker
void SessionBasic::CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(graph);
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
(void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
for (size_t index = kSwitchTrueBranchIndex; index < cnode->inputs().size(); index++) {
auto node_input = cnode->input(index);
@ -1135,7 +1136,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
}
} else {
// get primitive of old node
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
auto prim = common::AnfAlgo::GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(prim);
// push attr to inputs[0] of new cnode
cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
@ -1148,12 +1149,12 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
if (new_cnode->inputs().size() > 1) {
auto first_input = new_cnode->input(kFirstDataInputIndex);
MS_EXCEPTION_IF_NULL(first_input);
if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) {
if (common::AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
common::AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) {
new_cnode = first_input->cast<CNodePtr>();
}
if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitchLayer)) {
if (common::AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
common::AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitchLayer)) {
auto abstract = cnode->abstract();
new_cnode = first_input->cast<CNodePtr>();
new_cnode->set_abstract(abstract);
@ -1167,7 +1168,7 @@ ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, Ker
MS_EXCEPTION_IF_NULL(graph);
auto value_node = anf->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
auto sub_func_graph = common::AnfAlgo::GetValueNodeFuncGraph(anf);
MS_EXCEPTION_IF_NULL(sub_func_graph);
if (front_backend_graph_map_.find(sub_func_graph.get()) == front_backend_graph_map_.end()) {
MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph.";
@ -1293,11 +1294,11 @@ void SessionBasic::SetInputNodeUsage(const KernelGraphPtr &graph, const FuncGrap
GraphInfo SessionBasic::GetSingleOpGraphInfo(const CNodePtr &kernel,
const std::vector<tensor::TensorPtr> &input_tensors) {
MS_EXCEPTION_IF_NULL(kernel);
auto prim = AnfAlgo::GetCNodePrimitive(kernel);
auto prim = common::AnfAlgo::GetCNodePrimitive(kernel);
MS_EXCEPTION_IF_NULL(prim);
const AbstractBasePtr &abstract = kernel->abstract();
MS_EXCEPTION_IF_NULL(abstract);
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel);
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel);
GraphInfo graph_info;
// get input tensor info
for (const auto &tensor : input_tensors) {
@ -1328,7 +1329,7 @@ GraphInfo SessionBasic::GetSingleOpGraphInfo(const CNodePtr &kernel,
MS_EXCEPTION_IF_NULL(build_shape);
(void)graph_info.append(build_shape->ToString() + "_");
for (size_t output_index = 0; output_index < output_num; output_index += 1) {
const auto output_type = AnfAlgo::GetOutputInferDataType(kernel, output_index);
const auto output_type = common::AnfAlgo::GetOutputInferDataType(kernel, output_index);
(void)graph_info.append(std::to_string(output_type) + "_");
}
graph_info.append(std::to_string(prim->id()));
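Editor's note: GetSingleOpGraphInfo builds a single-op cache key by appending input tensor markers, the abstract shape, each output's inferred dtype, and finally the primitive id; the hunks above only redirect the dtype lookups through common::AnfAlgo. A toy sketch of the key-building idea (hypothetical standalone code, not the MindSpore API):

#include <cstdint>
#include <string>
#include <vector>

// Toy version: fold output dtype ids and the op id into one underscore-separated key.
std::string BuildOpKey(const std::vector<int> &output_dtypes, uint64_t prim_id) {
  std::string key;
  for (int dtype : output_dtypes) {
    key.append(std::to_string(dtype) + "_");
  }
  key.append(std::to_string(prim_id));
  return key;
}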
@ -1339,7 +1340,7 @@ OpRunInfo SessionBasic::GetSingleOpRunInfo(const CNodePtr &cnode, const GraphInf
const InputTensorInfo &tensor_info,
GraphOutputInfo *const graph_output_info) {
MS_EXCEPTION_IF_NULL(cnode);
auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode);
const auto &abstract = cnode->abstract();
if (abstract == nullptr) {
MS_LOG(EXCEPTION) << "Abstract is nullptr, node = " << cnode->DebugString();
@ -1375,7 +1376,7 @@ void SessionBasic::GetParameterIndex(const KernelGraph *graph, const std::vector
MS_EXCEPTION_IF_NULL(parameter_index);
size_t index = 0;
for (const auto &input_node : graph->input_nodes()) {
auto params = AnfAlgo::GetAllOutput(input_node);
auto params = common::AnfAlgo::GetAllOutput(input_node);
for (const auto &param : params) {
if (index >= inputs.size()) {
MS_LOG(EXCEPTION) << "Parameter size out of range. Parameter index: " << index
@ -1385,7 +1386,7 @@ void SessionBasic::GetParameterIndex(const KernelGraph *graph, const std::vector
MS_EXCEPTION_IF_NULL(input);
// Check shape of input and parameter
const auto &input_shape = input->shape();
const auto &param_shape = AnfAlgo::GetOutputInferShape(param, 0);
const auto &param_shape = common::AnfAlgo::GetOutputInferShape(param, 0);
if (input_shape.size() != param_shape.size()) {
MS_LOG(EXCEPTION) << "Shape size of input tensor(" << input_shape << ") and parameter(" << param_shape
<< ") are different, input index: " << index << ", parameter: " << param->DebugString();
@ -1422,7 +1423,7 @@ void SessionBasic::GetRefCount(const KernelGraph *graph, std::map<KernelWithInde
for (const auto &kernel : graph->execution_order()) {
for (size_t i = 1; i < kernel->inputs().size(); i += 1) {
const auto &input = kernel->input(i);
auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
auto kernel_with_index = common::AnfAlgo::VisitKernel(input, 0);
const auto &node = kernel_with_index.first;
if (node->isa<CNode>()) {
(*ref_count)[kernel_with_index] += 1;
@ -1446,10 +1447,10 @@ void SessionBasic::GetForwardOpOutputRefCount(const KernelGraph *graph, const st
const auto &forward_op_output_id = pynative::PynativeExecutor::GetInstance()->grad_executor()->forward_op_output_id();
MS_LOG(DEBUG) << "Total forward op out put size " << forward_op_output_id.size();
for (const auto &kernel : graph->execution_order()) {
const auto input_tensor_num = AnfAlgo::GetInputTensorNum(kernel);
const auto input_tensor_num = common::AnfAlgo::GetInputTensorNum(kernel);
for (size_t i = 1; i <= input_tensor_num; ++i) {
const auto &input = kernel->input(i);
auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
auto kernel_with_index = common::AnfAlgo::VisitKernel(input, 0);
auto real_input = kernel_with_index.first;
MS_EXCEPTION_IF_NULL(real_input);
if (real_input->isa<ValueNode>()) {
@ -1632,15 +1633,15 @@ void SessionBasic::GetOpInputTensors(const CNodePtr &cnode,
InputTensorInfo *input_tensor_info) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(input_tensor_info);
auto has_const_input_to_attr = AnfAlgo::HasNodeAttr(kAttrNeedConvertToValueNode, cnode);
auto has_const_input_to_attr = common::AnfAlgo::HasNodeAttr(kAttrNeedConvertToValueNode, cnode);
std::vector<size_t> const_input_attr_index = {};
if (has_const_input_to_attr) {
const_input_attr_index = AnfAlgo::GetNodeAttr<std::vector<size_t>>(cnode, kAttrNeedConvertToValueNode);
const_input_attr_index = common::AnfAlgo::GetNodeAttr<std::vector<size_t>>(cnode, kAttrNeedConvertToValueNode);
}
const auto input_tensor_num = AnfAlgo::GetInputTensorNum(cnode);
const auto input_tensor_num = common::AnfAlgo::GetInputTensorNum(cnode);
for (size_t i = 1; i <= input_tensor_num; i += 1) {
const auto &input = cnode->input(i);
auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
auto kernel_with_index = common::AnfAlgo::VisitKernel(input, 0);
auto real_input = kernel_with_index.first;
MS_EXCEPTION_IF_NULL(real_input);
tensor::TensorPtr tensor = nullptr;
@ -1661,7 +1662,7 @@ void SessionBasic::GetOpInputTensors(const CNodePtr &cnode,
: kParameterDataTensorMask);
} else if (real_input->isa<CNode>()) {
tensor = GetCNodeOutputTensor(kernel_with_index, op_output);
if (AnfAlgo::IsControlOpExecInBackend(real_input)) {
if (common::AnfAlgo::IsControlOpExecInBackend(real_input)) {
CheckInputTensorShape(tensor, cnode, i - 1);
}
input_tensor_info->input_kernel.insert(kernel_with_index);
@ -1690,7 +1691,7 @@ tensor::TensorPtr SessionBasic::GetOpInputTensorByIndex(const CNodePtr &cnode,
}

const auto &input = cnode->input(input_index + 1);
auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
auto kernel_with_index = common::AnfAlgo::VisitKernel(input, 0);
auto real_input = kernel_with_index.first;
MS_EXCEPTION_IF_NULL(real_input);

@ -1698,7 +1699,7 @@ tensor::TensorPtr SessionBasic::GetOpInputTensorByIndex(const CNodePtr &cnode,
return GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
} else if (real_input->isa<CNode>()) {
tensor::TensorPtr tensor = GetCNodeOutputTensor(kernel_with_index, op_output);
if (AnfAlgo::IsControlOpExecInBackend(real_input)) {
if (common::AnfAlgo::IsControlOpExecInBackend(real_input)) {
CheckInputTensorShape(tensor, cnode, input_index);
}
input_tensor_info->input_kernel.insert(kernel_with_index);
@ -1765,7 +1766,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
continue;
}
// Create child kernel graph according to ValueNode<FuncGraph>
FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
FuncGraphPtr child_graph = common::AnfAlgo::GetValueNodeFuncGraph(node);
if (front_backend_graph_map_.find(child_graph.get()) == front_backend_graph_map_.end()) {
(void)ConstructKernelGraph(child_graph, all_out_graph, device_target);
}
@ -1850,8 +1851,8 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
tensor->SetNeedWait(false);
MS_LOG(DEBUG) << "Debug address: Output tensor obj " << tensor.get() << ", tensor id " << tensor->id()
<< ", device address " << tensor->device_address().get();
if (AnfAlgo::IsDynamicShape(node)) {
const auto &updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
if (common::AnfAlgo::IsDynamicShape(node)) {
const auto &updated_shape = common::AnfAlgo::GetOutputInferShape(node, output_index);
ShapeVector int_shape;
(void)std::transform(updated_shape.begin(), updated_shape.end(), std::back_inserter(int_shape), SizeToInt);
(void)tensor->set_shape(int_shape);
@ -1870,7 +1871,7 @@ void SessionBasic::UpdateOutputAbstract(const std::shared_ptr<KernelGraph> &kern
const auto &kernels = kernel_graph->execution_order();
for (const auto &kernel : kernels) {
MS_EXCEPTION_IF_NULL(kernel);
if (AnfAlgo::GetCNodeName(kernel) == op_run_info->op_name) {
if (common::AnfAlgo::GetCNodeName(kernel) == op_run_info->op_name) {
op_run_info->abstract = kernel->abstract();
}
}
@ -1945,8 +1946,8 @@ void SessionBasic::UpdateOutputTensors(const VectorRef *outputs,
const auto &address = AnfAlgo::GetMutableOutputAddr(node, output_index);
tensor->set_device_address(address);

if (AnfAlgo::IsDynamicShape(node)) {
const auto &updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
if (common::AnfAlgo::IsDynamicShape(node)) {
const auto &updated_shape = common::AnfAlgo::GetOutputInferShape(node, output_index);
ShapeVector int_shape;
(void)std::transform(updated_shape.begin(), updated_shape.end(), std::back_inserter(int_shape), SizeToInt);
(void)tensor->set_shape(int_shape);
@ -1976,7 +1977,7 @@ void SessionBasic::GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::Ten
continue;
}
auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
if (!AnfAlgo::IsParameterWeight(parameter)) {
if (!common::AnfAlgo::IsParameterWeight(parameter)) {
vector<int64_t> input_shape;
auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
(void)std::transform(parameter_shape.begin(), parameter_shape.end(), std::back_inserter(input_shape),
@ -2042,7 +2043,7 @@ void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
}
auto node = cnode->input(kSummaryGetItem);
MS_EXCEPTION_IF_NULL(node);
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false);
auto item_with_index = common::AnfAlgo::VisitKernelWithReturnType(node, 0, false);
MS_EXCEPTION_IF_NULL(item_with_index.first);
if (!AnfUtils::IsRealKernel(item_with_index.first)) {
MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
@ -2078,8 +2079,8 @@ void SessionBasic::Summary(KernelGraph *graph) {
auto node = output_item.second.first;
size_t index = IntToSize(output_item.second.second);
auto address = AnfAlgo::GetOutputAddr(node, index, false);
auto shape = AnfAlgo::GetOutputInferShape(node, index);
TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
auto shape = common::AnfAlgo::GetOutputInferShape(node, index);
TypeId type_id = common::AnfAlgo::GetOutputInferDataType(node, index);
std::vector<int64_t> temp_shape;
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
@ -2121,8 +2122,8 @@ std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_gr
auto &users = front_func_graph_manager->node_users()[front_node];
std::vector<AnfNodePtr> result;
for (auto &user : users) {
if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimDepend) ||
AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimLoad)) {
if (common::AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimDepend) ||
common::AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimLoad)) {
auto depend_cnode = user.first->cast<CNodePtr>();
if (depend_cnode == nullptr) {
continue;
@ -2132,7 +2133,7 @@ std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_gr
}
auto res = ExtendNodeUsers(front_func_graph_manager, user.first);
result.insert(result.end(), res.begin(), res.end());
} else if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimMakeTuple)) {
} else if (common::AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimMakeTuple)) {
auto res = ExtendNodeUsers(front_func_graph_manager, user.first);
(void)result.insert(result.end(), res.begin(), res.end());
} else {
@ -2150,10 +2151,10 @@ AnfNodePtr GetSupportedInternalNode(const AnfNodePtr &front_node) {
if (AnfUtils::IsRealKernel(front_node)) {
return front_node;
}
if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
if (common::AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
return front_node;
}
if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimMakeTuple)) {
if (common::AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimMakeTuple)) {
auto cnode = front_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
@ -2161,7 +2162,7 @@ AnfNodePtr GetSupportedInternalNode(const AnfNodePtr &front_node) {
return GetSupportedInternalNode(inputs[1]);
}
}
if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimDepend)) {
if (common::AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimDepend)) {
auto cnode = front_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
@ -2239,8 +2240,8 @@ void SessionBasic::HandleInternalOutput(const AnfNodePtr &input_front_node, cons
if (front_node == nullptr) {
return;
}
auto front_real_kernel_pair = AnfAlgo::VisitKernel(front_node, 0);
auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_node, 0);
auto front_real_kernel_pair = common::AnfAlgo::VisitKernel(front_node, 0);
auto backend_real_kernel_pair = common::AnfAlgo::VisitKernel(backend_node, 0);
auto backend_real_kernel = backend_real_kernel_pair.first;
if (backend_real_kernel == nullptr || !backend_real_kernel->isa<CNode>()) {
return;
@ -2249,8 +2250,8 @@ void SessionBasic::HandleInternalOutput(const AnfNodePtr &input_front_node, cons
std::string kernel_target = GetCNodeTarget(front_real_kernel);
bool internal_output = CNodeFirstInputIsPrimitive(front_real_kernel);
bool unique_target = true;
if (internal_output && opt::IsNopNode(front_real_kernel)) {
auto pre_node_pair = AnfAlgo::GetPrevNodeOutput(front_real_kernel, 0);
if (internal_output && common::AnfAlgo::IsNopNode(front_real_kernel)) {
auto pre_node_pair = common::AnfAlgo::GetPrevNodeOutput(front_real_kernel, 0);
auto pre_node_target = GetCNodeTarget(pre_node_pair.first);
if (pre_node_target != kernel_target) {
unique_target = false;
@ -2259,7 +2260,7 @@ void SessionBasic::HandleInternalOutput(const AnfNodePtr &input_front_node, cons
if (internal_output) {
auto users = ExtendNodeUsers(front_func_graph_manager, front_node);
for (auto &user : users) {
if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial) && kernel_target != kGPUDevice &&
if (common::AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial) && kernel_target != kGPUDevice &&
!ExistGraphCaller(user)) {
auto partial_target = AddPartialParametersMap(user);
if (partial_target != kNoTarget && partial_target != kernel_target) {
@ -2267,7 +2268,7 @@ void SessionBasic::HandleInternalOutput(const AnfNodePtr &input_front_node, cons
}
continue;
}
if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimUpdateState)) {
if (common::AnfAlgo::CheckPrimitiveType(user, prim::kPrimUpdateState)) {
continue;
}
if (IsUnusedInternlOutput(user)) {
@ -2324,16 +2325,16 @@ void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr
std::vector<AnfNodePtr> make_tuple_inputs;
make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
MS_EXCEPTION_IF_NULL(graph);
if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) {
for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) {
if (common::AnfAlgo::GetOutputTensorNum(cnode) > 1) {
for (size_t output_index = 0; output_index < common::AnfAlgo::GetOutputTensorNum(cnode); output_index++) {
auto idx = NewValueNode(SizeToLong(output_index));
MS_EXCEPTION_IF_NULL(idx);
auto imm = std::make_shared<Int64Imm>(output_index);
idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
std::vector<TypeId> types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)};
std::vector<std::vector<size_t>> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
std::vector<TypeId> types = {common::AnfAlgo::GetOutputInferDataType(cnode, output_index)};
std::vector<std::vector<size_t>> shapes = {common::AnfAlgo::GetOutputInferShape(cnode, output_index)};
common::AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
make_tuple_inputs.push_back(getitem);
}
} else {
@ -2380,10 +2381,10 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
// set abstract, which includes inferred shapes and types
cnode->set_abstract(op_run_info.abstract);
// get output dynamic shape info
AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(op_run_info.is_dynamic_shape), cnode);
common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(op_run_info.is_dynamic_shape), cnode);
if (op_run_info.is_auto_mixed_precision) {
AnfAlgo::SetNodeAttr(kAttrPynativeNextOpName, MakeValue(op_run_info.next_op_name), cnode);
AnfAlgo::SetNodeAttr(kAttrPynativeNextIndex, MakeValue(op_run_info.next_input_index), cnode);
common::AnfAlgo::SetNodeAttr(kAttrPynativeNextOpName, MakeValue(op_run_info.next_op_name), cnode);
common::AnfAlgo::SetNodeAttr(kAttrPynativeNextIndex, MakeValue(op_run_info.next_input_index), cnode);
}
// set execution order
std::vector<CNodePtr> exe_order = {cnode};
@ -2420,8 +2421,8 @@ AnfNodePtr SessionBasic::FindPullNode(const AnfNodePtr &push_node, const std::ve
for (auto &node : node_list) {
if (node != nullptr && node->isa<CNode>()) {
for (auto input : node->cast<CNodePtr>()->inputs()) {
if (push_node == AnfAlgo::VisitKernel(input, 0).first) {
if (AnfAlgo::GetCNodeName(node) != kPullOpName) {
if (push_node == common::AnfAlgo::VisitKernel(input, 0).first) {
if (common::AnfAlgo::GetCNodeName(node) != kPullOpName) {
MS_LOG(EXCEPTION) << "The edge between Push and Pull node is invalid.";
}
return node;
@ -2583,9 +2584,9 @@ void SessionBasic::EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask
bool SessionBasic::IsGetNextGraph(const std::shared_ptr<KernelGraph> &kernel_graph, std::string *channel_name) {
MS_EXCEPTION_IF_NULL(kernel_graph);
for (const auto &kernel_node : kernel_graph->execution_order()) {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name == kGetNextOpName) {
auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node);
MS_EXCEPTION_IF_NULL(prim);
*channel_name = GetValue<std::string>(prim->GetAttr("shared_name"));
return true;
@ -2622,7 +2623,7 @@ std::vector<uint32_t> SessionBasic::GetAllReduceSplitIndex() {
}

uint32_t GetBpropGraphGradsCount(const KernelGraphPtr &graph) {
return AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}).size();
return common::AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}).size();
}

void SetGraphBpropAttr(const KernelGraphPtr &graph) {
@ -2699,7 +2700,7 @@ void SessionBasic::InitAllBucket(const KernelGraphPtr &graph, const device::Devi
auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
auto parallel_mode = parallel_context->parallel_mode();
if (!pynative_mode || parallel_mode != parallel::DATA_PARALLEL) {
if (!pynative_mode || parallel_mode != parallel::kDataParallel) {
return;
}
SetGraphBpropAttr(graph);
@ -2741,7 +2742,7 @@ void SessionBasic::AddGradAddrToBucket(const GraphId &graph_id, const std::vecto
auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
auto parallel_mode = parallel_context->parallel_mode();
if (parallel_mode != parallel::DATA_PARALLEL) {
if (parallel_mode != parallel::kDataParallel) {
return;
}
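Editor's note: besides the namespace moves, these two hunks rename the parallel-mode constant from DATA_PARALLEL to the k-prefixed kDataParallel, matching the parallel context that is now exported through include/common/utils/parallel_context.h. A sketch of the resulting guard, assuming kDataParallel compares against parallel_mode() exactly as the diff shows:

// Early-out unless the session runs in data-parallel mode (sketch of the
// guard shared by InitAllBucket and AddGradAddrToBucket above).
auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
if (parallel_context->parallel_mode() != parallel::kDataParallel) {
  return;
}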

@ -2873,7 +2874,7 @@ void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
MS_EXCEPTION_IF_NULL(runtime_instance);
auto context = runtime_instance->context();
const auto &kernels = kernel_graph->execution_order();
if (kernels.size() > 0 && AnfAlgo::GetCNodeName(kernels[0]) == "InitDataSetQueue") {
if (kernels.size() > 0 && common::AnfAlgo::GetCNodeName(kernels[0]) == "InitDataSetQueue") {
GetBatchElements(kernels[0]);
ps::ps_cache_instance.Initialize();
}
@ -2886,8 +2887,8 @@ void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
}

void SessionBasic::GetBatchElements(const AnfNodePtr &kernel_node) const {
auto shapes = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "shapes");
auto types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "types");
auto shapes = common::AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "shapes");
auto types = common::AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "types");
if (shapes.size() != types.size() || shapes.size() == 0 || types.size() == 0) {
MS_LOG(EXCEPTION) << "Invalid shapes of op[InitDataSetQueue]: shapes size " << shapes.size() << ", types size "
<< types;
@ -2933,12 +2934,12 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
if (node != nullptr && node->isa<CNode>()) {
// Assign key for forward kernel EmbeddingLookup.
// The key will be assigned to embedding table and Push kernel as well.
if (AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) {
if (common::AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) {
size_t embedding_table_idx = 0;
auto embedding_table = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), embedding_table_idx);
auto embedding_table = common::AnfAlgo::GetInputNode(node->cast<CNodePtr>(), embedding_table_idx);
size_t key = ps::Worker::GetInstance().SetParamKey(embedding_table->fullname_with_scope());
AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
} else if (AnfAlgo::GetCNodeName(node) == kPushOpName) {
common::AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
} else if (common::AnfAlgo::GetCNodeName(node) == kPushOpName) {
auto pull_node = FindPullNode(node, node_list);
if (!pull_node) {
MS_LOG(EXCEPTION) << "Assigning parameter key failed: can't find Pull node of the Push node.";
@ -2946,12 +2947,12 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {

// Second input of Pull node is the trainable parameter.
size_t parameter_index = 1;
auto parameter_node = AnfAlgo::GetInputNode(pull_node->cast<CNodePtr>(), parameter_index);
auto parameter_node = common::AnfAlgo::GetInputNode(pull_node->cast<CNodePtr>(), parameter_index);
size_t key = ps::Worker::GetInstance().SetParamKey(parameter_node->fullname_with_scope());
AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), pull_node);
common::AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
common::AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), pull_node);

std::string optimizer_name = AnfAlgo::GetNodeAttr<std::string>(node, kAttrOptimizerType);
std::string optimizer_name = common::AnfAlgo::GetNodeAttr<std::string>(node, kAttrOptimizerType);
ps::Worker::GetInstance().SetKeyOptimId(key, optimizer_name);
}
}
|
|
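[Editor's note] The change pattern repeated throughout this commit is mechanical: graph-query helpers that moved from backend/common/session/anf_runtime_algorithm.h into the new mindspore_common shared library are now reached through the common:: namespace, declared in include/common/utils/anfalgo.h. A minimal before/after sketch of a call site (the LogKernelName helper below is hypothetical, not part of this commit):

// Before: resolved against the static backend objects.
// #include "backend/common/session/anf_runtime_algorithm.h"
// auto name = AnfAlgo::GetCNodeName(node);

// After: resolved against the shared mindspore_common library.
#include "include/common/utils/anfalgo.h"

namespace mindspore {
void LogKernelName(const AnfNodePtr &node) {
  // Only the namespace and header change; the call signature is identical.
  auto name = common::AnfAlgo::GetCNodeName(node);
  MS_LOG(INFO) << "kernel: " << name;
}
}  // namespace mindspore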
@@ -26,10 +26,11 @@
 #include "backend/common/session/session_context.h"
 #include "backend/common/session/kernel_graph.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "ir/anf.h"
 #include "ir/tensor.h"
 #include "utils/any.h"
-#include "utils/contract.h"
+#include "include/common/utils/contract.h"
 #include "runtime/device/kernel_info.h"
 #include "utils/ms_context.h"
 #include "runtime/device/bucket.h"

@@ -15,8 +15,8 @@
  */

 #include "backend/common/session/single_kernel_graph.h"
-#include "backend/common/session/anf_runtime_algorithm.h"
 #include "utils/trace_base.h"
+#include "include/common/utils/anfalgo.h"

 namespace mindspore {
 namespace session {

@@ -44,14 +44,14 @@ std::shared_ptr<session::KernelGraph> SingleKernelGraph::ConstructKernelGraphBas
   auto cnode = graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(cnode);
   // get output dynamic shape info
-  AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(false), cnode);
+  common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(false), cnode);
   if (output_dtypes.size() != output_shapes.size()) {
     MS_LOG(EXCEPTION)
       << "The size of output_dtypes should be equal to size of output_shapes, but got output_dtypes size: "
       << output_dtypes.size() << ", output_shapes size: " << output_shapes.size() << ". The op name is: " << op_name
       << trace::DumpSourceLines(cnode);
   }
-  AnfAlgo::SetOutputInferTypeAndShape(output_dtypes, output_shapes, cnode.get());
+  common::AnfAlgo::SetOutputInferTypeAndShape(output_dtypes, output_shapes, cnode.get());
   // set execution order
   std::vector<CNodePtr> exe_order = {cnode};
   graph->set_execution_order(exe_order);
@@ -36,7 +36,7 @@
 #ifdef ENABLE_DUMP_IR
 #include "debug/rdr/running_data_recorder.h"
 #endif
-#include "common/thread_pool.h"
+#include "include/common/thread_pool.h"
 #ifndef ENABLE_SECURITY
 #include "profiler/device/ascend/memory_profiling.h"

@@ -409,7 +409,7 @@ void Somas::InitSomasStreamAndNode(const session::KernelGraph *graph) {

     // Node
     NodeType type = kCommonNode;
-    if (AnfAlgo::IsCommunicationOp(kernel)) {
+    if (common::AnfAlgo::IsCommunicationOp(kernel)) {
       type = kCommunicationNode;
     }
     auto node = std::make_shared<SomasNode>(node_index, type, stream);

@@ -500,7 +500,7 @@ void Somas::InitSomasInputTensors(const session::KernelGraph *graph) {
   static const auto enable_fusion_clear = (common::GetEnv("ENV_FUSION_CLEAR") == "1");
   auto kernel_cnodes = graph->execution_order();
   for (const auto &kernel : kernel_cnodes) {
-    if (AnfAlgo::GetCNodeName(kernel) != kAtomicAddrCleanOpName) {
+    if (common::AnfAlgo::GetCNodeName(kernel) != kAtomicAddrCleanOpName) {
       InitCommonNodeInputs(is_all_nop_node, kernel);
     } else {
       InitAtomicCleanInputs(enable_fusion_clear, kernel);

@@ -516,24 +516,24 @@ void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) {
   MS_EXCEPTION_IF_NULL(stream);

   // Input Tensor
-  auto input_tensor_num = AnfAlgo::GetInputTensorNum(kernel);
+  auto input_tensor_num = common::AnfAlgo::GetInputTensorNum(kernel);
   size_t real_input_index = 0;
   for (size_t i = 0; i < input_tensor_num; i++) {
     auto input_node = kernel->input(i + 1);
     MS_EXCEPTION_IF_NULL(input_node);
     session::KernelWithIndex prenode_index;
     if (is_all_nop_node) {
-      prenode_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
+      prenode_index = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
     } else {
-      prenode_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
+      prenode_index = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
     }
-    if (AnfAlgo::CheckPrimitiveType(prenode_index.first, prim::kPrimMakeTuple)) {
+    if (common::AnfAlgo::CheckPrimitiveType(prenode_index.first, prim::kPrimMakeTuple)) {
       MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple";
     }
     MS_EXCEPTION_IF_NULL(prenode_index.first);
     if (!AnfUtils::IsRealCNodeKernel(prenode_index.first)) {
-      auto op_name = AnfAlgo::GetCNodeName(kernel);
-      TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel, i);
+      auto op_name = common::AnfAlgo::GetCNodeName(kernel);
+      TypeId input_origin_type = common::AnfAlgo::GetPrevNodeOutputInferDataType(kernel, i);
       if ((op_name == kDynamicRNNOpName || op_name == kDynamicGRUV2OpName) && input_origin_type == kMetaTypeNone) {
         continue;
       }

@@ -588,7 +588,7 @@ void Somas::InitAtomicCleanInputs(bool enable_fusion_clear, const CNodePtr &kern
   auto stream = node->GetStream();
   MS_EXCEPTION_IF_NULL(stream);

-  auto input_tensor_num = AnfAlgo::GetInputTensorNum(kernel);
+  auto input_tensor_num = common::AnfAlgo::GetInputTensorNum(kernel);
   for (size_t i = 0; i < input_tensor_num; i++) {
     MS_EXCEPTION_IF_NULL(kernel->inputs()[i + 1]);
     auto pre_node = kernel->input(i + 1)->cast<CNodePtr>();

@@ -600,8 +600,8 @@ void Somas::InitAtomicCleanInputs(bool enable_fusion_clear, const CNodePtr &kern
     auto pre_somas_node = iter->second.at(0);
     MS_EXCEPTION_IF_NULL(pre_somas_node);
     // set clean output tensors
-    if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
-      auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
+    if (common::AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
+      auto clean_output_indexs = common::AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
       for (auto index : clean_output_indexs) {
         if (index > pre_somas_node->output_tensors_.size()) {
           MS_LOG(EXCEPTION) << "Output index " << index << " exceed input node [" << pre_node->fullname_with_scope()

@@ -618,8 +618,9 @@ void Somas::InitAtomicCleanInputs(bool enable_fusion_clear, const CNodePtr &kern
       }
     }
     // set clean workspace tensors
-    if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
-      auto clean_workspace_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
+    if (common::AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
+      auto clean_workspace_indexs =
+        common::AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
       for (const auto &index : clean_workspace_indexs) {
         if (index > pre_somas_node->output_tensors_.size()) {
           MS_LOG(EXCEPTION) << "Workspace index " << index << " exceed input node [" << pre_node->fullname_with_scope()

@@ -645,7 +646,7 @@ void Somas::InitSomasEventInfos() {
   send_recv_map = device::ascend::AscendStreamAssign::GetInstance().get_event_map();
 #endif
   for (auto &send_recv : send_recv_map) {
-    size_t event_id = AnfAlgo::GetNodeAttr<uint32_t>(send_recv.first, kAttrEventId);
+    size_t event_id = common::AnfAlgo::GetNodeAttr<uint32_t>(send_recv.first, kAttrEventId);
     event_map_[event_id] = std::make_pair(send_recv.first, send_recv.second);
   }

@@ -748,7 +749,7 @@ void Somas::GetNextOutputProcess(const session::KernelGraph *graph) {
   auto kernel_cnodes = graph->execution_order();
   size_t total_size = 0;
   for (const auto &kernel : kernel_cnodes) {
-    if (AnfAlgo::GetCNodeName(kernel) != kGetNextOpName) {
+    if (common::AnfAlgo::GetCNodeName(kernel) != kGetNextOpName) {
       continue;
     }
     auto iter = nodes_map_.find(kernel.get());

@@ -809,7 +810,7 @@ void Somas::SummaryInputProcess(const session::KernelGraph *graph) {
   for (auto &node_item : summary_nodes) {
     auto origin_node = node_item.second.first;
     size_t origin_index = IntToSize(node_item.second.second);
-    auto item_with_index = AnfAlgo::VisitKernelWithReturnType(origin_node, origin_index, true);
+    auto item_with_index = common::AnfAlgo::VisitKernelWithReturnType(origin_node, origin_index, true);
     auto node = item_with_index.first;
     size_t index = item_with_index.second;
     auto iter = nodes_map_.find(node.get());

@@ -895,8 +896,8 @@ void Somas::NonTaskSplitProcess(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
   auto kernel_cnodes = graph->execution_order();
   for (const auto &kernel : kernel_cnodes) {
-    auto op_name = AnfAlgo::GetCNodeName(kernel);
-    if (AnfAlgo::IsNonTaskOp(kernel)) {
+    auto op_name = common::AnfAlgo::GetCNodeName(kernel);
+    if (common::AnfAlgo::IsNonTaskOp(kernel)) {
       std::vector<size_t> refnode_input_output;
       auto node = nodes_map_[kernel.get()].at(0);
       MS_EXCEPTION_IF_NULL(node);

@@ -1815,7 +1816,7 @@ uint8_t *Somas::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const {
     auto output_tensor = somas_node->output_tensors_[index];
     ptr = mem_base_addr_ + output_tensor->offset_;
   } else {
-    MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in nodes_map";
+    MS_LOG(EXCEPTION) << "node [" << common::AnfAlgo::GetCNodeName(node) << "] don't exist in nodes_map";
   }
   return ptr;
 }
@@ -31,6 +31,7 @@
 #include "backend/common/somas/somas_stream.h"
 #include "backend/common/somas/somas_parameter.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "backend/common/session/kernel_graph.h"

 namespace mindspore {

@@ -19,7 +19,7 @@
 #include <memory>
 #include <string>
 #include <utility>
-#include "common/thread_pool.h"
+#include "include/common/thread_pool.h"

 #include "backend/common/somas/somas_solver_core.h"
 #include "backend/common/somas/somas_solver_pre.h"
@@ -19,7 +19,7 @@
 #include <vector>
 #include <map>

-#include "frontend/parallel/context.h"
+#include "include/common/utils/parallel_context.h"
 #include "backend/graph_compiler/transform.h"
 #include "backend/common/session/session_factory.h"
 #include "runtime/op_builder/op_lazy_builder.h"

@@ -30,15 +30,15 @@
 #include "ir/anf.h"
 #include "pybind_api/ir/base_ref_py.h"
 #include "pybind_api/pybind_patch.h"
-#include "utils/callbacks.h"
-#include "utils/convert_utils.h"
+#include "include/common/utils/callbacks.h"
+#include "include/common/utils/convert_utils.h"
 #include "utils/log_adapter.h"
 #include "utils/ms_utils.h"
 #include "runtime/hardware/device_context_manager.h"
 #include "runtime/graph_scheduler/graph_compiler.h"
-#include "utils/scoped_long_running.h"
+#include "include/common/utils/scoped_long_running.h"
 #ifdef ENABLE_D
-#include "utils/callbacks_ge.h"
+#include "include/common/utils/callbacks_ge.h"
 #endif
 #ifdef ENABLE_DEBUGGER
 #include "debug/debugger/debugger.h"

@@ -196,7 +196,7 @@ void UpdateOutputAbstract(const KernelGraphPtr &kernel_graph, OpRunInfo *op_run_
   const auto &kernels = kernel_graph->execution_order();
   for (const auto &kernel : kernels) {
     MS_EXCEPTION_IF_NULL(kernel);
-    if (AnfAlgo::GetCNodeName(kernel) == op_run_info->op_name) {
+    if (common::AnfAlgo::GetCNodeName(kernel) == op_run_info->op_name) {
       op_run_info->abstract = kernel->abstract();
     }
   }

@@ -206,9 +206,9 @@ TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index)
   MS_EXCEPTION_IF_NULL(output_node);
   // Create host tensor, the output tensor should use the infer type, it will be handed correctly by tensor data sync
   // when infer type is not equal to device type.
-  auto type_id = AnfAlgo::GetOutputInferDataType(output_node, output_index);
+  auto type_id = common::AnfAlgo::GetOutputInferDataType(output_node, output_index);
   std::vector<int64_t> temp_shape;
-  const auto &shape = AnfAlgo::GetOutputInferShape(output_node, output_index);
+  const auto &shape = common::AnfAlgo::GetOutputInferShape(output_node, output_index);
   (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
   auto tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
   tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(output_node, output_index));

@@ -263,7 +263,7 @@ void UpdateInputDeviceAddress(const KernelGraphPtr &graph) {
   MS_EXCEPTION_IF_NULL(graph);
   for (const auto &node : graph->input_nodes()) {
     MS_EXCEPTION_IF_NULL(node);
-    if (node->isa<Parameter>() && (!AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
+    if (node->isa<Parameter>() && (!common::AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
       AnfAlgo::SetOutputAddr(nullptr, 0, node.get());
     }
   }

@@ -277,7 +277,7 @@ std::vector<tensor::TensorPtr> GetRealValueNodeTensorFromGraph(
   }

   const auto &node = graph->execution_order().back();
-  auto input_num = AnfAlgo::GetInputTensorNum(node);
+  auto input_num = common::AnfAlgo::GetInputTensorNum(node);
   // No value node in graph
   if (input_num == tensors_without_value_node.size()) {
     return new_input_tensors;

@@ -287,7 +287,7 @@ std::vector<tensor::TensorPtr> GetRealValueNodeTensorFromGraph(

   std::map<size_t, tensor::TensorPtr> value_node_pos;
   for (size_t i = 0; i < input_num; ++i) {
-    auto input = AnfAlgo::GetInputNode(node, i);
+    auto input = common::AnfAlgo::GetInputNode(node, i);
     MS_EXCEPTION_IF_NULL(input);
     if (input->isa<ValueNode>()) {
       auto value_node = input->cast<ValueNodePtr>();

@@ -453,7 +453,7 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
   ms_execution_mode_ = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
   real_execution_mode_ = ms_execution_mode_;
   auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
-  auto is_parallel = (parallel_mode == parallel::SEMI_AUTO_PARALLEL || parallel_mode == parallel::AUTO_PARALLEL);
+  auto is_parallel = (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel);

   // Run in GRAPH_MODE if the func_graph is ms_function or the func_graph contain multi-subgraph.
   if (ms_execution_mode_ == kPynativeMode &&

@@ -573,8 +573,8 @@ void MindRTBackend::CompileGraph(const GraphSegmentPtr &segment) {
     MS_EXCEPTION_IF_NULL(cut_node);
     MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->DebugString();
     control_nodes_.push_back(cut_node);
-    if (AnfAlgo::IsCallNode(cut_node) || AnfAlgo::CheckPrimitiveType(cut_node, prim::kPrimSwitch) ||
-        AnfAlgo::CheckPrimitiveType(cut_node, prim::kPrimSwitchLayer)) {
+    if (common::AnfAlgo::IsCallNode(cut_node) || common::AnfAlgo::CheckPrimitiveType(cut_node, prim::kPrimSwitch) ||
+        common::AnfAlgo::CheckPrimitiveType(cut_node, prim::kPrimSwitchLayer)) {
      const auto &func_graph = cut_node->func_graph();
      MS_EXCEPTION_IF_NULL(func_graph);
      (void)func_graph_to_kernel_graph_ids_[func_graph].emplace_back(std::vector<GraphId>());

@@ -629,7 +629,7 @@ void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler, con
       continue;
     }
     // Hook single-input or single-output.
-    auto real_input = AnfAlgo::VisitKernel(input_node, 0).first;
+    auto real_input = common::AnfAlgo::VisitKernel(input_node, 0).first;
     MS_EXCEPTION_IF_NULL(real_input);
     if (!real_input->isa<ValueNode>()) {
       auto tensor = graph_compiler->GetSingleOpInputTensorByIndex(backend_cnode, op_output_map, parameter_index,

@@ -870,7 +870,7 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
     for (const auto &kernel : graph->execution_order()) {
       InputTensorInfo input_tensor_info;
       VectorRef op_outputs;
-      if (!AnfAlgo::IsControlOpExecInBackend(kernel)) {
+      if (!common::AnfAlgo::IsControlOpExecInBackend(kernel)) {
         OpRunInfo op_run_info;
         GraphInfo graph_info;
         graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index],

@@ -893,7 +893,7 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
       graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info);

       // Save grad node to Bucket
-      if (graph->is_bprop() && (!AnfAlgo::IsControlOpExecInBackend(kernel)) && !kernel->is_parallel()) {
+      if (graph->is_bprop() && (!common::AnfAlgo::IsControlOpExecInBackend(kernel)) && !kernel->is_parallel()) {
         graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors);
       }
     }

@@ -1002,7 +1002,7 @@ BaseRef MindRTBackend::ConstructOutputByAbstract(const abstract::AbstractBasePtr
   MS_EXCEPTION_IF_NULL(abstract);
   MS_EXCEPTION_IF_NULL(output_position);

-  size_t outputs_num = AnfAlgo::GetOutputNumByAbstract(abstract);
+  size_t outputs_num = common::AnfAlgo::GetOutputNumByAbstract(abstract);
   if (*output_position + outputs_num > output_tensors.size()) {
     MS_LOG(EXCEPTION) << "The output position is out of range: " << *output_position << " need:" << outputs_num
                       << " total:" << output_tensors.size();

@@ -1070,14 +1070,14 @@ void MindRTBackend::ConstructOutputs(const AnfNodePtr &output_node,
   }

   // The depend node need get the real node.
-  if (AnfAlgo::CheckPrimitiveType(output_node, prim::kPrimDepend)) {
+  if (common::AnfAlgo::CheckPrimitiveType(output_node, prim::kPrimDepend)) {
     auto depend_node = output_node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(depend_node);
     ConstructOutputs(depend_node->input(kRealInputIndexInDepend), output_tensors, output_position, outputs);
     return;
   }

-  auto outputs_num = AnfAlgo::GetOutputTensorNum(output_node);
+  auto outputs_num = common::AnfAlgo::GetOutputTensorNum(output_node);
   // The value node uses the value to be output, to avoid the host memory of value free due to value node destruction.
   if (output_node->isa<ValueNode>()) {
     auto value = output_node->cast<ValueNodePtr>()->value();

@@ -1093,7 +1093,7 @@ void MindRTBackend::ConstructOutputs(const AnfNodePtr &output_node,
     return;
   }

-  if (AnfAlgo::IsCallNode(output_node)) {
+  if (common::AnfAlgo::IsCallNode(output_node)) {
     auto abstract = output_node->abstract();
     MS_EXCEPTION_IF_NULL(abstract);
     outputs->emplace_back(ConstructOutputByAbstract(abstract, output_tensors, output_position));

@@ -1176,9 +1176,9 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(con

   runtime::KernelMapPosition outputs_order;
   const auto &root_output =
-    AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first;
+    common::AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first;
   size_t position = 0;
-  auto outputs = AnfAlgo::GetAllOutputWithIndex(root_output);
+  auto outputs = common::AnfAlgo::GetAllOutputWithIndex(root_output);
   size_t outputs_num = outputs.size();
   for (const auto &output : outputs) {
     if (outputs_order.count(output) == 0) {

@@ -1209,7 +1209,7 @@ std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
     (void)graphs.emplace_back(graph);
     (void)device_contexts.emplace_back(graph_info_to_context.second);

-    auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
+    auto outputs = common::AnfAlgo::GetAllOutputWithIndex(graph->output());
     for (const auto &output : outputs) {
       if (outputs_order.count(output) == 0) {
         outputs_order[output] = {position++};

@@ -1276,7 +1276,7 @@ void MindRTBackend::RunSingleOpGraph(const KernelGraphPtr &graph, const OpRunInf
   const auto &kernels = graph->execution_order();
   for (const auto &kernel : kernels) {
     MS_EXCEPTION_IF_NULL(kernel);
-    if (kOpCacheBlackList.find(AnfAlgo::GetCNodeName(kernel)) != kOpCacheBlackList.end()) {
+    if (kOpCacheBlackList.find(common::AnfAlgo::GetCNodeName(kernel)) != kOpCacheBlackList.end()) {
       auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
       if (kernel_mod) {
         kernel_mod->ReleaseResource();

@@ -1485,7 +1485,7 @@ void MindRTBackend::UpdateOutput(const std::vector<session::KernelWithIndex> &ou
   MS_EXCEPTION_IF_NULL(outputs);
   for (auto &item_with_index : output_nodes) {
     MS_EXCEPTION_IF_NULL(item_with_index.first);
-    if (AnfAlgo::GetOutputTensorNum(item_with_index.first) == 0) {
+    if (common::AnfAlgo::GetOutputTensorNum(item_with_index.first) == 0) {
       continue;
     }
     auto output_tensor = CreateOutputTensor(item_with_index.first, item_with_index.second);
@@ -25,7 +25,7 @@
 #include <vector>

 #include "utils/hash_map.h"
-#include "utils/contract.h"
+#include "include/common/utils/contract.h"
 #include "ir/anf.h"
 #include "backend/graph_compiler/segment_runner.h"
 #include "backend/graph_compiler/graph_partition.h"

@@ -24,12 +24,12 @@
 #include <set>
 #include <algorithm>
 #include "base/core_ops.h"
-#include "utils/utils.h"
+#include "include/common/utils/utils.h"
 #include "utils/ms_context.h"
 #include "ps/ps_context.h"
 #include "utils/anf_utils.h"
 #ifdef ENABLE_D
-#include "transform/graph_ir/convert.h"
+#include "include/transform/graph_ir/convert.h"
 #endif
 namespace mindspore {
 const char kMsConvert[] = "ms";

@@ -29,7 +29,7 @@
 #include "utils/hash_map.h"
 #include "utils/hash_set.h"
 #include "utils/log_adapter.h"
-#include "utils/utils.h"
+#include "include/common/utils/utils.h"
 #include "ir/manager.h"
 #include "ir/func_graph_cloner.h"
 #include "frontend/operator/ops.h"

@@ -26,7 +26,7 @@

 #include "abstract/abstract_value.h"
 #ifdef ENABLE_D
-#include "transform/graph_ir/convert.h"
+#include "include/transform/graph_ir/convert.h"
 #endif
 #include "ir/graph_utils.h"
 #include "utils/ms_context.h"

@@ -26,8 +26,8 @@
 #include "frontend/operator/ops.h"
 #include "ir/manager.h"
 #include "ir/func_graph_cloner.h"
-#include "utils/convert_utils.h"
-#include "utils/primitive_utils.h"
+#include "include/common/utils/convert_utils.h"
+#include "include/common/utils/primitive_utils.h"

 namespace mindspore {
 namespace compile {
@@ -10,17 +10,6 @@ else()
     )
 endif()

-if(ENABLE_AKG AND ${CMAKE_SYSTEM_NAME} MATCHES "Linux")
-    file(GLOB_RECURSE _GK_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
-        "graph_kernel/*.cc"
-    )
-    file(GLOB_RECURSE _GK_LITE_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
-        "graph_kernel/lite_adapter/*.cc"
-    )
-    list(REMOVE_ITEM _GK_SRC_LIST ${_GK_LITE_LIST})
-    list(APPEND _COMMON_ALL_SRC_FILES ${_GK_SRC_LIST})
-endif()
-
 set_property(SOURCE ${_COMMON_ALL_SRC_FILES} PROPERTY COMPILE_DEFINITIONS
     SUBMODULE_ID=mindspore::SubModuleId::SM_COMMON)
 add_library(_mindspore_common_obj OBJECT ${_COMMON_ALL_SRC_FILES})
@@ -14,7 +14,7 @@
  * limitations under the License.
  */

-#include "common/duplex_pipe.h"
+#include "include/common/duplex_pipe.h"

 #include <sys/wait.h>
 #include <iostream>

@@ -14,7 +14,7 @@
  * limitations under the License.
  */

-#include "common/duplex_pipe.h"
+#include "include/common/duplex_pipe.h"

 namespace mindspore {
 int DuplexPipe::Open(const std::initializer_list<std::string> &arg_list, bool append_fds) {
@@ -0,0 +1,12 @@
+if(ENABLE_AKG AND ${CMAKE_SYSTEM_NAME} MATCHES "Linux")
+    file(GLOB_RECURSE _GK_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
+        "*.cc"
+    )
+    file(GLOB_RECURSE _GK_LITE_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
+        "lite_adapter/*.cc"
+    )
+    list(REMOVE_ITEM _GK_SRC_LIST ${_GK_LITE_LIST})
+    list(APPEND _GRAPH_KERNEL_SRC_FILES ${_GK_SRC_LIST})
+
+    add_library(_mindspore_common_graph_kernel_obj OBJECT ${_GRAPH_KERNEL_SRC_FILES})
+endif()
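[Editor's note] With the glob removed from the common component's CMakeLists.txt above, the graph_kernel sources now build as their own CMake OBJECT library rather than being folded into _COMMON_ALL_SRC_FILES. A minimal sketch of how an OBJECT library of this shape is typically absorbed into a shared library (illustrative only; the example_common target name is hypothetical, not from this commit):

# Fold the object library's compiled objects into a shared library target.
add_library(example_common SHARED
    $<TARGET_OBJECTS:_mindspore_common_graph_kernel_obj>)
# Objects that end up inside a .so must be compiled as position-independent code.
set_target_properties(_mindspore_common_graph_kernel_obj PROPERTIES
    POSITION_INDEPENDENT_CODE ON)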
@@ -22,6 +22,7 @@
 #include <memory>
 #include "utils/ms_context.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "kernel/common_utils.h"
 #include "common/graph_kernel/adapter/fake_abstract_shape.h"
 #if ENABLE_D

@@ -50,14 +51,14 @@ ShapeVector CallbackImpl::GetOutputShape(const AnfNodePtr &node, size_t i) {
 }

 ShapeVector CallbackImpl::GetInputInferShape(const AnfNodePtr &node, size_t i) {
-  auto vec = AnfAlgo::GetPrevNodeOutputInferShape(node, i);
+  auto vec = common::AnfAlgo::GetPrevNodeOutputInferShape(node, i);
   ShapeVector ret;
   (void)std::transform(vec.begin(), vec.end(), std::back_inserter(ret), SizeToLong);
   return ret;
 }

 ShapeVector CallbackImpl::GetOutputInferShape(const AnfNodePtr &node, size_t i) {
-  auto vec = AnfAlgo::GetOutputInferShape(node, i);
+  auto vec = common::AnfAlgo::GetOutputInferShape(node, i);
   ShapeVector ret;
   (void)std::transform(vec.begin(), vec.end(), std::back_inserter(ret), SizeToLong);
   return ret;

@@ -70,11 +71,11 @@ TypeId CallbackImpl::GetOutputType(const AnfNodePtr &node, size_t i) {
 }

 TypeId CallbackImpl::GetInputInferType(const AnfNodePtr &node, size_t i) {
-  return AnfAlgo::GetPrevNodeOutputInferDataType(node, i);
+  return common::AnfAlgo::GetPrevNodeOutputInferDataType(node, i);
 }

 TypeId CallbackImpl::GetOutputInferType(const AnfNodePtr &node, size_t i) {
-  return AnfAlgo::GetOutputInferDataType(node, i);
+  return common::AnfAlgo::GetOutputInferDataType(node, i);
 }

 std::string CallbackImpl::GetInputFormat(const AnfNodePtr &node, size_t i) { return AnfAlgo::GetInputFormat(node, i); }

@@ -135,7 +136,7 @@ void CallbackImpl::SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) {
     outputs.push_back(fg->output());
   }
   for (size_t i = 0; i < outputs.size(); ++i) {
-    auto kernel_with_index = AnfAlgo::VisitKernel(outputs[i], 0);
+    auto kernel_with_index = common::AnfAlgo::VisitKernel(outputs[i], 0);
     graph_output_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
     graph_output_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));
   }

@@ -160,7 +161,7 @@ void CallbackImpl::SetBasicNodeKernelInfo(const AnfNodePtr &node, const std::vec
   if (cnode != nullptr) {
     auto &inputs = cnode->inputs();
     for (size_t i = 1; i < inputs.size(); ++i) {
-      auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
+      auto kernel_with_index = common::AnfAlgo::VisitKernel(inputs[i], 0);
       auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
       input_formats.push_back(input_format);
       auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
@@ -15,7 +15,7 @@
  */
 #include "common/graph_kernel/adapter/fake_abstract_shape.h"

-#include "utils/utils.h"
+#include "include/common/utils/utils.h"

 namespace mindspore::graphkernel {
 namespace {

@@ -22,14 +22,15 @@
 #include <algorithm>

 #include "utils/ms_context.h"
-#include "utils/context/graph_kernel_flags.h"
+#include "include/common/utils/context/graph_kernel_flags.h"
 #include "kernel/akg/akg_kernel_json_generator.h"
 #include "common/graph_kernel/graph_kernel_helper.h"
 #include "common/graph_kernel/split_umonad.h"
 #include "common/graph_kernel/substitute_dropout.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "mindspore/core/ir/graph_utils.h"
-#include "pipeline/jit/parse/python_adapter.h"
+#include "include/common/utils/python_adapter.h"
 #include "pybind_api/ir/primitive_py.h"

 namespace mindspore::graphkernel {

@@ -58,7 +59,7 @@ FuncGraphPtr PyExpander::CreateExpandFuncGraph(const CNodePtr &node) {

   // call graph kernel ops generator.
   MS_LOG(DEBUG) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] with input json:\n" << node_desc_str;
-  auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGetGraphKernelOpExpander, node_desc_str);
+  auto ret = python_adapter::CallPyFn(kGraphKernelModule, kGetGraphKernelOpExpander, node_desc_str);
   // parse result.
   if (py::isinstance<py::none>(ret)) {
     MS_LOG(ERROR) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] return invalid result, input json:\n"

@@ -151,7 +152,7 @@ ExpanderPtr GraphKernelComplexExpander::GetExpander(const AnfNodePtr &) {
 bool ComplexOpExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
   auto cnode = node->cast<CNodePtr>();
   if (!PyExpander::ExpandJsonInfo(cnode, kernel_json)) return false;
-  (*kernel_json)["name"] = std::string("C") + AnfAlgo::GetCNodeName(cnode);
+  (*kernel_json)["name"] = std::string("C") + common::AnfAlgo::GetCNodeName(cnode);
   return true;
 }
 bool GraphKernelComplexExpander::Run(const FuncGraphPtr &func_graph) { return DoExpand(func_graph); }
@@ -21,7 +21,7 @@

 #include "ir/func_graph.h"
 #include "utils/ms_context.h"
-#include "utils/context/graph_kernel_flags.h"
+#include "include/common/utils/context/graph_kernel_flags.h"
 #include "common/graph_kernel/add_atomic_clean.h"
 #include "common/graph_kernel/add_stitch_atomic_clean_gpu.h"
 #include "common/graph_kernel/arithmetic_simplify.h"

@@ -21,7 +21,7 @@
 #include <string>
 #include <memory>

-#include "utils/context/graph_kernel_flags.h"
+#include "include/common/utils/context/graph_kernel_flags.h"
 #include "backend/common/optimizer/pass_manager.h"

 namespace mindspore::graphkernel {

@@ -23,12 +23,11 @@
 #include <map>
 #include "utils/hash_map.h"
 #include "utils/ms_context.h"
-#include "pipeline/jit/parse/python_adapter.h"
+#include "include/common/utils/python_adapter.h"
 #include "kernel/akg/akg_kernel_json_generator.h"
 #include "kernel/common_utils.h"
 #include "common/graph_kernel/graph_kernel_helper.h"
-#include "backend/common/session/anf_runtime_algorithm.h"
-#include "utils/context/graph_kernel_flags.h"
+#include "include/common/utils/context/graph_kernel_flags.h"

 namespace mindspore::graphkernel {
 class CostModelSplitSchemer : public SplitSchemer {

@@ -71,7 +70,7 @@ class CostModelSplitSchemer : public SplitSchemer {
     auto flags_str = CollectSplitFlags();
     MS_LOG(DEBUG) << "CallPyFn: [" << kGraphKernelSplitFunc << "] with input json: " << json_desc_str
                   << ". flag: " << flags_str;
-    auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelSplitFunc, json_desc_str, flags_str);
+    auto ret = python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelSplitFunc, json_desc_str, flags_str);
     if (py::isinstance<py::none>(ret)) {
       MS_LOG(ERROR) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n"
                     << json_desc_str << ". flag: " << flags_str;

@@ -180,7 +179,7 @@ class CostModelSplitSchemer : public SplitSchemer {
       return;
     }
     // assign the make_tuple node to a new group.
-    if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimMakeTuple)) {
+    if (common::AnfAlgo::CheckPrimitiveType(output, prim::kPrimMakeTuple)) {
       auto group_id = split_plan_.size();
       split_plan_.emplace_back(AnfNodePtrList{output, ret_node});
       need_inline_.emplace_back(1);
@@ -26,12 +26,11 @@
 #include <vector>
 #include "base/core_ops.h"
 #include "ir/tensor.h"
-#include "utils/utils.h"
+#include "include/common/utils/utils.h"
 #include "utils/log_adapter.h"
 #include "kernel/kernel.h"
 #include "common/graph_kernel/graph_kernel_helper.h"
 #include "common/graph_kernel/core/graph_kernel_utils.h"
-#include "backend/common/session/anf_runtime_algorithm.h"
 #include "backend/common/session/kernel_graph.h"
 #include "debug/anf_ir_dump.h"
 #include "kernel/common_utils.h"

@@ -41,7 +40,7 @@ namespace {
 auto constexpr NUMBER_COND_FOR_FILTER_INPLACE = 2;
 std::set<int64_t> GetUniqReduceAxes(const AnfNodePtr &node, bool is_ascend = false) {
   if (!IsPrimitiveCNode(node, prim::kPrimReduceSum)) {
-    MS_LOG(EXCEPTION) << "Expect ReduceSum node, but got " << AnfAlgo::GetCNodeName(node);
+    MS_LOG(EXCEPTION) << "Expect ReduceSum node, but got " << common::AnfAlgo::GetCNodeName(node);
   }

   auto input = node->cast<CNodePtr>()->input(kFirstDataInputIndex);

@@ -110,7 +109,7 @@ bool AtomicAddChecker::FindCandidate(const AnfNodePtr &anf_node) {
   atomic_add_infos_.clear();
   auto node = anf_node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(node);
-  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
+  auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
   auto mng_sub = sub_graph->manager();
   if (mng_sub == nullptr) {
     mng_sub = Manage(sub_graph, false);

@@ -176,7 +175,7 @@ bool AtomicAddChecker::CanActivateAtomicAdd(const AnfNodePtr &anf_node) {
 }

 bool AtomicAddChecker::Check(const AnfNodePtr &node) {
-  return (AnfAlgo::IsGraphKernel(node) && CanActivateAtomicAdd(node));
+  return (common::AnfAlgo::IsGraphKernel(node) && CanActivateAtomicAdd(node));
 }

 bool AtomicAddCheckerGPU::SuitableForAtomicAdd(const AnfNodePtr &node) {

@@ -277,7 +276,7 @@ void AtomicCleanInsertter::CorrectKernelBuildInfo(

   for (const auto &clean_info : clean_infos) {
     auto &new_input = clean_info.second;
-    auto kernel_with_index = AnfAlgo::VisitKernel(new_input, 0);
+    auto kernel_with_index = common::AnfAlgo::VisitKernel(new_input, 0);
     new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
     new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));
   }

@@ -390,7 +389,7 @@ void AtomicCleanInsertter::CorrectAbstract(
 void AtomicCleanInsertter::ProcessOriginCNode(
   const AnfNodePtr &composite_node,
   const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes) {
-  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
+  auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
   auto mng_sub = sub_graph->manager();
   if (mng_sub == nullptr) {
     mng_sub = Manage(sub_graph, false);
@@ -20,13 +20,12 @@
 #include <string>
 #include "base/core_ops.h"
 #include "ir/tensor.h"
-#include "utils/utils.h"
+#include "include/common/utils/utils.h"
 #include "utils/log_adapter.h"
 #include "kernel/kernel.h"
 #include "kernel/common_utils.h"
 #include "common/graph_kernel/graph_kernel_helper.h"
 #include "common/graph_kernel/core/graph_kernel_utils.h"
-#include "backend/common/session/anf_runtime_algorithm.h"
 #include "backend/common/session/kernel_graph.h"

 namespace mindspore::graphkernel {

@@ -49,7 +48,7 @@ void StitchAtomicCleanInsertter::CorrectKernelBuildInfo(
     new_outputs_type.push_back(origin_outputs_type[i]);
   }

-  auto kernel_with_index = AnfAlgo::VisitKernel(clean_infos[0].second, 0);
+  auto kernel_with_index = common::AnfAlgo::VisitKernel(clean_infos[0].second, 0);
   new_inputs_format.push_back(AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second));
   new_inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second));

@@ -88,7 +87,7 @@ CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNode(const FuncGraphPtr
     CreateCNode({NewValueNode(prim::kPrimInplaceAssign), new_parameter, out_node, out_node}, sub_graph,
                 {.format = GetFormat(out_node), .shape = GetShape(out_node), .type = GetType(out_node)});
   SetNodeAttrSafely("fake_output", MakeValue(true), inplace_assign_node);
-  AnfAlgo::EraseNodeAttr(kAttrStitch, out_node);
+  common::AnfAlgo::EraseNodeAttr(kAttrStitch, out_node);
   SetNodeAttrSafely(kAttrStitch, MakeValue("common"), inplace_assign_node);
   return inplace_assign_node;
 }

@@ -96,7 +95,7 @@ CNodePtr StitchAtomicCleanInsertter::CreateInplaceAssignNode(const FuncGraphPtr
 void StitchAtomicCleanInsertter::ProcessOriginCNode(
   const AnfNodePtr &composite_node,
   const std::vector<std::pair<AtomicAddInfo, AnfNodePtr>> &info_and_broadcast_to_nodes) {
-  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
+  auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(composite_node);
   auto mng_sub = sub_graph->manager();
   if (mng_sub == nullptr) {
     mng_sub = Manage(sub_graph, false);

@@ -147,7 +146,7 @@ std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInsertter::FindInnerCNo
                                                                                    const CNodePtr &target) const {
   auto node = inner_node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(node);
-  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
+  auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
   auto mng_sub = sub_graph->manager();
   if (mng_sub == nullptr) {
     mng_sub = Manage(sub_graph, false);

@@ -161,15 +160,16 @@ std::vector<std::pair<AnfNodePtr, int>> StitchAtomicCleanInsertter::FindInnerCNo
 }

 std::pair<bool, AtomicAddInfo> StitchAtomicCleanInsertter::IsStitchWithAtomic(const AnfNodePtr &anf_node) {
-  if (!AnfAlgo::IsGraphKernel(anf_node)) return {false, AtomicAddInfo()};
+  if (!common::AnfAlgo::IsGraphKernel(anf_node)) return {false, AtomicAddInfo()};
   auto node = anf_node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(node);
-  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
+  auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
   AnfNodePtrList kernel_nodes;
   kernel::GetValidKernelNodes(sub_graph, &kernel_nodes);
   for (auto &n : kernel_nodes) {
-    if (AnfAlgo::HasNodeAttr(kAttrStitch, n->cast<CNodePtr>()) &&
-        AnfAlgo::GetNodeAttr<std::string>(n, kAttrStitch) == "atomic" && IsPrimitiveCNode(n, prim::kPrimReduceSum)) {
+    if (common::AnfAlgo::HasNodeAttr(kAttrStitch, n->cast<CNodePtr>()) &&
+        common::AnfAlgo::GetNodeAttr<std::string>(n, kAttrStitch) == "atomic" &&
+        IsPrimitiveCNode(n, prim::kPrimReduceSum)) {
       MS_LOG(INFO) << "GOT STITCH WITH ATOMIC!!!";
       AtomicAddInfo info;
       info.atomic_add_node = n->cast<CNodePtr>();
@@ -28,8 +28,9 @@
 #include "common/graph_kernel/core/graph_builder.h"
 #include "common/graph_kernel/core/graph_kernel_utils.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "ir/anf.h"
-#include "utils/context/graph_kernel_flags.h"
+#include "include/common/utils/context/graph_kernel_flags.h"

 namespace mindspore::graphkernel {
 // operator which follows commutative rules

@@ -634,8 +635,8 @@ bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) {
   bool do_simplify = false;
   expressions_map_ = GetExpressions();
   for (auto node : func_graph->GetOrderedCnodes()) {
-    if (AnfAlgo::IsGraphKernel(node)) {
-      auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
+    if (common::AnfAlgo::IsGraphKernel(node)) {
+      auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
       inner::LiteGraphPtr lg = GkUtils::AnfGraph2LiteGraph(sub_graph);
       bool find_pattern = true;
       bool change_anf_graph = false;

@@ -20,6 +20,7 @@
 #include "ir/scalar.h"
 #include "common/graph_kernel/graph_kernel_helper.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"

 namespace mindspore::graphkernel {
 int64_t AxisNormalizer::NormAxis(int64_t x, size_t rank) const { return x >= 0 ? x : x + static_cast<int64_t>(rank); }

@@ -83,8 +84,8 @@ bool AxisNormalizer::Run(const FuncGraphPtr &func_graph) {
   bool changed = false;
   auto todos = TopoSort(func_graph->get_return());
   for (auto node : todos) {
-    if (AnfAlgo::IsGraphKernel(node)) {
-      auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
+    if (common::AnfAlgo::IsGraphKernel(node)) {
+      auto sub_func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
       changed = Process(sub_func_graph) || changed;
     }
   }
@@ -18,6 +18,7 @@
 #include <vector>
 #include <string>
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "kernel/common_utils.h"
 #include "common/graph_kernel/graph_kernel_helper.h"

@@ -98,8 +99,8 @@ bool CastMatmulFusion::Run(const FuncGraphPtr &func_graph) {
   auto changed = false;
   auto nodes = TopoSort(func_graph->get_return());
   for (auto node : nodes) {
-    if (!AnfAlgo::IsGraphKernel(node)) continue;
-    auto graph_kernel_fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
+    if (!common::AnfAlgo::IsGraphKernel(node)) continue;
+    auto graph_kernel_fg = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
     MS_EXCEPTION_IF_NULL(graph_kernel_fg);
     changed = DoFuse(graph_kernel_fg) || changed;
   }

@@ -24,7 +24,7 @@

 #include "base/core_ops.h"
 #include "ir/func_graph.h"
-#include "utils/utils.h"
+#include "include/common/utils/utils.h"
 #include "utils/anf_utils.h"
 #include "utils/ordered_set.h"
 #include "common/graph_kernel/core/graph_kernel_callback.h"

@@ -27,7 +27,7 @@
 #include "utils/anf_utils.h"
 #include "utils/ms_context.h"
 #include "utils/file_utils.h"
-#include "utils/context/graph_kernel_flags.h"
+#include "include/common/utils/context/graph_kernel_flags.h"
 #include "backend/common/pass/getitem_tuple.h"
 #include "common/graph_kernel/core/graph_kernel_callback.h"
 #include "common/graph_kernel/core/graph_kernel_utils.h"

@@ -24,7 +24,7 @@
 #include "ir/graph_utils.h"
 #include "utils/anf_utils.h"
 #include "utils/ms_context.h"
-#include "utils/context/graph_kernel_flags.h"
+#include "include/common/utils/context/graph_kernel_flags.h"
 #include "common/graph_kernel/core/graph_builder.h"
 #include "common/graph_kernel/core/graph_kernel_callback.h"
 #include "common/graph_kernel/core/graph_kernel_utils.h"

@@ -25,7 +25,7 @@
 #include "utils/anf_utils.h"
 #include "utils/hash_map.h"
 #include "utils/hash_set.h"
-#include "utils/context/graph_kernel_flags.h"
+#include "include/common/utils/context/graph_kernel_flags.h"
 #include "kernel/akg/akg_kernel_json_generator.h"
 #include "common/graph_kernel/core/graph_kernel_callback.h"
 #include "common/graph_kernel/core/graph_kernel_utils.h"

@@ -21,7 +21,7 @@
 #include <utility>
 #include "base/core_ops.h"
 #include "utils/anf_utils.h"
-#include "utils/utils.h"
+#include "include/common/utils/utils.h"
 #include "common/graph_kernel/core/graph_kernel_callback.h"

 namespace mindspore::graphkernel {
@@ -23,6 +23,7 @@
 #include "backend/common/optimizer/helper.h"
 #include "plugin/device/ascend/optimizer/ascend_helper.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "common/graph_kernel/graph_kernel_helper.h"
 #include "common/graph_kernel/core/graph_kernel_utils.h"
 #include "runtime/device/kernel_info.h"

@@ -52,16 +53,16 @@ CNodePtr AddCastCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, c
     cast->set_kernel_info(kernel_info);
   }
   AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
-  AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get());
-  AnfAlgo::SetNodeAttr(kAttrDstType, TypeIdToType(output_type), cast);
-  AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
-  AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), cast);
+  common::AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get());
+  common::AnfAlgo::SetNodeAttr(kAttrDstType, TypeIdToType(output_type), cast);
+  common::AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
+  common::AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), cast);
   return cast;
 }

 // Update Output Abatract and BuildInfo as Input Changed
 void UpdateOutputInfo(const AnfNodePtr &cnode) {
-  if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimMakeTuple)) {
+  if (!common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimMakeTuple)) {
     ShapeVector out_shape = GetShape(cnode);
     auto abs_shape_ptr = std::make_shared<abstract::Shape>(abstract::Shape(out_shape));
     auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(TypeId::kNumberTypeFloat16), abs_shape_ptr);

@@ -81,13 +82,13 @@
 CNodePtr InsertCastForGraphKernel(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
   auto mng = func_graph->manager();
-  size_t in_num = AnfAlgo::GetInputNum(cnode);  // include monads.
+  size_t in_num = common::AnfAlgo::GetInputNum(cnode);  // include monads.
   for (size_t input_index = 0; input_index < in_num; ++input_index) {
-    auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
+    auto cur_input = common::AnfAlgo::GetInputNode(cnode, input_index);
     if (HasAbstractMonad(cur_input)) {
       continue;
     }
-    auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, input_index);
+    auto prev_node = common::AnfAlgo::GetPrevNodeOutput(cnode, input_index);
     auto in_node = prev_node.first;
     auto in_index = prev_node.second;
     auto ori_shape = AnfAlgo::GetOutputDeviceShape(in_node, in_index);

@@ -115,7 +116,7 @@ CNodePtr InsertCastForGraphKernel(const FuncGraphPtr &func_graph, const CNodePtr
         auto abstract =
           std::make_shared<abstract::AbstractTensor>(TypeIdToType(TypeId::kNumberTypeFloat16), abs_shape_ptr);
         cast->set_abstract(abstract);
-        AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast);
+        common::AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast);
         (void)mng->Replace(cur_input, cast);
       }
     }

@@ -136,7 +137,7 @@ bool DecreaseComputePrecision::Process(const FuncGraphPtr &func_graph) {
   bool changed = false;
   // Cast Down CNODES
   for (auto node : todos) {
-    if (node->isa<CNode>() && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
+    if (node->isa<CNode>() && !common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
       auto cnode = node->cast<CNodePtr>();
       if (IsPrimitiveCNode(cnode, prim::kPrimCast)) {
         if (AnfAlgo::GetOutputDeviceDataType(cnode->input(1), 0) == kNumberTypeFloat16) {

@@ -197,9 +198,9 @@ bool DecreaseComputePrecision::Process(const FuncGraphPtr &func_graph) {
   };

   std::vector<AnfNodePtr> new_inputs;
-  if (AnfAlgo::CheckPrimitiveType(old_output, prim::kPrimMakeTuple)) {
+  if (common::AnfAlgo::CheckPrimitiveType(old_output, prim::kPrimMakeTuple)) {
     (void)new_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
-    auto all_out = AnfAlgo::GetAllOutput(old_output);
+    auto all_out = common::AnfAlgo::GetAllOutput(old_output);
     for (const auto &out : all_out) {
       auto c_out = out->cast<CNodePtr>();
       if (c_out) {

@@ -229,7 +230,7 @@ bool IsCastUnAware(const FuncGraphPtr &func_graph) {
   auto todos = TopoSort(func_graph->get_return());
   for (auto node : todos) {
     if (node->isa<CNode>()) {
-      if (std::find(cast_aware_list.begin(), cast_aware_list.end(), AnfAlgo::GetCNodePrimitive(node)) !=
+      if (std::find(cast_aware_list.begin(), cast_aware_list.end(), common::AnfAlgo::GetCNodePrimitive(node)) !=
          cast_aware_list.end()) {
        return false;
      }

@@ -251,8 +252,8 @@ bool DecreaseComputePrecision::Run(const FuncGraphPtr &func_graph) {
   auto todos = TopoSort(func_graph->get_return());
   bool changed = false;
   for (const auto &node : todos) {
-    if (AnfAlgo::IsGraphKernel(node)) {
-      auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
+    if (common::AnfAlgo::IsGraphKernel(node)) {
+      auto sub_func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
       MS_ERROR_IF_NULL(sub_func_graph);
       if (IsCastUnAware(sub_func_graph)) {
         changed = Process(sub_func_graph) || changed;
@ -20,10 +20,8 @@
|
|||
#include <memory>
|
||||
#include <utility>
|
||||
#include "base/core_ops.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/common/optimizer/helper.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "common/graph_kernel/graph_kernel_helper.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "ir/manager.h"
|
||||
#include "kernel/kernel_build_info.h"
|
||||
|
@ -41,15 +39,15 @@ int64_t ObtainGetItemIndex(const AnfNodePtr &getitem) {
|
|||
}
|
||||
|
||||
bool IsPreNodeReduce(const FuncGraphPtr &, const AnfNodePtr &node, bool is_tuple_out, size_t index) {
|
||||
auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
auto gk_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(gk_graph);
|
||||
if (is_tuple_out) {
|
||||
auto tuple_output = gk_graph->output()->cast<CNodePtr>();
|
||||
if (AnfAlgo::GetCNodeName(tuple_output) != prim::kPrimMakeTuple->name()) {
|
||||
MS_LOG(EXCEPTION) << "Expect MakeTuple node, but got " << AnfAlgo::GetCNodeName(tuple_output);
|
||||
if (common::AnfAlgo::GetCNodeName(tuple_output) != prim::kPrimMakeTuple->name()) {
|
||||
MS_LOG(EXCEPTION) << "Expect MakeTuple node, but got " << common::AnfAlgo::GetCNodeName(tuple_output);
|
||||
}
|
||||
auto input_node = tuple_output->input(index + 1);
|
||||
if (AnfAlgo::GetCNodeName(input_node) == prim::kPrimReduceSum->name()) {
|
||||
if (common::AnfAlgo::GetCNodeName(input_node) == prim::kPrimReduceSum->name()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -57,17 +55,17 @@ bool IsPreNodeReduce(const FuncGraphPtr &, const AnfNodePtr &node, bool is_tuple
|
|||
}
|
||||
|
||||
size_t GetGraphKernelSize(const AnfNodePtr &node) {
|
||||
auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
auto gk_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(gk_graph);
|
||||
return gk_graph->GetOrderedCnodes().size();
|
||||
}
|
||||
|
||||
bool IsCandidateNode(const AnfNodePtr &node) {
|
||||
bool is_gk = AnfAlgo::IsGraphKernel(node);
|
||||
bool is_gk = common::AnfAlgo::IsGraphKernel(node);
|
||||
if (is_gk) {
|
||||
auto num = GetGraphKernelSize(node);
|
||||
if (num > GK_MIN_SIZE) {
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
auto graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
|
||||
if (graph_name.find("atomic") == std::string::npos) {
|
||||
return true;
|
||||
|
@@ -144,7 +142,7 @@ bool DecreaseTransferPrecision::Run(const FuncGraphPtr &func_graph) {
 
 bool DecreaseTransferPrecision::ProcessFather(const FuncGraphPtr &, const AnfNodePtr &node, bool is_tuple_out,
                                               size_t index) {
-  auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
+  auto gk_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
   MS_EXCEPTION_IF_NULL(gk_graph);
   auto mng = gk_graph->manager();
   MS_EXCEPTION_IF_NULL(mng);
@@ -179,7 +177,7 @@ bool DecreaseTransferPrecision::ProcessFather(const FuncGraphPtr &, const AnfNod
   if (!is_tuple_out) {
     auto old_output = gk_graph->output()->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(old_output);
-    if (AnfAlgo::GetCNodeName(old_output) == prim::kPrimCast->name() &&
+    if (common::AnfAlgo::GetCNodeName(old_output) == prim::kPrimCast->name() &&
         AnfAlgo::GetInputDeviceDataType(old_output, 0) == kNumberTypeFloat16 &&
         AnfAlgo::GetOutputDeviceDataType(old_output, 0) == kNumberTypeFloat32) {
       auto real_output = old_output->input(1);
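This hunk shows where the boundary falls: the IR-level name query moves to common::AnfAlgo, while the device data-type queries on the next two lines stay on the backend AnfAlgo, since kernel build info does not move into the shared library. A hedged sketch of such a mixed call site:

// Mixed call site after the split (sketch; helper name is illustrative):
// IR queries through the shared facade, device queries through the
// backend-only AnfAlgo.
bool IsFp16ToFp32Cast(const CNodePtr &cnode) {
  return common::AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name() &&  // IR-level
         AnfAlgo::GetInputDeviceDataType(cnode, 0) == kNumberTypeFloat16 &&  // device-level
         AnfAlgo::GetOutputDeviceDataType(cnode, 0) == kNumberTypeFloat32;   // device-level
}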
@@ -200,8 +198,8 @@ bool DecreaseTransferPrecision::ProcessFather(const FuncGraphPtr &, const AnfNod
   } else {
     // cast for graph kernel with make tuple output
     auto tuple_output = gk_graph->output()->cast<CNodePtr>();
-    if (AnfAlgo::GetCNodeName(tuple_output) != prim::kPrimMakeTuple->name()) {
-      MS_LOG(EXCEPTION) << "Expect MakeTuple node, but got " << AnfAlgo::GetCNodeName(tuple_output);
+    if (common::AnfAlgo::GetCNodeName(tuple_output) != prim::kPrimMakeTuple->name()) {
+      MS_LOG(EXCEPTION) << "Expect MakeTuple node, but got " << common::AnfAlgo::GetCNodeName(tuple_output);
     }
     auto input_node = tuple_output->input(index + 1);
     auto cnode = func_add_cast_fp16(input_node);
@@ -234,7 +232,7 @@ bool DecreaseTransferPrecision::ProcessFather(const FuncGraphPtr &, const AnfNod
 }
 
 bool DecreaseTransferPrecision::ProcessSon(const FuncGraphPtr &, const AnfNodePtr &node, size_t index) {
-  auto gk_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
+  auto gk_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
   MS_EXCEPTION_IF_NULL(gk_graph);
   auto mng = gk_graph->manager();
   MS_EXCEPTION_IF_NULL(mng);
@@ -15,7 +15,7 @@
  */
 
 #include "common/graph_kernel/depend_elimination.h"
-#include "utils/utils.h"
+#include "include/common/utils/utils.h"
 
 namespace mindspore::graphkernel {
 bool DependElimination::Run(const FuncGraphPtr &func_graph) {
@@ -19,17 +19,16 @@
 #include <algorithm>
 #include <memory>
 #include <string>
 #include <utility>
 #include <vector>
-#include "backend/common/session/anf_runtime_algorithm.h"
-#include "runtime/device/kernel_info.h"
 #include "utils/ms_utils.h"
+#include "include/common/utils/anfalgo.h"
 
 namespace mindspore::graphkernel {
 namespace {
 bool IsCNodePrimitveEqual(const CNodePtr &main, const CNodePtr &node, const std::vector<PrimitivePtr> &black_list) {
-  auto main_primitive = AnfAlgo::GetCNodePrimitive(main);
-  auto node_primitive = AnfAlgo::GetCNodePrimitive(node);
+  auto main_primitive = common::AnfAlgo::GetCNodePrimitive(main);
+  auto node_primitive = common::AnfAlgo::GetCNodePrimitive(node);
   if (main_primitive != nullptr && node_primitive != nullptr) {
     // Some ops such as Reshape are not real ops, so CSE on them brings no gain. For op fusion, keeping these ops
     // alone can also prevent redundant-output cases (input -> reshape -> output).

@@ -62,7 +61,7 @@ bool GraphKernelBackendCSE::CheckEqualKernelBuildInfo(const AnfNodePtr &main, co
   MS_EXCEPTION_IF_NULL(main);
   MS_EXCEPTION_IF_NULL(node);
 
-  if (!AnfAlgo::IsNodeInGraphKernel(main)) {
+  if (!common::AnfAlgo::IsNodeInGraphKernel(main)) {
     return BackendCSE::CheckEqualKernelBuildInfo(main, node);
   }
 

@@ -98,7 +97,7 @@ bool GraphKernelBackendCSE::CheckEqualCnodeInputs(const AnfNodePtr &main, const
   auto c_node = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(c_node);
 
-  if (!AnfAlgo::IsNodeInGraphKernel(c_main)) {
+  if (!common::AnfAlgo::IsNodeInGraphKernel(c_main)) {
     return BackendCSE::CheckEqualCnodeInputs(main, node);
   }
 
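For context, the CSE overrides above only fall back to the generic BackendCSE checks when a node is outside any graph kernel; inside one, primitive equality is tested directly against a blacklist. A condensed, hedged sketch of that shape (the helper name and equality details are illustrative, not the commit's full logic):

// Sketch of the in-graph-kernel equality test behind IsCNodePrimitveEqual:
// two CNodes are CSE candidates only if both carry a primitive, neither
// primitive is blacklisted, and the primitives compare equal.
bool PrimitivesMatchSketch(const CNodePtr &main, const CNodePtr &node,
                           const std::vector<PrimitivePtr> &black_list) {
  auto main_prim = common::AnfAlgo::GetCNodePrimitive(main);
  auto node_prim = common::AnfAlgo::GetCNodePrimitive(node);
  if (main_prim == nullptr || node_prim == nullptr) {
    return false;
  }
  auto blacklisted = [&black_list](const PrimitivePtr &p) {
    return std::any_of(black_list.begin(), black_list.end(),
                       [&p](const PrimitivePtr &b) { return p->name() == b->name(); });
  };
  if (blacklisted(main_prim) || blacklisted(node_prim)) {
    return false;  // e.g. Reshape: merging yields no gain
  }
  return *main_prim == *node_prim;  // name plus attrs equality
}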
@@ -27,13 +27,14 @@
 #include "kernel/akg/akg_kernel_json_decoder.h"
 #include "kernel/kernel.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "common/graph_kernel/adapter/fake_abstract_shape.h"
 #include "common/graph_kernel/core/graph_builder.h"
 #include "ir/func_graph_cloner.h"
 #include "ir/func_graph.h"
-#include "pipeline/jit/parse/python_adapter.h"
+#include "include/common/utils/python_adapter.h"
 #include "pipeline/jit/action.h"
-#include "utils/context/graph_kernel_flags.h"
+#include "include/common/utils/context/graph_kernel_flags.h"
 
 namespace mindspore::graphkernel {
 namespace {
@@ -47,7 +48,7 @@ bool IsMakeTupleOut(const AnfNodePtr &out, AnfNodePtrList *real_outs) {
     return true;
   }
 
-  if (auto fg = AnfAlgo::GetCNodeFuncGraphPtr(out); fg != nullptr) {
+  if (auto fg = common::AnfAlgo::GetCNodeFuncGraphPtr(out); fg != nullptr) {
     auto fg_out = fg->output();
     if (IsPrimitiveCNode(fg_out, prim::kPrimMakeTuple)) {
       auto inputs = fg_out->cast<CNodePtr>()->inputs();
@@ -75,7 +76,7 @@ bool GenJson(const AnfNodePtrList &op_nodes, const std::pair<AnfNodePtrList, Anf
   }
   std::string fused_name;
   std::for_each(op_nodes.begin(), op_nodes.end(), [&fused_name](const AnfNodePtr &node) {
-    (void)fused_name.append(AnfAlgo::GetCNodeName(node)).append("_");
+    (void)fused_name.append(common::AnfAlgo::GetCNodeName(node)).append("_");
   });
   MS_LOG(DEBUG) << "Collect fusion json: " << fused_name;
   return true;
@@ -130,7 +131,7 @@ void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const
   std::vector<std::string> graph_output_format;
   std::vector<TypeId> graph_output_type;
   for (size_t i = 0; i < inputs.size(); ++i) {
-    auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
+    auto kernel_with_index = common::AnfAlgo::VisitKernel(inputs[i], 0);
     if (kernel_with_index.first->isa<ValueNode>()) {
       auto tensor = GetValueNode<tensor::TensorPtr>(kernel_with_index.first);
       MS_EXCEPTION_IF_NULL(tensor);
@@ -153,14 +154,14 @@ void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const
     AnfAlgo::SetSelectKernelBuildInfo(para_info_builder.Build(), fg->parameters()[i].get());
   }
   auto new_outputs = outputs;
-  if (outputs.size() == 1 && AnfAlgo::IsGraphKernel(outputs[0])) {
+  if (outputs.size() == 1 && common::AnfAlgo::IsGraphKernel(outputs[0])) {
     std::vector<AnfNodePtr> real_outs;
     if (IsMakeTupleOut(outputs[0], &real_outs)) {
       new_outputs = real_outs;
     }
   }
   for (size_t i = 0; i < new_outputs.size(); ++i) {
-    auto kernel_with_index = AnfAlgo::VisitKernel(new_outputs[i], 0);
+    auto kernel_with_index = common::AnfAlgo::VisitKernel(new_outputs[i], 0);
     auto output_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
     auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
     graph_output_format.push_back(output_format);
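The loops above all lean on the same resolution idiom: common::AnfAlgo::VisitKernel skips transparent wrappers such as TupleGetItem and returns the real producing node together with its output index, and the device-side format/type queries are then made against that resolved pair. A hedged sketch of the idiom, mirroring those loop bodies (the wrapper function is illustrative):

// Resolve (node, 0) to the real producer, then query the backend for the
// per-output layout data of that producer.
std::pair<std::string, TypeId> ResolveOutputInfoSketch(const AnfNodePtr &node) {
  auto kernel_with_index = common::AnfAlgo::VisitKernel(node, 0);
  auto format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
  auto type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
  return {format, type};
}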
@@ -185,13 +186,13 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n
     MS_LOG(ERROR) << "Input nodes is empty.";
     return false;
   }
-  bool has_graph_kernel = std::any_of(nodes.begin(), nodes.end(), AnfAlgo::IsGraphKernel);
+  bool has_graph_kernel = std::any_of(nodes.begin(), nodes.end(), common::AnfAlgo::IsGraphKernel);
   bool is_single_graph_kernel = has_graph_kernel && nodes.size() == 1;
 
   FuncGraphPtr fg;
   AnfNodePtrList op_nodes, inputs, outputs;
   if (is_single_graph_kernel) {
-    fg = AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]);
+    fg = common::AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]);
     kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs);
   } else if (!has_graph_kernel) {
     std::tie(fg, inputs, outputs) = BuildGraphFromNodes(nodes);
@@ -215,8 +216,8 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n
 
   FuncGraphPtr fg;
 
-  if (nodes.size() == 1 && AnfAlgo::IsGraphKernel(nodes[0])) {
-    fg = AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]);
+  if (nodes.size() == 1 && common::AnfAlgo::IsGraphKernel(nodes[0])) {
+    fg = common::AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]);
   } else {
     std::tie(fg, std::ignore, std::ignore) = BuildSingleGraphFromNodes(nodes);
   }
@@ -385,7 +386,7 @@ CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &
   std::vector<std::string> input_formats;
   std::vector<TypeId> input_types;
   for (size_t i = 1; i < inputs.size(); ++i) {
-    auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
+    auto kernel_with_index = common::AnfAlgo::VisitKernel(inputs[i], 0);
     auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
     input_formats.push_back(input_format);
     auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
@@ -416,12 +417,12 @@ void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfN
   if (cnode == nullptr) {
     return;
   }
-  AnfNodePtrList new_inputs = {NewValueNode(AnfAlgo::GetCNodePrimitive(cnode)->Clone())};
+  AnfNodePtrList new_inputs = {NewValueNode(common::AnfAlgo::GetCNodePrimitive(cnode)->Clone())};
   auto inputs = cnode->inputs();
   new_inputs.insert(new_inputs.end(), inputs.begin() + 1, inputs.end());
   cnode->set_inputs(new_inputs);
 
   // Set attr secondly.
-  AnfAlgo::SetNodeAttr(key, value, node);
+  common::AnfAlgo::SetNodeAttr(key, value, node);
 }
 }  // namespace mindspore::graphkernel
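SetNodeAttrSafely exists because a Primitive can be shared by several CNodes; setting an attribute on it directly would leak to every node that references it, so the helper first rebuilds input(0) around a cloned primitive. A hedged sketch of the hazard (node and input names are illustrative):

// Two nodes built from the same Primitive instance alias its attrs.
auto prim = std::make_shared<Primitive>("ExampleOp");
auto node_a = func_graph->NewCNode({NewValueNode(prim), input_x});
auto node_b = func_graph->NewCNode({NewValueNode(prim), input_y});
// Mutating the shared primitive would tag BOTH nodes:
prim->set_attr("fused", MakeValue(true));
// Cloning first, as the helper above does, keeps the change local:
auto private_prim = prim->Clone();
private_prim->set_attr("fused", MakeValue(true));  // local to the clone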
@@ -28,6 +28,7 @@
 #include "ir/func_graph.h"
 #include "ir/primitive.h"
 #include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
 #include "backend/common/session/kernel_graph.h"
 #include "kernel/akg/akg_kernel_json_generator.h"
 #include <nlohmann/json.hpp>
@@ -147,7 +147,7 @@ OrderedSet<AnfNodePtr> GetLongTermNodes(const AnfNodePtrList &nodes, const AnfNo
                                         const FuncGraphManagerPtr &mng) {
   OrderedSet<AnfNodePtr> long_term_nodes;
   for (auto node : nodes) {
-    auto real_node = AnfAlgo::VisitKernelWithReturnType(node, 0).first;
+    auto real_node = common::AnfAlgo::VisitKernelWithReturnType(node, 0).first;
     // Parameters and values hold long-term tensors.
     if (!utils::isa<CNodePtr>(real_node)) {
       (void)long_term_nodes.insert(node);
@@ -225,10 +225,10 @@ AnfNodePtrList AutoRecompute::Filter(const AnfNodePtr &source_node, const AnfNod
   AnfNodePtrList check_inputs;
   if (IsPrimitiveCNode(end_node->cast<CNodePtr>()->input(IntToSize(edge_pos)), prim::kPrimTupleGetItem)) {
     auto out_index = GetSourceLinkOutPos(end_node, edge_pos);
-    auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(source_node);
+    auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(source_node);
     auto out = sub_graph->output();
     if (!IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
-      MS_LOG(EXCEPTION) << "Expect MakeTuple node, but got " << AnfAlgo::GetCNodeName(out);
+      MS_LOG(EXCEPTION) << "Expect MakeTuple node, but got " << common::AnfAlgo::GetCNodeName(out);
     }
 
     // Find subgraph's input according to edge node.
@@ -345,7 +345,7 @@ OutPosLinkList AutoRecompute::JudegeTargetAndCaptureSource(const AnfNodePtr &nod
   // Direct users include long-term users and short-term users.
   // If a short-term user is a graph kernel composite node, it may absorb the source and reduce the local peak memory.
   for (auto user : direct_users) {
-    if (long_term_users.count(user) == 0 && AnfAlgo::IsGraphKernel(user)) {
+    if (long_term_users.count(user) == 0 && common::AnfAlgo::IsGraphKernel(user)) {
       (void)target_link_infos.emplace_back(user, user_edge_pos[user], EdgeLifeTimeType::ShortTerm);
     }
   }
@@ -441,7 +441,7 @@ void AutoRecompute::FindCandidates(const FuncGraphPtr &func_graph) {
   // 2. The memory difference between the split-out graph and the original exceeds a threshold:
   //    `Size(gs_direct_outs_to_gt) - filter(gs_inputs, its) > threshold`.
   for (auto node : topo_nodes) {
-    if (!AnfAlgo::IsGraphKernel(node)) {
+    if (!common::AnfAlgo::IsGraphKernel(node)) {
       continue;
     }
     auto target_graphs = JudegeTargetAndCaptureSource(node, mng);
@@ -544,7 +544,7 @@ void AutoRecompute::RecomputeCandidatesLog(const std::vector<Candidate> &candida
 std::pair<FuncGraphPtr, AnfNodePtrList> GraphKernelRecompute::CloneGraph(const CNodePtr &source_graph,
                                                                          const AnfNodePtrList &recompute_edges) {
   MS_EXCEPTION_IF_NULL(source_graph);
-  auto gs = AnfAlgo::GetCNodeFuncGraphPtr(source_graph);
+  auto gs = common::AnfAlgo::GetCNodeFuncGraphPtr(source_graph);
   MS_EXCEPTION_IF_NULL(gs);
   AnfNodePtrList inputs(source_graph->inputs().begin() + 1, source_graph->inputs().end());
   auto new_funcgraph = BasicClone(gs);
@@ -576,7 +576,7 @@ void GraphKernelRecompute::LinkIntoTargetFuncGraph(
     const Candidate &candidate, const FuncGraphPtr &cloned_func, const AnfNodePtrList &cloned_inputs,
     const std::function<std::pair<bool, size_t>(const Candidate &, const AnfNodePtr &)> &edge_match_func) {
   auto cloned_nodes = TopoSort(cloned_func->get_return());
-  auto gt = AnfAlgo::GetCNodeFuncGraphPtr(candidate.target_graph);
+  auto gt = common::AnfAlgo::GetCNodeFuncGraphPtr(candidate.target_graph);
   MS_EXCEPTION_IF_NULL(gt);
   auto mng = gt->manager();
   if (mng == nullptr) {
@@ -639,7 +639,7 @@ void GraphKernelRecompute::Process(const Candidate &candidate) {
   std::function<std::pair<bool, size_t>(const Candidate &, const AnfNodePtr &)> edge_match_func;
   if (candidate.recompute_edges.empty()) {
     // single output, clone the whole source_graph.
-    auto gs = AnfAlgo::GetCNodeFuncGraphPtr(candidate.source_graph);
+    auto gs = common::AnfAlgo::GetCNodeFuncGraphPtr(candidate.source_graph);
     MS_EXCEPTION_IF_NULL(gs);
     new_funcgraph = BasicClone(gs);
     auto source_cnode = candidate.source_graph->cast<CNodePtr>();
@@ -669,7 +669,7 @@ void GraphKernelRecompute::Process(const Candidate &candidate) {
     new_funcgraph->set_manager(mng);
   }
 
-  if (AnfAlgo::IsGraphKernel(candidate.target_graph)) {
+  if (common::AnfAlgo::IsGraphKernel(candidate.target_graph)) {
     // the target graph is a GraphKernel, push the new_funcgraph into the target graph.
     LinkIntoTargetFuncGraph(candidate, new_funcgraph, inputs, edge_match_func);
   } else {
@@ -690,7 +690,7 @@ bool GraphKernelRecompute::Run(const FuncGraphPtr &func_graph) {
   auto mng = func_graph->manager();
   MS_EXCEPTION_IF_NULL(mng);
   for (auto &c : candidates_) {
-    if (!AnfAlgo::IsGraphKernel(c.target_graph)) {
+    if (!common::AnfAlgo::IsGraphKernel(c.target_graph)) {
       continue;
     }
     std::ostringstream oss;
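For orientation, the recompute hunks above all follow one shape: walk the graph in topological order, keep only graph-kernel composite nodes, and inspect their users through the graph manager to decide whether a cloned copy of the subgraph could be absorbed by a consumer. A condensed, hedged sketch of that scan (the helper name is illustrative; the real pass also classifies edge lifetimes and memory deltas):

std::vector<AnfNodePtr> CollectRecomputeSourcesSketch(const FuncGraphPtr &func_graph) {
  std::vector<AnfNodePtr> sources;
  auto mng = func_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  for (const auto &node : TopoSort(func_graph->get_return())) {
    if (!common::AnfAlgo::IsGraphKernel(node)) {
      continue;  // only composite nodes can serve as recompute sources
    }
    auto iter = mng->node_users().find(node);
    if (iter == mng->node_users().end()) {
      continue;
    }
    // A composite user could absorb the cloned subgraph (a short-term edge).
    for (const auto &user : iter->second) {
      if (common::AnfAlgo::IsGraphKernel(user.first)) {
        sources.push_back(node);
        break;
      }
    }
  }
  return sources;
}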
@@ -22,7 +22,7 @@
 #include <tuple>
 #include <utility>
 #include <vector>
-#include "utils/context/graph_kernel_flags.h"
+#include "include/common/utils/context/graph_kernel_flags.h"
 #include "backend/common/optimizer/pass.h"
 #include "ir/func_graph.h"
 
@@ -18,7 +18,7 @@
 #include <tuple>
 #include <vector>
-#include "backend/common/session/anf_runtime_algorithm.h"
 #include "kernel/common_utils.h"
+#include "include/common/utils/anfalgo.h"
 #include "common/graph_kernel/graph_kernel_helper.h"
 
 namespace mindspore {
@@ -108,7 +108,7 @@ bool IsAkgMatMul(size_t K, size_t M, size_t N) {
 // Return true if (K, M, N) need pad
 std::tuple<bool, bool, bool> NeedPad(const CNodePtr &matmul, vec *pad_shape_a, vec *pad_shape_b, vec *unpad_shape,
                                      vec *tail_shape_a, vec *tail_shape_b, vec *tail_shape_unpad) {
-  auto mm_attrs = AnfAlgo::GetCNodePrimitive(matmul)->attrs();
+  auto mm_attrs = common::AnfAlgo::GetCNodePrimitive(matmul)->attrs();
   if (mm_attrs.count("transpose_a") == 0 || mm_attrs.count("transpose_b") == 0) {
     MS_LOG(ERROR) << "attrs transpose_a and transpose_b need to be set in node " << matmul->fullname_with_scope();
     return std::tuple(false, false, false);
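The InsertPad/InsertUnpad hunks that follow apply the result of NeedPad. The underlying arithmetic is simple round-up-to-tile padding; a self-contained sketch (kTile = 16 is an assumed tile size, not a value taken from this diff):

#include <cstddef>
#include <utility>

constexpr size_t kTile = 16;  // assumed alignment for the sketch

// Returns {padded_dim, tail}; "tail" is what InsertUnpad must strip again.
std::pair<size_t, size_t> PadDimSketch(size_t dim) {
  size_t padded = (dim + kTile - 1) / kTile * kTile;  // round up to a multiple
  return {padded, padded - dim};
}
// Example: PadDimSketch(30) -> {32, 2}; PadDimSketch(32) -> {32, 0}.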
@@ -164,7 +164,7 @@ void InsertPad(const CNodePtr &matmul, const FuncGraphPtr &func_graph, const Fun
   SetNodeAttrSafely("head", MakeValue(head), pad_cnode);
   SetNodeAttrSafely("tail", MakeValue(tail), pad_cnode);
   SetNodeAttrSafely("pad_val", MakeValue(std::make_shared<Int32Imm>(0)), pad_cnode);
-  std::vector<TypeId> pad_type = {AnfAlgo::GetPrevNodeOutputInferDataType(matmul, 0)};
+  std::vector<TypeId> pad_type = {common::AnfAlgo::GetPrevNodeOutputInferDataType(matmul, 0)};
 
   ShapeVector abs_shape;
   (void)abs_shape.insert(abs_shape.begin(), pad_shape.begin(), pad_shape.end());
@@ -194,7 +194,7 @@ void InsertUnpad(const CNodePtr &matmul, const FuncGraphPtr &func_graph, const F
   ShapeVector tail;
   (void)tail.insert(tail.begin(), tail_shape.begin(), tail_shape.end());
   SetNodeAttrSafely("tail", MakeValue(tail), unpad_cnode);
-  std::vector<TypeId> unpad_type = {AnfAlgo::GetOutputInferDataType(matmul, 0)};
+  std::vector<TypeId> unpad_type = {common::AnfAlgo::GetOutputInferDataType(matmul, 0)};
 
   ShapeVector abs_shape;
   (void)abs_shape.insert(abs_shape.begin(), unpad_shape.begin(), unpad_shape.end());
@@ -221,7 +221,7 @@ void UpdateMatmulInfo(const AnfNodePtr &matmul_node, const vec &unpad_shape, con
     abs_shape.push_back(unpad_shape[i] + tail_shape[i]);
   }
   auto abs_shape_ptr = std::make_shared<abstract::Shape>(abstract::Shape(abs_shape));
-  TypeId abs_type = AnfAlgo::GetOutputInferDataType(matmul_node, 0);
+  TypeId abs_type = common::AnfAlgo::GetOutputInferDataType(matmul_node, 0);
   auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(abs_type), abs_shape_ptr);
   matmul_node->set_abstract(abstract);
 
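After padding changes a MatMul's output shape, the node's abstract (its inferred shape/type metadata) must be rebuilt to match, which is what UpdateMatmulInfo does above. A hedged sketch of the rebuild idiom, mirroring those lines (padded_shape is a hypothetical input):

// Rebuild a node's abstract after its output shape changed.
void RefreshAbstractSketch(const AnfNodePtr &node, const ShapeVector &padded_shape) {
  auto shape_ptr = std::make_shared<abstract::Shape>(padded_shape);
  TypeId type_id = common::AnfAlgo::GetOutputInferDataType(node, 0);
  auto abs = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type_id), shape_ptr);
  node->set_abstract(abs);
}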
@@ -240,7 +240,7 @@ bool InsertPadUnpad(const FuncGraphPtr &func_graph) {
   auto todos = TopoSort(func_graph->get_return());
   bool changed = false;
   for (const auto &n : todos) {
-    if (!AnfAlgo::CheckPrimitiveType(n, prim::kPrimMatMul)) continue;
+    if (!common::AnfAlgo::CheckPrimitiveType(n, prim::kPrimMatMul)) continue;
     auto mm_cnode = n->cast<CNodePtr>();
     vec pad_shape_a, pad_shape_b, tail_shape_a, tail_shape_b, tail_shape_unpad, unpad_shape;
     bool pad_K{false}, pad_M{false}, pad_N{false};
@@ -283,8 +283,8 @@ bool InsertPadOps::Run(const FuncGraphPtr &func_graph) {
   auto changed = false;
   auto nodes = TopoSort(func_graph->get_return());
   for (auto node : nodes) {
-    if (!AnfAlgo::IsGraphKernel(node)) continue;
-    auto graph_kernel_fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
+    if (!common::AnfAlgo::IsGraphKernel(node)) continue;
+    auto graph_kernel_fg = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
     MS_EXCEPTION_IF_NULL(graph_kernel_fg);
     changed = InsertPadUnpad(graph_kernel_fg) || changed;
   }