forked from OSSInnovation/mindspore
Support Common Hccl Op
1. Support Broadcast op 2. Support communication op as graph output 3. Optimize communication op memory allocation 4. Support hccl multi-group
This commit is contained in:
parent
93f6fc0ab0
commit
7d07e17f5a
|
@ -9,11 +9,11 @@ include(${GE_SOURCE_DIR}/cmake/external_libs/eigen.cmake)
|
|||
include(${GE_SOURCE_DIR}/cmake/external_libs/gtest.cmake)
|
||||
include(${GE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake)
|
||||
include(${GE_SOURCE_DIR}/cmake/external_libs/onnx.cmake)
|
||||
include(${GE_SOURCE_DIR}/cmake/external_libs/securec.cmake)
|
||||
|
||||
# for CPU/GPU mode, find c_sec and slog from local prebuild
|
||||
# for CPU/GPU mode, find slog from local prebuild
|
||||
if (NOT ENABLE_D)
|
||||
set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR})
|
||||
find_library(c_sec libc_sec.so ${GE_PREBUILD_PATH})
|
||||
find_library(slog libslog.so ${GE_PREBUILD_PATH})
|
||||
elseif (DEFINED ENV{D_LINK_PATH})
|
||||
set(GE_LIB_PATH $ENV{D_LINK_PATH})
|
||||
|
@ -28,7 +28,6 @@ elseif (DEFINED ENV{D_LINK_PATH})
|
|||
message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated")
|
||||
endif()
|
||||
set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH})
|
||||
find_library(c_sec libc_sec.so ${GE_LIB_PATH})
|
||||
find_library(slog libslog.so ${GE_LIB_PATH})
|
||||
find_library(mmpa libmmpa.so ${GE_LIB_PATH})
|
||||
find_library(runtime libruntime.so ${GE_LIB_PATH})
|
||||
|
|
|
@ -153,7 +153,7 @@ if (NOT ENABLE_GE)
|
|||
FILES
|
||||
${CMAKE_BINARY_DIR}/graphengine/src/common/graph/libgraph.so
|
||||
${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/libslog.so
|
||||
${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/libc_sec.so
|
||||
${CMAKE_SOURCE_DIR}/build/graphengine/libc_sec.so
|
||||
DESTINATION ${INSTALL_LIB_DIR}
|
||||
COMPONENT mindspore
|
||||
)
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 995b6dadc0fbbe4b80a08196886a53a18bffa60e
|
||||
Subproject commit 579dcb75a990b533f9182733a6424f2bd66f0f23
|
|
@ -333,8 +333,7 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
|
|||
bool status = ge::model_runner::ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first,
|
||||
model_iter->second, listener);
|
||||
if (!status) {
|
||||
MS_LOG(ERROR) << "load task failed";
|
||||
return false;
|
||||
MS_LOG(EXCEPTION) << "Load Task Failed";
|
||||
}
|
||||
if (ProfilingManager::GetInstance().IsProfiling()) {
|
||||
auto task_ids = ge::model_runner::ModelRunner::Instance().GetTaskIdList(model_iter->first);
|
||||
|
|
|
@ -29,6 +29,7 @@ class GraphDescReporter : public DescReporter {
|
|||
public:
|
||||
GraphDescReporter(uint32_t device_id, const std::string &file_name, std::vector<CNodePtr> cnode_list)
|
||||
: DescReporter(device_id, file_name, std::move(cnode_list)) {}
|
||||
~GraphDescReporter() override = default;
|
||||
void ReportData() override;
|
||||
};
|
||||
} // namespace ascend
|
||||
|
|
|
@ -60,7 +60,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
|
|||
const string tag_broadcast = kHcomBroadcast + std::to_string(task_counter++) + kUnderline + std::to_string(0);
|
||||
ret = hcom_broadcast(tag_broadcast.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
|
||||
static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()),
|
||||
static_cast<u32>(task_info->root_id()), nullptr, stream);
|
||||
static_cast<u32>(task_info->root_id()), task_info->group().c_str(), stream);
|
||||
if (ret != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "hcom_broadcast fail, return ret: " << static_cast<int>(ret);
|
||||
return false;
|
||||
|
@ -70,7 +70,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
|
|||
const string tag_all_gather = kHcomAllGather + std::to_string(task_counter++) + kUnderline + std::to_string(0);
|
||||
ret = hcom_all_gather(tag_all_gather.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
|
||||
reinterpret_cast<void *>(task_info->output_data_addr()), static_cast<u64>(task_info->count()),
|
||||
static_cast<hcclDataType_t>(task_info->data_type()), nullptr, stream);
|
||||
static_cast<hcclDataType_t>(task_info->data_type()), task_info->group().c_str(), stream);
|
||||
if (ret != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "hcom_all_gather fail, return ret: " << ret;
|
||||
return false;
|
||||
|
@ -81,7 +81,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
|
|||
ret = hcom_all_reduce(tag_all_reduce.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
|
||||
reinterpret_cast<void *>(task_info->output_data_addr()), static_cast<u64>(task_info->count()),
|
||||
static_cast<hcclDataType_t>(task_info->data_type()),
|
||||
static_cast<hcclRedOp_t>(task_info->op_type()), nullptr, stream);
|
||||
static_cast<hcclRedOp_t>(task_info->op_type()), task_info->group().c_str(), stream);
|
||||
if (ret != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "hcom_all_reduce fail, return ret: " << ret;
|
||||
return false;
|
||||
|
@ -93,7 +93,7 @@ bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info
|
|||
ret = hcom_reduce_scatter(tag_reduce_scatter.c_str(), reinterpret_cast<void *>(task_info->input_data_addr()),
|
||||
reinterpret_cast<void *>(task_info->output_data_addr()),
|
||||
static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()),
|
||||
static_cast<hcclRedOp_t>(task_info->op_type()), nullptr, stream);
|
||||
static_cast<hcclRedOp_t>(task_info->op_type()), task_info->group().c_str(), stream);
|
||||
if (ret != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "hcom_reduce_scatter fail, return ret: " << ret;
|
||||
return false;
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "device/kernel_runtime.h"
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <numeric>
|
||||
#include <functional>
|
||||
|
@ -130,20 +131,16 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
|
|||
mem_manager_->ResetDynamicMemory();
|
||||
AssignStaticMemory(graph);
|
||||
AssignDynamicMemory(graph);
|
||||
|
||||
UpdateRefNodeOutputMem(graph);
|
||||
}
|
||||
|
||||
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
|
||||
session::KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
// assign memory for input nodes
|
||||
RunOpAssignInputMemory(input_tensors, graph);
|
||||
AssignStaticMemoryValueNode(graph);
|
||||
for (const auto &cnode : graph->execution_order()) {
|
||||
// assign memory for output nodes
|
||||
RunOpAssignOutputMemory(cnode);
|
||||
// assign memory for workspace
|
||||
RunOpAssignWorkSpaceMemory(cnode);
|
||||
}
|
||||
UpdateRefNodeOutputMem(graph);
|
||||
|
@ -280,12 +277,22 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
|
|||
// Assigns static memory to the kernels that produce the outputs of `graph`.
//
// The graph outputs are collected with AnfAlgo::GetAllOutput and each one is
// resolved to its producing kernel. Outputs that are not real CNode kernels
// are ignored. Communication ops are assigned immediately through
// AssignCommunicationNodeMem; every other output is queued and assigned
// afterwards through AssignNodeOutputMem, so communication ops always get
// their memory first.
void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  const auto output_nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
  // Non-communication outputs, kept in visiting order for the second pass.
  std::vector<session::KernelWithIndex> deferred_outputs;

  // First pass: communication ops only.
  for (const auto &output_node : output_nodes) {
    auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(output_node, 0, true);
    const auto &kernel = kernel_with_index.first;
    MS_EXCEPTION_IF_NULL(kernel);
    const bool is_real_cnode_kernel = kernel->isa<CNode>() && AnfAlgo::IsRealKernel(kernel);
    if (!is_real_cnode_kernel) {
      continue;
    }
    if (!AnfAlgo::IsCommunicationOp(kernel)) {
      deferred_outputs.emplace_back(kernel_with_index);
      continue;
    }
    AssignCommunicationNodeMem(kStaticMem, kernel);
  }

  // Second pass: everything that was deferred above.
  for (const auto &deferred : deferred_outputs) {
    AssignNodeOutputMem(kStaticMem, deferred.first, SizeToInt(deferred.second));
  }
}
|
||||
|
@ -322,6 +329,11 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
|
|||
}
|
||||
}
|
||||
|
||||
// Assigns memory for one communication node.
//
// `flag` is forwarded unchanged to AssignCommunicationNodeOutputMem; the input
// side takes no flag. `node` is not validated here, it is passed straight to
// the two helpers below.
void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) {
  // Input memory is handled before output memory; keep this order.
  AssignCommunicationNodeInputMem(node);
  AssignCommunicationNodeOutputMem(flag, node);
}
|
||||
|
||||
void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||
|
@ -335,8 +347,13 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr
|
|||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
size_t total_size = 0;
|
||||
size_t output_index = 0;
|
||||
std::vector<size_t> align_size_list;
|
||||
for (uint64_t mem_size : output_sizes) {
|
||||
if (AnfAlgo::OutputAddrExist(node, output_index++)) {
|
||||
MS_LOG(INFO) << "communication op addr exist";
|
||||
continue;
|
||||
}
|
||||
if (context_ptr->enable_hccl()) {
|
||||
mem_size = mem_manager_->GetCommonAlignSize(mem_size);
|
||||
}
|
||||
|
@ -353,7 +370,21 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr
|
|||
}
|
||||
}
|
||||
|
||||
void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
|
||||
// Creates a device address object for output `index` of `anf_node` and attaches
// it to the node as that output's address.
//
// The address is created via CreateDeviceAddress with a null device pointer and
// the size reported by the node's kernel mod for that output, using the node's
// output format and device data type.
//
// Raises an exception if `anf_node` or its kernel mod is null, or if `index`
// is not a valid position in the kernel mod's output size list.
// Returns the address that was set on the node.
DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  auto kernel_mod = AnfAlgo::GetKernelMod(anf_node);
  // Guard the dereference below: a node without a kernel mod would otherwise crash here.
  MS_EXCEPTION_IF_NULL(kernel_mod);
  // Bind by const reference so the size list is not copied just to read one entry.
  const auto &output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.size() <= index) {
    MS_LOG(EXCEPTION) << "Previous node output size < node index";
  }
  std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
  auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
  auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type);
  AnfAlgo::SetOutputAddr(address, index, anf_node.get());
  return address;
}
|
||||
|
||||
void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -361,12 +392,16 @@ void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
|
|||
size_t total_size = 0;
|
||||
std::vector<std::pair<mindspore::device::DeviceAddress *, size_t>> addr_size;
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) {
|
||||
auto address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i);
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
auto mem_size = address->size();
|
||||
if (context_ptr->enable_hccl()) {
|
||||
mem_size = mem_manager_->GetCommonAlignSize(mem_size);
|
||||
auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
|
||||
auto input_node = input_node_with_index.first;
|
||||
DeviceAddressPtr address = nullptr;
|
||||
if (input_node->isa<CNode>()) {
|
||||
address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Communication node inputs only support CNode";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
auto mem_size = mem_manager_->GetCommonAlignSize(address->size());
|
||||
total_size += mem_size;
|
||||
addr_size.emplace_back(address.get(), mem_size);
|
||||
}
|
||||
|
@ -381,11 +416,6 @@ void KernelRuntime::UpdateCommunicationOpInputMem(const AnfNodePtr &node) {
|
|||
void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||
if (AnfAlgo::IsCommunicationOp(node)) {
|
||||
UpdateCommunicationOpInputMem(node);
|
||||
AssignCommunicationNodeOutputMem(flag, node);
|
||||
return;
|
||||
}
|
||||
if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) {
|
||||
MS_LOG(INFO) << "GetNext disable mem_reuse";
|
||||
flag = kDynamicMem;
|
||||
|
@ -506,10 +536,22 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
|
|||
mem_manager_->MallocReusedDynamicMem(graph);
|
||||
mem_flag = kReuseDynamicMem;
|
||||
}
|
||||
auto &kernels = graph->execution_order();
|
||||
for (auto &kernel : kernels) {
|
||||
AssignNodeOutputMem(mem_flag, kernel, kGetAllOuts);
|
||||
AssignWorkSpaceMem(mem_flag, kernel);
|
||||
auto &execution_nodes = graph->execution_order();
|
||||
std::vector<CNodePtr> compute_nodes;
|
||||
// communication nodes first
|
||||
for (auto &node : execution_nodes) {
|
||||
if (AnfAlgo::IsCommunicationOp(node)) {
|
||||
// skip if the memory is already allocated
|
||||
AssignCommunicationNodeMem(mem_flag, node);
|
||||
} else {
|
||||
compute_nodes.emplace_back(node);
|
||||
}
|
||||
}
|
||||
|
||||
// then compute nodes
|
||||
for (auto &node : compute_nodes) {
|
||||
AssignNodeOutputMem(mem_flag, node, kGetAllOuts);
|
||||
AssignWorkSpaceMem(mem_flag, node);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -73,9 +73,12 @@ class KernelRuntime {
|
|||
void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index);
|
||||
void AssignWorkSpaceMem(int flag, const AnfNodePtr &node);
|
||||
void AssignReuseWorkSpaceMem(const AnfNodePtr &node);
|
||||
void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node);
|
||||
|
||||
void UpdateRefNodeOutputMem(const session::KernelGraph *graph);
|
||||
void UpdateCommunicationOpInputMem(const AnfNodePtr &node);
|
||||
|
||||
void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node);
|
||||
void AssignCommunicationNodeInputMem(const AnfNodePtr &node);
|
||||
void AssignCommunicationNodeMem(int flag, const AnfNodePtr &node);
|
||||
#ifdef ENABLE_DUMP_E2E
|
||||
bool SetDumpConf();
|
||||
#endif
|
||||
|
@ -91,6 +94,7 @@ class KernelRuntime {
|
|||
void RunOpAssignOutputMemory(const AnfNodePtr &kernel);
|
||||
void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel);
|
||||
void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
|
||||
DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index);
|
||||
|
||||
protected:
|
||||
uint32_t device_id_{0};
|
||||
|
|
|
@ -90,6 +90,7 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) {
|
|||
return false;
|
||||
}
|
||||
}
|
||||
HcomUtil::GetHcomGroup(NOT_NULL(anf_node), NOT_NULL(&group_));
|
||||
anf_node_ = anf_node;
|
||||
return true;
|
||||
}
|
||||
|
@ -147,7 +148,7 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
|
|||
|
||||
HcclTaskInfoPtr task_info_ptr = std::make_shared<HcclTaskInfo>(
|
||||
stream_id, hccl_type, input_data_addr, output_data_addr, workspace_address, workspace_num, 0, private_def, nullptr,
|
||||
hccl_count_, root_id_, op_type_, data_type, RuntimeUtils::HcomBindModel, RuntimeUtils::HcomUnbindModel,
|
||||
hccl_count_, root_id_, op_type_, data_type, group_, RuntimeUtils::HcomBindModel, RuntimeUtils::HcomUnbindModel,
|
||||
RuntimeUtils::HcomDistribute);
|
||||
MS_EXCEPTION_IF_NULL(task_info_ptr);
|
||||
return {task_info_ptr};
|
||||
|
|
|
@ -54,6 +54,7 @@ class HcclKernel : public AscendKernelMod {
|
|||
mutable std::vector<size_t> workspace_size_list_;
|
||||
AnfNodePtr anf_node_;
|
||||
std::string op_name_;
|
||||
std::string group_;
|
||||
};
|
||||
|
||||
using HcclKernelCreater = std::function<std::shared_ptr<HcclKernel>()>;
|
||||
|
|
|
@ -176,11 +176,22 @@ bool HcomUtil::GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id) {
|
|||
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
if (primitive->GetAttr("root_rank") != nullptr) {
|
||||
*root_id = GetValue<const vector<uint32_t>>(primitive->GetAttr("root_rank"))[0];
|
||||
*root_id = (uint32_t)GetValue<int>(primitive->GetAttr("root_rank"));
|
||||
} else {
|
||||
MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_ROOT_INDEX fail, not support!";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Reads the "group" attribute from the primitive of `anf_node` and stores its
// string value in `*group`.
//
// Raises an exception if the node has no primitive, and logs at EXCEPTION level
// (naming the op by its full scope name) if the "group" attribute is missing.
void HcomUtil::GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group) {
  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  MS_EXCEPTION_IF_NULL(primitive);
  const auto group_attr = primitive->GetAttr("group");
  if (group_attr == nullptr) {
    // Missing attribute: report which op is misconfigured.
    MS_LOG(EXCEPTION) << "Get Hcom Group Attr of Op:" << anf_node->fullname_with_scope() << " failed";
  } else {
    *group = GetValue<std::string>(group_attr);
  }
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <memory>
|
||||
#include "ir/dtype.h"
|
||||
#include "hccl/base.h"
|
||||
#include "utils/contract.h"
|
||||
|
||||
namespace mindspore {
|
||||
using std::map;
|
||||
|
@ -61,6 +62,7 @@ class HcomUtil {
|
|||
const vector<vector<size_t>> &shape_list, uint64_t *total_count);
|
||||
static bool GetHcomOperationType(const AnfNodePtr &anf_node, hcclRedOp_t *op_type);
|
||||
static bool GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id);
|
||||
static void GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group);
|
||||
};
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -66,8 +66,7 @@ const AnfNodePtr AddMemcpyAsync::Process(const FuncGraphPtr &func_graph, const A
|
|||
return nullptr;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto op_name = AnfAlgo::GetCNodeName(cnode);
|
||||
if (op_name != kAllReduceOpName && op_name != kAllGatherOpName && op_name != kReduceScatterOpName) {
|
||||
if (!AnfAlgo::IsCommunicationOp(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
return AddMemcpyAsyncIfInputIsUsedByOthers(func_graph, cnode);
|
||||
|
|
|
@ -173,6 +173,19 @@ const BaseRef DealRefTransAndCast::DefinePattern() const {
|
|||
return VectorRef({V, Xs});
|
||||
}
|
||||
|
||||
void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) {
|
||||
auto input_size = AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t i = 0; i < input_size; ++i) {
|
||||
auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i);
|
||||
auto input_node = input_node_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
MS_LOG(INFO) << "origin node:" << input_node->fullname_with_scope();
|
||||
AddRefPairToKernelGraph(func_graph, cnode, nullptr, cnode, i, input_node_with_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
if (node == nullptr || !node->isa<CNode>()) {
|
||||
|
@ -184,6 +197,9 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A
|
|||
if (!AnfAlgo::IsRealCNodeKernel(cnode)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
DealBroadCastAsRef(graph, cnode);
|
||||
|
||||
auto op_name = AnfAlgo::GetCNodeName(cnode);
|
||||
auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE);
|
||||
if (op_info == nullptr || !op_info->is_ref()) {
|
||||
|
|
Loading…
Reference in New Issue