[auto-monad] Support side-effects by auto-monad

The basic idea is to exploit data dependencies to control the execution order
of side-effect operations, while keeping the ANF semantics unchanged.

The ControlDepend primitive is removed and two primitives are added:

1. UpdateState:
```
  a = Assign(para, value)
```
becomes:
```
  a = Assign(para, value, u1)
  u2 = UpdateState(u1, a)
```

2. Load:
```
  x = Add(para, value)
```
becomes:
```
  p = Load(para, u1)
  x = Add(p, value)
  u2 = UpdateState(u1, p)
```
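
Chaining these rewrites orders a write before a subsequent read through the monad alone. A minimal sketch in the same notation (u0/u1/u2 are illustrative state names, not from this commit):
```
  a = Assign(para, value, u0)    # side effect consumes state u0
  u1 = UpdateState(u0, a)        # u1 carries a dependency on the Assign
  p = Load(para, u1)             # the Load depends on u1, so it runs after the Assign
  x = Add(p, value)
  u2 = UpdateState(u1, p)        # later side effects are ordered after the Load
```
Because each side-effect op both consumes and, via UpdateState, produces the state value, ordinary data-dependency scheduling is enough to serialize them; no ControlDepend edges are needed.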
He Wei 2021-02-05 11:54:29 +08:00
parent f0a9cb7c20
commit 7d9a783993
342 changed files with 13734 additions and 2841 deletions

View File

@ -1,4 +1,4 @@
cmake_minimum_required(VERSION 3.14.1)
cmake_minimum_required(VERSION 3.14.0)
project(MindSpore)
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0)
@ -14,18 +14,25 @@ endif()
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(CMAKE_OSX_SYSROOT "")
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Winconsistent-missing-override -Wuser-defined-warnings -Wno-return-std-move -Wno-unused-private-field -Wno-unused-lambda-capture -Wno-sign-compare -Wno-overloaded-virtual -Wno-unneeded-internal-declaration -Wno-unused-variable -Wno-pessimizing-move -Wno-inconsistent-missing-override -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Winconsistent-missing-override -Wuser-defined-warnings \
-Wno-return-std-move -Wno-unused-private-field -Wno-unused-lambda-capture -Wno-sign-compare \
-Wno-overloaded-virtual -Wno-unneeded-internal-declaration -Wno-unused-variable -Wno-pessimizing-move \
-Wno-inconsistent-missing-override -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
else()
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Wl,--allow-shlib-undefined -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Wl,--allow-shlib-undefined \
-DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
endif()
if(ENABLE_PYTHON)
add_compile_definitions(ENABLE_PYTHON)
endif()
set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer -Wl,--allow-shlib-undefined -D_LIBCPP_INLINE_VISIBILITY='' -D_LIBCPP_DISABLE_EXTERN_TEMPLATE=1 -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2 -Wno-cpp")
set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer \
-Wl,--allow-shlib-undefined -D_LIBCPP_INLINE_VISIBILITY='' -D_LIBCPP_DISABLE_EXTERN_TEMPLATE=1 \
-DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2 -Wno-cpp")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I/usr/local/include -std=c++17 -Werror -Wall -Wno-deprecated-declarations -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I/usr/local/include -std=c++17 \
-Werror -Wall -Wno-deprecated-declarations -fPIC")
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(PYBIND11_CPP_STANDARD -std=c++17)

View File

@ -132,6 +132,16 @@ def Depend(value, expr):
return value
def UpdateState(monad, expr):
"""Implement `UpdateState`."""
return monad
def Load(value, u=None):
"""Implement `Load`."""
return value
# only used in PyNative mode
def make_ref(key, value, ref):
return value
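
These fallback stubs are identity functions: evaluated eagerly they pass data and state through unchanged, so the monad only constrains ordering in graph mode. A hypothetical walkthrough, assuming the definitions above:
```python
para, value = 1, 2
u = "U"                # any placeholder works as the monad here
p = Load(para, u)      # identity on data: p == 1
x = p + value          # x == 3
u = UpdateState(u, p)  # identity on the monad: u == "U"
assert (x, u) == (3, "U")
```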

View File

@ -42,14 +42,16 @@ void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<
std::vector<std::string> inputs_format{};
std::vector<TypeId> inputs_type{};
if (op_name == kPrint || op_name == kPack || op_name == kMeshgrid) {
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_format.emplace_back(kOpFormat_DEFAULT);
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
}
}
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
outputs_format.emplace_back(kOpFormat_DEFAULT);
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
}

View File

@ -139,9 +139,9 @@ bool CheckCache(const std::string &kernel_name) {
std::string kernel_json = bin_map->Search(kernel_name);
bool ret = (!kernel_json.empty());
if (ret) {
MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registed.";
MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registered.";
} else {
MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registed.";
MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registered.";
}
return ret;
}
@ -730,30 +730,6 @@ bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann:
return false;
}
void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node_list);
auto output = func_graph->output();
MS_EXCEPTION_IF_NULL(output);
if (AnfAlgo::IsRealKernel(output)) {
// single output.
node_list->push_back(std::make_pair(output, 0));
return;
} else if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
auto output_cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_cnode);
// multi output.
auto &inputs = output_cnode->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
auto in_with_idx = AnfAlgo::VisitKernel(inputs[i], 0);
node_list->push_back(in_with_idx);
}
return;
}
MS_EXCEPTION(ArgumentError) << "Unknown output type: " << output->DebugString(2)
<< " of graph: " << func_graph->ToString();
}
bool IsWeightBoundary(const AnfNodePtr &node) {
if (node->isa<ValueNode>()) {
return true;
@ -776,7 +752,7 @@ std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(primitive);
auto axis_attr = primitive->GetAttr(kAxis);
if (axis_attr == nullptr) {
MS_LOG(ERROR) << "This node does't have axie attr.";
MS_LOG(ERROR) << "This node doesn't have axie attr.";
return std::vector<int64_t>();
}
std::vector<int64_t> axis_list;

View File

@ -181,7 +181,8 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
std::vector<size_t> out_shape;
out_shape.emplace_back(miss_count);
std::vector<TypeId> dtypes;
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node_); i++) {
size_t output_num = AnfAlgo::GetOutputTensorNum(node_);
for (size_t i = 0; i < output_num; i++) {
dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i));
}
AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetOutputInferShape(node_, 0), out_shape, out_shape, out_shape},

View File

@ -69,7 +69,8 @@ void SubAndFilterCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
std::vector<size_t> out_shape;
out_shape.emplace_back(count);
std::vector<TypeId> dtypes;
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node_); i++) {
size_t output_num = AnfAlgo::GetOutputTensorNum(node_);
for (size_t i = 0; i < output_num; i++) {
dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i));
}
AnfAlgo::SetOutputInferTypeAndShape(dtypes, {out_shape, out_shape}, node_.get());

View File

@ -29,5 +29,8 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
Assign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
AssignGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
Assign, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
AssignGpuKernel, int64_t)
} // namespace kernel
} // namespace mindspore

View File

@ -63,13 +63,15 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
for (const auto &type : kHcclSupportTypes) {
std::vector<std::string> inputs_format{};
std::vector<TypeId> inputs_type{};
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_format.emplace_back(GetKernelFormat(kernel_node, input_index));
inputs_type.push_back(type);
}
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
if (op_name == kReduceScatter && AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrFusion) > 0) {
outputs_format.emplace_back(GetKernelFormat(kernel_node, 0));
} else {

View File

@ -31,7 +31,8 @@ bool IsPyNativeMode() {
bool HcomUtil::GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_intput_shape_list) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(hccl_kernel_intput_shape_list);
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) {
size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
for (size_t i = 0; i < input_num; ++i) {
std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i);
hccl_kernel_intput_shape_list->emplace_back(shape_i);
}
@ -42,7 +43,8 @@ bool HcomUtil::GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<siz
bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_output_shape_list) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(hccl_kernel_output_shape_list);
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf_node); ++i) {
size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);
for (size_t i = 0; i < output_num; ++i) {
std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i);
hccl_kernel_output_shape_list->emplace_back(shape_i);
}
@ -53,11 +55,12 @@ bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<si
bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType> *data_type_list) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(data_type_list);
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) {
size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
for (size_t i = 0; i < input_num; ++i) {
auto type_ptr = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, i);
auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(type_ptr);
if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) {
MS_LOG(EXCEPTION) << "HcomDataType cann't support Current Ascend Data Type : " << type_ptr;
MS_LOG(EXCEPTION) << "HcomDataType can't support Current Ascend Data Type : " << type_ptr;
}
data_type_list->emplace_back(iter->second);
}

View File

@ -37,13 +37,15 @@ void HostMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
std::vector<std::string> inputs_format{};
std::vector<TypeId> inputs_type{};
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_format.emplace_back(kOpFormat_DEFAULT);
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
}
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
outputs_format.emplace_back(kOpFormat_DEFAULT);
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
}

View File

@ -30,7 +30,7 @@ std::string KernelBuildInfo::GetInputFormat(size_t input_index) const {
std::string KernelBuildInfo::GetOutputFormat(size_t output_index) const {
if (output_index >= outputs_format_.size()) {
MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of input node";
MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of output";
return kInvalidFormat;
}
return outputs_format_[output_index];

View File

@ -86,6 +86,9 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> LabelSwitchDesc::GetKernel
builder.SetProcessor(AICORE);
builder.SetKernelType(RT_KERNEL);
builder.SetFusionType(OPAQUE);
// LabelSwitch always return UMonad.
builder.SetOutputsFormat({kOpFormat_DEFAULT});
builder.SetOutputsDeviceType({TypeId::kObjectTypeUMonad});
label_switch_build_info.emplace_back(builder.Build());
}
return label_switch_build_info;

View File

@ -74,11 +74,10 @@ void GetRtKelInfo(const CNodePtr &kernel_node,
input_types.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, i));
}
kernel_build_info_builder->SetInputsDeviceType(input_types);
// set output info
auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_num, kOpFormat_DEFAULT));
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>(output_num, TypeId::kTypeUnknown));
// set ohter info
// Kernel ops in white-list such as 'LabelSet' always return UMonad.
kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT});
kernel_build_info_builder->SetOutputsDeviceType({TypeId::kObjectTypeUMonad});
// set other info
kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);

View File

@ -1052,10 +1052,16 @@ size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool i
auto node_name = AnfAlgo::GetCNodeName(cnode);
auto op_info = tbe::TbeDynamicShapeUtil::FindOp(node_name, cnode);
MS_EXCEPTION_IF_NULL(cnode);
if (op_info->inputs_ptr().size() < (cnode->inputs().size() - 1)) {
auto node_inputs_size = cnode->inputs().size();
for (auto &input : cnode->inputs()) {
if (HasAbstractMonad(input)) {
node_inputs_size--;
}
}
if (op_info->inputs_ptr().size() < (node_inputs_size - 1)) {
MS_EXCEPTION(ArgumentError) << "op info error, node name:" << cnode->fullname_with_scope();
}
return (op_info->inputs_ptr().size() + 1 - cnode->inputs().size());
return (op_info->inputs_ptr().size() + 1 - node_inputs_size);
}
std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) {
@ -1103,6 +1109,9 @@ bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode,
bool is_dynamic_input = IsDynamicInput(cnode);
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
auto input = cnode->input(i);
if (HasAbstractMonad(input)) {
continue;
}
auto kernel_idx = AnfAlgo::VisitKernel(input, 0);
auto real_node = kernel_idx.first;
size_t real_idx = kernel_idx.second;

View File

@ -112,6 +112,10 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
const KernelSelectPtr &kernel_select) {
MS_EXCEPTION_IF_NULL(node);
auto input_node = AnfAlgo::GetInputNode(node, index);
if (HasAbstractMonad(input_node)) {
// No transfer for monad inputs.
return input_node;
}
auto node_with_index = AnfAlgo::VisitKernel(input_node, 0);
MS_EXCEPTION_IF_NULL(node_with_index.first);
auto real_input = node_with_index.first;
@ -330,8 +334,9 @@ AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePt
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
size_t in_num = AnfAlgo::GetInputNum(cnode); // include monads.
for (size_t input_index = 0; input_index < in_num; ++input_index) {
// Monad inputs keep unchanged from GetTransInputNodePtr().
AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select);
MS_EXCEPTION_IF_NULL(input_node);
new_inputs.push_back(input_node);
@ -352,12 +357,18 @@ AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePt
CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
size_t in_num = AnfAlgo::GetInputNum(cnode); // include monads.
for (size_t input_index = 0; input_index < in_num; ++input_index) {
auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
if (HasAbstractMonad(cur_input)) {
// No cast for monad inputs.
new_inputs.push_back(cur_input);
continue;
}
auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, input_index);
const auto infer_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
TypeId origin_type(kTypeUnknown);
auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(cur_input, 0);
auto real_input_node = kernel_with_index.first;
if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {

View File

@ -244,7 +244,9 @@ void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph,
if (auto in = cnode->input(idx); std::find((*buffer_fusion_infos)[fusion_id].inputs_list.begin(),
(*buffer_fusion_infos)[fusion_id].inputs_list.end(),
in) == (*buffer_fusion_infos)[fusion_id].inputs_list.end()) {
(*buffer_fusion_infos)[fusion_id].inputs_list.push_back(in);
if (!HasAbstractMonad(in)) {
(*buffer_fusion_infos)[fusion_id].inputs_list.push_back(in);
}
}
}
}

View File

@ -40,62 +40,6 @@ bool IsParameterOrValueNode(const AnfNodePtr &node) {
return real_node->isa<ValueNode>();
}
void SetInput(const CNodePtr &control_depend, const int index, const FuncGraphPtr &graph, const CNodePtr &hccl_node,
const std::vector<AnfNodePtr> &memcpy_async_list) {
MS_EXCEPTION_IF_NULL(control_depend);
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(hccl_node);
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end());
make_tuple_inputs.emplace_back(hccl_node);
auto make_tuple = graph->NewCNode(make_tuple_inputs);
MS_EXCEPTION_IF_NULL(make_tuple);
control_depend->set_input(IntToSize(index), make_tuple);
}
void DealControlForGetitem(const CNodePtr &tuple_getitem, const FuncGraphPtr &graph, const CNodePtr &hccl_node,
const std::vector<AnfNodePtr> &memcpy_async_list) {
MS_EXCEPTION_IF_NULL(tuple_getitem);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &node_users = manager->node_users();
auto iter = node_users.find(tuple_getitem);
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "node has no output in manager"
<< " trace: " << trace::DumpSourceLines(hccl_node);
}
for (const auto &node_index : iter->second) {
AnfNodePtr output = node_index.first;
MS_EXCEPTION_IF_NULL(output);
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
SetInput(output->cast<CNodePtr>(), node_index.second, graph, hccl_node, memcpy_async_list);
}
}
}
void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &memcpy_async_list,
const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(hccl_node);
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &node_users = manager->node_users();
auto iter = node_users.find(hccl_node);
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "node has no output in manager"
<< " trace: " << trace::DumpSourceLines(hccl_node);
}
// find hccl_node's output which is a control depend
for (const auto &node_index : iter->second) {
AnfNodePtr output = node_index.first;
MS_EXCEPTION_IF_NULL(output);
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
SetInput(output->cast<CNodePtr>(), node_index.second, graph, hccl_node, memcpy_async_list);
} else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimTupleGetItem)) {
DealControlForGetitem(output->cast<CNodePtr>(), graph, hccl_node, memcpy_async_list);
}
}
}
// NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i)
bool IsNodeOutPutUsedByOtherRealKernel(const AnfNodeIndexSet &node_users) {
if (node_users.size() == 1) {
@ -155,7 +99,7 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con
void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(hccl_node);
std::vector<AnfNodePtr> memcpy_async_list;
bool need_memcpy_async = false;
std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
for (size_t i = 1; i < hccl_node->size(); ++i) {
auto input = hccl_node->input(i);
@ -164,17 +108,17 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
if (memcpy_async == nullptr) {
MS_LOG(EXCEPTION) << "Create memcpy_async op failed.";
}
if (AnfAlgo::IsNodeDynamicShape(input)) {
if (input->isa<CNode>() && AnfAlgo::IsNodeDynamicShape(input)) {
AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), memcpy_async);
}
new_inputs.push_back(memcpy_async);
memcpy_async_list.push_back(memcpy_async);
need_memcpy_async = true;
} else {
new_inputs.push_back(input);
}
}
if (!memcpy_async_list.empty()) {
if (need_memcpy_async) {
CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
new_hccl_node->set_inputs(new_inputs);
auto manager = graph->manager();
@ -182,9 +126,6 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node";
(void)manager->Replace(hccl_node, new_hccl_node);
MS_LOG(DEBUG) << "end replace";
// transer hccl op's control to the memcpy_async
TransferControl(new_hccl_node, memcpy_async_list, graph);
}
}

View File

@ -57,7 +57,7 @@ const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph
return nullptr;
}
std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
for (size_t input_idx = 0; input_idx < AnfAlgo::GetInputTensorNum(cnode); input_idx++) {
for (size_t input_idx = 0; input_idx < input_num; input_idx++) {
auto cur_input = AnfAlgo::GetInputNode(cnode, input_idx);
auto origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_idx);
auto origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_idx);

View File

@ -40,7 +40,8 @@ const AnfNodePtr ConvertCastFormat::Process(const FuncGraphPtr &func_graph, cons
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
auto input_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, input_index), 0).first;
MS_EXCEPTION_IF_NULL(input_node);
if (!input_node->isa<CNode>()) {
@ -77,7 +78,8 @@ void ConvertCastFormat::ChangeCastFormat(const CNodePtr &cast_node, const FuncGr
MS_EXCEPTION_IF_NULL(node_info.first);
auto cast_out_node = node_info.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cast_out_node);
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cast_out_node); ++index) {
size_t input_num = AnfAlgo::GetInputTensorNum(cast_out_node);
for (size_t index = 0; index < input_num; ++index) {
if (AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cast_out_node->cast<CNodePtr>(), index), 0).first !=
cast_node) {
continue;

View File

@ -162,7 +162,8 @@ CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_
std::vector<AnfNodePtr> make_tuple_inputs;
AbstractBasePtrList abstract_list;
make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) {
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
CNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index);
// deal with ref output
if (ref_infos.count(output_index) != 0) {

View File

@ -37,7 +37,7 @@ const AnfNodePtr DynamicRNNGradReformat::Process(const FuncGraphPtr &func_graph,
MS_EXCEPTION_IF_NULL(node);
auto split_v = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(split_v);
auto matmul = CheckAnfNodeIfCNodeAndInputSize(split_v->input(1), 3);
auto matmul = CheckAnfNodeIfCNodeAndInputSize(split_v->input(1), kMatMulInputTensorNum);
MS_EXCEPTION_IF_NULL(matmul);
auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(matmul, 0);
auto input_node = input_node_with_idx.first;

View File

@ -129,9 +129,21 @@ AnfNodePtr ProcessGraphKernelOp(const FuncGraphPtr &func_graph, const AnfNodePtr
auto mng = sub_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
std::vector<AnfNodePtr> todo;
std::vector<std::pair<AnfNodePtr, size_t>> graph_rets;
kernel::GetValidKernelNodes(sub_graph, &todo);
kernel::GetGraphRealOutput(sub_graph, &graph_rets);
auto outputs = AnfAlgo::GetAllOutput(sub_graph->output(), {prim::kPrimTupleGetItem});
std::vector<std::pair<AnfNodePtr, size_t>> graph_rets;
for (auto &output : outputs) {
size_t index = 0;
if (IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
ValuePtr tuple_index_value = GetValueNode(output->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
MS_EXCEPTION_IF_NULL(tuple_index_value);
if (!tuple_index_value->isa<Int64Imm>()) {
MS_LOG(EXCEPTION) << "The index of tuple getitem is not int64";
}
index = tuple_index_value->cast<Int64ImmPtr>()->value();
}
graph_rets.emplace_back(std::pair<AnfNodePtr, size_t>(output, index));
}
for (auto &t : todo) {
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), t);
// process input

View File

@ -33,7 +33,8 @@ bool RunOpInsertTransData::Run(const FuncGraphPtr &graph) {
continue;
}
auto cnode = node->cast<CNodePtr>();
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) {
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
for (size_t index = 0; index < input_num; ++index) {
auto prev_input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
auto prev_node_out_infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
auto input_format = AnfAlgo::GetInputFormat(cnode, index);

View File

@ -28,8 +28,6 @@
namespace mindspore {
namespace opt {
namespace {
const size_t kCastInputNum = 2;
const size_t kTupleGetitemInputNum = 3;
bool AlternativeKernelInfoForInput(const CNodePtr &node, const TypeId dst_type, const size_t change_idx,
const std::shared_ptr<kernel::KernelBuildInfo> &candidate_kernel_info) {
if (node == nullptr || node->kernel_info() == nullptr || candidate_kernel_info == nullptr) {
@ -126,7 +124,8 @@ void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size
auto cast_shape = AnfAlgo::GetOutputInferShape(cast, 0);
std::vector<Shape> shapes;
std::vector<TypeId> types;
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) {
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
for (size_t index = 0; index < output_num; ++index) {
if (cast_index == index) {
shapes.emplace_back(cast_shape);
types.emplace_back(cast_dtype);
@ -175,7 +174,7 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co
<< "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info"
<< (*alternative_kernel_info)->ToString();
AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_info, next_cnode.get());
if (node->inputs().size() < kCastInputNum) {
if (AnfAlgo::GetInputTensorNum(node) < kCastInputTensorNum) {
MS_LOG(EXCEPTION) << "Op[" << node->DebugString() << "] has wrong input num:";
}
return node->input(1);
@ -188,9 +187,7 @@ bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_outpu
*prior_op = x_cnode;
// when x_node is tuple_getitem
if (AnfAlgo::GetCNodeName(x_node) == prim::kPrimTupleGetItem->name()) {
if (x_cnode->inputs().size() < kTupleGetitemInputNum) {
MS_LOG(EXCEPTION) << "tuple getitem node has wrong input num" << x_cnode->inputs().size();
}
CheckCNodeInputSize(x_cnode, kTupleGetItemInputTensorNum);
MS_EXCEPTION_IF_NULL(output_idx);
AnfNodePtr input1 = x_cnode->input(1);
MS_EXCEPTION_IF_NULL(input1);
@ -214,9 +211,7 @@ bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_outpu
AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_node, const KernelQueryPtr kernel_query) {
MS_EXCEPTION_IF_NULL(cur_node);
MS_EXCEPTION_IF_NULL(kernel_query);
if (cur_node->inputs().size() < kCastInputNum) {
MS_LOG(EXCEPTION) << "op[Cast] has wrong input num:";
}
CheckCNodeInputSize(cur_node, kCastInputTensorNum);
AnfNodePtr x_node = cur_node->input(1);
if (IsUsedByOthers(graph, x_node)) {
return nullptr;

View File

@ -69,7 +69,7 @@ const AnfNodePtr RemoveInternalOutput::Process(const FuncGraphPtr &func_graph, c
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
CheckCNodeInputSize(cnode, kTransOpInputNum);
CheckCNodeInputSize(cnode, kTransOpInputTensorNum);
auto input_node = cnode->input(1);
if (!AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimTupleGetItem)) {
kernel_graph->ReplaceInternalOutput(node, input_node);

View File

@ -111,8 +111,8 @@ AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNod
auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract());
MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
if (bn_abstract_tuple->elements().size() != kBatchNormOutputNum) {
MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBatchNormOutputNum << ", but it is "
if (bn_abstract_tuple->elements().size() != kBnOutputNum) {
MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBnOutputNum << ", but it is "
<< bn_abstract_tuple->elements().size() << " trace: " << trace::DumpSourceLines(bn);
}
std::vector<AbstractBasePtr> abstract_list{bn_abstract_tuple->elements()[0], bn_abstract_tuple->elements()[3],

View File

@ -33,10 +33,7 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(bn_grad_node);
const auto &bn_grad_inputs = bn_grad_node->inputs();
if (bn_grad_inputs.size() < kBNGradInputNum) {
MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size."
<< " trace: " << trace::DumpSourceLines(bn_grad_node);
}
CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum);
std::vector<AnfNodePtr> bn_update_grad_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)), bn_grad_inputs[1], bn_grad_inputs[2],
bn_grad_inputs[4], bn_grad_inputs[5]};
@ -60,10 +57,7 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
MS_EXCEPTION_IF_NULL(bn_grad_node);
MS_EXCEPTION_IF_NULL(bn_reduce_grad_outputs);
const auto &bn_grad_inputs = bn_grad_node->inputs();
if (bn_grad_inputs.size() < kBNGradInputNum) {
MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size"
<< " trace: " << trace::DumpSourceLines(bn_grad_node);
}
CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum);
if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) {
MS_LOG(EXCEPTION) << "BNTrainingReduceGrad_outputs has wrong size"
<< " trace: " << trace::DumpSourceLines(bn_grad_node);

View File

@ -33,10 +33,7 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(bn_grad_node);
auto bn_grad_inputs = bn_grad_node->inputs();
if (bn_grad_inputs.size() != kBNGradInputNum) {
MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size"
<< " trace: " << trace::DumpSourceLines(bn_grad_node);
}
CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum);
std::vector<AnfNodePtr> bn_update_grad_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)), bn_grad_inputs[1], bn_grad_inputs[2],
bn_grad_inputs[4], bn_grad_inputs[5]};
@ -59,10 +56,7 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(bn_grad_node);
auto bn_grad_inputs = bn_grad_node->inputs();
if (bn_grad_inputs.size() != kBNGradInputNum) {
MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size"
<< " trace: " << trace::DumpSourceLines(bn_grad_node);
}
CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum);
if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) {
MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size";
}

View File

@ -32,8 +32,8 @@ bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &
std::vector<AnfNodePtr> *bn_training_reduce_outputs) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(bn_cnode);
if (bn_cnode->inputs().size() != kBnInputNum) {
MS_LOG(INFO) << "FusedbatchNorm's input size less than " << kBnInputNum << ". " << bn_cnode->DebugString();
if (AnfAlgo::GetInputTensorNum(bn_cnode) != kBnInputTensorNum) {
MS_LOG(INFO) << "FusedbatchNorm's input size less than " << kBnInputTensorNum << ". " << bn_cnode->DebugString();
return false;
}
std::vector<AnfNodePtr> bn_training_reduce_inputs = {
@ -64,10 +64,7 @@ AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNod
const std::vector<AnfNodePtr> &bn_training_reduce_outputs) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(bn_cnode);
if (bn_cnode->inputs().size() != kBnInputNum) {
MS_LOG(EXCEPTION) << "BN node has wrong input size"
<< " trace: " << trace::DumpSourceLines(bn_cnode);
}
CheckCNodeInputSize(bn_cnode, kBnInputTensorNum);
if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) {
MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size"
<< " trace: " << trace::DumpSourceLines(bn_cnode);
@ -102,8 +99,8 @@ AnfNodePtr SplitBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() < kBnInputNum) {
MS_LOG(INFO) << "op[FusedBatchNorm] has less than " << kBnInputNum << " inputs.";
if (AnfAlgo::GetInputTensorNum(cnode) < kBnInputTensorNum) {
MS_LOG(INFO) << "op[" << cnode->DebugString() << "] has less input than " << kBnInputTensorNum << " inputs.";
return nullptr;
}
// Create BNTrainingReduce node and get outputs of BNTrainingReduce

View File

@ -123,8 +123,8 @@ CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &gather_v2, const
bool CheckInputs(const CNodePtr &origin_node) {
MS_EXCEPTION_IF_NULL(origin_node);
if (origin_node->size() != kGatherV2DynInputNum + 1) {
MS_LOG(DEBUG) << "GatherV2 in dynamic shape has wrong inputs num, not equal " << kGatherV2DynInputNum
if (AnfAlgo::GetInputTensorNum(origin_node) != kGatherV2DynInputTensorNum) {
MS_LOG(DEBUG) << "GatherV2 in dynamic shape has wrong inputs num, not equal " << kGatherV2DynInputTensorNum
<< ". CNode= " << origin_node->DebugString();
return false;
}

View File

@ -28,11 +28,7 @@ void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars
std::vector<AnfNodePtr> *square_sum_all_outputs) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(lars_v2);
if (lars_v2->size() != kLarsV2InputNum) {
MS_LOG(EXCEPTION) << "Op lars_v2's input not equal " << kLarsV2InputNum
<< " trace: " << trace::DumpSourceLines(lars_v2);
}
CheckCNodeInputSize(lars_v2, kLarsV2InputTensorNum);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kSquareSumAllOpName)), lars_v2->input(1),
lars_v2->input(2)};
auto square_sum_all = graph->NewCNode(inputs);
@ -55,10 +51,7 @@ CNodePtr CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
MS_LOG(EXCEPTION) << "square_sum_all_outputs' size not equal 2"
<< " trace: " << trace::DumpSourceLines(lars_v2);
}
if (lars_v2->size() != kLarsV2InputNum) {
MS_LOG(EXCEPTION) << "Op lars_v2's input not equal " << kLarsV2InputNum
<< " trace: " << trace::DumpSourceLines(lars_v2);
}
CheckCNodeInputSize(lars_v2, kLarsV2InputTensorNum);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kLarsV2UpdateOpName)),
lars_v2->input(1),
lars_v2->input(2),

View File

@ -91,7 +91,7 @@ const AnfNodePtr LayerNormGradSplit::Process(const FuncGraphPtr &graph, const An
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
if (cnode->inputs().size() != kLayerNormGradInputNum) {
if (AnfAlgo::GetInputTensorNum(cnode) != kLayerNormGradInputTensorNum) {
return nullptr;
}

View File

@ -110,7 +110,7 @@ const AnfNodePtr ReduceMinFission::Process(const FuncGraphPtr &graph, const AnfN
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
CheckCNodeInputSize(cnode, 2);
CheckCNodeInputSize(cnode, 1);
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0);
auto prim = AnfAlgo::GetCNodePrimitive(cnode);

View File

@ -76,8 +76,8 @@ AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNod
auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract());
MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
if (bn_abstract_tuple->elements().size() != kBatchNormOutputNum) {
MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBatchNormOutputNum << ", but it is "
if (bn_abstract_tuple->elements().size() != kBnOutputNum) {
MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBnOutputNum << ", but it is "
<< bn_abstract_tuple->elements().size() << " trace: " << trace::DumpSourceLines(bn);
}
bn_training_update_v3->set_abstract(bn->abstract());

View File

@ -34,10 +34,7 @@ CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inpu
CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) {
MS_EXCEPTION_IF_NULL(origin_cnode);
if (origin_cnode->inputs().size() < kSplitInputNum) {
MS_LOG(EXCEPTION) << "The input number of split: " << origin_cnode->DebugString() << " should be "
<< kSplitInputNum - 1 << " trace: " << trace::DumpSourceLines(origin_cnode);
}
CheckCNodeInputSize(origin_cnode, kSplitInputTensorNum);
return CreateSplitVNode(func_graph, origin_cnode->input(1));
}

View File

@ -32,10 +32,7 @@ CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inpu
CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) {
MS_EXCEPTION_IF_NULL(origin_cnode);
if (origin_cnode->inputs().size() < kSplitInputNum) {
MS_LOG(EXCEPTION) << "The input number of split: " << origin_cnode->DebugString() << " should be "
<< kSplitInputNum - 1;
}
CheckCNodeInputSize(origin_cnode, kSplitInputTensorNum);
return CreateSplitVNode(func_graph, origin_cnode->input(1));
}

View File

@ -146,7 +146,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
new_cnode->set_abstract(cnode->abstract());
new_cnode->set_scope(cnode->scope());
AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
CheckCNodeInputSize(new_cnode, kTopkInputNum);
CheckCNodeInputSize(new_cnode, kTopkInputTensorNum);
// Convert the tensor input to scalar and convert it to attr
auto input_k = new_cnode->input(kTopkIndexK + 1);
MS_EXCEPTION_IF_NULL(input_k);

View File

@ -31,7 +31,7 @@ const AnfNodePtr TransDataSplit::Process(const FuncGraphPtr &func_graph, const A
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) {
CheckCNodeInputSize(node->cast<CNodePtr>(), kBackendTransDataInputNum);
CheckCNodeInputSize(node->cast<CNodePtr>(), kTransOpInputTensorNum);
if (IsFormatInvaild(node)) {
TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info()));
return DoSplit(func_graph, node);

View File

@ -77,8 +77,8 @@ CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_s
bool CheckInputs(const CNodePtr &origin_node) {
MS_EXCEPTION_IF_NULL(origin_node);
if (origin_node->size() != kUnsortedSegmentSumInputNum + 1) {
MS_LOG(DEBUG) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputNum
if (AnfAlgo::GetInputTensorNum(origin_node) != kUnsortedSegmentSumInputTensorNum) {
MS_LOG(DEBUG) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputTensorNum
<< ". CNode= " << origin_node->DebugString();
return false;
}

View File

@ -62,8 +62,8 @@ bool CheckIndex(const AnfNodePtr &index_node) {
bool CheckBatchNorm(const FuncGraphPtr &graph, const CNodePtr &batchnorm) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(batchnorm);
if (batchnorm->size() < kBatchNormInputNum + 1) {
MS_LOG(DEBUG) << "BatchNorm's input less than " << kBatchNormInputNum;
if (AnfAlgo::GetInputTensorNum(batchnorm) < kBnInputTensorNum) {
MS_LOG(DEBUG) << "BatchNorm's input less than " << kBnInputTensorNum;
return false;
}
if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnorm)) {
@ -87,7 +87,7 @@ bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *bat
MS_EXCEPTION_IF_NULL(node);
auto tuple_getitem = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getitem);
CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputSize);
CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputTensorNum);
AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(index_node);
if (!CheckIndex(index_node)) {

View File

@ -61,8 +61,8 @@ bool CheckIndex(const AnfNodePtr &index_node) {
bool CheckBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(batchnormgrad);
if (batchnormgrad->size() < kBatchNormInputNum + 1) {
MS_LOG(DEBUG) << "BatchNormGrad's input less than " << kBatchNormInputNum;
if (AnfAlgo::GetInputTensorNum(batchnormgrad) < kBNGradInputTensorNum) {
MS_LOG(DEBUG) << "BatchNormGrad's input less than " << kBnInputTensorNum;
return false;
}
if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnormgrad)) {
@ -86,7 +86,7 @@ bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *bat
MS_EXCEPTION_IF_NULL(node);
auto tuple_getitem = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getitem);
CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputSize);
CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputTensorNum);
AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(index_node);
if (!CheckIndex(index_node)) {

View File

@ -79,7 +79,7 @@ const AnfNodePtr ClipByValueFusion::Process(const FuncGraphPtr &graph, const Anf
return nullptr;
}
MS_EXCEPTION_IF_NULL(minimum);
if (minimum->inputs().size() != kMinimumInputNum) {
if (AnfAlgo::GetInputTensorNum(minimum) != kMinimumInputTensorNum) {
return nullptr;
}

View File

@ -30,9 +30,7 @@ const size_t kReluV2OutputNum = 2;
CNodePtr GetRelu(const CNodePtr &relu_grad) {
MS_EXCEPTION_IF_NULL(relu_grad);
if (relu_grad->size() != kReluGradInputNum) {
MS_LOG_EXCEPTION << "ReluGrad has wrong input size " << relu_grad->size();
}
CheckCNodeInputSize(relu_grad, kReluGradInputTensorNum);
auto relu_anf = relu_grad->input(2);
MS_EXCEPTION_IF_NULL(relu_anf);
return relu_anf->cast<CNodePtr>();
@ -41,9 +39,7 @@ CNodePtr GetRelu(const CNodePtr &relu_grad) {
CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(relu);
if (relu->size() != kReluInputNum) {
MS_LOG_EXCEPTION << "Relu has wrong input size " << relu->size();
}
CheckCNodeInputSize(relu, kReluInputTensorNum);
auto prim = std::make_shared<Primitive>(kReluV2OpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), relu->input(1)};

View File

@ -53,32 +53,9 @@ void GetBNOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vect
}
} // namespace
const BaseRef FusedBatchNormFusion::DefinePattern() const {
std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
VarPtr index0 = std::make_shared<CondVar>(IsC);
VarPtr index1 = std::make_shared<CondVar>(IsC);
VarPtr index2 = std::make_shared<CondVar>(IsC);
VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs});
VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0});
VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1});
VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2});
VectorRef sub0 = VectorRef({prim::kPrimSub, variable_input0_var_, tuple_getitem1});
VectorRef sub1 = VectorRef({prim::kPrimSub, variable_input1_var_, tuple_getitem2});
VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_});
VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_});
VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0});
VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1});
VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
}
ValuePtr FusedBatchNormFusion::GetFactor(const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(equiv);
auto iter_constant_input0 = (*equiv).find(constant_input0_var_);
if (iter_constant_input0 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the constant_input0 var after matched.";
}
auto constant_input = utils::cast<AnfNodePtr>(iter_constant_input0->second);
auto constant_input = GetAnfNodeByVar(equiv, constant_input0_var_);
MS_EXCEPTION_IF_NULL(constant_input);
if (!constant_input->isa<ValueNode>()) {
return nullptr;
@ -113,31 +90,15 @@ AnfNodePtr FusedBatchNormFusion::CreateBNTrainingReduce(const FuncGraphPtr &func
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
// Set input to create node
auto iter_data_input0 = (*equiv).find(data_input0_var_);
if (iter_data_input0 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched."
<< " trace: " << trace::DumpSourceLines(node);
}
std::vector<AnfNodePtr> bn_training_reduce_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)),
utils::cast<AnfNodePtr>(iter_data_input0->second)};
NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), GetAnfNodeByVar(equiv, data_input0_var_)};
auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
MS_EXCEPTION_IF_NULL(bn_training_reduce);
bn_training_reduce->set_scope(node->scope());
// Set abstract
auto iter_data_input1 = (*equiv).find(data_input1_var_);
if (iter_data_input1 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched."
<< " trace: " << trace::DumpSourceLines(node);
}
auto data_input1 = utils::cast<AnfNodePtr>(iter_data_input1->second);
auto data_input1 = GetAnfNodeByVar(equiv, data_input1_var_);
MS_EXCEPTION_IF_NULL(data_input1);
auto iter_data_input2 = (*equiv).find(data_input2_var_);
if (iter_data_input2 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched."
<< " trace: " << trace::DumpSourceLines(node);
}
auto data_input2 = utils::cast<AnfNodePtr>(iter_data_input2->second);
auto data_input2 = GetAnfNodeByVar(equiv, data_input2_var_);
MS_EXCEPTION_IF_NULL(data_input2);
AbstractBasePtrList abstract_list{data_input1->abstract(), data_input2->abstract()};
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
@ -150,39 +111,15 @@ void FusedBatchNormFusion::GetBNTrainingUpdateInputs(const EquivPtr &equiv,
std::vector<AnfNodePtr> *bn_training_update_inputs) const {
MS_EXCEPTION_IF_NULL(equiv);
MS_EXCEPTION_IF_NULL(bn_training_update_inputs);
auto iter_data_input0 = (*equiv).find(data_input0_var_);
if (iter_data_input0 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched.";
}
auto iter_data_input1 = (*equiv).find(data_input1_var_);
if (iter_data_input1 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched.";
}
auto iter_data_input2 = (*equiv).find(data_input2_var_);
if (iter_data_input2 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched.";
}
auto iter_variable_input0 = (*equiv).find(variable_input0_var_);
if (iter_variable_input0 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched.";
}
auto iter_variable_input1 = (*equiv).find(variable_input1_var_);
if (iter_variable_input1 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched.";
}
if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) {
MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum
<< ", but it is " << bn_training_reduce_outputs.size();
}
*bn_training_update_inputs = {
NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateOpName)),
utils::cast<AnfNodePtr>(iter_data_input0->second),
utils::cast<AnfNodePtr>(GetAnfNodeByVar(equiv, data_input0_var_)),
bn_training_reduce_outputs[0],
bn_training_reduce_outputs[1],
utils::cast<AnfNodePtr>(iter_data_input1->second),
utils::cast<AnfNodePtr>(iter_data_input2->second),
utils::cast<AnfNodePtr>(iter_variable_input0->second),
utils::cast<AnfNodePtr>(iter_variable_input1->second),
GetAnfNodeByVar(equiv, data_input1_var_),
GetAnfNodeByVar(equiv, data_input2_var_),
GetAnfNodeByVar(equiv, variable_input0_var_),
GetAnfNodeByVar(equiv, variable_input1_var_),
};
}
@ -197,19 +134,9 @@ void FusedBatchNormFusion::GetBNTrainingUpdateAbstractList(const EquivPtr &equiv
MS_LOG(EXCEPTION) << "The abstract size of node bn must not be less than " << kBnOutputNum << ", but it is "
<< bn_abstract_tuple->elements().size() << " trace: " << trace::DumpSourceLines(bn);
}
auto iter_variable_input0 = (*equiv).find(variable_input0_var_);
if (iter_variable_input0 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched."
<< " trace: " << trace::DumpSourceLines(bn);
}
auto variable_input0 = utils::cast<AnfNodePtr>(iter_variable_input0->second);
auto variable_input0 = GetAnfNodeByVar(equiv, variable_input0_var_);
auto variable_input1 = GetAnfNodeByVar(equiv, variable_input1_var_);
MS_EXCEPTION_IF_NULL(variable_input0);
auto iter_variable_input1 = (*equiv).find(variable_input1_var_);
if (iter_variable_input1 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched."
<< " trace: " << trace::DumpSourceLines(bn);
}
auto variable_input1 = utils::cast<AnfNodePtr>(iter_variable_input1->second);
MS_EXCEPTION_IF_NULL(variable_input1);
*abstract_list = {bn_abstract_tuple->elements()[0], variable_input0->abstract(), variable_input1->abstract(),
bn_abstract_tuple->elements()[1], bn_abstract_tuple->elements()[2]};
@ -227,13 +154,7 @@ AnfNodePtr FusedBatchNormFusion::CreateBNTrainingUpdate(
auto bn_training_update = func_graph->NewCNode(bn_training_update_inputs);
MS_EXCEPTION_IF_NULL(bn_training_update);
// Set abstract
auto iter_batch_norm = (*equiv).find(batch_norm_var_);
if (iter_batch_norm == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched."
<< " trace: " << trace::DumpSourceLines(node);
}
AnfNodePtr bn = utils::cast<AnfNodePtr>(iter_batch_norm->second);
MS_EXCEPTION_IF_NULL(bn);
AnfNodePtr bn = GetAnfNodeByVar(equiv, batch_norm_var_);
AbstractBasePtrList abstract_list;
GetBNTrainingUpdateAbstractList(equiv, bn, &abstract_list);
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
@ -249,6 +170,23 @@ AnfNodePtr FusedBatchNormFusion::CreateBNTrainingUpdate(
return bn_training_update;
}
void FusedBatchNormFusion::EliminateMonadNodes(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(equiv);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto assign_sub1 = GetAnfNodeByVar(equiv, assign_sub1_var_);
MS_EXCEPTION_IF_NULL(assign_sub1);
for (const auto &node_index : manager->node_users()[assign_sub1]) {
const AnfNodePtr &output = node_index.first;
MS_EXCEPTION_IF_NULL(output);
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
(void)manager->Replace(output, GetAnfNodeByVar(equiv, monad0_var_));
break;
}
}
}
const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
@ -271,14 +209,8 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c
<< bn_training_update_outputs.size() << " trace: " << trace::DumpSourceLines(node);
}
// Replace old bn outputs with new outputs
auto iter_batch_norm = (*equiv).find(batch_norm_var_);
if (iter_batch_norm == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched."
<< " trace: " << trace::DumpSourceLines(node);
}
AnfNodePtr bn = utils::cast<AnfNodePtr>(iter_batch_norm->second);
std::vector<AnfNodePtr> bn_outputs;
GetBNOutput(func_graph, bn, &bn_outputs);
GetBNOutput(func_graph, GetAnfNodeByVar(equiv, batch_norm_var_), &bn_outputs);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
for (const auto &output : bn_outputs) {
@ -297,7 +229,28 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c
(void)manager->Replace(output, bn_training_update_outputs[index]);
}
}
return bn_training_update_outputs[0];
(void)manager->Replace(node, bn_training_update_outputs[0]);
EliminateMonadNodes(func_graph, equiv);
return nullptr;
}
const BaseRef FusedBatchNormFusion::DefinePattern() const {
std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
VarPtr index0 = std::make_shared<CondVar>(IsC);
VarPtr index1 = std::make_shared<CondVar>(IsC);
VarPtr index2 = std::make_shared<CondVar>(IsC);
VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs});
VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0});
VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1});
VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2});
VectorRef sub0 = VectorRef({prim::kPrimSub, variable_input0_var_, tuple_getitem1});
VectorRef sub1 = VectorRef({prim::kPrimSub, variable_input1_var_, tuple_getitem2});
VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_});
VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_});
VectorRef assign_sub0 = VectorRef({assign_sub0_var_, variable_input0_var_, mul0, monad0_var_});
VectorRef assign_sub1 = VectorRef({assign_sub1_var_, variable_input1_var_, mul1, monad1_var_});
VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
}
const BaseRef FusedBatchNormMixPrecisionFusion0::DefinePattern() const {
@ -317,8 +270,8 @@ const BaseRef FusedBatchNormMixPrecisionFusion0::DefinePattern() const {
VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_});
VectorRef cast2 = VectorRef({prim::kPrimCast, mul0});
VectorRef cast3 = VectorRef({prim::kPrimCast, mul1});
VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, cast2});
VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, cast3});
VectorRef assign_sub0 = VectorRef({assign_sub0_var_, variable_input0_var_, cast2, monad0_var_});
VectorRef assign_sub1 = VectorRef({assign_sub1_var_, variable_input1_var_, cast3, monad1_var_});
VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
}
@ -340,8 +293,8 @@ const BaseRef FusedBatchNormMixPrecisionFusion1::DefinePattern() const {
VectorRef cast1 = VectorRef({prim::kPrimCast, sub1});
VectorRef mul0 = VectorRef({prim::kPrimMul, cast0, constant_input0_var_});
VectorRef mul1 = VectorRef({prim::kPrimMul, cast1, constant_input1_var_});
VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0});
VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1});
VectorRef assign_sub0 = VectorRef({assign_sub0_var_, variable_input0_var_, mul0, monad0_var_});
VectorRef assign_sub1 = VectorRef({assign_sub1_var_, variable_input1_var_, mul1, monad1_var_});
VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
}

View File

@ -27,15 +27,20 @@ namespace opt {
class FusedBatchNormFusion : public PatternProcessPass {
public:
explicit FusedBatchNormFusion(const std::string &name = "fused_batch_norm_fusion", bool multigraph = true)
: PatternProcessPass(name, multigraph),
data_input0_var_(std::make_shared<Var>()),
data_input1_var_(std::make_shared<Var>()),
data_input2_var_(std::make_shared<Var>()),
variable_input0_var_(std::make_shared<Var>()),
variable_input1_var_(std::make_shared<Var>()),
constant_input0_var_(std::make_shared<Var>()),
constant_input1_var_(std::make_shared<Var>()),
batch_norm_var_(std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimBatchNorm->name()))) {}
: PatternProcessPass(name, multigraph) {
data_input0_var_ = std::make_shared<Var>();
data_input1_var_ = std::make_shared<Var>();
data_input2_var_ = std::make_shared<Var>();
variable_input0_var_ = std::make_shared<Var>();
variable_input1_var_ = std::make_shared<Var>();
constant_input0_var_ = std::make_shared<Var>();
constant_input1_var_ = std::make_shared<Var>();
batch_norm_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimBatchNorm->name()));
assign_sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAssignSub->name()));
assign_sub1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAssignSub->name()));
monad0_var_ = std::make_shared<Var>();
monad1_var_ = std::make_shared<Var>();
}
~FusedBatchNormFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
@ -50,6 +55,7 @@ class FusedBatchNormFusion : public PatternProcessPass {
AnfNodePtr CreateBNTrainingUpdate(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const;
ValuePtr GetFactor(const EquivPtr &equiv) const;
void EliminateMonadNodes(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const;
VarPtr data_input0_var_;
VarPtr data_input1_var_;
@ -59,6 +65,10 @@ class FusedBatchNormFusion : public PatternProcessPass {
VarPtr constant_input0_var_;
VarPtr constant_input1_var_;
VarPtr batch_norm_var_;
VarPtr assign_sub0_var_;
VarPtr assign_sub1_var_;
VarPtr monad0_var_;
VarPtr monad1_var_;
};
class FusedBatchNormMixPrecisionFusion0 : public FusedBatchNormFusion {

View File

@ -30,33 +30,21 @@ std::tuple<AnfNodePtr, AnfNodePtr, AnfNodePtr, AnfNodePtr> GetSharedNodes(const
MS_EXCEPTION_IF_NULL(node);
auto add3 = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(add3);
if (add3->inputs().size() < kAddInputNum) {
MS_LOG(EXCEPTION) << "The input size of Add3 is less than " << kAddInputNum
<< " trace: " << trace::DumpSourceLines(node);
}
CheckCNodeInputSize(add3, kAddInputTensorNum);
auto real_div2_anf = add3->input(1);
MS_EXCEPTION_IF_NULL(real_div2_anf);
auto real_div2 = real_div2_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(real_div2);
if (real_div2->inputs().size() < kRealDivInputNum) {
MS_LOG(EXCEPTION) << "The input size of RealDiv2 is less than " << kRealDivInputNum
<< " trace: " << trace::DumpSourceLines(node);
}
CheckCNodeInputSize(real_div2, kRealDivInputTensorNum);
auto sqrt0_anf = real_div2->input(2);
MS_EXCEPTION_IF_NULL(sqrt0_anf);
auto sqrt0 = sqrt0_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sqrt0);
if (sqrt0->inputs().size() < kRsqrtInputNum) {
MS_LOG(EXCEPTION) << "The input size of Sqrt0 is less than " << kSqrtInputNum
<< " trace: " << trace::DumpSourceLines(node);
}
CheckCNodeInputSize(sqrt0, kSqrtInputTensorNum);
auto add2_anf = sqrt0->input(1);
MS_EXCEPTION_IF_NULL(add2_anf);
auto add2 = add2_anf->cast<CNodePtr>();
if (add2->inputs().size() < kAddInputNum) {
MS_LOG(EXCEPTION) << "The input size of Add2 is less than " << kAddInputNum
<< " trace: " << trace::DumpSourceLines(node);
}
CheckCNodeInputSize(add2, kAddInputTensorNum);
return std::make_tuple(add3->input(2), real_div2->input(1), add2->input(1), add2->input(2));
}
@ -66,7 +54,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN
return false;
}
auto add5 = node->cast<CNodePtr>();
if (AnfAlgo::GetCNodeName(add5) != prim::kPrimAdd->name() || add5->inputs().size() != kAddInputNum) {
if (AnfAlgo::GetCNodeName(add5) != prim::kPrimAdd->name() || AnfAlgo::GetInputTensorNum(add5) != kAddInputTensorNum) {
return false;
}
auto real_div4_anf = add5->input(1);
@ -74,7 +62,8 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN
return false;
}
auto real_div4 = real_div4_anf->cast<CNodePtr>();
if (AnfAlgo::GetCNodeName(real_div4) != kRealDivOpName || real_div4->inputs().size() != kRealDivInputNum) {
if (AnfAlgo::GetCNodeName(real_div4) != kRealDivOpName ||
AnfAlgo::GetInputTensorNum(real_div4) != kRealDivInputTensorNum) {
return false;
}
auto add4_anf = real_div4->input(2);
@ -82,7 +71,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN
return false;
}
auto add4 = add4_anf->cast<CNodePtr>();
if (AnfAlgo::GetCNodeName(add4) != prim::kPrimAdd->name() || add4->inputs().size() != kAddInputNum) {
if (AnfAlgo::GetCNodeName(add4) != prim::kPrimAdd->name() || AnfAlgo::GetInputTensorNum(add4) != kAddInputTensorNum) {
return false;
}
auto sqrt1_anf = add4->input(1);
@ -90,7 +79,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN
return false;
}
auto sqrt1 = sqrt1_anf->cast<CNodePtr>();
if (AnfAlgo::GetCNodeName(sqrt1) != kSqrtOpName || sqrt1->inputs().size() != kSqrtInputNum) {
if (AnfAlgo::GetCNodeName(sqrt1) != kSqrtOpName || AnfAlgo::GetInputTensorNum(sqrt1) != kSqrtInputTensorNum) {
return false;
}
return add5->input(2) == mul4 && real_div4->input(1) == real_div0 && sqrt1->input(1) == real_div1 &&
@ -104,14 +93,8 @@ std::tuple<AnfNodePtr, AnfNodePtr> GetAdd0Add1Nodes(const AnfNodePtr &real_div0_
auto real_div1 = real_div1_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(real_div0);
MS_EXCEPTION_IF_NULL(real_div1);
if (real_div0->inputs().size() != kRealDivInputNum) {
MS_LOG(EXCEPTION) << "RealDiv0 has wrong input size"
<< " trace: " << trace::DumpSourceLines(real_div0_anf);
}
if (real_div1->inputs().size() != kRealDivInputNum) {
MS_LOG(EXCEPTION) << "RealDiv1 has wrong input size"
<< " trace: " << trace::DumpSourceLines(real_div1_anf);
}
CheckCNodeInputSize(real_div0, kRealDivInputTensorNum);
CheckCNodeInputSize(real_div1, kRealDivInputTensorNum);
return std::make_tuple(real_div0->input(1), real_div1->input(1));
}
} // namespace

View File

@ -77,9 +77,9 @@ bool CheckLayernormBetaGammaBackprop(const FuncGraphPtr &func_graph, const CNode
MS_LOG(INFO) << "The node " << cnode->DebugString() << " has no " << kAttrShapeGamma << " attr";
return false;
}
if (cnode->inputs().size() != kLayerNormBetaGammaBackpropInputNum) {
if (AnfAlgo::GetInputTensorNum(cnode) != kLayerNormBetaGammaBackpropInputTensorNum) {
MS_LOG(INFO) << "The node " << cnode->DebugString() << " inputs num is not equal to "
<< kLayerNormBetaGammaBackpropInputNum;
<< kLayerNormBetaGammaBackpropInputTensorNum;
return false;
}
if (AnfAlgo::GetOutputTensorNum(cnode) != kLayerNormBetaGammaBackpropOutputNum) {
@ -87,7 +87,8 @@ bool CheckLayernormBetaGammaBackprop(const FuncGraphPtr &func_graph, const CNode
<< kLayerNormBetaGammaBackpropOutputNum;
return false;
}
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); ++i) {
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
for (size_t i = 0; i < input_num; ++i) {
if (AnfAlgo::GetInputDeviceDataType(cnode, i) != kNumberTypeFloat16) {
MS_LOG(INFO) << "The data type of node " << cnode->DebugString() << " input " << i << " is not float16";
return false;
@ -148,15 +149,9 @@ const AnfNodePtr LayerNormBetaGammaBackpropFusion::Process(const FuncGraphPtr &f
// The cast_nodes size has been checked above.
MS_EXCEPTION_IF_NULL(cast_nodes[0]);
MS_EXCEPTION_IF_NULL(cast_nodes[1]);
if (cast_nodes[0]->inputs().size() != kCastInputNum) {
MS_LOG(EXCEPTION) << "The cast0 " << cast_nodes[0]->DebugString() << " input size should be " << kCastInputNum
<< " trace: " << trace::DumpSourceLines(node);
}
CheckCNodeInputSize(cast_nodes[0], kCastInputTensorNum);
CheckCNodeInputSize(cast_nodes[1], kCastInputTensorNum);
(void)manager->Replace(cast_nodes[0], cast_nodes[0]->input(1));
if (cast_nodes[1]->inputs().size() != kCastInputNum) {
MS_LOG(EXCEPTION) << "The cast1 " << cast_nodes[1]->DebugString() << " input size should be " << kCastInputNum
<< " trace: " << trace::DumpSourceLines(node);
}
(void)manager->Replace(cast_nodes[1], cast_nodes[1]->input(1));
return nullptr;
}

View File

@ -31,6 +31,20 @@ const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &graph, const A
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
auto matmul = GetAnfNodeByVar(equiv, matmul_var_);
if (matmul == nullptr || !matmul->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Get CNode MatMul failed!"
<< " trace: " << trace::DumpSourceLines(node);
}
  // If a side-effect operator occurs between the nodes to be fused, do not merge them
MonadState state_matmul = GetMonadState(matmul);
MonadState state_node = GetMonadState(node, matmul);
if (!IsStateEquivalent(state_matmul, state_node)) {
return node;
}
std::vector<AnfNodePtr> inputs;
inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name())));
inputs.emplace_back(GetAnfNodeByVar(equiv, x0_));
@ -41,11 +55,6 @@ const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &graph, const A
new_node->set_scope(node->scope());
new_node->set_abstract(node->abstract());
auto matmul = GetAnfNodeByVar(equiv, matmul_var_);
if (matmul == nullptr || !matmul->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Get CNode MatMul failed!"
<< " trace: " << trace::DumpSourceLines(node);
}
AnfAlgo::CopyNodeAttrs(matmul, new_node);
return new_node;
}
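The guard above compares the monad state observed at the MatMul with the state observed at the fusion node before merging. A rough illustration of the idea, using a hypothetical `MonadState`/`IsStateEquivalent` rather than the real definitions: fusion is allowed only when no side effect was recorded between the two nodes.

```
#include <cassert>

struct MonadState {
  int u_version;   // universal (memory) side effects seen so far
  int io_version;  // IO side effects seen so far
};

bool IsStateEquivalent(const MonadState &a, const MonadState &b) {
  return a.u_version == b.u_version && a.io_version == b.io_version;
}

int main() {
  MonadState at_matmul{3, 1};
  MonadState at_biasadd_same{3, 1};    // nothing happened in between: safe to fuse
  MonadState at_biasadd_assign{4, 1};  // an Assign ran in between: keep nodes apart
  assert(IsStateEquivalent(at_matmul, at_biasadd_same));
  assert(!IsStateEquivalent(at_matmul, at_biasadd_assign));
  return 0;
}
```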

View File

@ -43,7 +43,9 @@ const BaseRef MomentumLossscaleFusion::DefinePattern() const {
VarPtr X1 = std::make_shared<Var>();
VarPtr X2 = std::make_shared<Var>();
VarPtr X4 = std::make_shared<Var>();
return VectorRef({prim::kPrimApplyMomentum, X0, X1, X2, VectorRef({prim::kPrimMul, Xs}), X4});
// UpdateState node
VarPtr X5 = std::make_shared<Var>();
return VectorRef({prim::kPrimApplyMomentum, X0, X1, X2, VectorRef({prim::kPrimMul, Xs}), X4, X5});
}
const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
@ -52,14 +54,15 @@ const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
CheckCNodeInputSize(cnode, kApplyMomentumInputNum);
CheckCNodeInputSize(cnode, kApplyMomentumInputTensorNum);
AnfNodePtr mul = cnode->input(4);
MS_EXCEPTION_IF_NULL(mul);
auto mul_cnode = mul->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(mul_cnode);
CheckCNodeInputSize(mul_cnode, kMulInputNum);
CheckCNodeInputSize(mul_cnode, kMulInputTensorNum);
size_t value_node_index = 0;
for (size_t i = 1; i < kMulInputNum; ++i) {
    // A CNode's inputs are [primitive, tensor inputs...], so the real inputs occupy indexes 1..kMulInputTensorNum
for (size_t i = 1; i < kMulInputTensorNum + 1; ++i) {
if (CheckValueNodeInputOfMul(mul_cnode->input(i))) {
value_node_index = i;
break;
@ -70,12 +73,16 @@ const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph
return nullptr;
}
auto new_prim = std::make_shared<Primitive>(kFusedMulApplyMomentumOpName);
auto depend_prim = NewValueNode(prim::kPrimDepend);
auto depend = func_graph->NewCNode({depend_prim, cnode->input(5), cnode->input(6)}); // depend on monad
depend->set_abstract(cnode->input(5)->abstract());
depend->set_scope(cnode->input(5)->scope());
std::vector<AnfNodePtr> new_node_inputs{NewValueNode(new_prim),
cnode->input(1),
cnode->input(2),
cnode->input(3),
mul_cnode->input(kMulInputNum - value_node_index),
cnode->input(5),
mul_cnode->input(kMulInputTensorNum + 1 - value_node_index),
depend,
mul_cnode->input(value_node_index)};
auto new_node = func_graph->NewCNode(new_node_inputs);
MS_EXCEPTION_IF_NULL(new_node);
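The `depend` node built above stands in for the monad pair (inputs 5 and 6) of the original ApplyMomentum, so the fused node keeps its place in the side-effect order. A minimal sketch of Depend's semantics, assuming only that `Depend(value, attach)` forwards `value` while creating an ordering edge to `attach`:

```
#include <iostream>

// Hypothetical Depend: forwards value; the attach argument exists only to
// create a graph edge that fixes execution order.
template <typename T, typename U>
const T &Depend(const T &value, const U &attach) {
  (void)attach;
  return value;
}

int main() {
  int accum = 7;  // stands in for cnode->input(5)
  int monad = 0;  // stands in for the UpdateState at cnode->input(6)
  const int &ordered = Depend(accum, monad);
  std::cout << ordered << std::endl;  // prints 7: the value itself is unchanged
  return 0;
}
```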

View File

@ -67,7 +67,7 @@ const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodeP
return nullptr;
}
auto add = node->cast<CNodePtr>();
if (add == nullptr || add->inputs().size() != kAddInputNum) {
if (add == nullptr || AnfAlgo::GetInputTensorNum(add) != kAddInputTensorNum) {
return nullptr;
}
CNodePtr mul = nullptr;

View File

@ -31,7 +31,7 @@ CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const
MS_EXCEPTION_IF_NULL(addn);
auto prim = std::make_shared<Primitive>(kFusedMulAddNOpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
inputs.push_back(mul->input(kMulInputNum - lossscale_input_index));
inputs.push_back(mul->input(kMulInputTensorNum + 1 - lossscale_input_index));
inputs.push_back(addn->input(2));
  // The scalar input should be the 3rd input
inputs.push_back(mul->input(lossscale_input_index));
@ -60,7 +60,7 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode
}
auto addn = node->cast<CNodePtr>();
if (addn == nullptr || addn->inputs().size() != kAddNInputNum) {
if (addn == nullptr) {
return nullptr;
}
auto mul_anf = addn->input(1);
@ -68,7 +68,7 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode
return nullptr;
}
auto mul = mul_anf->cast<CNodePtr>();
if (mul == nullptr || mul->inputs().size() != kMulInputNum) {
if (mul == nullptr || AnfAlgo::GetInputTensorNum(mul) != kMulInputTensorNum) {
return nullptr;
}
if (IsUsedByOthers(graph, mul)) {

View File

@ -98,7 +98,8 @@ bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) {
MS_LOG(DEBUG) << "Skip trans op";
continue;
}
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) {
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
for (size_t input_index = 0; input_index < input_num; input_index++) {
std::vector<CNodePtr> trans_road;
bool first_flag = true;
auto final_node = ParamTransRoad(func_graph, AnfAlgo::GetInputNode(cnode, input_index), first_flag, &trans_road);

View File

@ -26,8 +26,10 @@ void DoRefresh(const CNodePtr &cnode) {
if (cnode == nullptr) {
MS_LOG(EXCEPTION) << "node is nullptr";
}
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) {
auto input_kernel_node = AnfAlgo::GetInputNode(cnode, input_index);
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
for (size_t input_index = 0; input_index < input_num; input_index++) {
auto input_kernel_node = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(cnode, input_index), 0).first;
MS_EXCEPTION_IF_NULL(input_kernel_node);
if (input_kernel_node->isa<Parameter>()) {
std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
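`VisitKernel` is needed here because, after auto-monad, an input may be wrapped by a Load or Depend node before the underlying Parameter is reached. A toy model of that unwrapping, with an assumed simplified node type rather than the real ANF classes:

```
#include <cassert>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string kind;  // "Parameter", "Load", "Depend", ...
  std::vector<std::shared_ptr<Node>> inputs;
};

// Hypothetical VisitKernel: follow the first real input of transparent
// wrapper nodes until a non-wrapper is found.
std::shared_ptr<Node> VisitKernel(std::shared_ptr<Node> node) {
  while (node->kind == "Load" || node->kind == "Depend") {
    node = node->inputs.at(0);
  }
  return node;
}

int main() {
  auto param = std::make_shared<Node>(Node{"Parameter", {}});
  auto load = std::make_shared<Node>(Node{"Load", {param}});
  assert(VisitKernel(load)->kind == "Parameter");  // the Parameter branch is taken
  return 0;
}
```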

View File

@ -34,13 +34,14 @@ const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, cons
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(equiv);
auto out_reshape = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum);
auto out_reshape = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputTensorNum);
MS_EXCEPTION_IF_NULL(out_reshape);
  // If the Reshape operator is used by more than one other operator, it cannot be deleted directly
if (IsUsedByOthers(func_graph, out_reshape)) {
return nullptr;
}
auto in_reshape = CheckAnfNodeIfCNodeAndInputSize(AnfAlgo::GetInputNode(out_reshape, 0), kBackendReshapeInputNum);
auto in_reshape =
CheckAnfNodeIfCNodeAndInputSize(AnfAlgo::GetInputNode(out_reshape, 0), kBackendReshapeInputTensorNum);
MS_EXCEPTION_IF_NULL(in_reshape);
if (IsUsedByOthers(func_graph, in_reshape)) {
return nullptr;

View File

@ -46,9 +46,9 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(equiv);
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum);
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputTensorNum);
MS_EXCEPTION_IF_NULL(transpose_cnode);
auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum);
auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputTensorNum);
MS_EXCEPTION_IF_NULL(reshape_cnode);
if (AnfAlgo::IsDynamicShape(transpose_cnode) || AnfAlgo::IsDynamicShape(reshape_cnode)) {
return nullptr;

View File

@ -33,10 +33,7 @@ CNodePtr GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square,
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(square);
MS_EXCEPTION_IF_NULL(sum);
if (square->inputs().size() != kSquareNodeInputNum) {
MS_LOG(EXCEPTION) << "Square node has wrong input size"
<< " trace: " << trace::DumpSourceLines(square);
}
CheckCNodeInputSize(square, kSquareNodeInputTensorNum);
auto prim = std::make_shared<Primitive>(kSquareSumV1OpName);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> square_sumv1_inputs = {NewValueNode(prim), square->input(1)};
@ -60,10 +57,7 @@ CNodePtr GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square,
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(square);
MS_EXCEPTION_IF_NULL(sum);
if (square->inputs().size() != kSquareNodeInputNum) {
MS_LOG(EXCEPTION) << "Square node has wrong input size"
<< " trace: " << trace::DumpSourceLines(square);
}
CheckCNodeInputSize(square, kSquareNodeInputTensorNum);
auto prim = std::make_shared<Primitive>(kSquareSumV2OpName);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> square_sumv2_inputs = {NewValueNode(prim), square->input(1)};
@ -84,10 +78,7 @@ std::tuple<CNodePtr, AnfNodePtr, CNodePtr> GetPrevNodes(const AnfNodePtr &node)
MS_EXCEPTION_IF_NULL(node);
auto sum = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sum);
if (sum->inputs().size() != kSumNodeInputNum) {
MS_LOG(EXCEPTION) << "ReduceSumD node has wrong input size"
<< " trace: " << trace::DumpSourceLines(sum);
}
CheckCNodeInputSize(sum, kSumNodeInputTensorNum);
auto square_anf = sum->input(1);
MS_EXCEPTION_IF_NULL(square_anf);
auto square = square_anf->cast<CNodePtr>();

View File

@ -46,9 +46,9 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(equiv);
auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum);
auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputTensorNum);
MS_EXCEPTION_IF_NULL(reshape_cnode);
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum);
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputTensorNum);
MS_EXCEPTION_IF_NULL(transpose_cnode);
if (AnfAlgo::IsDynamicShape(transpose_cnode) || AnfAlgo::IsDynamicShape(reshape_cnode)) {
return nullptr;

View File

@ -33,9 +33,9 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(equiv);
auto transdata_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendTransposeInputNum);
auto transdata_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendTransposeInputTensorNum);
MS_EXCEPTION_IF_NULL(transdata_cnode);
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(transdata_cnode->input(1), kBackendTransDataInputNum);
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(transdata_cnode->input(1), kTransOpInputTensorNum);
MS_EXCEPTION_IF_NULL(transpose_cnode);
auto transpose_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(transpose_cnode);
auto transdata_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(transdata_cnode);

View File

@ -136,10 +136,7 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons
CNodePtr CreateDepthwiseConv2D(const FuncGraphPtr &graph, const CNodePtr &conv2d, const CNodePtr &transpose) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(conv2d);
if (conv2d->inputs().size() != kConvInputNum) {
MS_LOG(EXCEPTION) << "Conv2D's input number should be " << kConvInputNum - 1 << ", but got "
<< conv2d->inputs().size() - 1;
}
CheckCNodeInputSize(conv2d, kConvInputTensorNum);
std::vector<AnfNodePtr> depth_conv_inputs = {NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeOpName)),
conv2d->input(1), transpose};
auto depth_conv = graph->NewCNode(depth_conv_inputs);
@ -270,11 +267,7 @@ const AnfNodePtr Conv2DUnifyMindIR::Process(const FuncGraphPtr &graph, const Anf
if (!NeedUpdate(conv2d, input_shape, output_shape)) {
return nullptr;
}
if (conv2d->inputs().size() != kConvInputNum) {
MS_LOG(EXCEPTION) << "Conv2D's input number should be " << kConvInputNum - 1 << ", but got "
<< conv2d->inputs().size() - 1;
}
CheckCNodeInputSize(conv2d, kConvInputTensorNum);
auto transpose = CreateTranspose(graph, conv2d, conv2d->input(2), true);
auto depth_conv = CreateDepthwiseConv2D(graph, conv2d, transpose);
SetConv2DAttrs(conv2d, depth_conv);

View File

@ -70,7 +70,8 @@ const BaseRef FtrlUnifyOutput::DefinePattern() const {
VarPtr l1 = std::make_shared<Var>();
VarPtr l2 = std::make_shared<Var>();
VarPtr lr_power = std::make_shared<Var>();
VectorRef pattern({prim::kPrimApplyFtrl, var, accum, linear, grad, lr, l1, l2, lr_power});
VarPtr u = std::make_shared<SeqVar>();
VectorRef pattern({prim::kPrimApplyFtrl, var, accum, linear, grad, lr, l1, l2, lr_power, u});
return pattern;
}
@ -84,7 +85,8 @@ const BaseRef MomentumUnifyOutput::DefinePattern() const {
VarPtr lr = std::make_shared<Var>();
VarPtr grad = std::make_shared<Var>();
VarPtr momentum = std::make_shared<Var>();
VectorRef pattern({prim::kPrimApplyMomentum, var, accum, lr, grad, momentum});
VarPtr u = std::make_shared<SeqVar>();
VectorRef pattern({prim::kPrimApplyMomentum, var, accum, lr, grad, momentum, u});
return pattern;
}
@ -114,7 +116,8 @@ const BaseRef CenteredRMSPropUnifyOutput::DefinePattern() const {
VarPtr rho = std::make_shared<Var>();
VarPtr momentum = std::make_shared<Var>();
VarPtr epsilon = std::make_shared<Var>();
VectorRef pattern({prim::kPrimApplyCenteredRMSProp, var, mg, ms, mom, grad, lr, rho, momentum, epsilon});
VarPtr u = std::make_shared<SeqVar>();
VectorRef pattern({prim::kPrimApplyCenteredRMSProp, var, mg, ms, mom, grad, lr, rho, momentum, epsilon, u});
return pattern;
}
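Each pattern above now ends with a `SeqVar`, which matches any remaining trailing inputs. That lets a single pattern cover the operator both before and after auto-monad appends a U input at the end. A toy matcher with that assumed semantics:

```
#include <cassert>
#include <string>
#include <vector>

// Fixed pattern slots must match positionally; any extra trailing node
// inputs are swallowed by the trailing SeqVar.
bool MatchWithTrailingSeqVar(const std::vector<std::string> &pattern,
                             const std::vector<std::string> &node) {
  if (node.size() < pattern.size()) return false;
  for (size_t i = 0; i < pattern.size(); ++i) {
    if (pattern[i] != node[i]) return false;
  }
  return true;
}

int main() {
  std::vector<std::string> pattern = {"ApplyMomentum", "var", "accum", "lr", "grad", "momentum"};
  std::vector<std::string> without_monad = pattern;
  std::vector<std::string> with_monad = pattern;
  with_monad.push_back("U");  // the appended monad input
  assert(MatchWithTrailingSeqVar(pattern, without_monad));
  assert(MatchWithTrailingSeqVar(pattern, with_monad));
  return 0;
}
```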

View File

@ -109,12 +109,7 @@ CNodePtr CreateSoftmaxCrossEntropyWithLogits(const FuncGraphPtr &graph, const CN
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
MS_EXCEPTION_IF_NULL(one_hot_node);
if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal "
<< kSparseSoftmaxCrossEntropyWithLogitsInputNum;
}
CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kSoftmaxCrossEntropyWithLogitsOpName)),
sparse_softmax_node->input(1), one_hot_node};
auto softmax_node = graph->NewCNode(inputs);
@ -162,10 +157,7 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
MS_EXCEPTION_IF_NULL(softmax_output_node);
if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal "
<< kSparseSoftmaxCrossEntropyWithLogitsInputNum;
}
CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
auto axis_value = GetAxis(softmax_output_node);
auto axis_node = GetAxisNode(softmax_output_node);
@ -200,9 +192,7 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft
CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_node) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(real_div_node);
if (real_div_node->size() != kRealDivInputNum) {
MS_LOG(EXCEPTION) << "Op real_div's input num not equal " << kRealDivInputNum;
}
CheckCNodeInputSize(real_div_node, kRealDivInputTensorNum);
int64_t axis = -1;
auto axis_node = NewValueNode(axis);
@ -230,9 +220,8 @@ CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_no
CNodePtr CreateExpandDimsPynative(const FuncGraphPtr &graph, const CNodePtr &real_div_node) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(real_div_node);
if (real_div_node->size() != kRealDivInputNum) {
MS_LOG(EXCEPTION) << "Op real_div's input num not equal " << kRealDivInputNum;
}
CheckCNodeInputSize(real_div_node, kRealDivInputTensorNum);
int64_t axis = -1;
auto expand_dims_primitive = std::make_shared<Primitive>(kExpandDimsOpName);
std::vector<std::string> input_names = {"x"};
@ -257,13 +246,8 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
MS_EXCEPTION_IF_NULL(mul_node);
if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal "
<< kSparseSoftmaxCrossEntropyWithLogitsInputNum;
}
if (mul_node->size() != kMulInputNum) {
MS_LOG(EXCEPTION) << "Op Mul's input not equal " << kMulInputNum;
}
CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
CheckCNodeInputSize(mul_node, kMulInputTensorNum);
auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, 1);
std::vector<int64_t> multiple_value;
@ -310,12 +294,7 @@ CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
MS_EXCEPTION_IF_NULL(tile_node);
if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal "
<< kSparseSoftmaxCrossEntropyWithLogitsInputNum;
}
CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
std::vector<size_t> labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, 1);
if (labels_shape.size() != 1) {
MS_LOG(EXCEPTION) << "label's shape should be 1-D.";
@ -343,9 +322,7 @@ CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax
CNodePtr GetSparseNode(const CNodePtr &depend_node, size_t index) {
MS_EXCEPTION_IF_NULL(depend_node);
if (depend_node->size() != kDependInputNum) {
MS_LOG(EXCEPTION) << "Op Depend's input not equal " << kDependInputNum;
}
CheckCNodeInputSize(depend_node, kDependInputTensorNum);
auto sparse_node = depend_node->input(index);
MS_EXCEPTION_IF_NULL(sparse_node);
return sparse_node->cast<CNodePtr>();
@ -353,9 +330,7 @@ CNodePtr GetSparseNode(const CNodePtr &depend_node, size_t index) {
CNodePtr GetDependNode(const CNodePtr &mul_node) {
MS_EXCEPTION_IF_NULL(mul_node);
if (mul_node->size() != kMulInputNum) {
MS_LOG(EXCEPTION) << "Op Mul's input not equal " << kMulInputNum;
}
CheckCNodeInputSize(mul_node, kMulInputTensorNum);
auto depend_node = mul_node->input(1);
MS_EXCEPTION_IF_NULL(depend_node);
return depend_node->cast<CNodePtr>();
@ -413,10 +388,7 @@ const AnfNodePtr SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const F
auto sparse_softmax_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal "
<< kSparseSoftmaxCrossEntropyWithLogitsInputNum;
}
CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
if (AnfAlgo::HasNodeAttr(kAttrIsGrad, sparse_softmax_node) &&
AnfAlgo::GetNodeAttr<bool>(sparse_softmax_node, kAttrIsGrad)) {
return nullptr;
@ -451,17 +423,12 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
auto mul_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(mul_node);
if (mul_node->size() != kMulInputNum) {
MS_LOG(EXCEPTION) << "Op Mul's input not equal " << kMulInputNum;
}
CheckCNodeInputSize(mul_node, kMulInputTensorNum);
auto depend_node = GetDependNode(mul_node);
auto sparse_softmax_node = GetSparseNode(depend_node, 2);
auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1);
if (sparse_softmax_node_grad->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal "
<< kSparseSoftmaxCrossEntropyWithLogitsInputNum;
}
CheckCNodeInputSize(sparse_softmax_node_grad, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
CNodePtr softmax_node;
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad);
@ -538,10 +505,8 @@ const AnfNodePtr PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process
auto sparse_softmax_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal "
<< kSparseSoftmaxCrossEntropyWithLogitsInputNum;
}
CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
if (AnfAlgo::HasNodeAttr(kAttrIsGrad, sparse_softmax_node) &&
AnfAlgo::GetNodeAttr<bool>(sparse_softmax_node, kAttrIsGrad)) {
return nullptr;
@ -573,17 +538,12 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Pro
auto mul_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(mul_node);
if (mul_node->size() != kMulInputNum) {
MS_LOG(EXCEPTION) << "Op Mul's input not equal " << kMulInputNum;
}
CheckCNodeInputSize(mul_node, kMulInputTensorNum);
auto sparse_softmax_node = mul_node->input(1);
auto sparse_softmax_node_grad = sparse_softmax_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sparse_softmax_node_grad);
if (sparse_softmax_node_grad->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal "
<< kSparseSoftmaxCrossEntropyWithLogitsInputNum;
}
CheckCNodeInputSize(sparse_softmax_node_grad, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
CNodePtr softmax_node;
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad);

View File

@ -124,18 +124,16 @@ CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_si
MS_LOG(EXCEPTION) << "The node is expected to be a cnode";
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() != input_size) {
auto op_name = AnfAlgo::GetCNodeName(cnode);
MS_LOG(EXCEPTION) << "op[" + op_name + "] has less than " << input_size << " inputs.";
}
CheckCNodeInputSize(cnode, input_size);
return cnode;
}
void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size) {
void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_tensor_size) {
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() != input_size) {
MS_LOG(EXCEPTION) << "The input size of node " + cnode->DebugString() + " is not equal to " << input_size;
auto real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode);
if (real_input_tensor_num != input_tensor_size) {
MS_LOG(EXCEPTION) << "The input tensor size[" << real_input_tensor_num
<< "] of node " + cnode->DebugString() + " is not equal to " << input_tensor_size;
}
}
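The check now counts tensor inputs instead of `inputs().size()`, because the raw input list also holds the primitive and any monad inputs that auto-monad appends, so its size is no longer stable. A self-contained sketch of the distinction, with a hypothetical `Input` type rather than MindSpore's classes:

```
#include <cassert>
#include <string>
#include <vector>

struct Input {
  std::string name;
  bool is_monad;  // true for U/IO monad inputs added by auto-monad
};

// Hypothetical helper mirroring AnfAlgo::GetInputTensorNum: count only real
// tensor inputs, skipping the leading primitive and any monad inputs.
size_t GetInputTensorNum(const std::vector<Input> &inputs) {
  size_t count = 0;
  for (size_t i = 1; i < inputs.size(); ++i) {  // inputs[0] is the primitive
    if (!inputs[i].is_monad) ++count;
  }
  return count;
}

int main() {
  // Assign(para, value) before auto-monad: [prim, para, value]
  std::vector<Input> assign = {{"Assign", false}, {"para", false}, {"value", false}};
  // After auto-monad a U monad is appended: [prim, para, value, u]
  std::vector<Input> assign_with_monad = assign;
  assign_with_monad.push_back({"u", true});

  // inputs().size() differs (3 vs 4), but the tensor count is 2 in both
  // cases, which is what CheckCNodeInputSize now validates.
  assert(GetInputTensorNum(assign) == 2);
  assert(GetInputTensorNum(assign_with_monad) == 2);
  return 0;
}
```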
@ -149,17 +147,15 @@ bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y
const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(func_graph);
auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputNum);
auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputTensorNum);
MS_EXCEPTION_IF_NULL(transop_cnode);
auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(kCastInputNum - 1), kDependInputNum);
auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputNum);
MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependInputNum - 1));
MS_EXCEPTION_IF_NULL(prev_transop_cnode->input(kTransOpInputNum - 1));
auto transed_node = prev_transop_cnode->input(kTransOpInputNum - 1);
auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(1), kDependInputTensorNum);
auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputTensorNum);
auto transed_node = prev_transop_cnode->input(1);
MS_EXCEPTION_IF_NULL(transed_node);
std::vector<AnfNodePtr> replace_depend_inputs{NewValueNode(prim::kPrimDepend), transed_node,
depend_cnode->input(kDependInputNum - 1)};
depend_cnode->input(kDependAttachNodeIndex)};
AnfNodePtr replace_depend = func_graph->NewCNode(replace_depend_inputs);
MS_EXCEPTION_IF_NULL(replace_depend);
auto transed_abstract = transed_node->abstract();
@ -422,13 +418,13 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con
}
auto output_info_list = iter->second;
for (const auto &output_info : output_info_list) {
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) {
continue;
}
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
output_info.second == kDependAttachNodeIndex) {
continue;
}
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimUpdateState->name()) {
continue;
}
output_node_list->push_back(output_info);
}
return output_node_list;
@ -537,6 +533,9 @@ void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &i
bool need_update = false;
for (size_t i = 0; i < inputs.size() - 1; ++i) {
auto input_node = inputs[i + 1];
if (AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimDepend)) {
input_node = AnfAlgo::VisitKernel(input_node, 0).first;
}
MS_EXCEPTION_IF_NULL(input_node);
if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>()) {
auto value_node = input_node->cast<ValueNodePtr>();
@ -548,7 +547,7 @@ void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &i
primitive->set_attr(input_names_vec[i], value_node->value());
need_update = true;
} else {
new_inputs.push_back(input_node);
new_inputs.push_back(inputs[i + 1]);
}
}
if (need_update) {
@ -785,7 +784,8 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
// set value node initial device data type = infer data type
std::vector<TypeId> types;
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) {
size_t output_num = AnfAlgo::GetOutputTensorNum(value_node);
for (size_t index = 0; index < output_num; ++index) {
types.push_back(kTypeUnknown);
}
kernel_build_info_builder->SetOutputsDeviceType(types);

View File

@ -29,36 +29,34 @@
namespace mindspore {
namespace opt {
constexpr size_t kTransOpInputNum = 2;
constexpr size_t kCastInputNum = 2;
constexpr size_t kDependInputNum = 3;
constexpr size_t kReluInputNum = 2;
constexpr size_t kReluGradInputNum = 3;
constexpr size_t kAddInputNum = 3;
constexpr size_t kAddNInputNum = 3;
constexpr size_t kTupleGetitemInputNum = 3;
constexpr size_t kConvInputNum = 3;
constexpr size_t kRealDivInputNum = 3;
constexpr size_t kSqrtInputNum = 2;
constexpr size_t kMulInputNum = 3;
constexpr size_t kRsqrtInputNum = 2;
constexpr size_t kSubInputNum = 3;
constexpr size_t kAssignSubInputNum = 3;
constexpr size_t kDropoutInputNum = 2;
constexpr size_t kTransOpInputTensorNum = 1;
constexpr size_t kCastInputTensorNum = 1;
constexpr size_t kDependInputTensorNum = 2;
constexpr size_t kReluInputTensorNum = 1;
constexpr size_t kReluGradInputTensorNum = 2;
constexpr size_t kAddInputTensorNum = 2;
constexpr size_t kTupleGetItemInputTensorNum = 2;
constexpr size_t kConvInputTensorNum = 2;
constexpr size_t kRealDivInputTensorNum = 2;
constexpr size_t kSqrtInputTensorNum = 1;
constexpr size_t kMatMulInputTensorNum = 2;
constexpr size_t kMulInputTensorNum = 2;
constexpr size_t kSubInputTensorNum = 2;
constexpr size_t kAssignSubInputTensorNum = 2;
constexpr size_t kDropoutInputTensorNum = 1;
constexpr size_t kAssignInputTensorNum = 2;
constexpr size_t kConvBn1OutputNum = 3;
constexpr size_t kBn2ReluOutputNum = 4;
constexpr size_t kBnInputNum = 6;
constexpr size_t kBnInputTensorNum = 5;
constexpr size_t kBnOutputNum = 5;
constexpr size_t kBatchNormInputNum = 5;
constexpr size_t kBatchNormOutputNum = 5;
constexpr size_t kBN1OutputNum = 2;
constexpr size_t kBN2OutputNum = 3;
constexpr size_t kBN3OutputNum = 1;
constexpr size_t kBNGradInputNum = 6;
constexpr size_t kBNGradInputTensorNum = 5;
constexpr size_t kBNGradOutputNum = 3;
constexpr size_t kBNGrad1OutputNum = 3;
@ -72,10 +70,10 @@ constexpr size_t kBNTrainingUpdateV3OutputNum = 5;
constexpr size_t kBNTrainingUpdateGradOutputNum = 2;
constexpr size_t kSingleOutputNum = 1;
constexpr size_t kSumNodeInputNum = 2;
constexpr size_t kSquareNodeInputNum = 2;
constexpr size_t kSumNodeInputTensorNum = 1;
constexpr size_t kSquareNodeInputTensorNum = 1;
constexpr size_t kSquareSumv2OutputNum = 2;
constexpr size_t kMinimumInputNum = 3;
constexpr size_t kMinimumInputTensorNum = 2;
constexpr size_t kLambNextMVWithDecayInputNum = 7;
constexpr size_t kLambNextMVWithDecayConstantMulInputNum = 5;
@ -85,26 +83,25 @@ constexpr size_t kLambNextRightOutputNum = 2;
constexpr size_t kLambUpdateWithLrV2InputNum = 8;
constexpr size_t kLambNextMVRuleInputNum = 14;
constexpr size_t kLambNextMVRuleOutputNum = 4;
constexpr size_t kBackendReshapeInputNum = 2;
constexpr size_t kBackendTransposeInputNum = 2;
constexpr size_t kBackendReshapeInputTensorNum = 1;
constexpr size_t kBackendTransposeInputTensorNum = 1;
constexpr size_t kAdamApplyOneWithDecayOutputNum = 3;
constexpr size_t kLayerNormBetaGammaBackpropInputNum = 5;
constexpr size_t kLayerNormBetaGammaBackpropInputTensorNum = 4;
constexpr size_t kLayerNormBetaGammaBackpropOutputNum = 2;
constexpr size_t kLayerNormGradInputNum = 6;
constexpr size_t kLayerNormGradInputTensorNum = 5;
constexpr size_t kAdamApplyOneOutputNum = 3;
constexpr size_t kBackendTransDataInputNum = 2;
constexpr size_t kApplyMomentumInputNum = 6;
constexpr size_t kBiasAddInputNum = 3;
constexpr size_t kTopkInputNum = 3;
constexpr size_t kLarsV2InputNum = 5;
constexpr size_t kApplyMomentumInputTensorNum = 5;
constexpr size_t kBiasAddInputTensorNum = 2;
constexpr size_t kTopkInputTensorNum = 2;
constexpr size_t kLarsV2InputTensorNum = 4;
constexpr size_t kFusedMulApplyMomentumOutputNum = 2;
constexpr size_t kSplitInputNum = 2;
constexpr size_t kGatherV2DynInputNum = 3;
constexpr size_t kUnsortedSegmentSumInputNum = 2;
constexpr size_t kSplitInputTensorNum = 1;
constexpr size_t kGatherV2DynInputTensorNum = 3;
constexpr size_t kUnsortedSegmentSumInputTensorNum = 2;
constexpr size_t kSoftmaxCrossEntropyWithLogitsOutputNum = 2;
constexpr size_t kSparseSoftmaxCrossEntropyWithLogitsInputNum = 3;
constexpr size_t kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum = 2;
constexpr size_t kOneHotOutputNum = 1;
constexpr size_t kOneHotInputNum = 5;
constexpr size_t kOneHotInputTensorNum = 4;
enum FusedBatchNormInput {
kX = 1,
@ -137,7 +134,7 @@ bool Visited(const BaseRef &n);
// check if the input node is CNode, then check it's input_size, return CNodePtr if check success.
CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_size);
void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size);
void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_tensor_num);
bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y);
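The renaming above follows one convention: the old `k*InputNum` constants counted the primitive itself, while the new `k*InputTensorNum` constants count only tensor inputs, so each value drops by one. A trivial sanity check of that off-by-one relation:

```
#include <cstddef>

int main() {
  constexpr std::size_t kMulInputNum = 3;        // old: prim + x + y
  constexpr std::size_t kMulInputTensorNum = 2;  // new: x + y only
  static_assert(kMulInputTensorNum == kMulInputNum - 1, "tensor count drops the primitive");
  return 0;
}
```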

View File

@ -34,11 +34,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
std::vector<TypeId> outputs_type;
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
size_t input_num = AnfAlgo::GetInputTensorNum(node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
inputs_format.push_back(kOpFormat_DEFAULT);
}
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
outputs_format.push_back(kOpFormat_DEFAULT);
}
@ -51,19 +53,30 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
} // namespace
const BaseRef AdamFusion::DefinePattern() const {
VectorRef next_m = VectorRef(
{prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, m_}), VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
VectorRef load_m = VectorRef({prim::kPrimLoad, m_, u_});
VectorRef next_m = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, load_m}),
VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_});
VectorRef next_v =
VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, load_v}),
VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
VectorRef update =
VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update});
VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, param_, next_param})});
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, m_, next_m})});
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, v_, next_v})});
VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, u2_});
VectorRef next_state = VectorRef({prim::kPrimUpdateState, u2_, assign_param});
next_param = VectorRef({prim::kPrimDepend, next_param, assign_param});
VectorRef assign_m = VectorRef({prim::kPrimAssign, m_, next_m, next_state});
next_state = VectorRef({prim::kPrimUpdateState, next_state, assign_m});
next_param = VectorRef({prim::kPrimDepend, next_param, assign_m});
VectorRef assign_v = VectorRef({prim::kPrimAssign, v_, next_v, next_state});
next_param = VectorRef({prim::kPrimDepend, next_param, assign_v});
return next_param;
}
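The pattern now threads a monad through Load, Assign, and UpdateState instead of relying on bare Depend chains. A toy model of how that chain turns execution order into ordinary data dependency (assumed semantics, not the real primitives): each Assign consumes the current state, and UpdateState produces the next one, so topological order alone fixes the sequence param, m, v.

```
#include <iostream>
#include <string>

struct State { int version; };

// Hypothetical Assign: performs the side effect "under" state u and returns
// a handle representing that effect.
State Assign(const std::string &var, State u) {
  std::cout << "Assign " << var << " at state " << u.version << std::endl;
  return u;
}

// Hypothetical UpdateState: folds an effect into the state, yielding the next one.
State UpdateState(State u, State /*effect*/) { return State{u.version + 1}; }

int main() {
  State u{0};                        // plays the role of u2_
  State assign_param = Assign("param", u);
  u = UpdateState(u, assign_param);  // next_state after assign_param
  State assign_m = Assign("m", u);
  u = UpdateState(u, assign_m);      // next_state after assign_m
  State assign_v = Assign("v", u);   // ordered last through the state chain
  (void)assign_v;                    // no ControlDepend needed anywhere
  return 0;
}
```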
@ -81,6 +94,7 @@ const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr
auto m_input = utils::cast<AnfNodePtr>((*equiv)[m_]);
auto v_input = utils::cast<AnfNodePtr>((*equiv)[v_]);
auto gradient_input = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
auto u_input = utils::cast<AnfNodePtr>((*equiv)[u_]);
MS_EXCEPTION_IF_NULL(beta1_input);
MS_EXCEPTION_IF_NULL(one_sub_beta1_input);
MS_EXCEPTION_IF_NULL(beta2_input);
@ -91,13 +105,30 @@ const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr
MS_EXCEPTION_IF_NULL(m_input);
MS_EXCEPTION_IF_NULL(v_input);
MS_EXCEPTION_IF_NULL(gradient_input);
MS_EXCEPTION_IF_NULL(u_input);
// Use depend(param, u) to maintain the execution order of FusedAdam and the previous operators.
auto prim_depend = std::make_shared<Primitive>(prim::kPrimDepend->name());
MS_EXCEPTION_IF_NULL(prim_depend);
std::vector<AnfNodePtr> param_inputs = {NewValueNode(prim_depend), param_input, u_input};
auto param = graph->NewCNode(param_inputs);
MS_EXCEPTION_IF_NULL(param);
param->set_abstract(param_input->abstract());
// Fused into a FusedAdam operator.
auto prim = std::make_shared<Primitive>(kFusedAdamName);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {
NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input,
eps_input, lr_input, param_input, m_input, v_input,
gradient_input};
std::vector<AnfNodePtr> inputs = {NewValueNode(prim),
beta1_input,
one_sub_beta1_input,
beta2_input,
one_sub_beta2_input,
eps_input,
lr_input,
param,
m_input,
v_input,
gradient_input};
auto adam = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(adam);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
@ -107,6 +138,30 @@ const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr
auto build_info = GenerateKernelBuildInfo(adam);
AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get());
// Replace the parameters of the last UpdateState to maintain
// the execution order of FusedAdam and the following operators.
// n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
auto n = node->cast<CNodePtr>()->input(2);
auto fg = n->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto mgr = fg->manager();
MS_EXCEPTION_IF_NULL(mgr);
auto &node_users = mgr->node_users();
auto iter = node_users.find(n);
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString();
}
auto &users = iter->second;
for (auto &user : users) {
if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
(user.first)->cast<CNodePtr>()->set_input(1, u_input);
(user.first)->cast<CNodePtr>()->set_input(2, adam);
break;
}
}
return adam;
}
} // namespace opt
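The rewiring loop above points the UpdateState that followed `assign_v` at the original monad `u_input` and the new fused node, so downstream side-effect operators still order after FusedAdam. A sketch of that two-input rewrite, with a hypothetical node type standing in for CNode:

```
#include <cassert>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<Node *> inputs;  // inputs[0] models input(1), inputs[1] models input(2)
};

int main() {
  Node u{"u", {}};  // the original monad u_input
  Node next_state{"next_state", {}};
  Node assign_v{"assign_v", {}};
  Node update_state{"UpdateState", {&next_state, &assign_v}};
  Node fused_adam{"FusedAdam", {}};

  // Mirror set_input(1, u_input) and set_input(2, adam) from the pass:
  update_state.inputs[0] = &u;
  update_state.inputs[1] = &fused_adam;

  assert(update_state.inputs[0]->name == "u");
  assert(update_state.inputs[1]->name == "FusedAdam");
  return 0;
}
```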

View File

@ -34,6 +34,8 @@ class AdamFusion : public PatternProcessPass {
m_ = std::make_shared<Var>();
v_ = std::make_shared<Var>();
gradient_ = std::make_shared<Var>();
u_ = std::make_shared<Var>();
u2_ = std::make_shared<Var>();
}
~AdamFusion() override = default;
const BaseRef DefinePattern() const override;
@ -50,6 +52,8 @@ class AdamFusion : public PatternProcessPass {
VarPtr m_;
VarPtr v_;
VarPtr gradient_;
VarPtr u_;
VarPtr u2_;
};
} // namespace opt
} // namespace mindspore

View File

@ -34,11 +34,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
std::vector<TypeId> outputs_type;
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
size_t input_num = AnfAlgo::GetInputTensorNum(node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
inputs_format.push_back(kOpFormat_DEFAULT);
}
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
outputs_format.push_back(kOpFormat_DEFAULT);
}
@ -51,11 +53,14 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
} // namespace
const BaseRef AdamWeightDecayFusion::DefinePattern() const {
VectorRef next_m = VectorRef(
{prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, m_}), VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
VectorRef load_m = VectorRef({prim::kPrimLoad, m_, u_});
VectorRef next_m = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, load_m}),
VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_});
VectorRef next_v =
VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, load_v}),
VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
VectorRef update =
VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
VectorRef new_update = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update});
@ -63,9 +68,16 @@ const BaseRef AdamWeightDecayFusion::DefinePattern() const {
VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update});
VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, param_, next_param})});
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, m_, next_m})});
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, v_, next_v})});
VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, u2_});
VectorRef next_state = VectorRef({prim::kPrimUpdateState, u2_, assign_param});
next_param = VectorRef({prim::kPrimDepend, next_param, assign_param});
VectorRef assign_m = VectorRef({prim::kPrimAssign, m_, next_m, next_state});
next_state = VectorRef({prim::kPrimUpdateState, next_state, assign_m});
next_param = VectorRef({prim::kPrimDepend, next_param, assign_m});
VectorRef assign_v = VectorRef({prim::kPrimAssign, v_, next_v, next_state});
next_param = VectorRef({prim::kPrimDepend, next_param, assign_v});
return next_param;
}
@ -85,6 +97,7 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const
auto m_input = utils::cast<AnfNodePtr>((*equiv)[m_]);
auto v_input = utils::cast<AnfNodePtr>((*equiv)[v_]);
auto gradient_input = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
auto u_input = utils::cast<AnfNodePtr>((*equiv)[u_]);
MS_EXCEPTION_IF_NULL(beta1_input);
MS_EXCEPTION_IF_NULL(one_sub_beta1_input);
MS_EXCEPTION_IF_NULL(beta2_input);
@ -96,13 +109,31 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const
MS_EXCEPTION_IF_NULL(m_input);
MS_EXCEPTION_IF_NULL(v_input);
MS_EXCEPTION_IF_NULL(gradient_input);
MS_EXCEPTION_IF_NULL(u_input);
// Use depend(param, u) to maintain the execution order of FusedAdamWeightDecay and the previous operators.
auto prim_depend = std::make_shared<Primitive>(prim::kPrimDepend->name());
MS_EXCEPTION_IF_NULL(prim_depend);
std::vector<AnfNodePtr> param_inputs = {NewValueNode(prim_depend), param_input, u_input};
auto param = graph->NewCNode(param_inputs);
MS_EXCEPTION_IF_NULL(param);
param->set_abstract(param_input->abstract());
// Fused into a FusedAdamWeightDecay operator.
auto prim = std::make_shared<Primitive>(kFusedAdamWeightDecayName);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {
NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input,
eps_input, lr_input, param_input, m_input, v_input,
gradient_input, weight_decay_input};
std::vector<AnfNodePtr> inputs = {NewValueNode(prim),
beta1_input,
one_sub_beta1_input,
beta2_input,
one_sub_beta2_input,
eps_input,
lr_input,
param,
m_input,
v_input,
gradient_input,
weight_decay_input};
auto adam_weight_decay = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(adam_weight_decay);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
@ -112,6 +143,30 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const
auto build_info = GenerateKernelBuildInfo(adam_weight_decay);
AnfAlgo::SetSelectKernelBuildInfo(build_info, adam_weight_decay.get());
// Replace the parameters of the last UpdateState to maintain
// the execution order of FusedAdamWeightDecay and the following operators.
// n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
auto n = node->cast<CNodePtr>()->input(2);
auto fg = n->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto mgr = fg->manager();
MS_EXCEPTION_IF_NULL(mgr);
auto &node_users = mgr->node_users();
auto iter = node_users.find(n);
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString();
}
auto &users = iter->second;
for (auto &user : users) {
if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
(user.first)->cast<CNodePtr>()->set_input(1, u_input);
(user.first)->cast<CNodePtr>()->set_input(2, adam_weight_decay);
break;
}
}
return adam_weight_decay;
}
} // namespace opt

View File

@ -35,6 +35,8 @@ class AdamWeightDecayFusion : public PatternProcessPass {
m_ = std::make_shared<Var>();
v_ = std::make_shared<Var>();
gradient_ = std::make_shared<Var>();
u_ = std::make_shared<Var>();
u2_ = std::make_shared<Var>();
}
~AdamWeightDecayFusion() override = default;
const BaseRef DefinePattern() const override;
@ -52,6 +54,8 @@ class AdamWeightDecayFusion : public PatternProcessPass {
VarPtr m_;
VarPtr v_;
VarPtr gradient_;
VarPtr u_;
VarPtr u2_;
};
} // namespace opt
} // namespace mindspore

View File

@ -34,11 +34,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
std::vector<TypeId> outputs_type;
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
size_t input_num = AnfAlgo::GetInputTensorNum(node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
inputs_format.push_back(kOpFormat_DEFAULT);
}
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
outputs_format.push_back(kOpFormat_DEFAULT);
}

View File

@ -34,11 +34,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
std::vector<TypeId> outputs_type;
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
size_t input_num = AnfAlgo::GetInputTensorNum(node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
inputs_format.push_back(kOpFormat_DEFAULT);
}
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
outputs_format.push_back(kOpFormat_DEFAULT);
}
@ -78,7 +80,8 @@ const AnfNodePtr AddReluV2Fusion::Process(const FuncGraphPtr &graph, const AnfNo
std::vector<TypeId> types;
std::vector<std::vector<size_t>> shapes;
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node); i++) {
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
for (size_t i = 0; i < output_num; i++) {
types.push_back(AnfAlgo::GetOutputInferDataType(node, i));
shapes.push_back(AnfAlgo::GetOutputInferShape(node, i));
}

View File

@ -51,7 +51,7 @@ bool ApplyMomentumScaleFusion::IsScalar(const BaseRef &n) {
const BaseRef ApplyMomentumScaleFusion::DefinePattern() const {
VectorRef scale = VectorRef({prim::kPrimMul, gradient_, scale_});
VectorRef apply_momentum =
VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_});
VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_, monad_state_});
return apply_momentum;
}
@ -66,17 +66,19 @@ const AnfNodePtr ApplyMomentumScaleFusion::Process(const FuncGraphPtr &graph, co
auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]);
auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]);
auto monad_state = utils::cast<AnfNodePtr>((*equiv)[monad_state_]);
MS_EXCEPTION_IF_NULL(scale);
MS_EXCEPTION_IF_NULL(variable);
MS_EXCEPTION_IF_NULL(accumulation);
MS_EXCEPTION_IF_NULL(learning_rate);
MS_EXCEPTION_IF_NULL(gradient);
MS_EXCEPTION_IF_NULL(momentum);
MS_EXCEPTION_IF_NULL(monad_state);
auto prim = std::make_shared<Primitive>(kFusedScaleApplyMomentum);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), scale, variable, accumulation,
learning_rate, gradient, momentum};
learning_rate, gradient, momentum, monad_state};
auto replace_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(replace_node);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};

View File

@ -31,6 +31,7 @@ class ApplyMomentumScaleFusion : public PatternProcessPass {
learning_rate_ = std::make_shared<Var>();
gradient_ = std::make_shared<Var>();
momentum_ = std::make_shared<Var>();
monad_state_ = std::make_shared<Var>();
}
~ApplyMomentumScaleFusion() override = default;
const BaseRef DefinePattern() const override;
@ -45,6 +46,7 @@ class ApplyMomentumScaleFusion : public PatternProcessPass {
VarPtr learning_rate_;
VarPtr gradient_;
VarPtr momentum_;
VarPtr monad_state_;
};
} // namespace opt
} // namespace mindspore

View File

@ -49,10 +49,11 @@ bool ApplyMomentumWeightDecayFusion::IsScalar(const BaseRef &n) {
}
const BaseRef ApplyMomentumWeightDecayFusion::DefinePattern() const {
VectorRef load_para = VectorRef({prim::kPrimLoad, variable_, monad_});
VectorRef weight_decay =
VectorRef({prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), gradient_});
VectorRef apply_momentum =
VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, weight_decay, momentum_});
VectorRef({prim::kPrimAddN, VectorRef({prim::kPrimMul, load_para, weight_decay_}), gradient_});
VectorRef apply_momentum = VectorRef(
{prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, weight_decay, momentum_, monad_state_});
return apply_momentum;
}
@ -67,17 +68,19 @@ const AnfNodePtr ApplyMomentumWeightDecayFusion::Process(const FuncGraphPtr &gra
auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]);
auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]);
auto monad_state = utils::cast<AnfNodePtr>((*equiv)[monad_state_]);
MS_EXCEPTION_IF_NULL(weight_decay);
MS_EXCEPTION_IF_NULL(variable);
MS_EXCEPTION_IF_NULL(accumulation);
MS_EXCEPTION_IF_NULL(learning_rate);
MS_EXCEPTION_IF_NULL(gradient);
MS_EXCEPTION_IF_NULL(momentum);
MS_EXCEPTION_IF_NULL(monad_state);
auto prim = std::make_shared<Primitive>(kFusedWeightApplyMomentum);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), weight_decay, variable, accumulation,
learning_rate, gradient, momentum};
learning_rate, gradient, momentum, monad_state};
auto replace_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(replace_node);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};

View File

@ -25,12 +25,14 @@ class ApplyMomentumWeightDecayFusion : public PatternProcessPass {
public:
explicit ApplyMomentumWeightDecayFusion(bool multigraph = true)
: PatternProcessPass("momentum_weightdecay_fusion", multigraph) {
monad_ = std::make_shared<Var>();
weight_decay_ = std::make_shared<Var>();
variable_ = std::make_shared<Var>();
accumulation_ = std::make_shared<Var>();
learning_rate_ = std::make_shared<Var>();
gradient_ = std::make_shared<Var>();
momentum_ = std::make_shared<Var>();
monad_state_ = std::make_shared<Var>();
}
~ApplyMomentumWeightDecayFusion() override = default;
const BaseRef DefinePattern() const override;
@ -39,12 +41,14 @@ class ApplyMomentumWeightDecayFusion : public PatternProcessPass {
private:
static bool IsScalar(const BaseRef &n);
VarPtr monad_;
VarPtr weight_decay_;
VarPtr variable_;
VarPtr accumulation_;
VarPtr learning_rate_;
VarPtr gradient_;
VarPtr momentum_;
VarPtr monad_state_;
};
} // namespace opt
} // namespace mindspore

View File

@ -49,11 +49,12 @@ bool ApplyMomentumWeightDecayScaleFusion::IsScalar(const BaseRef &n) {
}
const BaseRef ApplyMomentumWeightDecayScaleFusion::DefinePattern() const {
VectorRef load_para = VectorRef({prim::kPrimLoad, variable_, monad_});
VectorRef weight = VectorRef(
{prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), VectorRef({prim::kPrimCast, gradient_})});
{prim::kPrimAddN, VectorRef({prim::kPrimMul, load_para, weight_decay_}), VectorRef({prim::kPrimCast, gradient_})});
VectorRef scale = VectorRef({prim::kPrimMul, weight, scale_});
VectorRef apply_momentum =
VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_});
VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_, monad_state_});
return apply_momentum;
}
@ -69,6 +70,8 @@ const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr
auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]);
auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]);
auto monad_state = utils::cast<AnfNodePtr>((*equiv)[monad_state_]);
MS_EXCEPTION_IF_NULL(weight_decay);
MS_EXCEPTION_IF_NULL(scale);
MS_EXCEPTION_IF_NULL(variable);
@ -76,11 +79,12 @@ const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr
MS_EXCEPTION_IF_NULL(learning_rate);
MS_EXCEPTION_IF_NULL(gradient);
MS_EXCEPTION_IF_NULL(momentum);
MS_EXCEPTION_IF_NULL(monad_state);
auto prim = std::make_shared<Primitive>(kFusedWeightScaleApplyMomentum);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), weight_decay, scale, variable,
accumulation, learning_rate, gradient, momentum};
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), weight_decay, scale, variable, accumulation,
learning_rate, gradient, momentum, monad_state};
auto replace_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(replace_node);
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};

View File

@ -25,6 +25,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
public:
explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true)
: PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) {
monad_ = std::make_shared<Var>();
weight_decay_ = std::make_shared<Var>();
scale_ = std::make_shared<CondVar>(IsScalar);
variable_ = std::make_shared<Var>();
@ -32,6 +33,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
learning_rate_ = std::make_shared<Var>();
gradient_ = std::make_shared<Var>();
momentum_ = std::make_shared<Var>();
monad_state_ = std::make_shared<Var>();
}
~ApplyMomentumWeightDecayScaleFusion() override = default;
const BaseRef DefinePattern() const override;
@ -40,6 +42,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
private:
static bool IsScalar(const BaseRef &n);
VarPtr monad_;
VarPtr weight_decay_;
VarPtr scale_;
VarPtr variable_;
@ -47,6 +50,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
VarPtr learning_rate_;
VarPtr gradient_;
VarPtr momentum_;
VarPtr monad_state_;
};
} // namespace opt
} // namespace mindspore

View File

@ -37,11 +37,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr>
for (size_t idx = 0; idx < node_list.size(); ++idx) {
auto cnode = utils::cast<CNodePtr>(node_list[idx]);
MS_EXCEPTION_IF_NULL(cnode);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_device_format.push_back(kOpFormat_DEFAULT);
inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
}
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) {
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
outputs_device_format.push_back(kOpFormat_DEFAULT);
outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
@ -57,16 +59,39 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr>
bool GetDealList(const std::vector<AnfNodePtr> &node_list, std::vector<std::vector<AnfNodePtr>> *deal_list) {
std::vector<AnfNodePtr> cast_32to16_list;
std::vector<AnfNodePtr> cast_16to32_list;
AnfNodePtr cast_32to16_load_monad = nullptr;
AnfNodePtr cast_16to32_load_monad = nullptr;
constexpr size_t second_input_index = 2;
for (auto &cast_node : node_list) {
// Currently, we only deal with the construct [Param->Cast->] to avoid creating a cycle.
if (cast_node != nullptr && cast_node->isa<CNode>() && AnfAlgo::GetCNodeName(cast_node) == "Cast" &&
(AnfAlgo::GetInputNode(utils::cast<CNodePtr>(cast_node), 0))->isa<Parameter>()) {
auto dst = AnfAlgo::GetOutputInferDataType(cast_node, 0);
auto src = AnfAlgo::GetPrevNodeOutputInferDataType(cast_node, 0);
if (dst == kNumberTypeFloat16 && src == kNumberTypeFloat32) {
cast_32to16_list.push_back(cast_node);
} else if (dst == kNumberTypeFloat32 && src == kNumberTypeFloat16) {
cast_16to32_list.push_back(cast_node);
// { prim::kPrimCast, { prim::kPrimLoad, Parameter, U }}
if (IsPrimitiveCNode(cast_node, prim::kPrimCast)) {
auto input0 = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(cast_node), 0);
if (input0->isa<Parameter>() || (IsPrimitiveCNode(input0, prim::kPrimLoad) &&
(AnfAlgo::GetInputNode(utils::cast<CNodePtr>(input0), 0))->isa<Parameter>())) {
auto dst = AnfAlgo::GetOutputInferDataType(cast_node, 0);
auto src = AnfAlgo::GetPrevNodeOutputInferDataType(cast_node, 0);
if (dst == kNumberTypeFloat16 && src == kNumberTypeFloat32) {
cast_32to16_list.push_back(cast_node);
if (IsPrimitiveCNode(input0, prim::kPrimLoad)) {
auto &monad = input0->cast<CNodePtr>()->inputs().at(second_input_index);
if (cast_32to16_load_monad == nullptr) {
cast_32to16_load_monad = monad;
} else if (cast_32to16_load_monad != monad) {
return false;
}
}
} else if (dst == kNumberTypeFloat32 && src == kNumberTypeFloat16) {
cast_16to32_list.push_back(cast_node);
if (IsPrimitiveCNode(input0, prim::kPrimLoad)) {
auto &monad = input0->cast<CNodePtr>()->inputs().at(second_input_index);
if (cast_16to32_load_monad == nullptr) {
cast_16to32_load_monad = monad;
} else if (cast_16to32_load_monad != monad) {
return false;
}
}
}
}
}
}
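The cast_32to16_load_monad / cast_16to32_load_monad bookkeeping above enforces one invariant; stated on its own (hypothetical helper, assuming Load's monad is its second operand as in the code above):
```
// Every Load feeding the casts chosen for one fused group must observe the
// same monad; merging reads taken under different side-effect states could
// reorder them incorrectly.
bool LoadsShareOneMonad(const std::vector<CNodePtr> &load_nodes) {
  constexpr size_t second_input_index = 2;  // Load(param, monad)
  AnfNodePtr monad = nullptr;
  for (const auto &load : load_nodes) {
    const auto &m = load->inputs().at(second_input_index);
    if (monad == nullptr) {
      monad = m;
    } else if (monad != m) {
      return false;  // mixed states: give up on this fusion group
    }
  }
  return true;
}
```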

View File

@ -36,11 +36,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr>
for (size_t idx = 0; idx < node_list.size(); ++idx) {
auto cnode = utils::cast<CNodePtr>(node_list[idx]);
MS_EXCEPTION_IF_NULL(cnode);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_device_format.push_back(kOpFormat_DEFAULT);
inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
}
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) {
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
outputs_device_format.push_back(kOpFormat_DEFAULT);
outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));

View File

@ -53,6 +53,8 @@ std::set<string> kSkipOpNames = {
std::map<string, uint32_t> kAggregatesOpNames = {
{kConv2DBackpropInputOpName, 0}, {kmaxPoolGradOpName, 2}, {kFusedBatchNormGradExWithAddAndActivation, 0}};
constexpr size_t inplace_node_size = 2;
template <typename T>
void SetPrimAttr(AnfNodePtr inplace_node, const string &key, const T &value) {
auto primitive = AnfAlgo::GetCNodePrimitive(inplace_node);
@ -60,40 +62,103 @@ void SetPrimAttr(AnfNodePtr inplace_node, const string &key, const T &value) {
primitive->AddAttr(key, MakeValue(value));
}
void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<AnfNodeIndex> *inplace_node) {
std::pair<size_t, bool> GetCoverIndex(const std::vector<AnfNodeIndex> &inplace_node) {
if (inplace_node.size() != inplace_node_size) {
return {0, false};
}
auto first_node = inplace_node[0].node;
auto second_node = inplace_node[1].node;
if (AnfAlgo::GetCNodeName(first_node) != kConv2DBackpropInputOpName ||
AnfAlgo::GetCNodeName(second_node) != kConv2DBackpropInputOpName) {
return {0, false};
}
auto first_node_prim = AnfAlgo::GetCNodePrimitive(first_node);
auto first_node_channel = first_node_prim.get()->GetAttr("out_channel");
MS_EXCEPTION_IF_NULL(first_node_channel);
size_t first_channel = first_node_channel->cast<Int64ImmPtr>()->value();
auto second_node_prim = AnfAlgo::GetCNodePrimitive(second_node);
auto second_node_channel = second_node_prim.get()->GetAttr("out_channel");
MS_EXCEPTION_IF_NULL(second_node_channel);
size_t second_channel = second_node_channel->cast<Int64ImmPtr>()->value();
size_t cover_index = (first_channel >= second_channel) ? 0 : 1;
return {cover_index, true};
}
void CopyKernelInfo(AnfNodePtr src, AnfNodePtr dst) {
auto build_info = AnfAlgo::GetSelectKernelBuildInfo(src);
AnfAlgo::SetSelectKernelBuildInfo(build_info, dst.get());
size_t output_num = AnfAlgo::GetOutputTensorNum(src);
std::vector<TypeId> types;
std::vector<std::vector<size_t>> shapes;
for (size_t i = 0; i < output_num; i++) {
types.emplace_back(AnfAlgo::GetOutputInferDataType(src, i));
shapes.emplace_back(AnfAlgo::GetOutputInferShape(src, i));
}
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, dst.get());
}
void CheckInplaceNodeInputs(std::vector<AnfNodeIndex> *inplace_node, const FuncGraphPtr &graph) {
if (inplace_node->size() == inplace_node_size) {
auto first_cnode = (*inplace_node)[0].node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(first_cnode);
auto first_node_input = first_cnode->input(1);
auto second_cnode = (*inplace_node)[1].node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(second_cnode);
auto second_node_input = second_cnode->input(1);
// If two inplace nodes have the same input, a loop would appear after inserting Depend,
// so copy a new input for one of the inplace nodes.
if (first_node_input == second_node_input) {
auto cnode = first_node_input->cast<CNodePtr>();
auto new_input = graph->NewCNode(cnode->inputs());
new_input->set_abstract(first_node_input->abstract());
CopyKernelInfo(first_node_input, new_input);
auto new_inplace_node = graph->NewCNode({first_cnode->input(0), new_input, first_cnode->input(2)});
new_inplace_node->set_abstract(first_cnode->abstract());
CopyKernelInfo(first_cnode, new_inplace_node);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->Replace(first_cnode, new_inplace_node);
(*inplace_node)[0].node = new_inplace_node;
}
}
}
void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<AnfNodeIndex> *inplace_node,
const FuncGraphPtr &graph) {
SetPrimAttr(aggregate_node.node, "aggregate", true);
SetPrimAttr(aggregate_node.node, "aggregate_input_index", aggregate_node.index);
SetPrimAttr(skip_node, "skip", true);
static uint32_t group = 0;
auto [cover_index, order_required] = GetCoverIndex(*inplace_node);
if (order_required) {
CheckInplaceNodeInputs(inplace_node, graph);
}
for (size_t i = 0; i < inplace_node->size(); i++) {
auto algo = (i == 0) ? "cover" : "accumulation";
SetPrimAttr((*inplace_node)[i].node, "inplace_algo", algo);
SetPrimAttr((*inplace_node)[i].node, "inplace_group", group);
SetPrimAttr((*inplace_node)[i].node, "inplace_output_index", (*inplace_node)[i].index);
auto algo = (i == cover_index) ? "cover" : "accumulation";
auto node = (*inplace_node)[i].node;
SetPrimAttr(node, "inplace_algo", algo);
SetPrimAttr(node, "inplace_group", group);
SetPrimAttr(node, "inplace_output_index", (*inplace_node)[i].index);
// For Conv2DBackpropInputOp, insert a Depend node to keep the order, and set the node with the larger out_channel to cover.
if (order_required && i != cover_index) {
auto acc_node = node;
auto cover_node = (*inplace_node)[cover_index].node;
auto acc_node_input = acc_node->cast<CNodePtr>()->input(1);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
acc_node_input, cover_node};
auto depend_node = graph->NewCNode(inputs);
depend_node->set_abstract(acc_node_input->abstract());
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->Replace(acc_node_input, depend_node);
}
}
group++;
}
void InsertControlDependToGraph(const FuncGraphPtr &graph, const std::vector<AnfNodeIndex> &inplace_nodes,
const AnfNodePtr aggregate_node) {
std::vector<AnfNodePtr> inputs1 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())),
inplace_nodes[0].node, inplace_nodes[1].node};
auto control_depend_node = graph->NewCNode(inputs1);
std::vector<AnfNodePtr> inputs2 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
aggregate_node, control_depend_node};
auto depend_node = graph->NewCNode(inputs2);
auto users = GetRealNodeUsedList(graph, aggregate_node);
if (users->size() == 0) {
MS_LOG(EXCEPTION) << "No users found: " << aggregate_node->DebugString();
}
auto mount_node = users->at(0).first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(mount_node);
mount_node->set_input(kFirstDataInputIndex, depend_node);
}
bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeIndex *aggregate, AnfNodePtr *skip_node,
std::vector<AnfNodeIndex> *inplace) {
MS_EXCEPTION_IF_NULL(skip_node);
@ -117,7 +182,8 @@ bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeInde
auto cnode = (*skip_node)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) {
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
for (size_t i = 0; i < input_num; i++) {
auto inplace_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(*skip_node), i);
if (!inplace_node->isa<CNode>()) {
return false;
@ -187,9 +253,7 @@ bool CudnnInplaceAggregate::Run(const FuncGraphPtr &graph) {
<< "; inplace node 1: " << inplace_node[1].index << ", " << inplace_node[1].node->DebugString()
<< std::endl;
// 2. Set Node attr
SetNodeAttr(aggregate_node, skip_node, &inplace_node);
// 3. Set dependence for inplace nodes
InsertControlDependToGraph(graph, inplace_node, aggregate_node.node);
SetNodeAttr(aggregate_node, skip_node, &inplace_node, graph);
}
return true;
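With InsertControlDependToGraph deleted, every ordering constraint in this pass is expressed through one idiom: reroute a data input of the node that must run later through a Depend on the node that must run first. In isolation (illustrative helper name):
```
// Depend forwards its first input unchanged; the second input exists only to
// create the execution-order edge, so the constraint lives inside the
// dataflow graph instead of a separate ControlDepend node.
AnfNodePtr RouteThroughDepend(const FuncGraphPtr &graph, const AnfNodePtr &data_input,
                              const AnfNodePtr &must_run_first) {
  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
                                    data_input, must_run_first};
  auto depend = graph->NewCNode(inputs);
  depend->set_abstract(data_input->abstract());  // same value, added ordering
  return depend;
}
```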

View File

@ -0,0 +1,97 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/gpu/post_batch_norm_add_relu_fusion.h"
#include <memory>
#include <vector>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"
namespace mindspore {
namespace opt {
const BaseRef PostBatchNormAddReluFusion::DefinePattern() const {
VectorRef batch_norm_ex = VectorRef({prim::kPrimFusedBatchNormEx, x_, scale_, bias_, mean_, var_});
VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm_ex, index_});
VectorRef tensor_add = VectorRef({prim::kPrimAdd, z_, tuple_get_item});
VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add});
return relu;
}
const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto tensor_add = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
MS_EXCEPTION_IF_NULL(tensor_add);
auto tuple_get_item = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), 1);
MS_EXCEPTION_IF_NULL(tuple_get_item);
auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
MS_EXCEPTION_IF_NULL(batch_norm_ex);
auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm_ex)->GetAttr("format");
MS_EXCEPTION_IF_NULL(format_attr);
auto format = GetValue<std::string>(format_attr);
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") {
return nullptr;
}
auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0);
if (shape.back() % kBNChannelMultipleFactor != 0) {
return nullptr;
}
auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0);
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1);
auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 2);
auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 3);
auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 4);
auto z = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), 0);
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(scale);
MS_EXCEPTION_IF_NULL(bias);
MS_EXCEPTION_IF_NULL(mean);
MS_EXCEPTION_IF_NULL(var);
MS_EXCEPTION_IF_NULL(z);
auto prim = std::make_shared<Primitive>(kFusedBatchNormExWithAddAndActivation);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var, z};
auto fused_batch_norm_with_add_relu = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(fused_batch_norm_with_add_relu);
std::vector<TypeId> outputs_type;
std::vector<std::vector<size_t>> outputs_shape;
auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm_ex);
for (size_t i = 0; i < output_num; i++) {
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm_ex, i));
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm_ex, i));
}
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get());
AnfAlgo::CopyNodeAttrs(batch_norm_ex, fused_batch_norm_with_add_relu);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->Replace(batch_norm_ex, fused_batch_norm_with_add_relu);
device::gpu::SetKernelInfo(fused_batch_norm_with_add_relu);
return tuple_get_item;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,51 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_POST_BATCH_NORM_ADD_RELU_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_POST_BATCH_NORM_ADD_RELU_FUSION_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class PostBatchNormAddReluFusion : public PatternProcessPass {
public:
explicit PostBatchNormAddReluFusion(bool multigraph = true)
: PatternProcessPass("post_batch_norm_add_relu_fusion", multigraph) {
x_ = std::make_shared<Var>();
scale_ = std::make_shared<Var>();
bias_ = std::make_shared<Var>();
mean_ = std::make_shared<Var>();
var_ = std::make_shared<Var>();
index_ = std::make_shared<Var>();
z_ = std::make_shared<Var>();
}
~PostBatchNormAddReluFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
VarPtr x_;
VarPtr scale_;
VarPtr bias_;
VarPtr mean_;
VarPtr var_;
VarPtr index_;
VarPtr z_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_POST_BATCH_NORM_ADD_RELU_FUSION_H_

View File

@ -32,9 +32,7 @@ const size_t kReluV2OutputNum = 2;
CNodePtr GetRelu(const CNodePtr &relu_grad) {
MS_EXCEPTION_IF_NULL(relu_grad);
if (relu_grad->size() != kReluGradInputNum) {
MS_LOG_EXCEPTION << "ReluGrad has wrong input size " << relu_grad->size();
}
CheckCNodeInputSize(relu_grad, kReluGradInputTensorNum);
auto relu_anf = relu_grad->input(2);
MS_EXCEPTION_IF_NULL(relu_anf);
return relu_anf->cast<CNodePtr>();
@ -47,11 +45,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
std::vector<TypeId> outputs_type;
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
size_t input_num = AnfAlgo::GetInputTensorNum(node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
inputs_format.push_back(kOpFormat_DEFAULT);
}
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
outputs_format.push_back(kOpFormat_DEFAULT);
}
@ -65,9 +65,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(relu);
if (relu->size() != kReluInputNum) {
MS_LOG_EXCEPTION << "Relu has wrong input size " << relu->size();
}
CheckCNodeInputSize(relu, kReluInputTensorNum);
auto prim = std::make_shared<Primitive>(kReluV2OpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), relu->input(1)};
@ -106,7 +104,8 @@ CNodePtr CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad,
std::vector<TypeId> types;
std::vector<std::vector<size_t>> shapes;
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(relu_grad); i++) {
size_t output_num = AnfAlgo::GetOutputTensorNum(relu_grad);
for (size_t i = 0; i < output_num; i++) {
types.push_back(AnfAlgo::GetOutputInferDataType(relu_grad, i));
shapes.push_back(AnfAlgo::GetOutputInferShape(relu_grad, i));
}
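The switch from raw size() checks to CheckCNodeInputSize(node, tensor_num) here (and in SubstituteDropout further down) is assumed to account for auto-monad: a CNode's raw input list holds the primitive at index 0 plus any trailing monad, while the tensor count covers data inputs only. A sketch under that assumption (hypothetical helper):
```
// Assumed layout for a monad-carrying node such as Assign(param, value, U):
//   assign->inputs()                   -> {ValueNode(Assign), param, value, U}
//   AnfAlgo::GetInputTensorNum(assign) -> 2   (data inputs only)
// A fixed inputs().size() check breaks once a monad is appended; a
// tensor-count check keeps working.
void CheckDataInputCount(const CNodePtr &cnode, size_t expected_tensor_num) {
  size_t tensor_num = AnfAlgo::GetInputTensorNum(cnode);
  if (tensor_num != expected_tensor_num) {
    MS_LOG(EXCEPTION) << "Node has wrong data input count: " << tensor_num;
  }
}
```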

View File

@ -305,52 +305,14 @@ void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNo
user_cnode->set_input(index, depend_cnode);
}
AnfNodePtr AtomicCleanInsertter::AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node,
const AnfNodePtr &behind_node, const AnfNodePtr &patron_node) {
// Create control depend, first input is composite op, second is user
AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), prior_node, behind_node};
auto control_depend_cnode = main_graph->NewCNode(cd_inputs);
main_graph->AddNode(control_depend_cnode);
// Create depend node to hold new control depend node.
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), patron_node, control_depend_cnode};
auto depend_cnode = main_graph->NewCNode(d_inputs);
depend_cnode->set_abstract(patron_node->abstract());
main_graph->AddNode(depend_cnode);
return depend_cnode;
}
std::tuple<AnfNodePtr, AnfNodePtr, int> AtomicCleanInsertter::FindPatronNode(const KernelGraphPtr &main_graph) {
auto mng = main_graph->manager();
if (mng == nullptr) {
mng = Manage(main_graph, true);
main_graph->set_manager(mng);
}
AnfNodePtr patron_node;
auto return_cnode = main_graph->get_return()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(return_cnode);
auto output_node = return_cnode->input(kFirstDataInputIndex);
if (IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) {
auto output_cnode = output_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_cnode);
patron_node = output_cnode->input(kFirstDataInputIndex);
} else {
patron_node = output_node;
}
auto &user_nodes = mng->node_users()[patron_node];
auto user = user_nodes.begin();
return std::make_tuple(patron_node, user->first, user->second);
}
void AtomicCleanInsertter::PostprocessForLastPatron(const AnfNodePtr &patron_node, const AnfNodePtr &patron_user,
int index) {
auto patron_user_cnode = patron_user->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(patron_user_cnode);
patron_user_cnode->set_input(index, patron_node);
CNodePtr AtomicCleanInsertter::InsertUpdateState(const KernelGraphPtr &main_graph, const CNodePtr &composite_node) {
// Insert update_state_node, need mount a monad node.
auto u = NewValueNode(kUMonad);
u->set_abstract(kUMonad->ToAbstract());
AnfNodePtrList update_state_inputs = {NewValueNode(prim::kPrimUpdateState), u, composite_node};
auto update_state_cnode = main_graph->NewCNode(update_state_inputs);
main_graph->AddNode(update_state_cnode);
return update_state_cnode;
}
CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type) {
@ -474,24 +436,21 @@ std::vector<std::pair<AnfNodePtr, int> > AtomicCleanInsertter::FindOriginCNodeUs
}
void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng) {
const AnfNodePtr &broadcast_to_node,
const AnfNodePtr &update_state_node, const FuncGraphManagerPtr &mng) {
// 1. find users, change getitem index if needed.
std::vector<std::pair<AnfNodePtr, int> > reduce_user_nodes =
FindOriginCNodeUsers(main_graph, composite_node, mng, true);
for (const auto &[user_node, index] : reduce_user_nodes) {
// 2. set ac output as user's input.
// 3. Make sure modified composite node running first.
// * To not change the origin node's dependency relation, add ControlDepend and Depend node.
// * For the Return node and the output node, a ControlDepend node would change the order of these two nodes,
// which may make the main graph fail to run. So only a Depend node is added to meet the execution order.
if (IsPrimitiveCNode(user_node, prim::kPrimReturn) || user_node == main_graph->output()) {
AddDepend(main_graph, broadcast_to_node, composite_node, user_node, index);
} else {
auto user_cnode = user_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(user_cnode);
user_cnode->set_input(index, broadcast_to_node);
to_process_order_.emplace_back(composite_node, user_node);
}
// 2. Make sure the modified composite node runs first: create load_node, then add edges connecting
// update_state_node, broadcast_to_node and load_node to keep the order.
AnfNodePtrList load_inputs = {NewValueNode(prim::kPrimLoad), broadcast_to_node, update_state_node};
auto load_node = main_graph->NewCNode(load_inputs);
main_graph->AddNode(load_node);
auto user_cnode = user_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(user_cnode);
user_cnode->set_input(index, load_node);
to_process_order_.emplace_back(composite_node, user_node);
}
}
@ -509,8 +468,11 @@ void AtomicCleanInsertter::InsertAtomicClean(const KernelGraphPtr &main_graph, c
// Note: if it's single output, this will increase total memory because of a fake out.
ProcessOriginCNode(origin_composite_node, broadcast_to_node, mng);
// Replace origin ReduceSum's user with atomic clean output, and add control depend from composite op to user.
ProcessOriginCNodeUser(main_graph, origin_composite_node, broadcast_to_node, mng);
// Insert update_state_node to keep execution order.
auto update_state_node = InsertUpdateState(main_graph, origin_composite_node);
// Replace origin ReduceSum's user with atomic clean output
ProcessOriginCNodeUser(main_graph, origin_composite_node, broadcast_to_node, update_state_node, mng);
MS_LOG(INFO) << "Target node: " << origin_composite_node->fullname_with_scope()
<< ", clean node: " << broadcast_to_node->fullname_with_scope();
}
@ -554,14 +516,6 @@ bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
}
if (changed) {
if (!to_process_order_.empty()) {
auto [patron_node, patron_user, user_index] = FindPatronNode(kernel_graph);
for (const auto &[prior, behind] : to_process_order_) {
patron_node = AddControlDepend(kernel_graph, prior, behind, patron_node);
}
PostprocessForLastPatron(patron_node, patron_user, user_index);
}
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}

View File

@ -37,9 +37,10 @@ class AtomicCleanInsertter : public Pass {
virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input);
virtual void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
const FuncGraphManagerPtr &mng);
void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng);
void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node,
const AnfNodePtr &user_node, int index);
void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng);
CNodePtr InsertUpdateState(const KernelGraphPtr &main_graph, const CNodePtr &composite_node);
CNodePtr atomic_add_node_{nullptr};
private:
@ -48,11 +49,8 @@ class AtomicCleanInsertter : public Pass {
CNodePtr CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type);
void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
void ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng);
std::tuple<AnfNodePtr, AnfNodePtr, int> FindPatronNode(const KernelGraphPtr &main_graph);
AnfNodePtr AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node,
const AnfNodePtr &behind_node, const AnfNodePtr &patron_node);
void PostprocessForLastPatron(const AnfNodePtr &patron_node, const AnfNodePtr &patron_user, int index);
const AnfNodePtr &broadcast_to_node, const AnfNodePtr &update_state_node,
const FuncGraphManagerPtr &mng);
std::vector<std::pair<AnfNodePtr, int>> FindOriginCNodeUsers(const KernelGraphPtr &main_graph,
const AnfNodePtr &composite_node,
const FuncGraphManagerPtr &mng, bool correct_index);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -149,9 +149,18 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
}
auto fuse_nodes = FindFuseCNodes(node, depend_prior);
if (fuse_nodes.empty() || (fuse_nodes.size() == 1 && AnfAlgo::IsGraphKernel(fuse_nodes[0]))) {
if (fuse_nodes.empty()) {
continue;
}
if (fuse_nodes.size() == 1) {
// Do not fuse a single GraphKernel again.
// Do not fuse a single Assign.
if (AnfAlgo::IsGraphKernel(fuse_nodes[0]) || IsPrimitiveCNode(fuse_nodes[0], prim::kPrimAssign)) {
continue;
}
}
changed = true;
fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end());
AnfNodePtr fused_new_node;

View File

@ -109,7 +109,10 @@ bool DependFormater::Run(const FuncGraphPtr &func_graph) {
// 1. Try to remove redundant depend.
bool changed = false;
auto nodes = TopoSort(func_graph->get_return());
std::for_each(nodes.rbegin(), nodes.rend(), [&changed, &mng](const AnfNodePtr &node) {
std::for_each(nodes.rbegin(), nodes.rend(), [&changed, &mng](const AnfNodePtr &node) -> void {
if (HasAbstractMonad(node)) {
return;
}
if (RemoveRedundantDepend(node, mng)) {
changed = true;
}
@ -126,7 +129,8 @@ bool DependFormater::Run(const FuncGraphPtr &func_graph) {
// Find depend and its free nodes.
for (const auto &node : nodes) {
if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
if (!IsPrimitiveCNode(node, prim::kPrimDepend) ||
HasAbstractMonad(node->cast<CNodePtr>()->input(kDependAttachNodeIndex))) {
continue;
}
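The second guard added here is the same predicate used later in the parallel-fusion pass; named on its own (hypothetical helper): a Depend whose attached node carries a monad abstract belongs to the auto-monad ordering chain and must not be simplified away.
```
// True for Depend nodes that exist purely to thread a monad through the
// graph; DependFormater leaves these untouched.
bool IsMonadDepend(const AnfNodePtr &node) {
  return IsPrimitiveCNode(node, prim::kPrimDepend) &&
         HasAbstractMonad(node->cast<CNodePtr>()->input(kDependAttachNodeIndex));
}
```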

View File

@ -177,6 +177,7 @@ bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
std::shared_ptr<Pass> pass = std::make_shared<opt::SubstituteDropout>();
pass->Run(func_graph);
}
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);

View File

@ -494,8 +494,8 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater,
prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
prim::kPrimCast, prim::kPrimExpandDims};
prim::kPrimCast, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
prim::kPrimAssign, prim::kPrimExpandDims};
#else
std::vector<PrimitivePtr> fusible_basic_ops;
#endif

View File

@ -629,7 +629,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {
}
GetValidKernelNodes();
// call CostModel to get a split plan.
if (!SplitByCostModel() || split_plan_.size() != need_inline_.size()) {
if (!SplitByCostModel() || split_plan_.size() != need_inline_.size() || split_plan_.empty()) {
split_plan_.clear();
need_inline_.clear();
return;

View File

@ -103,28 +103,23 @@ bool HasPathToParamUser(const AnfNodePtr &gk_node, const AnfNodePtr &param_user)
return result;
}
AnfNodePtr AddControlDepend(const FuncGraphPtr &func_graph, const AnfNodePtr &getitem, const AnfNodePtr &param_user) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), getitem, param_user};
auto cd_node = func_graph->NewCNode(cd_inputs);
func_graph->AddNode(cd_node);
return cd_node;
}
void KeepExecOrder(const FuncGraphPtr &func_graph, const AnfNodePtr &gk_node, const AnfNodePtr &par_user_node,
const FuncGraphManagerPtr &mng) {
// Insert update_state_node, need mount a monad node.
auto u = NewValueNode(kUMonad);
u->set_abstract(kUMonad->ToAbstract());
AnfNodePtrList update_state_inputs = {NewValueNode(prim::kPrimUpdateState), u, gk_node};
auto update_state_node = func_graph->NewCNode(update_state_inputs);
update_state_node->set_abstract(gk_node->abstract());
func_graph->AddNode(update_state_node);
void LinkControlDepends(const FuncGraphPtr &func_graph, const AnfNodePtrList &cd_nodes) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto output_tuple = func_graph->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_tuple);
auto cur_node = output_tuple->input(1);
for (const auto &cd : cd_nodes) {
AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), cur_node, cd};
auto depend_node = func_graph->NewCNode(depend_inputs);
depend_node->set_abstract(depend_inputs[1]->abstract());
cur_node = depend_node;
}
mng->Replace(output_tuple->input(1), cur_node);
// Insert load_node
AnfNodePtrList load_inputs = {NewValueNode(prim::kPrimLoad), par_user_node, update_state_node};
auto load_node = func_graph->NewCNode(load_inputs);
load_node->set_abstract(par_user_node->abstract());
func_graph->AddNode(load_node);
mng->Replace(gk_node, par_user_node);
}
int64_t GetitemIndex(const AnfNodePtr &getitem) {
@ -133,11 +128,10 @@ int64_t GetitemIndex(const AnfNodePtr &getitem) {
return GetValue<int64_t>(value_ptr);
}
AnfNodePtrList UpdateUsersOfGraphKernel(const FuncGraphPtr &func_graph, const AnfNodePtr &cnode,
const AnfNodePtr &assign_to, int64_t removed_index) {
void UpdateUsersOfGraphKernel(const FuncGraphPtr &func_graph, const AnfNodePtr &cnode, const AnfNodePtr &assign_to,
int64_t removed_index) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
AnfNodePtrList cd_nodes;
for (const auto &getitem_iter : mng->node_users()[cnode]) {
auto getitem = getitem_iter.first;
if (GetitemIndex(getitem) != removed_index) continue;
@ -152,13 +146,10 @@ AnfNodePtrList UpdateUsersOfGraphKernel(const FuncGraphPtr &func_graph, const An
if (!AnfAlgo::IsRealKernel(getitem_user) || HasPathToParamUser(cnode, getitem_user)) {
continue;
}
// keep execution order: cnode -> getitem_user
auto cd_node = AddControlDepend(func_graph, getitem, getitem_user);
cd_nodes.push_back(cd_node);
KeepExecOrder(func_graph, cnode, getitem_user, mng);
}
break;
}
return cd_nodes;
}
bool RepalceOutputByParameter(const FuncGraphPtr &func_graph) {
@ -166,7 +157,6 @@ bool RepalceOutputByParameter(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
bool changed = false;
AnfNodePtrList control_depend_nodes;
for (const auto &n : todos) {
if (!AnfAlgo::IsGraphKernel(n)) continue;
auto cnode = n->cast<CNodePtr>();
@ -174,11 +164,9 @@ bool RepalceOutputByParameter(const FuncGraphPtr &func_graph) {
if (replaceable_nodes.empty()) continue;
changed = true;
for (const auto &iter : replaceable_nodes) {
auto cd_nodes = UpdateUsersOfGraphKernel(func_graph, cnode, iter.second, iter.first);
control_depend_nodes.insert(control_depend_nodes.end(), cd_nodes.begin(), cd_nodes.end());
UpdateUsersOfGraphKernel(func_graph, cnode, iter.second, iter.first);
}
}
LinkControlDepends(func_graph, control_depend_nodes);
return changed;
}
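KeepExecOrder replaces the old AddControlDepend/LinkControlDepends pair; its core, restated as a standalone sketch (illustrative name; writer updates a parameter that reader consumes):
```
// Order a parameter read after a graph-kernel write using only data edges:
// UpdateState pins the write, Load pins the read to that state.
CNodePtr OrderReadAfterWrite(const FuncGraphPtr &fg, const AnfNodePtr &writer,
                             const AnfNodePtr &reader) {
  auto u = NewValueNode(kUMonad);
  u->set_abstract(kUMonad->ToAbstract());
  AnfNodePtrList us_inputs = {NewValueNode(prim::kPrimUpdateState), u, writer};
  auto ustate = fg->NewCNode(us_inputs);
  ustate->set_abstract(writer->abstract());
  AnfNodePtrList load_inputs = {NewValueNode(prim::kPrimLoad), reader, ustate};
  auto load = fg->NewCNode(load_inputs);
  load->set_abstract(reader->abstract());
  return load;
}
```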

View File

@ -97,7 +97,8 @@ void ProcessThroughPassCNode(std::function<bool(const AnfNodePtr &)> pass_fn,
void ProcessDependCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
for (auto &[node, node_rel] : (*node_rels)) {
if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
if (!IsPrimitiveCNode(node, prim::kPrimDepend) ||
HasAbstractMonad(node->cast<CNodePtr>()->input(kDependAttachNodeIndex))) {
continue;
}
@ -118,96 +119,6 @@ void ProcessDependCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
ProcessThroughPassCNode([](const AnfNodePtr &node) { return IsOneOf(node, {prim::kPrimDepend}); }, node_rels);
}
std::tuple<std::pair<AnfNodePtr, AnfNodePtr>, std::pair<AnfNodePtrList, AnfNodePtrList>> FindRelationOfControlDepend(
const AnfNodePtr &node, OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
auto cnode = node->cast<CNodePtr>();
auto prior_node = cnode->input(kControlDependPriorIndex);
auto behind_node = cnode->input(kControlDependBehindIndex);
MS_EXCEPTION_IF_NULL(prior_node);
MS_EXCEPTION_IF_NULL(behind_node);
OrderedSet<AnfNodePtr> prior_nodes;
prior_nodes.insert(prior_node);
OrderedSet<AnfNodePtr> behind_nodes;
behind_nodes.insert(behind_node);
int64_t depend_mode = 0;
if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
depend_mode = AnfAlgo::GetNodeAttr<int64_t>(cnode, kControlDependMode);
}
if (prior_node->isa<Parameter>() && depend_mode == 1) {
prior_nodes = (*node_rels)[prior_node].nexts;
}
if (behind_node->isa<Parameter>()) {
behind_nodes = depend_mode == 1 ? (*node_rels)[behind_node].nexts : OrderedSet<AnfNodePtr>();
}
// Get real nodes.
AnfNodePtrList real_prior_nodes;
std::set<AnfNodePtr> prior_visited;
for (const auto &tmp : prior_nodes) {
AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
}
AnfNodePtrList real_behind_nodes;
std::set<AnfNodePtr> behind_visited;
for (const auto &tmp : behind_nodes) {
AnfAlgo::GetAllFatherRealNode(tmp, &real_behind_nodes, &behind_visited);
}
return std::make_tuple(std::make_pair(prior_node, behind_node), std::make_pair(real_prior_nodes, real_behind_nodes));
}
void ReLinkNodesOfControlDependByRelation(const std::unordered_map<AnfNodePtr, AnfNodePtrList> &control_depend_info,
OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
// Relink and its log.
for (const auto &m : control_depend_info) {
const auto &prior = m.second[0];
const auto &behind = m.second[1];
(*node_rels)[prior].nexts.insert(behind);
(*node_rels)[behind].pres.insert(prior);
MS_LOG(DEBUG) << "Relink relation of " << m.first->fullname_with_scope() << ": " << prior->fullname_with_scope()
<< " -> " << behind->fullname_with_scope();
}
}
void ProcessControlDependCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
std::unordered_map<AnfNodePtr, AnfNodePtrList> control_depend_info;
AnfNodePtrList latter_to_be_erased;
// Collect ControlDepend node and its input and output nodes.
for (auto &[node, node_rel] : (*node_rels)) {
if (!IsPrimitiveCNode(node, prim::kPrimControlDepend)) {
continue;
}
auto [direct_relation, real_relations] = FindRelationOfControlDepend(node, node_rels);
auto &[prior_node, behind_node] = direct_relation;
auto &[real_prior_nodes, real_behind_nodes] = real_relations;
(*node_rels)[prior_node].nexts.erase(node);
(*node_rels)[behind_node].nexts.erase(node);
node_rel.pres.erase(prior_node);
node_rel.pres.erase(behind_node);
for (auto &first_node : real_prior_nodes) {
for (auto &second_node : real_behind_nodes) {
MS_EXCEPTION_IF_NULL(first_node);
MS_EXCEPTION_IF_NULL(second_node);
control_depend_info.insert({node, {first_node, second_node}});
}
}
latter_to_be_erased.push_back(node);
}
// Delete ControlDepend node before relink its relation.
for (const auto &node : latter_to_be_erased) {
node_rels->erase(node);
}
// Rebuild relation between prior and behind node.
ReLinkNodesOfControlDependByRelation(control_depend_info, node_rels);
}
void ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
AnfNodePtrList latter_to_be_erased;
for (auto &[node, node_rel] : (*node_rels)) {
@ -538,7 +449,6 @@ OrderedMap<AnfNodePtr, NodeRelation> ParallelOpFusion::GenAnalysisGraph(const An
}
ProcessDependCNode(&node_rels);
ProcessControlDependCNode(&node_rels);
ProcessThroughPassCNode(
[](const AnfNodePtr &node) {
return IsOneOf(node, {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimSqueeze, prim::kPrimTupleGetItem});

View File

@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/split_assign.h"
#include <vector>
#include <string>
#include <algorithm>
#include <memory>
#include "base/core_ops.h"
#include "utils/utils.h"
#include "runtime/device/kernel_info.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
const BaseRef SplitAssign::DefinePattern() const {
VarPtr Xs = std::make_shared<Var>();
VarPtr Us = std::make_shared<Var>();
VarPtr UMonad = std::make_shared<Var>();
return VectorRef({prim::kPrimAssign, Xs, Us, UMonad});
}
const AnfNodePtr SplitAssign::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(node);
CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
CheckCNodeInputSize(cnode, kAssignInputTensorNum);
// Get original assign op's abstract and inputs
AbstractBasePtr original_abstract = cnode->abstract()->Clone();
auto original_inputs = cnode->inputs();
// Create depend node
AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), original_inputs[1], original_inputs[3]};
auto depend_cnode = func_graph->NewCNode(depend_inputs);
depend_cnode->set_abstract(original_inputs[1]->abstract());
depend_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
// Create new assign node, delete U from inputs.
AnfNodePtrList new_assign_inputs = {NewValueNode(prim::kPrimAssign), depend_cnode, original_inputs[2]};
auto new_assign_cnode = func_graph->NewCNode(new_assign_inputs);
new_assign_cnode->set_abstract(original_abstract);
new_assign_cnode->set_kernel_info(cnode->kernel_info_ptr());
return new_assign_cnode;
}
} // namespace opt
} // namespace mindspore
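In IR terms the pass rewrites Assign(param, value, U) into Assign(Depend(param, U), value): the monad is detached from the Assign that codegen sees (presumably because graph-kernel codegen cannot consume monad operands), while the ordering constraint survives as a Depend edge on the target. Compacted, with the input indices used above:
```
// before: Assign(param, value, U)    after: Assign(Depend(param, U), value)
AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend),
                                original_inputs[1], original_inputs[3]};
auto depend = func_graph->NewCNode(depend_inputs);
depend->set_abstract(original_inputs[1]->abstract());
AnfNodePtrList assign_inputs = {NewValueNode(prim::kPrimAssign), depend, original_inputs[2]};
auto new_assign = func_graph->NewCNode(assign_inputs);
new_assign->set_abstract(original_abstract);
```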

View File

@ -0,0 +1,32 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class SplitAssign : public PatternProcessPass {
public:
explicit SplitAssign(bool multigraph = true) : PatternProcessPass("split_assign", multigraph) {}
~SplitAssign() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_
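A pass like this is registered with a backend PassManager; a minimal usage sketch (the manager name and surrounding pipeline are assumptions, not shown in this diff):
```
auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm");
pm->AddPass(std::make_shared<opt::SplitAssign>());
```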

View File

@ -41,13 +41,15 @@ const BaseRef SubstituteDropout::DefinePattern() const {
void SetNewKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_type;
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index));
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index));
}
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index));
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
}
@ -69,15 +71,13 @@ const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, cons
MS_EXCEPTION_IF_NULL(node);
CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() < kDropoutInputNum) {
MS_LOG(EXCEPTION) << "Dropout's input num is wrong";
}
CheckCNodeInputSize(cnode, kDropoutInputTensorNum);
AbstractBasePtr old_abstract = cnode->abstract()->Clone();
auto shape = AnfAlgo::GetInputDeviceShape(cnode, 0);
ShapeVector shape_i64;
std::transform(shape.begin(), shape.end(), std::back_inserter(shape_i64), [](size_t x) { return SizeToLong(x); });
// The primitive should use a clone, otherwise the attr seed will be overrode.
// The primitive should use a clone, otherwise the attr seed will be overridden.
AnfNodePtrList uniform_input = {NewValueNode(prim::kPrimCudnnUniformReal->Clone())};
auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, ShapeVector(1, SizeToLong(shape.size())),
static_cast<void *>(&shape[0]), kNumberTypeInt64);

View File

@ -249,7 +249,8 @@ KernelRefCountPtr MemReuseUtil::GetRef(const AnfNodePtr &node, int output_idx) {
if (node == nullptr) {
MS_LOG(EXCEPTION) << "The node pointer is a nullptr.";
}
if (node->isa<CNode>()) {
// Get ref count for cnode, except monad cnode.
if (node->isa<CNode>() && !HasAbstractMonad(node)) {
auto ak_node = node->cast<CNodePtr>();
auto key = ak_node.get();
MemReuseChecker::GetInstance().CheckOutRef(kernel_output_refs_, ak_node, IntToSize(output_idx));
@ -314,7 +315,8 @@ void MemReuseUtil::SetKernelDefInputs() {
MS_LOG(EXCEPTION) << "kernel [" << kernel->fullname_with_scope() << "] is not init.";
}
auto kernel_def = iter->second;
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
for (size_t i = 0; i < input_num; ++i) {
auto ref_ptr = GetKernelInputRef(kernel, i);
if (ref_ptr != nullptr) {
// set the inputs of this kernel_def

View File

@ -214,7 +214,8 @@ bool MemReuseChecker::CheckGraphOutputAssigned(const session::KernelGraph *graph
// set real graph output node to be special who's refcount equal kMaxRefCount
for (const auto &output : graph->outputs()) {
MS_EXCEPTION_IF_NULL(output);
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(output); ++i) {
size_t input_num = AnfAlgo::GetInputTensorNum(output);
for (size_t i = 0; i < input_num; ++i) {
if (output->isa<CNode>()) {
auto cnode = output->cast<CNodePtr>();
auto input_node = cnode->input(i + 1);
@ -364,7 +365,8 @@ void MemReuseChecker::CheckNormalIR(const session::KernelGraph *graph) {
const auto &cnodes = graph->execution_order();
for (const auto &node : cnodes) {
std::vector<const void *> curr_ous;
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node); ++i) {
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
for (size_t i = 0; i < output_num; ++i) {
auto it = AnfAlgo::GetOutputAddr(node, i);
MS_EXCEPTION_IF_NULL(it);
auto ptr = it->GetPtr();
@ -374,7 +376,8 @@ void MemReuseChecker::CheckNormalIR(const session::KernelGraph *graph) {
}
(void)node_ous_.insert(std::make_pair(node.get(), curr_ous));
std::vector<const void *> curr_ins;
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) {
size_t input_num = AnfAlgo::GetInputTensorNum(node);
for (size_t i = 0; i < input_num; ++i) {
if (i + 1 >= node->inputs().size()) {
MS_LOG(EXCEPTION) << "Input index: " << i
<< " is larger than input number: " << AnfAlgo::GetInputTensorNum(node);

View File

@ -37,7 +37,8 @@ bool MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph, s
MS_EXCEPTION_IF_NULL(kernel_mod);
auto output_sizes = kernel_mod->GetOutputSizeList();
for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(kernel); ++output_idx) {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel);
for (size_t output_idx = 0; output_idx < output_num; ++output_idx) {
TensorInfo tensor_info = {output_sizes[output_idx], kernel, output_idx};
ordered_tensors_.push_back(tensor_info);
}

View File

@ -51,12 +51,14 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &co
rank_size = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize);
}
MS_EXCEPTION_IF_NULL(cnode);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index));
}
for (size_t rank_index = 0; rank_index < IntToSize(rank_size); ++rank_index) {
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) {
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index));
outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index));
std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(cnode, output_index);
@ -170,6 +172,117 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
return CheckSegments(segments, communication_op_node_size, segment_index);
}
// Hard-code Load(%paraxxx, cnode()) to Load(%paraxxx, U) to prevent a
// cycle after AllReduce is fused. It's a workaround.
// case 1:
// cnode_load = Load(%para2, cnode_u)
// %100 = UpdateState(cnode_u, cnode_load)
// ...
// %109 = AssignAdd(%para485, Tensor(34), %100)
// %110 = UpdateState(%100, xxx)
// will convert to:
// cnode_load = Load(%para2, U)
// ...
// %109 = AssignAdd(%para485, Tensor(34), cnode_u)
// %110 = UpdateState(cnode_u, xxx)
//
// case 2:
// cnode_load = Load(%para2, cnode_u)
// %99 = make_tuple(yyy, ..., cnode_load, ...)
// %100 = UpdateState(cnode_u, %99)
// ...
// %109 = AssignAdd(%para485, Tensor(34), %100)
// %110 = UpdateState(%100, xxx)
// will convert to:
// cnode_load = Load(%para2, U)
// %99 = make_tuple(yyy, ...)
// %100 = UpdateState(cnode_u, %99)
// ...
// %109 = AssignAdd(%para485, Tensor(34), %100)
// %110 = UpdateState(%100, xxx)
//
// case 3:
// cnode_load = Load(%para2, cnode_u)
// %99 = make_tuple(cnode_load)
// %100 = UpdateState(cnode_u, %99)
// ...
// %109 = AssignAdd(%para485, Tensor(34), %100)
// %110 = UpdateState(%100, xxx)
// will convert to:
// cnode_load = Load(%para2, U)
// ...
// %109 = AssignAdd(%para485, Tensor(34), cnode_u)
// %110 = UpdateState(cnode_u, xxx)
static void AdjustAllReduceInputWithLoad(const CNodePtr &cnode) {
auto cnode_load = BroadFirstSearchFirstOf({cnode}, [](const CNodePtr &search_cnode) {
if (!IsPrimitiveCNode(search_cnode, prim::kPrimLoad)) {
return false;
}
if (search_cnode->inputs().size() != 3) {
MS_LOG(EXCEPTION) << "Load CNode should have 3 inputs, but: " << search_cnode->DebugString();
}
return search_cnode->input(2)->isa<CNode>();
});
if (cnode_load != nullptr) {
const auto &const_u_monad = NewValueNode(kUMonad);
const auto &cnode_u = cnode_load->input(2);
MS_LOG(DEBUG) << "Replace Load with CNode U to constant U for cnode: " << cnode_load->DebugString();
MS_EXCEPTION_IF_NULL(cnode->func_graph());
MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());
auto manager = cnode->func_graph()->manager();
manager->SetEdge(cnode_load, 2, const_u_monad);
// Rewire any UpdateState whose u_monad input is the same CNode U that the Load used.
CNodePtr cnode_update_state = nullptr;
CNodePtr cnode_make_tuple = nullptr;
const auto &cnode_load_users = manager->node_users()[cnode_load];
for (auto &load_user : cnode_load_users) {
if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
const auto &cnode_make_tuple_users = manager->node_users()[load_user.first];
for (auto &make_tuple_user : cnode_make_tuple_users) {
if (IsPrimitiveCNode(make_tuple_user.first, prim::kPrimUpdateState)) {
const auto &cnode_user = make_tuple_user.first->cast<CNodePtr>();
if (cnode_user->input(1) == cnode_u) {
cnode_update_state = cnode_user;
cnode_make_tuple = load_user.first->cast<CNodePtr>();
break;
}
}
}
if (cnode_update_state != nullptr) {
break;
}
}
if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
const auto &cnode_user = load_user.first->cast<CNodePtr>();
if (cnode_user->input(1) == cnode_u) {
cnode_update_state = cnode_user;
break;
}
}
}
if (cnode_update_state != nullptr) {
if (cnode_make_tuple == nullptr || cnode_make_tuple->inputs().size() == 2) {
        // Cases 1 and 3: replace cnode_update_state with cnode_u.
MS_LOG(DEBUG) << "Replace UpdateState with CNode U: " << cnode_update_state->DebugString()
<< " ::TO:: " << cnode_u->DebugString();
manager->Replace(cnode_update_state, cnode_u);
} else if (cnode_make_tuple->inputs().size() > 2) {
        // Case 2: remove cnode_load from cnode_make_tuple.
MS_LOG(DEBUG) << "Drop " << cnode_load->DebugString() << " from " << cnode_make_tuple->DebugString();
const auto &make_tuple_inputs = cnode_make_tuple->inputs();
AnfNodePtrList new_tuple_inputs(make_tuple_inputs.size() - 1);
std::copy_if(make_tuple_inputs.cbegin(), make_tuple_inputs.cend(), new_tuple_inputs.begin(),
[cnode_load](const auto &inp) { return inp != cnode_load; });
auto new_cnode_make_tuple = cnode_make_tuple->func_graph()->NewCNode(new_tuple_inputs);
manager->Replace(cnode_make_tuple, new_cnode_make_tuple);
} else {
        MS_LOG(EXCEPTION) << "Cannot replace UpdateState with CNode U: " << cnode_update_state->DebugString()
                          << ", unexpected make_tuple: " << cnode_make_tuple->DebugString();
}
}
}
}
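A note on the case-2 rewrite above: `new_tuple_inputs` is pre-sized to `size() - 1` before `std::copy_if`, which is only correct if `cnode_load` occurs exactly once among the make_tuple inputs; any other multiplicity would leave default-constructed (null) tail slots. A minimal self-contained sketch of the pattern, with `int`s standing in for `AnfNodePtr` (illustrative only, not MindSpore API):
```
#include <algorithm>
#include <cassert>
#include <iostream>
#include <vector>

int main() {
  // make_tuple inputs: slot 0 is the primitive, 42 plays the role of cnode_load.
  std::vector<int> make_tuple_inputs{0, 7, 42, 9};
  const int load = 42;
  // The pre-sized destination assumes `load` occurs exactly once.
  assert(std::count(make_tuple_inputs.cbegin(), make_tuple_inputs.cend(), load) == 1);
  std::vector<int> new_tuple_inputs(make_tuple_inputs.size() - 1);
  std::copy_if(make_tuple_inputs.cbegin(), make_tuple_inputs.cend(), new_tuple_inputs.begin(),
               [load](int inp) { return inp != load; });
  for (int v : new_tuple_inputs) {
    std::cout << v << ' ';  // prints: 0 7 9
  }
  std::cout << '\n';
  return 0;
}
```
When the single-occurrence assumption cannot be guaranteed, copying through `std::back_inserter` into an empty vector avoids the stale-slot risk.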
AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph,
const CommunicationOpInfo &communication_op_info,
size_t start_index, size_t end_index) const {
@ -184,6 +297,9 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
for (size_t idx = start_index; idx <= end_index; ++idx) {
auto cnode = communication_op_info.communication_op_nodes[idx];
MS_EXCEPTION_IF_NULL(cnode);
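    // Rewrite Load(%para, cnode_u) inputs for all but the first op to avoid a cycle after fusion
    // (see AdjustAllReduceInputWithLoad above).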
if (idx != start_index) {
AdjustAllReduceInputWithLoad(cnode);
}
fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
}
CheckInputs(fusion_inputs);

View File

@ -107,9 +107,7 @@ AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) {
auto mng = sub_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
std::vector<AnfNodePtr> todo;
std::vector<std::pair<AnfNodePtr, size_t>> graph_rets;
kernel::GetValidKernelNodes(sub_graph, &todo);
kernel::GetGraphRealOutput(sub_graph, &graph_rets);
for (auto &t : todo) {
auto t_new_node = ConstInputToTensorInput(sub_graph, t->cast<CNodePtr>());

View File

@ -37,7 +37,8 @@ void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePt
std::vector<AnfNodePtr> plant_inputs;
std::vector<int64_t> dyn_input_sizes;
plant_inputs.push_back(AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode_ptr); ++i) {
size_t input_num = AnfAlgo::GetInputTensorNum(cnode_ptr);
for (size_t i = 0; i < input_num; ++i) {
auto input_node = AnfAlgo::GetInputNode(cnode_ptr, i);
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
@ -45,7 +46,8 @@ void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePt
dyn_input_sizes.push_back(input_size);
auto make_tuple = input_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
for (size_t j = 0; j < AnfAlgo::GetInputTensorNum(make_tuple); ++j) {
size_t tuple_input_num = AnfAlgo::GetInputTensorNum(make_tuple);
for (size_t j = 0; j < tuple_input_num; ++j) {
auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j);
MS_EXCEPTION_IF_NULL(dyn_input_node);
if (IsValueNode<tensor::Tensor>(dyn_input_node)) {

View File

@ -65,7 +65,7 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
return nullptr;
}
}
if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
if (IsPrimitiveCNode(cnode, prim::kPrimUpdateState)) {
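    // UpdateState only threads the monad that orders side effects; it has no tuple output to convert.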
return nullptr;
}
bool cnode_input_changed = false;

Some files were not shown because too many files have changed in this diff