forked from mindspore-Ecosystem/mindspore
!12241 [auto-monad] Support side-effects by auto-monad
From: @hwhewei Reviewed-by: @zhunaipan,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
a063d7633d
|
@ -1,4 +1,4 @@
|
|||
cmake_minimum_required(VERSION 3.14.1)
|
||||
cmake_minimum_required(VERSION 3.14.0)
|
||||
project(MindSpore)
|
||||
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0)
|
||||
|
@ -14,18 +14,25 @@ endif()
|
|||
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
set(CMAKE_OSX_SYSROOT "")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Winconsistent-missing-override -Wuser-defined-warnings -Wno-return-std-move -Wno-unused-private-field -Wno-unused-lambda-capture -Wno-sign-compare -Wno-overloaded-virtual -Wno-unneeded-internal-declaration -Wno-unused-variable -Wno-pessimizing-move -Wno-inconsistent-missing-override -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Winconsistent-missing-override -Wuser-defined-warnings \
|
||||
-Wno-return-std-move -Wno-unused-private-field -Wno-unused-lambda-capture -Wno-sign-compare \
|
||||
-Wno-overloaded-virtual -Wno-unneeded-internal-declaration -Wno-unused-variable -Wno-pessimizing-move \
|
||||
-Wno-inconsistent-missing-override -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
|
||||
else()
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Wl,--allow-shlib-undefined -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Wl,--allow-shlib-undefined \
|
||||
-DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
|
||||
endif()
|
||||
|
||||
if(ENABLE_PYTHON)
|
||||
add_compile_definitions(ENABLE_PYTHON)
|
||||
endif()
|
||||
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer -Wl,--allow-shlib-undefined -D_LIBCPP_INLINE_VISIBILITY='' -D_LIBCPP_DISABLE_EXTERN_TEMPLATE=1 -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2 -Wno-cpp")
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer \
|
||||
-Wl,--allow-shlib-undefined -D_LIBCPP_INLINE_VISIBILITY='' -D_LIBCPP_DISABLE_EXTERN_TEMPLATE=1 \
|
||||
-DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2 -Wno-cpp")
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I/usr/local/include -std=c++17 -Werror -Wall -Wno-deprecated-declarations -fPIC")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I/usr/local/include -std=c++17 \
|
||||
-Werror -Wall -Wno-deprecated-declarations -fPIC")
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
set(PYBIND11_CPP_STANDARD -std=c++17)
|
||||
|
|
|
@ -132,6 +132,16 @@ def Depend(value, expr):
|
|||
return value
|
||||
|
||||
|
||||
def UpdateState(monad, expr):
|
||||
"""Implement `UpdateState`."""
|
||||
return monad
|
||||
|
||||
|
||||
def Load(value, u=None):
|
||||
"""Implement `Load`."""
|
||||
return value
|
||||
|
||||
|
||||
# only used in PyNative mode
|
||||
def make_ref(key, value, ref):
|
||||
return value
|
||||
|
|
|
@ -42,14 +42,16 @@ void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<
|
|||
std::vector<std::string> inputs_format{};
|
||||
std::vector<TypeId> inputs_type{};
|
||||
if (op_name == kPrint || op_name == kPack || op_name == kMeshgrid) {
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
|
||||
}
|
||||
}
|
||||
std::vector<std::string> outputs_format;
|
||||
std::vector<TypeId> outputs_type;
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
outputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
|
||||
}
|
||||
|
|
|
@ -139,9 +139,9 @@ bool CheckCache(const std::string &kernel_name) {
|
|||
std::string kernel_json = bin_map->Search(kernel_name);
|
||||
bool ret = (!kernel_json.empty());
|
||||
if (ret) {
|
||||
MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registed.";
|
||||
MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registered.";
|
||||
} else {
|
||||
MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registed.";
|
||||
MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registered.";
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
@ -730,30 +730,6 @@ bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann:
|
|||
return false;
|
||||
}
|
||||
|
||||
void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node_list);
|
||||
auto output = func_graph->output();
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
if (AnfAlgo::IsRealKernel(output)) {
|
||||
// single output.
|
||||
node_list->push_back(std::make_pair(output, 0));
|
||||
return;
|
||||
} else if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
|
||||
auto output_cnode = output->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(output_cnode);
|
||||
// multi output.
|
||||
auto &inputs = output_cnode->inputs();
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
auto in_with_idx = AnfAlgo::VisitKernel(inputs[i], 0);
|
||||
node_list->push_back(in_with_idx);
|
||||
}
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION(ArgumentError) << "Unknown output type: " << output->DebugString(2)
|
||||
<< " of graph: " << func_graph->ToString();
|
||||
}
|
||||
|
||||
bool IsWeightBoundary(const AnfNodePtr &node) {
|
||||
if (node->isa<ValueNode>()) {
|
||||
return true;
|
||||
|
@ -776,7 +752,7 @@ std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) {
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto axis_attr = primitive->GetAttr(kAxis);
|
||||
if (axis_attr == nullptr) {
|
||||
MS_LOG(ERROR) << "This node does't have axie attr.";
|
||||
MS_LOG(ERROR) << "This node doesn't have axie attr.";
|
||||
return std::vector<int64_t>();
|
||||
}
|
||||
std::vector<int64_t> axis_list;
|
||||
|
|
|
@ -181,7 +181,8 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
std::vector<size_t> out_shape;
|
||||
out_shape.emplace_back(miss_count);
|
||||
std::vector<TypeId> dtypes;
|
||||
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node_); i++) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(node_);
|
||||
for (size_t i = 0; i < output_num; i++) {
|
||||
dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i));
|
||||
}
|
||||
AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetOutputInferShape(node_, 0), out_shape, out_shape, out_shape},
|
||||
|
|
|
@ -69,7 +69,8 @@ void SubAndFilterCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
std::vector<size_t> out_shape;
|
||||
out_shape.emplace_back(count);
|
||||
std::vector<TypeId> dtypes;
|
||||
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node_); i++) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(node_);
|
||||
for (size_t i = 0; i < output_num; i++) {
|
||||
dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i));
|
||||
}
|
||||
AnfAlgo::SetOutputInferTypeAndShape(dtypes, {out_shape, out_shape}, node_.get());
|
||||
|
|
|
@ -29,5 +29,8 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
MS_REG_GPU_KERNEL_ONE(
|
||||
Assign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
AssignGpuKernel, int)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Assign, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
AssignGpuKernel, int64_t)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -63,13 +63,15 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
|
|||
for (const auto &type : kHcclSupportTypes) {
|
||||
std::vector<std::string> inputs_format{};
|
||||
std::vector<TypeId> inputs_type{};
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_format.emplace_back(GetKernelFormat(kernel_node, input_index));
|
||||
inputs_type.push_back(type);
|
||||
}
|
||||
std::vector<std::string> outputs_format;
|
||||
std::vector<TypeId> outputs_type;
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
if (op_name == kReduceScatter && AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrFusion) > 0) {
|
||||
outputs_format.emplace_back(GetKernelFormat(kernel_node, 0));
|
||||
} else {
|
||||
|
|
|
@ -31,7 +31,8 @@ bool IsPyNativeMode() {
|
|||
bool HcomUtil::GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_intput_shape_list) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(hccl_kernel_intput_shape_list);
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i);
|
||||
hccl_kernel_intput_shape_list->emplace_back(shape_i);
|
||||
}
|
||||
|
@ -42,7 +43,8 @@ bool HcomUtil::GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<siz
|
|||
bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_output_shape_list) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(hccl_kernel_output_shape_list);
|
||||
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf_node); ++i) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);
|
||||
for (size_t i = 0; i < output_num; ++i) {
|
||||
std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i);
|
||||
hccl_kernel_output_shape_list->emplace_back(shape_i);
|
||||
}
|
||||
|
@ -53,11 +55,12 @@ bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<si
|
|||
bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType> *data_type_list) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(data_type_list);
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto type_ptr = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, i);
|
||||
auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(type_ptr);
|
||||
if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) {
|
||||
MS_LOG(EXCEPTION) << "HcomDataType cann't support Current Ascend Data Type : " << type_ptr;
|
||||
MS_LOG(EXCEPTION) << "HcomDataType can't support Current Ascend Data Type : " << type_ptr;
|
||||
}
|
||||
data_type_list->emplace_back(iter->second);
|
||||
}
|
||||
|
|
|
@ -37,13 +37,15 @@ void HostMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
|
|||
|
||||
std::vector<std::string> inputs_format{};
|
||||
std::vector<TypeId> inputs_type{};
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
|
||||
}
|
||||
std::vector<std::string> outputs_format;
|
||||
std::vector<TypeId> outputs_type;
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
outputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ std::string KernelBuildInfo::GetInputFormat(size_t input_index) const {
|
|||
|
||||
std::string KernelBuildInfo::GetOutputFormat(size_t output_index) const {
|
||||
if (output_index >= outputs_format_.size()) {
|
||||
MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of input node";
|
||||
MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of output";
|
||||
return kInvalidFormat;
|
||||
}
|
||||
return outputs_format_[output_index];
|
||||
|
|
|
@ -86,6 +86,9 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> LabelSwitchDesc::GetKernel
|
|||
builder.SetProcessor(AICORE);
|
||||
builder.SetKernelType(RT_KERNEL);
|
||||
builder.SetFusionType(OPAQUE);
|
||||
// LabelSwitch always return UMonad.
|
||||
builder.SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
builder.SetOutputsDeviceType({TypeId::kObjectTypeUMonad});
|
||||
label_switch_build_info.emplace_back(builder.Build());
|
||||
}
|
||||
return label_switch_build_info;
|
||||
|
|
|
@ -74,11 +74,10 @@ void GetRtKelInfo(const CNodePtr &kernel_node,
|
|||
input_types.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, i));
|
||||
}
|
||||
kernel_build_info_builder->SetInputsDeviceType(input_types);
|
||||
// set output info
|
||||
auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_num, kOpFormat_DEFAULT));
|
||||
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>(output_num, TypeId::kTypeUnknown));
|
||||
// set ohter info
|
||||
// Kernel ops in while-list such as 'LabelSet' always return UMonad.
|
||||
kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
kernel_build_info_builder->SetOutputsDeviceType({TypeId::kObjectTypeUMonad});
|
||||
// set other info
|
||||
kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
|
||||
kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
|
||||
kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
|
||||
|
|
|
@ -1052,10 +1052,16 @@ size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool i
|
|||
auto node_name = AnfAlgo::GetCNodeName(cnode);
|
||||
auto op_info = tbe::TbeDynamicShapeUtil::FindOp(node_name, cnode);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (op_info->inputs_ptr().size() < (cnode->inputs().size() - 1)) {
|
||||
auto node_inputs_size = cnode->inputs().size();
|
||||
for (auto &input : cnode->inputs()) {
|
||||
if (HasAbstractMonad(input)) {
|
||||
node_inputs_size--;
|
||||
}
|
||||
}
|
||||
if (op_info->inputs_ptr().size() < (node_inputs_size - 1)) {
|
||||
MS_EXCEPTION(ArgumentError) << "op info error, node name:" << cnode->fullname_with_scope();
|
||||
}
|
||||
return (op_info->inputs_ptr().size() + 1 - cnode->inputs().size());
|
||||
return (op_info->inputs_ptr().size() + 1 - node_inputs_size);
|
||||
}
|
||||
|
||||
std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) {
|
||||
|
@ -1103,6 +1109,9 @@ bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode,
|
|||
bool is_dynamic_input = IsDynamicInput(cnode);
|
||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||
auto input = cnode->input(i);
|
||||
if (HasAbstractMonad(input)) {
|
||||
continue;
|
||||
}
|
||||
auto kernel_idx = AnfAlgo::VisitKernel(input, 0);
|
||||
auto real_node = kernel_idx.first;
|
||||
size_t real_idx = kernel_idx.second;
|
||||
|
|
|
@ -112,6 +112,10 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
|
|||
const KernelSelectPtr &kernel_select) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto input_node = AnfAlgo::GetInputNode(node, index);
|
||||
if (HasAbstractMonad(input_node)) {
|
||||
// No transfer for monad inputs.
|
||||
return input_node;
|
||||
}
|
||||
auto node_with_index = AnfAlgo::VisitKernel(input_node, 0);
|
||||
MS_EXCEPTION_IF_NULL(node_with_index.first);
|
||||
auto real_input = node_with_index.first;
|
||||
|
@ -330,8 +334,9 @@ AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePt
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
|
||||
size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||
size_t in_num = AnfAlgo::GetInputNum(cnode); // include monads.
|
||||
for (size_t input_index = 0; input_index < in_num; ++input_index) {
|
||||
// Monad inputs keep unchanged from GetTransInputNodePtr().
|
||||
AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
new_inputs.push_back(input_node);
|
||||
|
@ -352,12 +357,18 @@ AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePt
|
|||
CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
|
||||
size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||
size_t in_num = AnfAlgo::GetInputNum(cnode); // include monads.
|
||||
for (size_t input_index = 0; input_index < in_num; ++input_index) {
|
||||
auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
|
||||
if (HasAbstractMonad(cur_input)) {
|
||||
// No cast for monad inputs.
|
||||
new_inputs.push_back(cur_input);
|
||||
continue;
|
||||
}
|
||||
auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, input_index);
|
||||
const auto infer_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
|
||||
TypeId origin_type(kTypeUnknown);
|
||||
auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
|
||||
|
||||
auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(cur_input, 0);
|
||||
auto real_input_node = kernel_with_index.first;
|
||||
if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
|
|
|
@ -244,7 +244,9 @@ void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph,
|
|||
if (auto in = cnode->input(idx); std::find((*buffer_fusion_infos)[fusion_id].inputs_list.begin(),
|
||||
(*buffer_fusion_infos)[fusion_id].inputs_list.end(),
|
||||
in) == (*buffer_fusion_infos)[fusion_id].inputs_list.end()) {
|
||||
(*buffer_fusion_infos)[fusion_id].inputs_list.push_back(in);
|
||||
if (!HasAbstractMonad(in)) {
|
||||
(*buffer_fusion_infos)[fusion_id].inputs_list.push_back(in);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,62 +40,6 @@ bool IsParameterOrValueNode(const AnfNodePtr &node) {
|
|||
return real_node->isa<ValueNode>();
|
||||
}
|
||||
|
||||
void SetInput(const CNodePtr &control_depend, const int index, const FuncGraphPtr &graph, const CNodePtr &hccl_node,
|
||||
const std::vector<AnfNodePtr> &memcpy_async_list) {
|
||||
MS_EXCEPTION_IF_NULL(control_depend);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(hccl_node);
|
||||
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end());
|
||||
make_tuple_inputs.emplace_back(hccl_node);
|
||||
auto make_tuple = graph->NewCNode(make_tuple_inputs);
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
control_depend->set_input(IntToSize(index), make_tuple);
|
||||
}
|
||||
|
||||
void DealControlForGetitem(const CNodePtr &tuple_getitem, const FuncGraphPtr &graph, const CNodePtr &hccl_node,
|
||||
const std::vector<AnfNodePtr> &memcpy_async_list) {
|
||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto &node_users = manager->node_users();
|
||||
auto iter = node_users.find(tuple_getitem);
|
||||
if (iter == node_users.end()) {
|
||||
MS_LOG(EXCEPTION) << "node has no output in manager"
|
||||
<< " trace: " << trace::DumpSourceLines(hccl_node);
|
||||
}
|
||||
for (const auto &node_index : iter->second) {
|
||||
AnfNodePtr output = node_index.first;
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
|
||||
SetInput(output->cast<CNodePtr>(), node_index.second, graph, hccl_node, memcpy_async_list);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &memcpy_async_list,
|
||||
const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(hccl_node);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto &node_users = manager->node_users();
|
||||
auto iter = node_users.find(hccl_node);
|
||||
if (iter == node_users.end()) {
|
||||
MS_LOG(EXCEPTION) << "node has no output in manager"
|
||||
<< " trace: " << trace::DumpSourceLines(hccl_node);
|
||||
}
|
||||
// find hccl_node's output which is a control depend
|
||||
for (const auto &node_index : iter->second) {
|
||||
AnfNodePtr output = node_index.first;
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
|
||||
SetInput(output->cast<CNodePtr>(), node_index.second, graph, hccl_node, memcpy_async_list);
|
||||
} else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimTupleGetItem)) {
|
||||
DealControlForGetitem(output->cast<CNodePtr>(), graph, hccl_node, memcpy_async_list);
|
||||
}
|
||||
}
|
||||
}
|
||||
// NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i)
|
||||
bool IsNodeOutPutUsedByOtherRealKernel(const AnfNodeIndexSet &node_users) {
|
||||
if (node_users.size() == 1) {
|
||||
|
@ -155,7 +99,7 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con
|
|||
void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(hccl_node);
|
||||
std::vector<AnfNodePtr> memcpy_async_list;
|
||||
bool need_memcpy_async = false;
|
||||
std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
|
||||
for (size_t i = 1; i < hccl_node->size(); ++i) {
|
||||
auto input = hccl_node->input(i);
|
||||
|
@ -164,17 +108,17 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
|
|||
if (memcpy_async == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Create memcpy_async op failed.";
|
||||
}
|
||||
if (AnfAlgo::IsNodeDynamicShape(input)) {
|
||||
if (input->isa<CNode>() && AnfAlgo::IsNodeDynamicShape(input)) {
|
||||
AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), memcpy_async);
|
||||
}
|
||||
new_inputs.push_back(memcpy_async);
|
||||
memcpy_async_list.push_back(memcpy_async);
|
||||
need_memcpy_async = true;
|
||||
} else {
|
||||
new_inputs.push_back(input);
|
||||
}
|
||||
}
|
||||
|
||||
if (!memcpy_async_list.empty()) {
|
||||
if (need_memcpy_async) {
|
||||
CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
|
||||
new_hccl_node->set_inputs(new_inputs);
|
||||
auto manager = graph->manager();
|
||||
|
@ -182,9 +126,6 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
|
|||
MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node";
|
||||
(void)manager->Replace(hccl_node, new_hccl_node);
|
||||
MS_LOG(DEBUG) << "end replace";
|
||||
|
||||
// transer hccl op's control to the memcpy_async
|
||||
TransferControl(new_hccl_node, memcpy_async_list, graph);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph
|
|||
return nullptr;
|
||||
}
|
||||
std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
|
||||
for (size_t input_idx = 0; input_idx < AnfAlgo::GetInputTensorNum(cnode); input_idx++) {
|
||||
for (size_t input_idx = 0; input_idx < input_num; input_idx++) {
|
||||
auto cur_input = AnfAlgo::GetInputNode(cnode, input_idx);
|
||||
auto origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_idx);
|
||||
auto origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_idx);
|
||||
|
|
|
@ -40,7 +40,8 @@ const AnfNodePtr ConvertCastFormat::Process(const FuncGraphPtr &func_graph, cons
|
|||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
auto input_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, input_index), 0).first;
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (!input_node->isa<CNode>()) {
|
||||
|
@ -77,7 +78,8 @@ void ConvertCastFormat::ChangeCastFormat(const CNodePtr &cast_node, const FuncGr
|
|||
MS_EXCEPTION_IF_NULL(node_info.first);
|
||||
auto cast_out_node = node_info.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cast_out_node);
|
||||
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cast_out_node); ++index) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(cast_out_node);
|
||||
for (size_t index = 0; index < input_num; ++index) {
|
||||
if (AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cast_out_node->cast<CNodePtr>(), index), 0).first !=
|
||||
cast_node) {
|
||||
continue;
|
||||
|
|
|
@ -162,7 +162,8 @@ CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_
|
|||
std::vector<AnfNodePtr> make_tuple_inputs;
|
||||
AbstractBasePtrList abstract_list;
|
||||
make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
CNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index);
|
||||
// deal with ref output
|
||||
if (ref_infos.count(output_index) != 0) {
|
||||
|
|
|
@ -37,7 +37,7 @@ const AnfNodePtr DynamicRNNGradReformat::Process(const FuncGraphPtr &func_graph,
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto split_v = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(split_v);
|
||||
auto matmul = CheckAnfNodeIfCNodeAndInputSize(split_v->input(1), 3);
|
||||
auto matmul = CheckAnfNodeIfCNodeAndInputSize(split_v->input(1), kMatMulInputTensorNum);
|
||||
MS_EXCEPTION_IF_NULL(matmul);
|
||||
auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(matmul, 0);
|
||||
auto input_node = input_node_with_idx.first;
|
||||
|
|
|
@ -129,9 +129,21 @@ AnfNodePtr ProcessGraphKernelOp(const FuncGraphPtr &func_graph, const AnfNodePtr
|
|||
auto mng = sub_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
std::vector<AnfNodePtr> todo;
|
||||
std::vector<std::pair<AnfNodePtr, size_t>> graph_rets;
|
||||
kernel::GetValidKernelNodes(sub_graph, &todo);
|
||||
kernel::GetGraphRealOutput(sub_graph, &graph_rets);
|
||||
auto outputs = AnfAlgo::GetAllOutput(sub_graph->output(), {prim::kPrimTupleGetItem});
|
||||
std::vector<std::pair<AnfNodePtr, size_t>> graph_rets;
|
||||
for (auto &output : outputs) {
|
||||
size_t index = 0;
|
||||
if (IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
|
||||
ValuePtr tuple_index_value = GetValueNode(output->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
|
||||
MS_EXCEPTION_IF_NULL(tuple_index_value);
|
||||
if (!tuple_index_value->isa<Int64Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "The index of tuple getitem is not int64";
|
||||
}
|
||||
index = tuple_index_value->cast<Int64ImmPtr>()->value();
|
||||
}
|
||||
graph_rets.emplace_back(std::pair<AnfNodePtr, size_t>(output, index));
|
||||
}
|
||||
for (auto &t : todo) {
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), t);
|
||||
// process input
|
||||
|
|
|
@ -33,7 +33,8 @@ bool RunOpInsertTransData::Run(const FuncGraphPtr &graph) {
|
|||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t index = 0; index < input_num; ++index) {
|
||||
auto prev_input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
|
||||
auto prev_node_out_infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
|
||||
auto input_format = AnfAlgo::GetInputFormat(cnode, index);
|
||||
|
|
|
@ -28,8 +28,6 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
const size_t kCastInputNum = 2;
|
||||
const size_t kTupleGetitemInputNum = 3;
|
||||
bool AlternativeKernelInfoForInput(const CNodePtr &node, const TypeId dst_type, const size_t change_idx,
|
||||
const std::shared_ptr<kernel::KernelBuildInfo> &candidate_kernel_info) {
|
||||
if (node == nullptr || node->kernel_info() == nullptr || candidate_kernel_info == nullptr) {
|
||||
|
@ -126,7 +124,8 @@ void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size
|
|||
auto cast_shape = AnfAlgo::GetOutputInferShape(cast, 0);
|
||||
std::vector<Shape> shapes;
|
||||
std::vector<TypeId> types;
|
||||
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
|
||||
for (size_t index = 0; index < output_num; ++index) {
|
||||
if (cast_index == index) {
|
||||
shapes.emplace_back(cast_shape);
|
||||
types.emplace_back(cast_dtype);
|
||||
|
@ -175,7 +174,7 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co
|
|||
<< "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info"
|
||||
<< (*alternative_kernel_info)->ToString();
|
||||
AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_info, next_cnode.get());
|
||||
if (node->inputs().size() < kCastInputNum) {
|
||||
if (AnfAlgo::GetInputTensorNum(node) < kCastInputTensorNum) {
|
||||
MS_LOG(EXCEPTION) << "Op[" << node->DebugString() << "] has wrong input num:";
|
||||
}
|
||||
return node->input(1);
|
||||
|
@ -188,9 +187,7 @@ bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_outpu
|
|||
*prior_op = x_cnode;
|
||||
// when x_node is tuple_getitem
|
||||
if (AnfAlgo::GetCNodeName(x_node) == prim::kPrimTupleGetItem->name()) {
|
||||
if (x_cnode->inputs().size() < kTupleGetitemInputNum) {
|
||||
MS_LOG(EXCEPTION) << "tuple getitem node has wrong input num" << x_cnode->inputs().size();
|
||||
}
|
||||
CheckCNodeInputSize(x_cnode, kTupleGetItemInputTensorNum);
|
||||
MS_EXCEPTION_IF_NULL(output_idx);
|
||||
AnfNodePtr input1 = x_cnode->input(1);
|
||||
MS_EXCEPTION_IF_NULL(input1);
|
||||
|
@ -214,9 +211,7 @@ bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_outpu
|
|||
AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_node, const KernelQueryPtr kernel_query) {
|
||||
MS_EXCEPTION_IF_NULL(cur_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_query);
|
||||
if (cur_node->inputs().size() < kCastInputNum) {
|
||||
MS_LOG(EXCEPTION) << "op[Cast] has wrong input num:";
|
||||
}
|
||||
CheckCNodeInputSize(cur_node, kCastInputTensorNum);
|
||||
AnfNodePtr x_node = cur_node->input(1);
|
||||
if (IsUsedByOthers(graph, x_node)) {
|
||||
return nullptr;
|
||||
|
|
|
@ -69,7 +69,7 @@ const AnfNodePtr RemoveInternalOutput::Process(const FuncGraphPtr &func_graph, c
|
|||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
CheckCNodeInputSize(cnode, kTransOpInputNum);
|
||||
CheckCNodeInputSize(cnode, kTransOpInputTensorNum);
|
||||
auto input_node = cnode->input(1);
|
||||
if (!AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimTupleGetItem)) {
|
||||
kernel_graph->ReplaceInternalOutput(node, input_node);
|
||||
|
|
|
@ -111,8 +111,8 @@ AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNod
|
|||
|
||||
auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract());
|
||||
MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
|
||||
if (bn_abstract_tuple->elements().size() != kBatchNormOutputNum) {
|
||||
MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBatchNormOutputNum << ", but it is "
|
||||
if (bn_abstract_tuple->elements().size() != kBnOutputNum) {
|
||||
MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBnOutputNum << ", but it is "
|
||||
<< bn_abstract_tuple->elements().size() << " trace: " << trace::DumpSourceLines(bn);
|
||||
}
|
||||
std::vector<AbstractBasePtr> abstract_list{bn_abstract_tuple->elements()[0], bn_abstract_tuple->elements()[3],
|
||||
|
|
|
@ -33,10 +33,7 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(bn_grad_node);
|
||||
const auto &bn_grad_inputs = bn_grad_node->inputs();
|
||||
if (bn_grad_inputs.size() < kBNGradInputNum) {
|
||||
MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size."
|
||||
<< " trace: " << trace::DumpSourceLines(bn_grad_node);
|
||||
}
|
||||
CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum);
|
||||
std::vector<AnfNodePtr> bn_update_grad_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)), bn_grad_inputs[1], bn_grad_inputs[2],
|
||||
bn_grad_inputs[4], bn_grad_inputs[5]};
|
||||
|
@ -60,10 +57,7 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
|
|||
MS_EXCEPTION_IF_NULL(bn_grad_node);
|
||||
MS_EXCEPTION_IF_NULL(bn_reduce_grad_outputs);
|
||||
const auto &bn_grad_inputs = bn_grad_node->inputs();
|
||||
if (bn_grad_inputs.size() < kBNGradInputNum) {
|
||||
MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size"
|
||||
<< " trace: " << trace::DumpSourceLines(bn_grad_node);
|
||||
}
|
||||
CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum);
|
||||
if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) {
|
||||
MS_LOG(EXCEPTION) << "BNTrainingReduceGrad_outputs has wrong size"
|
||||
<< " trace: " << trace::DumpSourceLines(bn_grad_node);
|
||||
|
|
|
@ -33,10 +33,7 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(bn_grad_node);
|
||||
auto bn_grad_inputs = bn_grad_node->inputs();
|
||||
if (bn_grad_inputs.size() != kBNGradInputNum) {
|
||||
MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size"
|
||||
<< " trace: " << trace::DumpSourceLines(bn_grad_node);
|
||||
}
|
||||
CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum);
|
||||
std::vector<AnfNodePtr> bn_update_grad_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)), bn_grad_inputs[1], bn_grad_inputs[2],
|
||||
bn_grad_inputs[4], bn_grad_inputs[5]};
|
||||
|
@ -59,10 +56,7 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(bn_grad_node);
|
||||
auto bn_grad_inputs = bn_grad_node->inputs();
|
||||
if (bn_grad_inputs.size() != kBNGradInputNum) {
|
||||
MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size"
|
||||
<< " trace: " << trace::DumpSourceLines(bn_grad_node);
|
||||
}
|
||||
CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum);
|
||||
if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) {
|
||||
MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size";
|
||||
}
|
||||
|
|
|
@ -32,8 +32,8 @@ bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &
|
|||
std::vector<AnfNodePtr> *bn_training_reduce_outputs) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(bn_cnode);
|
||||
if (bn_cnode->inputs().size() != kBnInputNum) {
|
||||
MS_LOG(INFO) << "FusedbatchNorm's input size less than " << kBnInputNum << ". " << bn_cnode->DebugString();
|
||||
if (AnfAlgo::GetInputTensorNum(bn_cnode) != kBnInputTensorNum) {
|
||||
MS_LOG(INFO) << "FusedbatchNorm's input size less than " << kBnInputTensorNum << ". " << bn_cnode->DebugString();
|
||||
return false;
|
||||
}
|
||||
std::vector<AnfNodePtr> bn_training_reduce_inputs = {
|
||||
|
@ -64,10 +64,7 @@ AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNod
|
|||
const std::vector<AnfNodePtr> &bn_training_reduce_outputs) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(bn_cnode);
|
||||
if (bn_cnode->inputs().size() != kBnInputNum) {
|
||||
MS_LOG(EXCEPTION) << "BN node has wrong input size"
|
||||
<< " trace: " << trace::DumpSourceLines(bn_cnode);
|
||||
}
|
||||
CheckCNodeInputSize(bn_cnode, kBnInputTensorNum);
|
||||
if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) {
|
||||
MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size"
|
||||
<< " trace: " << trace::DumpSourceLines(bn_cnode);
|
||||
|
@ -102,8 +99,8 @@ AnfNodePtr SplitBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr
|
|||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().size() < kBnInputNum) {
|
||||
MS_LOG(INFO) << "op[FusedBatchNorm] has less than " << kBnInputNum << " inputs.";
|
||||
if (AnfAlgo::GetInputTensorNum(cnode) < kBnInputTensorNum) {
|
||||
MS_LOG(INFO) << "op[" << cnode->DebugString() << "] has less input than " << kBnInputTensorNum << " inputs.";
|
||||
return nullptr;
|
||||
}
|
||||
// Create BNTrainingReduce node and get outputs of BNTrainingReduce
|
||||
|
|
|
@ -123,8 +123,8 @@ CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &gather_v2, const
|
|||
|
||||
bool CheckInputs(const CNodePtr &origin_node) {
|
||||
MS_EXCEPTION_IF_NULL(origin_node);
|
||||
if (origin_node->size() != kGatherV2DynInputNum + 1) {
|
||||
MS_LOG(DEBUG) << "GatherV2 in dynamic shape has wrong inputs num, not equal " << kGatherV2DynInputNum
|
||||
if (AnfAlgo::GetInputTensorNum(origin_node) != kGatherV2DynInputTensorNum) {
|
||||
MS_LOG(DEBUG) << "GatherV2 in dynamic shape has wrong inputs num, not equal " << kGatherV2DynInputTensorNum
|
||||
<< ". CNode= " << origin_node->DebugString();
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -28,11 +28,7 @@ void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars
|
|||
std::vector<AnfNodePtr> *square_sum_all_outputs) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(lars_v2);
|
||||
if (lars_v2->size() != kLarsV2InputNum) {
|
||||
MS_LOG(EXCEPTION) << "Op lars_v2's input not equal " << kLarsV2InputNum
|
||||
<< " trace: " << trace::DumpSourceLines(lars_v2);
|
||||
}
|
||||
|
||||
CheckCNodeInputSize(lars_v2, kLarsV2InputTensorNum);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kSquareSumAllOpName)), lars_v2->input(1),
|
||||
lars_v2->input(2)};
|
||||
auto square_sum_all = graph->NewCNode(inputs);
|
||||
|
@ -55,10 +51,7 @@ CNodePtr CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
|
|||
MS_LOG(EXCEPTION) << "square_sum_all_outputs' size not equal 2"
|
||||
<< " trace: " << trace::DumpSourceLines(lars_v2);
|
||||
}
|
||||
if (lars_v2->size() != kLarsV2InputNum) {
|
||||
MS_LOG(EXCEPTION) << "Op lars_v2's input not equal " << kLarsV2InputNum
|
||||
<< " trace: " << trace::DumpSourceLines(lars_v2);
|
||||
}
|
||||
CheckCNodeInputSize(lars_v2, kLarsV2InputTensorNum);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kLarsV2UpdateOpName)),
|
||||
lars_v2->input(1),
|
||||
lars_v2->input(2),
|
||||
|
|
|
@ -91,7 +91,7 @@ const AnfNodePtr LayerNormGradSplit::Process(const FuncGraphPtr &graph, const An
|
|||
return nullptr;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode->inputs().size() != kLayerNormGradInputNum) {
|
||||
if (AnfAlgo::GetInputTensorNum(cnode) != kLayerNormGradInputTensorNum) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -110,7 +110,7 @@ const AnfNodePtr ReduceMinFission::Process(const FuncGraphPtr &graph, const AnfN
|
|||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
CheckCNodeInputSize(cnode, 2);
|
||||
CheckCNodeInputSize(cnode, 1);
|
||||
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
|
||||
auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0);
|
||||
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
|
||||
|
|
|
@ -76,8 +76,8 @@ AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNod
|
|||
|
||||
auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract());
|
||||
MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
|
||||
if (bn_abstract_tuple->elements().size() != kBatchNormOutputNum) {
|
||||
MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBatchNormOutputNum << ", but it is "
|
||||
if (bn_abstract_tuple->elements().size() != kBnOutputNum) {
|
||||
MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBnOutputNum << ", but it is "
|
||||
<< bn_abstract_tuple->elements().size() << " trace: " << trace::DumpSourceLines(bn);
|
||||
}
|
||||
bn_training_update_v3->set_abstract(bn->abstract());
|
||||
|
|
|
@ -34,10 +34,7 @@ CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inpu
|
|||
|
||||
CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) {
|
||||
MS_EXCEPTION_IF_NULL(origin_cnode);
|
||||
if (origin_cnode->inputs().size() < kSplitInputNum) {
|
||||
MS_LOG(EXCEPTION) << "The input number of split: " << origin_cnode->DebugString() << " should be "
|
||||
<< kSplitInputNum - 1 << " trace: " << trace::DumpSourceLines(origin_cnode);
|
||||
}
|
||||
CheckCNodeInputSize(origin_cnode, kSplitInputTensorNum);
|
||||
return CreateSplitVNode(func_graph, origin_cnode->input(1));
|
||||
}
|
||||
|
||||
|
|
|
@ -32,10 +32,7 @@ CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inpu
|
|||
|
||||
CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) {
|
||||
MS_EXCEPTION_IF_NULL(origin_cnode);
|
||||
if (origin_cnode->inputs().size() < kSplitInputNum) {
|
||||
MS_LOG(EXCEPTION) << "The input number of split: " << origin_cnode->DebugString() << " should be "
|
||||
<< kSplitInputNum - 1;
|
||||
}
|
||||
CheckCNodeInputSize(origin_cnode, kSplitInputTensorNum);
|
||||
return CreateSplitVNode(func_graph, origin_cnode->input(1));
|
||||
}
|
||||
|
||||
|
|
|
@ -146,7 +146,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
|
|||
new_cnode->set_abstract(cnode->abstract());
|
||||
new_cnode->set_scope(cnode->scope());
|
||||
AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
|
||||
CheckCNodeInputSize(new_cnode, kTopkInputNum);
|
||||
CheckCNodeInputSize(new_cnode, kTopkInputTensorNum);
|
||||
// Convert the tensor input to scalar and convert it to attr
|
||||
auto input_k = new_cnode->input(kTopkIndexK + 1);
|
||||
MS_EXCEPTION_IF_NULL(input_k);
|
||||
|
|
|
@ -31,7 +31,7 @@ const AnfNodePtr TransDataSplit::Process(const FuncGraphPtr &func_graph, const A
|
|||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) {
|
||||
CheckCNodeInputSize(node->cast<CNodePtr>(), kBackendTransDataInputNum);
|
||||
CheckCNodeInputSize(node->cast<CNodePtr>(), kTransOpInputTensorNum);
|
||||
if (IsFormatInvaild(node)) {
|
||||
TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info()));
|
||||
return DoSplit(func_graph, node);
|
||||
|
|
|
@ -77,8 +77,8 @@ CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_s
|
|||
|
||||
bool CheckInputs(const CNodePtr &origin_node) {
|
||||
MS_EXCEPTION_IF_NULL(origin_node);
|
||||
if (origin_node->size() != kUnsortedSegmentSumInputNum + 1) {
|
||||
MS_LOG(DEBUG) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputNum
|
||||
if (AnfAlgo::GetInputTensorNum(origin_node) != kUnsortedSegmentSumInputTensorNum) {
|
||||
MS_LOG(DEBUG) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputTensorNum
|
||||
<< ". CNode= " << origin_node->DebugString();
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -62,8 +62,8 @@ bool CheckIndex(const AnfNodePtr &index_node) {
|
|||
bool CheckBatchNorm(const FuncGraphPtr &graph, const CNodePtr &batchnorm) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(batchnorm);
|
||||
if (batchnorm->size() < kBatchNormInputNum + 1) {
|
||||
MS_LOG(DEBUG) << "BatchNorm's input less than " << kBatchNormInputNum;
|
||||
if (AnfAlgo::GetInputTensorNum(batchnorm) < kBnInputTensorNum) {
|
||||
MS_LOG(DEBUG) << "BatchNorm's input less than " << kBnInputTensorNum;
|
||||
return false;
|
||||
}
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnorm)) {
|
||||
|
@ -87,7 +87,7 @@ bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *bat
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto tuple_getitem = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||
CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputSize);
|
||||
CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputTensorNum);
|
||||
AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem);
|
||||
MS_EXCEPTION_IF_NULL(index_node);
|
||||
if (!CheckIndex(index_node)) {
|
||||
|
|
|
@ -61,8 +61,8 @@ bool CheckIndex(const AnfNodePtr &index_node) {
|
|||
bool CheckBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(batchnormgrad);
|
||||
if (batchnormgrad->size() < kBatchNormInputNum + 1) {
|
||||
MS_LOG(DEBUG) << "BatchNormGrad's input less than " << kBatchNormInputNum;
|
||||
if (AnfAlgo::GetInputTensorNum(batchnormgrad) < kBNGradInputTensorNum) {
|
||||
MS_LOG(DEBUG) << "BatchNormGrad's input less than " << kBnInputTensorNum;
|
||||
return false;
|
||||
}
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnormgrad)) {
|
||||
|
@ -86,7 +86,7 @@ bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *bat
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto tuple_getitem = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||
CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputSize);
|
||||
CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputTensorNum);
|
||||
AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem);
|
||||
MS_EXCEPTION_IF_NULL(index_node);
|
||||
if (!CheckIndex(index_node)) {
|
||||
|
|
|
@ -79,7 +79,7 @@ const AnfNodePtr ClipByValueFusion::Process(const FuncGraphPtr &graph, const Anf
|
|||
return nullptr;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(minimum);
|
||||
if (minimum->inputs().size() != kMinimumInputNum) {
|
||||
if (AnfAlgo::GetInputTensorNum(minimum) != kMinimumInputTensorNum) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -30,9 +30,7 @@ const size_t kReluV2OutputNum = 2;
|
|||
|
||||
CNodePtr GetRelu(const CNodePtr &relu_grad) {
|
||||
MS_EXCEPTION_IF_NULL(relu_grad);
|
||||
if (relu_grad->size() != kReluGradInputNum) {
|
||||
MS_LOG_EXCEPTION << "ReluGrad has wrong input size " << relu_grad->size();
|
||||
}
|
||||
CheckCNodeInputSize(relu_grad, kReluGradInputTensorNum);
|
||||
auto relu_anf = relu_grad->input(2);
|
||||
MS_EXCEPTION_IF_NULL(relu_anf);
|
||||
return relu_anf->cast<CNodePtr>();
|
||||
|
@ -41,9 +39,7 @@ CNodePtr GetRelu(const CNodePtr &relu_grad) {
|
|||
CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(relu);
|
||||
if (relu->size() != kReluInputNum) {
|
||||
MS_LOG_EXCEPTION << "Relu has wrong input size " << relu->size();
|
||||
}
|
||||
CheckCNodeInputSize(relu, kReluInputTensorNum);
|
||||
|
||||
auto prim = std::make_shared<Primitive>(kReluV2OpName);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), relu->input(1)};
|
||||
|
|
|
@ -53,32 +53,9 @@ void GetBNOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vect
|
|||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef FusedBatchNormFusion::DefinePattern() const {
|
||||
std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
|
||||
VarPtr index0 = std::make_shared<CondVar>(IsC);
|
||||
VarPtr index1 = std::make_shared<CondVar>(IsC);
|
||||
VarPtr index2 = std::make_shared<CondVar>(IsC);
|
||||
VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs});
|
||||
VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0});
|
||||
VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1});
|
||||
VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2});
|
||||
VectorRef sub0 = VectorRef({prim::kPrimSub, variable_input0_var_, tuple_getitem1});
|
||||
VectorRef sub1 = VectorRef({prim::kPrimSub, variable_input1_var_, tuple_getitem2});
|
||||
VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_});
|
||||
VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_});
|
||||
VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0});
|
||||
VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1});
|
||||
VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
|
||||
return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
|
||||
}
|
||||
|
||||
ValuePtr FusedBatchNormFusion::GetFactor(const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
auto iter_constant_input0 = (*equiv).find(constant_input0_var_);
|
||||
if (iter_constant_input0 == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the constant_input0 var after matched.";
|
||||
}
|
||||
auto constant_input = utils::cast<AnfNodePtr>(iter_constant_input0->second);
|
||||
auto constant_input = GetAnfNodeByVar(equiv, constant_input0_var_);
|
||||
MS_EXCEPTION_IF_NULL(constant_input);
|
||||
if (!constant_input->isa<ValueNode>()) {
|
||||
return nullptr;
|
||||
|
@ -113,31 +90,15 @@ AnfNodePtr FusedBatchNormFusion::CreateBNTrainingReduce(const FuncGraphPtr &func
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
// Set input to create node
|
||||
auto iter_data_input0 = (*equiv).find(data_input0_var_);
|
||||
if (iter_data_input0 == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched."
|
||||
<< " trace: " << trace::DumpSourceLines(node);
|
||||
}
|
||||
std::vector<AnfNodePtr> bn_training_reduce_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)),
|
||||
utils::cast<AnfNodePtr>(iter_data_input0->second)};
|
||||
NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), GetAnfNodeByVar(equiv, data_input0_var_)};
|
||||
auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
|
||||
MS_EXCEPTION_IF_NULL(bn_training_reduce);
|
||||
bn_training_reduce->set_scope(node->scope());
|
||||
// Set abstract
|
||||
auto iter_data_input1 = (*equiv).find(data_input1_var_);
|
||||
if (iter_data_input1 == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched."
|
||||
<< " trace: " << trace::DumpSourceLines(node);
|
||||
}
|
||||
auto data_input1 = utils::cast<AnfNodePtr>(iter_data_input1->second);
|
||||
auto data_input1 = GetAnfNodeByVar(equiv, data_input1_var_);
|
||||
MS_EXCEPTION_IF_NULL(data_input1);
|
||||
auto iter_data_input2 = (*equiv).find(data_input2_var_);
|
||||
if (iter_data_input2 == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched."
|
||||
<< " trace: " << trace::DumpSourceLines(node);
|
||||
}
|
||||
auto data_input2 = utils::cast<AnfNodePtr>(iter_data_input2->second);
|
||||
auto data_input2 = GetAnfNodeByVar(equiv, data_input2_var_);
|
||||
MS_EXCEPTION_IF_NULL(data_input2);
|
||||
AbstractBasePtrList abstract_list{data_input1->abstract(), data_input2->abstract()};
|
||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
|
@ -150,39 +111,15 @@ void FusedBatchNormFusion::GetBNTrainingUpdateInputs(const EquivPtr &equiv,
|
|||
std::vector<AnfNodePtr> *bn_training_update_inputs) const {
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
MS_EXCEPTION_IF_NULL(bn_training_update_inputs);
|
||||
auto iter_data_input0 = (*equiv).find(data_input0_var_);
|
||||
if (iter_data_input0 == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched.";
|
||||
}
|
||||
auto iter_data_input1 = (*equiv).find(data_input1_var_);
|
||||
if (iter_data_input1 == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched.";
|
||||
}
|
||||
auto iter_data_input2 = (*equiv).find(data_input2_var_);
|
||||
if (iter_data_input2 == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched.";
|
||||
}
|
||||
auto iter_variable_input0 = (*equiv).find(variable_input0_var_);
|
||||
if (iter_variable_input0 == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched.";
|
||||
}
|
||||
auto iter_variable_input1 = (*equiv).find(variable_input1_var_);
|
||||
if (iter_variable_input1 == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched.";
|
||||
}
|
||||
if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) {
|
||||
MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum
|
||||
<< ", but it is " << bn_training_reduce_outputs.size();
|
||||
}
|
||||
*bn_training_update_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateOpName)),
|
||||
utils::cast<AnfNodePtr>(iter_data_input0->second),
|
||||
utils::cast<AnfNodePtr>(GetAnfNodeByVar(equiv, data_input0_var_)),
|
||||
bn_training_reduce_outputs[0],
|
||||
bn_training_reduce_outputs[1],
|
||||
utils::cast<AnfNodePtr>(iter_data_input1->second),
|
||||
utils::cast<AnfNodePtr>(iter_data_input2->second),
|
||||
utils::cast<AnfNodePtr>(iter_variable_input0->second),
|
||||
utils::cast<AnfNodePtr>(iter_variable_input1->second),
|
||||
GetAnfNodeByVar(equiv, data_input1_var_),
|
||||
GetAnfNodeByVar(equiv, data_input2_var_),
|
||||
GetAnfNodeByVar(equiv, variable_input0_var_),
|
||||
GetAnfNodeByVar(equiv, variable_input1_var_),
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -197,19 +134,9 @@ void FusedBatchNormFusion::GetBNTrainingUpdateAbstractList(const EquivPtr &equiv
|
|||
MS_LOG(EXCEPTION) << "The abstract size of node bn must not be less than " << kBnOutputNum << ", but it is "
|
||||
<< bn_abstract_tuple->elements().size() << " trace: " << trace::DumpSourceLines(bn);
|
||||
}
|
||||
auto iter_variable_input0 = (*equiv).find(variable_input0_var_);
|
||||
if (iter_variable_input0 == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched."
|
||||
<< " trace: " << trace::DumpSourceLines(bn);
|
||||
}
|
||||
auto variable_input0 = utils::cast<AnfNodePtr>(iter_variable_input0->second);
|
||||
auto variable_input0 = GetAnfNodeByVar(equiv, variable_input0_var_);
|
||||
auto variable_input1 = GetAnfNodeByVar(equiv, variable_input1_var_);
|
||||
MS_EXCEPTION_IF_NULL(variable_input0);
|
||||
auto iter_variable_input1 = (*equiv).find(variable_input1_var_);
|
||||
if (iter_variable_input1 == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched."
|
||||
<< " trace: " << trace::DumpSourceLines(bn);
|
||||
}
|
||||
auto variable_input1 = utils::cast<AnfNodePtr>(iter_variable_input1->second);
|
||||
MS_EXCEPTION_IF_NULL(variable_input1);
|
||||
*abstract_list = {bn_abstract_tuple->elements()[0], variable_input0->abstract(), variable_input1->abstract(),
|
||||
bn_abstract_tuple->elements()[1], bn_abstract_tuple->elements()[2]};
|
||||
|
@ -227,13 +154,7 @@ AnfNodePtr FusedBatchNormFusion::CreateBNTrainingUpdate(
|
|||
auto bn_training_update = func_graph->NewCNode(bn_training_update_inputs);
|
||||
MS_EXCEPTION_IF_NULL(bn_training_update);
|
||||
// Set abstract
|
||||
auto iter_batch_norm = (*equiv).find(batch_norm_var_);
|
||||
if (iter_batch_norm == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched."
|
||||
<< " trace: " << trace::DumpSourceLines(node);
|
||||
}
|
||||
AnfNodePtr bn = utils::cast<AnfNodePtr>(iter_batch_norm->second);
|
||||
MS_EXCEPTION_IF_NULL(bn);
|
||||
AnfNodePtr bn = GetAnfNodeByVar(equiv, batch_norm_var_);
|
||||
AbstractBasePtrList abstract_list;
|
||||
GetBNTrainingUpdateAbstractList(equiv, bn, &abstract_list);
|
||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
|
@ -249,6 +170,23 @@ AnfNodePtr FusedBatchNormFusion::CreateBNTrainingUpdate(
|
|||
return bn_training_update;
|
||||
}
|
||||
|
||||
void FusedBatchNormFusion::EliminateMonadNodes(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto assign_sub1 = GetAnfNodeByVar(equiv, assign_sub1_var_);
|
||||
MS_EXCEPTION_IF_NULL(assign_sub1);
|
||||
for (const auto &node_index : manager->node_users()[assign_sub1]) {
|
||||
const AnfNodePtr &output = node_index.first;
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
|
||||
(void)manager->Replace(output, GetAnfNodeByVar(equiv, monad0_var_));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
@ -271,14 +209,8 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c
|
|||
<< bn_training_update_outputs.size() << " trace: " << trace::DumpSourceLines(node);
|
||||
}
|
||||
// Replace old bn outputs with new outputs
|
||||
auto iter_batch_norm = (*equiv).find(batch_norm_var_);
|
||||
if (iter_batch_norm == (*equiv).end()) {
|
||||
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched."
|
||||
<< " trace: " << trace::DumpSourceLines(node);
|
||||
}
|
||||
AnfNodePtr bn = utils::cast<AnfNodePtr>(iter_batch_norm->second);
|
||||
std::vector<AnfNodePtr> bn_outputs;
|
||||
GetBNOutput(func_graph, bn, &bn_outputs);
|
||||
GetBNOutput(func_graph, GetAnfNodeByVar(equiv, batch_norm_var_), &bn_outputs);
|
||||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
for (const auto &output : bn_outputs) {
|
||||
|
@ -297,7 +229,28 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c
|
|||
(void)manager->Replace(output, bn_training_update_outputs[index]);
|
||||
}
|
||||
}
|
||||
return bn_training_update_outputs[0];
|
||||
(void)manager->Replace(node, bn_training_update_outputs[0]);
|
||||
EliminateMonadNodes(func_graph, equiv);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const BaseRef FusedBatchNormFusion::DefinePattern() const {
|
||||
std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
|
||||
VarPtr index0 = std::make_shared<CondVar>(IsC);
|
||||
VarPtr index1 = std::make_shared<CondVar>(IsC);
|
||||
VarPtr index2 = std::make_shared<CondVar>(IsC);
|
||||
VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs});
|
||||
VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0});
|
||||
VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1});
|
||||
VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2});
|
||||
VectorRef sub0 = VectorRef({prim::kPrimSub, variable_input0_var_, tuple_getitem1});
|
||||
VectorRef sub1 = VectorRef({prim::kPrimSub, variable_input1_var_, tuple_getitem2});
|
||||
VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_});
|
||||
VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_});
|
||||
VectorRef assign_sub0 = VectorRef({assign_sub0_var_, variable_input0_var_, mul0, monad0_var_});
|
||||
VectorRef assign_sub1 = VectorRef({assign_sub1_var_, variable_input1_var_, mul1, monad1_var_});
|
||||
VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
|
||||
return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
|
||||
}
|
||||
|
||||
const BaseRef FusedBatchNormMixPrecisionFusion0::DefinePattern() const {
|
||||
|
@ -317,8 +270,8 @@ const BaseRef FusedBatchNormMixPrecisionFusion0::DefinePattern() const {
|
|||
VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_});
|
||||
VectorRef cast2 = VectorRef({prim::kPrimCast, mul0});
|
||||
VectorRef cast3 = VectorRef({prim::kPrimCast, mul1});
|
||||
VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, cast2});
|
||||
VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, cast3});
|
||||
VectorRef assign_sub0 = VectorRef({assign_sub0_var_, variable_input0_var_, cast2, monad0_var_});
|
||||
VectorRef assign_sub1 = VectorRef({assign_sub1_var_, variable_input1_var_, cast3, monad1_var_});
|
||||
VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
|
||||
return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
|
||||
}
|
||||
|
@ -340,8 +293,8 @@ const BaseRef FusedBatchNormMixPrecisionFusion1::DefinePattern() const {
|
|||
VectorRef cast1 = VectorRef({prim::kPrimCast, sub1});
|
||||
VectorRef mul0 = VectorRef({prim::kPrimMul, cast0, constant_input0_var_});
|
||||
VectorRef mul1 = VectorRef({prim::kPrimMul, cast1, constant_input1_var_});
|
||||
VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0});
|
||||
VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1});
|
||||
VectorRef assign_sub0 = VectorRef({assign_sub0_var_, variable_input0_var_, mul0, monad0_var_});
|
||||
VectorRef assign_sub1 = VectorRef({assign_sub1_var_, variable_input1_var_, mul1, monad1_var_});
|
||||
VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
|
||||
return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
|
||||
}
|
||||
|
|
|
@ -27,15 +27,20 @@ namespace opt {
|
|||
class FusedBatchNormFusion : public PatternProcessPass {
|
||||
public:
|
||||
explicit FusedBatchNormFusion(const std::string &name = "fused_batch_norm_fusion", bool multigraph = true)
|
||||
: PatternProcessPass(name, multigraph),
|
||||
data_input0_var_(std::make_shared<Var>()),
|
||||
data_input1_var_(std::make_shared<Var>()),
|
||||
data_input2_var_(std::make_shared<Var>()),
|
||||
variable_input0_var_(std::make_shared<Var>()),
|
||||
variable_input1_var_(std::make_shared<Var>()),
|
||||
constant_input0_var_(std::make_shared<Var>()),
|
||||
constant_input1_var_(std::make_shared<Var>()),
|
||||
batch_norm_var_(std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimBatchNorm->name()))) {}
|
||||
: PatternProcessPass(name, multigraph) {
|
||||
data_input0_var_ = std::make_shared<Var>();
|
||||
data_input1_var_ = std::make_shared<Var>();
|
||||
data_input2_var_ = std::make_shared<Var>();
|
||||
variable_input0_var_ = std::make_shared<Var>();
|
||||
variable_input1_var_ = std::make_shared<Var>();
|
||||
constant_input0_var_ = std::make_shared<Var>();
|
||||
constant_input1_var_ = std::make_shared<Var>();
|
||||
batch_norm_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimBatchNorm->name()));
|
||||
assign_sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAssignSub->name()));
|
||||
assign_sub1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAssignSub->name()));
|
||||
monad0_var_ = std::make_shared<Var>();
|
||||
monad1_var_ = std::make_shared<Var>();
|
||||
}
|
||||
~FusedBatchNormFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
@ -50,6 +55,7 @@ class FusedBatchNormFusion : public PatternProcessPass {
|
|||
AnfNodePtr CreateBNTrainingUpdate(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
|
||||
const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const;
|
||||
ValuePtr GetFactor(const EquivPtr &equiv) const;
|
||||
void EliminateMonadNodes(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const;
|
||||
|
||||
VarPtr data_input0_var_;
|
||||
VarPtr data_input1_var_;
|
||||
|
@ -59,6 +65,10 @@ class FusedBatchNormFusion : public PatternProcessPass {
|
|||
VarPtr constant_input0_var_;
|
||||
VarPtr constant_input1_var_;
|
||||
VarPtr batch_norm_var_;
|
||||
VarPtr assign_sub0_var_;
|
||||
VarPtr assign_sub1_var_;
|
||||
VarPtr monad0_var_;
|
||||
VarPtr monad1_var_;
|
||||
};
|
||||
|
||||
class FusedBatchNormMixPrecisionFusion0 : public FusedBatchNormFusion {
|
||||
|
|
|
@ -30,33 +30,21 @@ std::tuple<AnfNodePtr, AnfNodePtr, AnfNodePtr, AnfNodePtr> GetSharedNodes(const
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto add3 = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(add3);
|
||||
if (add3->inputs().size() < kAddInputNum) {
|
||||
MS_LOG(EXCEPTION) << "The input size of Add3 is less than " << kAddInputNum
|
||||
<< " trace: " << trace::DumpSourceLines(node);
|
||||
}
|
||||
CheckCNodeInputSize(add3, kAddInputTensorNum);
|
||||
auto real_div2_anf = add3->input(1);
|
||||
MS_EXCEPTION_IF_NULL(real_div2_anf);
|
||||
auto real_div2 = real_div2_anf->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(real_div2);
|
||||
if (real_div2->inputs().size() < kRealDivInputNum) {
|
||||
MS_LOG(EXCEPTION) << "The input size of RealDiv2 is less than " << kRealDivInputNum
|
||||
<< " trace: " << trace::DumpSourceLines(node);
|
||||
}
|
||||
CheckCNodeInputSize(real_div2, kRealDivInputTensorNum);
|
||||
auto sqrt0_anf = real_div2->input(2);
|
||||
MS_EXCEPTION_IF_NULL(sqrt0_anf);
|
||||
auto sqrt0 = sqrt0_anf->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(sqrt0);
|
||||
if (sqrt0->inputs().size() < kRsqrtInputNum) {
|
||||
MS_LOG(EXCEPTION) << "The input size of Sqrt0 is less than " << kSqrtInputNum
|
||||
<< " trace: " << trace::DumpSourceLines(node);
|
||||
}
|
||||
CheckCNodeInputSize(sqrt0, kSqrtInputTensorNum);
|
||||
auto add2_anf = sqrt0->input(1);
|
||||
MS_EXCEPTION_IF_NULL(add2_anf);
|
||||
auto add2 = add2_anf->cast<CNodePtr>();
|
||||
if (add2->inputs().size() < kAddInputNum) {
|
||||
MS_LOG(EXCEPTION) << "The input size of Add2 is less than " << kAddInputNum
|
||||
<< " trace: " << trace::DumpSourceLines(node);
|
||||
}
|
||||
CheckCNodeInputSize(add2, kAddInputTensorNum);
|
||||
return std::make_tuple(add3->input(2), real_div2->input(1), add2->input(1), add2->input(2));
|
||||
}
|
||||
|
||||
|
@ -66,7 +54,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN
|
|||
return false;
|
||||
}
|
||||
auto add5 = node->cast<CNodePtr>();
|
||||
if (AnfAlgo::GetCNodeName(add5) != prim::kPrimAdd->name() || add5->inputs().size() != kAddInputNum) {
|
||||
if (AnfAlgo::GetCNodeName(add5) != prim::kPrimAdd->name() || AnfAlgo::GetInputTensorNum(add5) != kAddInputTensorNum) {
|
||||
return false;
|
||||
}
|
||||
auto real_div4_anf = add5->input(1);
|
||||
|
@ -74,7 +62,8 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN
|
|||
return false;
|
||||
}
|
||||
auto real_div4 = real_div4_anf->cast<CNodePtr>();
|
||||
if (AnfAlgo::GetCNodeName(real_div4) != kRealDivOpName || real_div4->inputs().size() != kRealDivInputNum) {
|
||||
if (AnfAlgo::GetCNodeName(real_div4) != kRealDivOpName ||
|
||||
AnfAlgo::GetInputTensorNum(real_div4) != kRealDivInputTensorNum) {
|
||||
return false;
|
||||
}
|
||||
auto add4_anf = real_div4->input(2);
|
||||
|
@ -82,7 +71,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN
|
|||
return false;
|
||||
}
|
||||
auto add4 = add4_anf->cast<CNodePtr>();
|
||||
if (AnfAlgo::GetCNodeName(add4) != prim::kPrimAdd->name() || add4->inputs().size() != kAddInputNum) {
|
||||
if (AnfAlgo::GetCNodeName(add4) != prim::kPrimAdd->name() || AnfAlgo::GetInputTensorNum(add4) != kAddInputTensorNum) {
|
||||
return false;
|
||||
}
|
||||
auto sqrt1_anf = add4->input(1);
|
||||
|
@ -90,7 +79,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN
|
|||
return false;
|
||||
}
|
||||
auto sqrt1 = sqrt1_anf->cast<CNodePtr>();
|
||||
if (AnfAlgo::GetCNodeName(sqrt1) != kSqrtOpName || sqrt1->inputs().size() != kSqrtInputNum) {
|
||||
if (AnfAlgo::GetCNodeName(sqrt1) != kSqrtOpName || AnfAlgo::GetInputTensorNum(sqrt1) != kSqrtInputTensorNum) {
|
||||
return false;
|
||||
}
|
||||
return add5->input(2) == mul4 && real_div4->input(1) == real_div0 && sqrt1->input(1) == real_div1 &&
|
||||
|
@ -104,14 +93,8 @@ std::tuple<AnfNodePtr, AnfNodePtr> GetAdd0Add1Nodes(const AnfNodePtr &real_div0_
|
|||
auto real_div1 = real_div1_anf->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(real_div0);
|
||||
MS_EXCEPTION_IF_NULL(real_div1);
|
||||
if (real_div0->inputs().size() != kRealDivInputNum) {
|
||||
MS_LOG(EXCEPTION) << "RealDiv0 has wrong input size"
|
||||
<< " trace: " << trace::DumpSourceLines(real_div0_anf);
|
||||
}
|
||||
if (real_div1->inputs().size() != kRealDivInputNum) {
|
||||
MS_LOG(EXCEPTION) << "RealDiv1 has wrong input size"
|
||||
<< " trace: " << trace::DumpSourceLines(real_div1_anf);
|
||||
}
|
||||
CheckCNodeInputSize(real_div0, kRealDivInputTensorNum);
|
||||
CheckCNodeInputSize(real_div1, kRealDivInputTensorNum);
|
||||
return std::make_tuple(real_div0->input(1), real_div1->input(1));
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -77,9 +77,9 @@ bool CheckLayernormBetaGammaBackprop(const FuncGraphPtr &func_graph, const CNode
|
|||
MS_LOG(INFO) << "The node " << cnode->DebugString() << " has no " << kAttrShapeGamma << " attr";
|
||||
return false;
|
||||
}
|
||||
if (cnode->inputs().size() != kLayerNormBetaGammaBackpropInputNum) {
|
||||
if (AnfAlgo::GetInputTensorNum(cnode) != kLayerNormBetaGammaBackpropInputTensorNum) {
|
||||
MS_LOG(INFO) << "The node " << cnode->DebugString() << " inputs num is not equal to "
|
||||
<< kLayerNormBetaGammaBackpropInputNum;
|
||||
<< kLayerNormBetaGammaBackpropInputTensorNum;
|
||||
return false;
|
||||
}
|
||||
if (AnfAlgo::GetOutputTensorNum(cnode) != kLayerNormBetaGammaBackpropOutputNum) {
|
||||
|
@ -87,7 +87,8 @@ bool CheckLayernormBetaGammaBackprop(const FuncGraphPtr &func_graph, const CNode
|
|||
<< kLayerNormBetaGammaBackpropOutputNum;
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); ++i) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
if (AnfAlgo::GetInputDeviceDataType(cnode, i) != kNumberTypeFloat16) {
|
||||
MS_LOG(INFO) << "The data type of node " << cnode->DebugString() << " input " << i << " is not float16";
|
||||
return false;
|
||||
|
@ -148,15 +149,9 @@ const AnfNodePtr LayerNormBetaGammaBackpropFusion::Process(const FuncGraphPtr &f
|
|||
// The cast_nodes size has been checked above.
|
||||
MS_EXCEPTION_IF_NULL(cast_nodes[0]);
|
||||
MS_EXCEPTION_IF_NULL(cast_nodes[1]);
|
||||
if (cast_nodes[0]->inputs().size() != kCastInputNum) {
|
||||
MS_LOG(EXCEPTION) << "The cast0 " << cast_nodes[0]->DebugString() << " input size should be " << kCastInputNum
|
||||
<< " trace: " << trace::DumpSourceLines(node);
|
||||
}
|
||||
CheckCNodeInputSize(cast_nodes[0], kCastInputTensorNum);
|
||||
CheckCNodeInputSize(cast_nodes[1], kCastInputTensorNum);
|
||||
(void)manager->Replace(cast_nodes[0], cast_nodes[0]->input(1));
|
||||
if (cast_nodes[1]->inputs().size() != kCastInputNum) {
|
||||
MS_LOG(EXCEPTION) << "The cast1 " << cast_nodes[1]->DebugString() << " input size should be " << kCastInputNum
|
||||
<< " trace: " << trace::DumpSourceLines(node);
|
||||
}
|
||||
(void)manager->Replace(cast_nodes[1], cast_nodes[1]->input(1));
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -31,6 +31,20 @@ const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &graph, const A
|
|||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
||||
auto matmul = GetAnfNodeByVar(equiv, matmul_var_);
|
||||
if (matmul == nullptr || !matmul->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Get CNode MatMul failed!"
|
||||
<< " trace: " << trace::DumpSourceLines(node);
|
||||
}
|
||||
|
||||
// If there is a side-effect operator in the fusion, do not merge
|
||||
MonadState state_matmul = GetMonadState(matmul);
|
||||
MonadState state_node = GetMonadState(node, matmul);
|
||||
if (!IsStateEquivalent(state_matmul, state_node)) {
|
||||
return node;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name())));
|
||||
inputs.emplace_back(GetAnfNodeByVar(equiv, x0_));
|
||||
|
@ -41,11 +55,6 @@ const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &graph, const A
|
|||
new_node->set_scope(node->scope());
|
||||
new_node->set_abstract(node->abstract());
|
||||
|
||||
auto matmul = GetAnfNodeByVar(equiv, matmul_var_);
|
||||
if (matmul == nullptr || !matmul->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Get CNode MatMul failed!"
|
||||
<< " trace: " << trace::DumpSourceLines(node);
|
||||
}
|
||||
AnfAlgo::CopyNodeAttrs(matmul, new_node);
|
||||
return new_node;
|
||||
}
|
||||
|
|
|
@ -43,7 +43,9 @@ const BaseRef MomentumLossscaleFusion::DefinePattern() const {
|
|||
VarPtr X1 = std::make_shared<Var>();
|
||||
VarPtr X2 = std::make_shared<Var>();
|
||||
VarPtr X4 = std::make_shared<Var>();
|
||||
return VectorRef({prim::kPrimApplyMomentum, X0, X1, X2, VectorRef({prim::kPrimMul, Xs}), X4});
|
||||
// UpdateState node
|
||||
VarPtr X5 = std::make_shared<Var>();
|
||||
return VectorRef({prim::kPrimApplyMomentum, X0, X1, X2, VectorRef({prim::kPrimMul, Xs}), X4, X5});
|
||||
}
|
||||
|
||||
const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
|
@ -52,14 +54,15 @@ const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
CheckCNodeInputSize(cnode, kApplyMomentumInputNum);
|
||||
CheckCNodeInputSize(cnode, kApplyMomentumInputTensorNum);
|
||||
AnfNodePtr mul = cnode->input(4);
|
||||
MS_EXCEPTION_IF_NULL(mul);
|
||||
auto mul_cnode = mul->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(mul_cnode);
|
||||
CheckCNodeInputSize(mul_cnode, kMulInputNum);
|
||||
CheckCNodeInputSize(mul_cnode, kMulInputTensorNum);
|
||||
size_t value_node_index = 0;
|
||||
for (size_t i = 1; i < kMulInputNum; ++i) {
|
||||
// All real inputs include 1prim + x*TensorInput
|
||||
for (size_t i = 1; i < kMulInputTensorNum + 1; ++i) {
|
||||
if (CheckValueNodeInputOfMul(mul_cnode->input(i))) {
|
||||
value_node_index = i;
|
||||
break;
|
||||
|
@ -70,12 +73,16 @@ const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph
|
|||
return nullptr;
|
||||
}
|
||||
auto new_prim = std::make_shared<Primitive>(kFusedMulApplyMomentumOpName);
|
||||
auto depend_prim = NewValueNode(prim::kPrimDepend);
|
||||
auto depend = func_graph->NewCNode({depend_prim, cnode->input(5), cnode->input(6)}); // depend on monad
|
||||
depend->set_abstract(cnode->input(5)->abstract());
|
||||
depend->set_scope(cnode->input(5)->scope());
|
||||
std::vector<AnfNodePtr> new_node_inputs{NewValueNode(new_prim),
|
||||
cnode->input(1),
|
||||
cnode->input(2),
|
||||
cnode->input(3),
|
||||
mul_cnode->input(kMulInputNum - value_node_index),
|
||||
cnode->input(5),
|
||||
mul_cnode->input(kMulInputTensorNum + 1 - value_node_index),
|
||||
depend,
|
||||
mul_cnode->input(value_node_index)};
|
||||
auto new_node = func_graph->NewCNode(new_node_inputs);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
|
|
|
@ -67,7 +67,7 @@ const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodeP
|
|||
return nullptr;
|
||||
}
|
||||
auto add = node->cast<CNodePtr>();
|
||||
if (add == nullptr || add->inputs().size() != kAddInputNum) {
|
||||
if (add == nullptr || AnfAlgo::GetInputTensorNum(add) != kAddInputTensorNum) {
|
||||
return nullptr;
|
||||
}
|
||||
CNodePtr mul = nullptr;
|
||||
|
|
|
@ -31,7 +31,7 @@ CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const
|
|||
MS_EXCEPTION_IF_NULL(addn);
|
||||
auto prim = std::make_shared<Primitive>(kFusedMulAddNOpName);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
|
||||
inputs.push_back(mul->input(kMulInputNum - lossscale_input_index));
|
||||
inputs.push_back(mul->input(kMulInputTensorNum + 1 - lossscale_input_index));
|
||||
inputs.push_back(addn->input(2));
|
||||
// scalar input should be 3rd input
|
||||
inputs.push_back(mul->input(lossscale_input_index));
|
||||
|
@ -60,7 +60,7 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode
|
|||
}
|
||||
|
||||
auto addn = node->cast<CNodePtr>();
|
||||
if (addn == nullptr || addn->inputs().size() != kAddNInputNum) {
|
||||
if (addn == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto mul_anf = addn->input(1);
|
||||
|
@ -68,7 +68,7 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode
|
|||
return nullptr;
|
||||
}
|
||||
auto mul = mul_anf->cast<CNodePtr>();
|
||||
if (mul == nullptr || mul->inputs().size() != kMulInputNum) {
|
||||
if (mul == nullptr || AnfAlgo::GetInputTensorNum(mul) != kMulInputTensorNum) {
|
||||
return nullptr;
|
||||
}
|
||||
if (IsUsedByOthers(graph, mul)) {
|
||||
|
|
|
@ -98,7 +98,8 @@ bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) {
|
|||
MS_LOG(DEBUG) << "Skip trans op";
|
||||
continue;
|
||||
}
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t input_index = 0; input_index < input_num; input_index++) {
|
||||
std::vector<CNodePtr> trans_road;
|
||||
bool first_flag = true;
|
||||
auto final_node = ParamTransRoad(func_graph, AnfAlgo::GetInputNode(cnode, input_index), first_flag, &trans_road);
|
||||
|
|
|
@ -26,8 +26,10 @@ void DoRefresh(const CNodePtr &cnode) {
|
|||
if (cnode == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "node is nullptr";
|
||||
}
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) {
|
||||
auto input_kernel_node = AnfAlgo::GetInputNode(cnode, input_index);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t input_index = 0; input_index < input_num; input_index++) {
|
||||
auto input_kernel_node = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(cnode, input_index), 0).first;
|
||||
MS_EXCEPTION_IF_NULL(input_kernel_node);
|
||||
if (input_kernel_node->isa<Parameter>()) {
|
||||
std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
|
||||
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
|
|
|
@ -34,13 +34,14 @@ const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, cons
|
|||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
auto out_reshape = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum);
|
||||
auto out_reshape = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputTensorNum);
|
||||
MS_EXCEPTION_IF_NULL(out_reshape);
|
||||
// If reshape operator used by more than one other operators, reshape operator cant not be deleted directly
|
||||
if (IsUsedByOthers(func_graph, out_reshape)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto in_reshape = CheckAnfNodeIfCNodeAndInputSize(AnfAlgo::GetInputNode(out_reshape, 0), kBackendReshapeInputNum);
|
||||
auto in_reshape =
|
||||
CheckAnfNodeIfCNodeAndInputSize(AnfAlgo::GetInputNode(out_reshape, 0), kBackendReshapeInputTensorNum);
|
||||
MS_EXCEPTION_IF_NULL(in_reshape);
|
||||
if (IsUsedByOthers(func_graph, in_reshape)) {
|
||||
return nullptr;
|
||||
|
|
|
@ -46,9 +46,9 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph,
|
|||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum);
|
||||
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputTensorNum);
|
||||
MS_EXCEPTION_IF_NULL(transpose_cnode);
|
||||
auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum);
|
||||
auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputTensorNum);
|
||||
MS_EXCEPTION_IF_NULL(reshape_cnode);
|
||||
if (AnfAlgo::IsDynamicShape(transpose_cnode) || AnfAlgo::IsDynamicShape(reshape_cnode)) {
|
||||
return nullptr;
|
||||
|
|
|
@ -33,10 +33,7 @@ CNodePtr GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square,
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(square);
|
||||
MS_EXCEPTION_IF_NULL(sum);
|
||||
if (square->inputs().size() != kSquareNodeInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Square node has wrong input size"
|
||||
<< " trace: " << trace::DumpSourceLines(square);
|
||||
}
|
||||
CheckCNodeInputSize(square, kSquareNodeInputTensorNum);
|
||||
auto prim = std::make_shared<Primitive>(kSquareSumV1OpName);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> square_sumv1_inputs = {NewValueNode(prim), square->input(1)};
|
||||
|
@ -60,10 +57,7 @@ CNodePtr GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square,
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(square);
|
||||
MS_EXCEPTION_IF_NULL(sum);
|
||||
if (square->inputs().size() != kSquareNodeInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Square node has wrong input size"
|
||||
<< " trace: " << trace::DumpSourceLines(square);
|
||||
}
|
||||
CheckCNodeInputSize(square, kSquareNodeInputTensorNum);
|
||||
auto prim = std::make_shared<Primitive>(kSquareSumV2OpName);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> square_sumv2_inputs = {NewValueNode(prim), square->input(1)};
|
||||
|
@ -84,10 +78,7 @@ std::tuple<CNodePtr, AnfNodePtr, CNodePtr> GetPrevNodes(const AnfNodePtr &node)
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto sum = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(sum);
|
||||
if (sum->inputs().size() != kSumNodeInputNum) {
|
||||
MS_LOG(EXCEPTION) << "ReduceSumD node has wrong input size"
|
||||
<< " trace: " << trace::DumpSourceLines(sum);
|
||||
}
|
||||
CheckCNodeInputSize(sum, kSumNodeInputTensorNum);
|
||||
auto square_anf = sum->input(1);
|
||||
MS_EXCEPTION_IF_NULL(square_anf);
|
||||
auto square = square_anf->cast<CNodePtr>();
|
||||
|
|
|
@ -46,9 +46,9 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph,
|
|||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum);
|
||||
auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputTensorNum);
|
||||
MS_EXCEPTION_IF_NULL(reshape_cnode);
|
||||
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum);
|
||||
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputTensorNum);
|
||||
MS_EXCEPTION_IF_NULL(transpose_cnode);
|
||||
if (AnfAlgo::IsDynamicShape(transpose_cnode) || AnfAlgo::IsDynamicShape(reshape_cnode)) {
|
||||
return nullptr;
|
||||
|
|
|
@ -33,9 +33,9 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap
|
|||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
auto transdata_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendTransposeInputNum);
|
||||
auto transdata_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendTransposeInputTensorNum);
|
||||
MS_EXCEPTION_IF_NULL(transdata_cnode);
|
||||
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(transdata_cnode->input(1), kBackendTransDataInputNum);
|
||||
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(transdata_cnode->input(1), kTransOpInputTensorNum);
|
||||
MS_EXCEPTION_IF_NULL(transpose_cnode);
|
||||
auto transpose_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(transpose_cnode);
|
||||
auto transdata_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(transdata_cnode);
|
||||
|
|
|
@ -136,10 +136,7 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons
|
|||
CNodePtr CreateDepthwiseConv2D(const FuncGraphPtr &graph, const CNodePtr &conv2d, const CNodePtr &transpose) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(conv2d);
|
||||
if (conv2d->inputs().size() != kConvInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Conv2D's input number should be " << kConvInputNum - 1 << ", but got "
|
||||
<< conv2d->inputs().size() - 1;
|
||||
}
|
||||
CheckCNodeInputSize(conv2d, kConvInputTensorNum);
|
||||
std::vector<AnfNodePtr> depth_conv_inputs = {NewValueNode(std::make_shared<Primitive>(kDepthwiseConv2dNativeOpName)),
|
||||
conv2d->input(1), transpose};
|
||||
auto depth_conv = graph->NewCNode(depth_conv_inputs);
|
||||
|
@ -270,11 +267,7 @@ const AnfNodePtr Conv2DUnifyMindIR::Process(const FuncGraphPtr &graph, const Anf
|
|||
if (!NeedUpdate(conv2d, input_shape, output_shape)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (conv2d->inputs().size() != kConvInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Conv2D's input number should be " << kConvInputNum - 1 << ", but got "
|
||||
<< conv2d->inputs().size() - 1;
|
||||
}
|
||||
CheckCNodeInputSize(conv2d, kConvInputTensorNum);
|
||||
auto transpose = CreateTranspose(graph, conv2d, conv2d->input(2), true);
|
||||
auto depth_conv = CreateDepthwiseConv2D(graph, conv2d, transpose);
|
||||
SetConv2DAttrs(conv2d, depth_conv);
|
||||
|
|
|
@ -70,7 +70,8 @@ const BaseRef FtrlUnifyOutput::DefinePattern() const {
|
|||
VarPtr l1 = std::make_shared<Var>();
|
||||
VarPtr l2 = std::make_shared<Var>();
|
||||
VarPtr lr_power = std::make_shared<Var>();
|
||||
VectorRef pattern({prim::kPrimApplyFtrl, var, accum, linear, grad, lr, l1, l2, lr_power});
|
||||
VarPtr u = std::make_shared<SeqVar>();
|
||||
VectorRef pattern({prim::kPrimApplyFtrl, var, accum, linear, grad, lr, l1, l2, lr_power, u});
|
||||
return pattern;
|
||||
}
|
||||
|
||||
|
@ -84,7 +85,8 @@ const BaseRef MomentumUnifyOutput::DefinePattern() const {
|
|||
VarPtr lr = std::make_shared<Var>();
|
||||
VarPtr grad = std::make_shared<Var>();
|
||||
VarPtr momentum = std::make_shared<Var>();
|
||||
VectorRef pattern({prim::kPrimApplyMomentum, var, accum, lr, grad, momentum});
|
||||
VarPtr u = std::make_shared<SeqVar>();
|
||||
VectorRef pattern({prim::kPrimApplyMomentum, var, accum, lr, grad, momentum, u});
|
||||
return pattern;
|
||||
}
|
||||
|
||||
|
@ -114,7 +116,8 @@ const BaseRef CenteredRMSPropUnifyOutput::DefinePattern() const {
|
|||
VarPtr rho = std::make_shared<Var>();
|
||||
VarPtr momentum = std::make_shared<Var>();
|
||||
VarPtr epsilon = std::make_shared<Var>();
|
||||
VectorRef pattern({prim::kPrimApplyCenteredRMSProp, var, mg, ms, mom, grad, lr, rho, momentum, epsilon});
|
||||
VarPtr u = std::make_shared<SeqVar>();
|
||||
VectorRef pattern({prim::kPrimApplyCenteredRMSProp, var, mg, ms, mom, grad, lr, rho, momentum, epsilon, u});
|
||||
return pattern;
|
||||
}
|
||||
|
||||
|
|
|
@ -109,12 +109,7 @@ CNodePtr CreateSoftmaxCrossEntropyWithLogits(const FuncGraphPtr &graph, const CN
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
|
||||
MS_EXCEPTION_IF_NULL(one_hot_node);
|
||||
|
||||
if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
|
||||
MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal "
|
||||
<< kSparseSoftmaxCrossEntropyWithLogitsInputNum;
|
||||
}
|
||||
|
||||
CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kSoftmaxCrossEntropyWithLogitsOpName)),
|
||||
sparse_softmax_node->input(1), one_hot_node};
|
||||
auto softmax_node = graph->NewCNode(inputs);
|
||||
|
@ -162,10 +157,7 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
|
||||
MS_EXCEPTION_IF_NULL(softmax_output_node);
|
||||
if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
|
||||
MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal "
|
||||
<< kSparseSoftmaxCrossEntropyWithLogitsInputNum;
|
||||
}
|
||||
CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
|
||||
|
||||
auto axis_value = GetAxis(softmax_output_node);
|
||||
auto axis_node = GetAxisNode(softmax_output_node);
|
||||
|
@ -200,9 +192,7 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft
|
|||
CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_node) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(real_div_node);
|
||||
if (real_div_node->size() != kRealDivInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Op real_div's input num not equal " << kRealDivInputNum;
|
||||
}
|
||||
CheckCNodeInputSize(real_div_node, kRealDivInputTensorNum);
|
||||
|
||||
int64_t axis = -1;
|
||||
auto axis_node = NewValueNode(axis);
|
||||
|
@ -230,9 +220,8 @@ CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_no
|
|||
CNodePtr CreateExpandDimsPynative(const FuncGraphPtr &graph, const CNodePtr &real_div_node) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(real_div_node);
|
||||
if (real_div_node->size() != kRealDivInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Op real_div's input num not equal " << kRealDivInputNum;
|
||||
}
|
||||
CheckCNodeInputSize(real_div_node, kRealDivInputTensorNum);
|
||||
|
||||
int64_t axis = -1;
|
||||
auto expand_dims_primitive = std::make_shared<Primitive>(kExpandDimsOpName);
|
||||
std::vector<std::string> input_names = {"x"};
|
||||
|
@ -257,13 +246,8 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
|
||||
MS_EXCEPTION_IF_NULL(mul_node);
|
||||
if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
|
||||
MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal "
|
||||
<< kSparseSoftmaxCrossEntropyWithLogitsInputNum;
|
||||
}
|
||||
if (mul_node->size() != kMulInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Op Mul's input not equal " << kMulInputNum;
|
||||
}
|
||||
CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
|
||||
CheckCNodeInputSize(mul_node, kMulInputTensorNum);
|
||||
|
||||
auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, 1);
|
||||
std::vector<int64_t> multiple_value;
|
||||
|
@ -310,12 +294,7 @@ CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
|
||||
MS_EXCEPTION_IF_NULL(tile_node);
|
||||
|
||||
if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
|
||||
MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal "
|
||||
<< kSparseSoftmaxCrossEntropyWithLogitsInputNum;
|
||||
}
|
||||
|
||||
CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
|
||||
std::vector<size_t> labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, 1);
|
||||
if (labels_shape.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "label's shape should be 1-D.";
|
||||
|
@ -343,9 +322,7 @@ CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax
|
|||
|
||||
CNodePtr GetSparseNode(const CNodePtr &depend_node, size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(depend_node);
|
||||
if (depend_node->size() != kDependInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Op Depend's input not equal " << kDependInputNum;
|
||||
}
|
||||
CheckCNodeInputSize(depend_node, kDependInputTensorNum);
|
||||
auto sparse_node = depend_node->input(index);
|
||||
MS_EXCEPTION_IF_NULL(sparse_node);
|
||||
return sparse_node->cast<CNodePtr>();
|
||||
|
@ -353,9 +330,7 @@ CNodePtr GetSparseNode(const CNodePtr &depend_node, size_t index) {
|
|||
|
||||
CNodePtr GetDependNode(const CNodePtr &mul_node) {
|
||||
MS_EXCEPTION_IF_NULL(mul_node);
|
||||
if (mul_node->size() != kMulInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Op Mul's input not equal " << kMulInputNum;
|
||||
}
|
||||
CheckCNodeInputSize(mul_node, kMulInputTensorNum);
|
||||
auto depend_node = mul_node->input(1);
|
||||
MS_EXCEPTION_IF_NULL(depend_node);
|
||||
return depend_node->cast<CNodePtr>();
|
||||
|
@ -413,10 +388,7 @@ const AnfNodePtr SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const F
|
|||
|
||||
auto sparse_softmax_node = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
|
||||
if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal "
|
||||
<< kSparseSoftmaxCrossEntropyWithLogitsInputNum;
|
||||
}
|
||||
CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
|
||||
if (AnfAlgo::HasNodeAttr(kAttrIsGrad, sparse_softmax_node) &&
|
||||
AnfAlgo::GetNodeAttr<bool>(sparse_softmax_node, kAttrIsGrad)) {
|
||||
return nullptr;
|
||||
|
@ -451,17 +423,12 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con
|
|||
|
||||
auto mul_node = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(mul_node);
|
||||
if (mul_node->size() != kMulInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Op Mul's input not equal " << kMulInputNum;
|
||||
}
|
||||
CheckCNodeInputSize(mul_node, kMulInputTensorNum);
|
||||
|
||||
auto depend_node = GetDependNode(mul_node);
|
||||
auto sparse_softmax_node = GetSparseNode(depend_node, 2);
|
||||
auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1);
|
||||
if (sparse_softmax_node_grad->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal "
|
||||
<< kSparseSoftmaxCrossEntropyWithLogitsInputNum;
|
||||
}
|
||||
CheckCNodeInputSize(sparse_softmax_node_grad, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
|
||||
|
||||
CNodePtr softmax_node;
|
||||
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad);
|
||||
|
@ -538,10 +505,8 @@ const AnfNodePtr PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process
|
|||
|
||||
auto sparse_softmax_node = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(sparse_softmax_node);
|
||||
if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal "
|
||||
<< kSparseSoftmaxCrossEntropyWithLogitsInputNum;
|
||||
}
|
||||
CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
|
||||
|
||||
if (AnfAlgo::HasNodeAttr(kAttrIsGrad, sparse_softmax_node) &&
|
||||
AnfAlgo::GetNodeAttr<bool>(sparse_softmax_node, kAttrIsGrad)) {
|
||||
return nullptr;
|
||||
|
@ -573,17 +538,12 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Pro
|
|||
|
||||
auto mul_node = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(mul_node);
|
||||
if (mul_node->size() != kMulInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Op Mul's input not equal " << kMulInputNum;
|
||||
}
|
||||
CheckCNodeInputSize(mul_node, kMulInputTensorNum);
|
||||
|
||||
auto sparse_softmax_node = mul_node->input(1);
|
||||
auto sparse_softmax_node_grad = sparse_softmax_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(sparse_softmax_node_grad);
|
||||
|
||||
if (sparse_softmax_node_grad->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal "
|
||||
<< kSparseSoftmaxCrossEntropyWithLogitsInputNum;
|
||||
}
|
||||
CheckCNodeInputSize(sparse_softmax_node_grad, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum);
|
||||
|
||||
CNodePtr softmax_node;
|
||||
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad);
|
||||
|
|
|
@ -124,18 +124,16 @@ CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_si
|
|||
MS_LOG(EXCEPTION) << "The node is expected to be a cnode";
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().size() != input_size) {
|
||||
auto op_name = AnfAlgo::GetCNodeName(cnode);
|
||||
MS_LOG(EXCEPTION) << "op[" + op_name + "] has less than " << input_size << " inputs.";
|
||||
}
|
||||
CheckCNodeInputSize(cnode, input_size);
|
||||
return cnode;
|
||||
}
|
||||
|
||||
void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size) {
|
||||
void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_tensor_size) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().size() != input_size) {
|
||||
MS_LOG(EXCEPTION) << "The input size of node " + cnode->DebugString() + " is not equal to " << input_size;
|
||||
auto real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||
if (real_input_tensor_num != input_tensor_size) {
|
||||
MS_LOG(EXCEPTION) << "The input tensor size[" << real_input_tensor_num
|
||||
<< "] of node " + cnode->DebugString() + " is not equal to " << input_tensor_size;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -149,17 +147,15 @@ bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y
|
|||
const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
||||
auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputNum);
|
||||
auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputTensorNum);
|
||||
MS_EXCEPTION_IF_NULL(transop_cnode);
|
||||
auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(kCastInputNum - 1), kDependInputNum);
|
||||
auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputNum);
|
||||
MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependInputNum - 1));
|
||||
MS_EXCEPTION_IF_NULL(prev_transop_cnode->input(kTransOpInputNum - 1));
|
||||
auto transed_node = prev_transop_cnode->input(kTransOpInputNum - 1);
|
||||
auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(1), kDependInputTensorNum);
|
||||
auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputTensorNum);
|
||||
auto transed_node = prev_transop_cnode->input(1);
|
||||
MS_EXCEPTION_IF_NULL(transed_node);
|
||||
|
||||
std::vector<AnfNodePtr> replace_depend_inputs{NewValueNode(prim::kPrimDepend), transed_node,
|
||||
depend_cnode->input(kDependInputNum - 1)};
|
||||
depend_cnode->input(kDependAttachNodeIndex)};
|
||||
AnfNodePtr replace_depend = func_graph->NewCNode(replace_depend_inputs);
|
||||
MS_EXCEPTION_IF_NULL(replace_depend);
|
||||
auto transed_abstract = transed_node->abstract();
|
||||
|
@ -422,13 +418,13 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con
|
|||
}
|
||||
auto output_info_list = iter->second;
|
||||
for (const auto &output_info : output_info_list) {
|
||||
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) {
|
||||
continue;
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
|
||||
output_info.second == kDependAttachNodeIndex) {
|
||||
continue;
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimUpdateState->name()) {
|
||||
continue;
|
||||
}
|
||||
output_node_list->push_back(output_info);
|
||||
}
|
||||
return output_node_list;
|
||||
|
@ -537,6 +533,9 @@ void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &i
|
|||
bool need_update = false;
|
||||
for (size_t i = 0; i < inputs.size() - 1; ++i) {
|
||||
auto input_node = inputs[i + 1];
|
||||
if (AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimDepend)) {
|
||||
input_node = AnfAlgo::VisitKernel(input_node, 0).first;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>()) {
|
||||
auto value_node = input_node->cast<ValueNodePtr>();
|
||||
|
@ -548,7 +547,7 @@ void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &i
|
|||
primitive->set_attr(input_names_vec[i], value_node->value());
|
||||
need_update = true;
|
||||
} else {
|
||||
new_inputs.push_back(input_node);
|
||||
new_inputs.push_back(inputs[i + 1]);
|
||||
}
|
||||
}
|
||||
if (need_update) {
|
||||
|
@ -785,7 +784,8 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
|
|||
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
|
||||
// set value node initial device data type = infer data type
|
||||
std::vector<TypeId> types;
|
||||
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(value_node);
|
||||
for (size_t index = 0; index < output_num; ++index) {
|
||||
types.push_back(kTypeUnknown);
|
||||
}
|
||||
kernel_build_info_builder->SetOutputsDeviceType(types);
|
||||
|
|
|
@ -29,36 +29,34 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
constexpr size_t kTransOpInputNum = 2;
|
||||
constexpr size_t kCastInputNum = 2;
|
||||
constexpr size_t kDependInputNum = 3;
|
||||
constexpr size_t kReluInputNum = 2;
|
||||
constexpr size_t kReluGradInputNum = 3;
|
||||
constexpr size_t kAddInputNum = 3;
|
||||
constexpr size_t kAddNInputNum = 3;
|
||||
constexpr size_t kTupleGetitemInputNum = 3;
|
||||
constexpr size_t kConvInputNum = 3;
|
||||
constexpr size_t kRealDivInputNum = 3;
|
||||
constexpr size_t kSqrtInputNum = 2;
|
||||
constexpr size_t kMulInputNum = 3;
|
||||
constexpr size_t kRsqrtInputNum = 2;
|
||||
constexpr size_t kSubInputNum = 3;
|
||||
constexpr size_t kAssignSubInputNum = 3;
|
||||
constexpr size_t kDropoutInputNum = 2;
|
||||
constexpr size_t kTransOpInputTensorNum = 1;
|
||||
constexpr size_t kCastInputTensorNum = 1;
|
||||
constexpr size_t kDependInputTensorNum = 2;
|
||||
constexpr size_t kReluInputTensorNum = 1;
|
||||
constexpr size_t kReluGradInputTensorNum = 2;
|
||||
constexpr size_t kAddInputTensorNum = 2;
|
||||
constexpr size_t kTupleGetItemInputTensorNum = 2;
|
||||
constexpr size_t kConvInputTensorNum = 2;
|
||||
constexpr size_t kRealDivInputTensorNum = 2;
|
||||
constexpr size_t kSqrtInputTensorNum = 1;
|
||||
constexpr size_t kMatMulInputTensorNum = 2;
|
||||
constexpr size_t kMulInputTensorNum = 2;
|
||||
constexpr size_t kSubInputTensorNum = 2;
|
||||
constexpr size_t kAssignSubInputTensorNum = 2;
|
||||
constexpr size_t kDropoutInputTensorNum = 1;
|
||||
constexpr size_t kAssignInputTensorNum = 2;
|
||||
|
||||
constexpr size_t kConvBn1OutputNum = 3;
|
||||
constexpr size_t kBn2ReluOutputNum = 4;
|
||||
|
||||
constexpr size_t kBnInputNum = 6;
|
||||
constexpr size_t kBnInputTensorNum = 5;
|
||||
constexpr size_t kBnOutputNum = 5;
|
||||
constexpr size_t kBatchNormInputNum = 5;
|
||||
constexpr size_t kBatchNormOutputNum = 5;
|
||||
|
||||
constexpr size_t kBN1OutputNum = 2;
|
||||
constexpr size_t kBN2OutputNum = 3;
|
||||
constexpr size_t kBN3OutputNum = 1;
|
||||
|
||||
constexpr size_t kBNGradInputNum = 6;
|
||||
constexpr size_t kBNGradInputTensorNum = 5;
|
||||
constexpr size_t kBNGradOutputNum = 3;
|
||||
|
||||
constexpr size_t kBNGrad1OutputNum = 3;
|
||||
|
@ -72,10 +70,10 @@ constexpr size_t kBNTrainingUpdateV3OutputNum = 5;
|
|||
constexpr size_t kBNTrainingUpdateGradOutputNum = 2;
|
||||
|
||||
constexpr size_t kSingleOutputNum = 1;
|
||||
constexpr size_t kSumNodeInputNum = 2;
|
||||
constexpr size_t kSquareNodeInputNum = 2;
|
||||
constexpr size_t kSumNodeInputTensorNum = 1;
|
||||
constexpr size_t kSquareNodeInputTensorNum = 1;
|
||||
constexpr size_t kSquareSumv2OutputNum = 2;
|
||||
constexpr size_t kMinimumInputNum = 3;
|
||||
constexpr size_t kMinimumInputTensorNum = 2;
|
||||
|
||||
constexpr size_t kLambNextMVWithDecayInputNum = 7;
|
||||
constexpr size_t kLambNextMVWithDecayConstantMulInputNum = 5;
|
||||
|
@ -85,26 +83,25 @@ constexpr size_t kLambNextRightOutputNum = 2;
|
|||
constexpr size_t kLambUpdateWithLrV2InputNum = 8;
|
||||
constexpr size_t kLambNextMVRuleInputNum = 14;
|
||||
constexpr size_t kLambNextMVRuleOutputNum = 4;
|
||||
constexpr size_t kBackendReshapeInputNum = 2;
|
||||
constexpr size_t kBackendTransposeInputNum = 2;
|
||||
constexpr size_t kBackendReshapeInputTensorNum = 1;
|
||||
constexpr size_t kBackendTransposeInputTensorNum = 1;
|
||||
constexpr size_t kAdamApplyOneWithDecayOutputNum = 3;
|
||||
constexpr size_t kLayerNormBetaGammaBackpropInputNum = 5;
|
||||
constexpr size_t kLayerNormBetaGammaBackpropInputTensorNum = 4;
|
||||
constexpr size_t kLayerNormBetaGammaBackpropOutputNum = 2;
|
||||
constexpr size_t kLayerNormGradInputNum = 6;
|
||||
constexpr size_t kLayerNormGradInputTensorNum = 5;
|
||||
constexpr size_t kAdamApplyOneOutputNum = 3;
|
||||
constexpr size_t kBackendTransDataInputNum = 2;
|
||||
constexpr size_t kApplyMomentumInputNum = 6;
|
||||
constexpr size_t kBiasAddInputNum = 3;
|
||||
constexpr size_t kTopkInputNum = 3;
|
||||
constexpr size_t kLarsV2InputNum = 5;
|
||||
constexpr size_t kApplyMomentumInputTensorNum = 5;
|
||||
constexpr size_t kBiasAddInputTensorNum = 2;
|
||||
constexpr size_t kTopkInputTensorNum = 2;
|
||||
constexpr size_t kLarsV2InputTensorNum = 4;
|
||||
constexpr size_t kFusedMulApplyMomentumOutputNum = 2;
|
||||
constexpr size_t kSplitInputNum = 2;
|
||||
constexpr size_t kGatherV2DynInputNum = 3;
|
||||
constexpr size_t kUnsortedSegmentSumInputNum = 2;
|
||||
constexpr size_t kSplitInputTensorNum = 1;
|
||||
constexpr size_t kGatherV2DynInputTensorNum = 3;
|
||||
constexpr size_t kUnsortedSegmentSumInputTensorNum = 2;
|
||||
constexpr size_t kSoftmaxCrossEntropyWithLogitsOutputNum = 2;
|
||||
constexpr size_t kSparseSoftmaxCrossEntropyWithLogitsInputNum = 3;
|
||||
constexpr size_t kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum = 2;
|
||||
constexpr size_t kOneHotOutputNum = 1;
|
||||
constexpr size_t kOneHotInputNum = 5;
|
||||
constexpr size_t kOneHotInputTensorNum = 4;
|
||||
|
||||
enum FusedBatchNormInput {
|
||||
kX = 1,
|
||||
|
@ -137,7 +134,7 @@ bool Visited(const BaseRef &n);
|
|||
// check if the input node is CNode, then check it's input_size, return CNodePtr if check success.
|
||||
CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_size);
|
||||
|
||||
void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size);
|
||||
void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_tensor_num);
|
||||
|
||||
bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y);
|
||||
|
||||
|
|
|
@ -34,11 +34,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
|
|||
std::vector<TypeId> outputs_type;
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
|
||||
inputs_format.push_back(kOpFormat_DEFAULT);
|
||||
}
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
|
||||
outputs_format.push_back(kOpFormat_DEFAULT);
|
||||
}
|
||||
|
@ -51,19 +53,30 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
|
|||
} // namespace
|
||||
|
||||
const BaseRef AdamFusion::DefinePattern() const {
|
||||
VectorRef next_m = VectorRef(
|
||||
{prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, m_}), VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
|
||||
VectorRef load_m = VectorRef({prim::kPrimLoad, m_, u_});
|
||||
VectorRef next_m = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, load_m}),
|
||||
VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
|
||||
|
||||
VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_});
|
||||
VectorRef next_v =
|
||||
VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
|
||||
VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, load_v}),
|
||||
VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
|
||||
|
||||
VectorRef update =
|
||||
VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
|
||||
VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update});
|
||||
VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
|
||||
|
||||
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, param_, next_param})});
|
||||
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, m_, next_m})});
|
||||
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, v_, next_v})});
|
||||
VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, u2_});
|
||||
VectorRef next_state = VectorRef({prim::kPrimUpdateState, u2_, assign_param});
|
||||
next_param = VectorRef({prim::kPrimDepend, next_param, assign_param});
|
||||
|
||||
VectorRef assign_m = VectorRef({prim::kPrimAssign, m_, next_m, next_state});
|
||||
next_state = VectorRef({prim::kPrimUpdateState, next_state, assign_m});
|
||||
next_param = VectorRef({prim::kPrimDepend, next_param, assign_m});
|
||||
|
||||
VectorRef assign_v = VectorRef({prim::kPrimAssign, v_, next_v, next_state});
|
||||
next_param = VectorRef({prim::kPrimDepend, next_param, assign_v});
|
||||
return next_param;
|
||||
}
|
||||
|
||||
|
@ -81,6 +94,7 @@ const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr
|
|||
auto m_input = utils::cast<AnfNodePtr>((*equiv)[m_]);
|
||||
auto v_input = utils::cast<AnfNodePtr>((*equiv)[v_]);
|
||||
auto gradient_input = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
|
||||
auto u_input = utils::cast<AnfNodePtr>((*equiv)[u_]);
|
||||
MS_EXCEPTION_IF_NULL(beta1_input);
|
||||
MS_EXCEPTION_IF_NULL(one_sub_beta1_input);
|
||||
MS_EXCEPTION_IF_NULL(beta2_input);
|
||||
|
@ -91,13 +105,30 @@ const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr
|
|||
MS_EXCEPTION_IF_NULL(m_input);
|
||||
MS_EXCEPTION_IF_NULL(v_input);
|
||||
MS_EXCEPTION_IF_NULL(gradient_input);
|
||||
MS_EXCEPTION_IF_NULL(u_input);
|
||||
|
||||
// Use depend(param, u) to maintain the execution order of FusedAdam and the previous operators.
|
||||
auto prim_depend = std::make_shared<Primitive>(prim::kPrimDepend->name());
|
||||
MS_EXCEPTION_IF_NULL(prim_depend);
|
||||
std::vector<AnfNodePtr> param_inputs = {NewValueNode(prim_depend), param_input, u_input};
|
||||
auto param = graph->NewCNode(param_inputs);
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
param->set_abstract(param_input->abstract());
|
||||
|
||||
// Fused into a FusedAdam operator.
|
||||
auto prim = std::make_shared<Primitive>(kFusedAdamName);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {
|
||||
NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input,
|
||||
eps_input, lr_input, param_input, m_input, v_input,
|
||||
gradient_input};
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim),
|
||||
beta1_input,
|
||||
one_sub_beta1_input,
|
||||
beta2_input,
|
||||
one_sub_beta2_input,
|
||||
eps_input,
|
||||
lr_input,
|
||||
param,
|
||||
m_input,
|
||||
v_input,
|
||||
gradient_input};
|
||||
auto adam = graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(adam);
|
||||
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
|
||||
|
@ -107,6 +138,30 @@ const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr
|
|||
|
||||
auto build_info = GenerateKernelBuildInfo(adam);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get());
|
||||
|
||||
// Replace the parameters of the last UpdateState to maintain
|
||||
// the execution order of FusedAdam and the following operators.
|
||||
// n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
|
||||
auto n = node->cast<CNodePtr>()->input(2);
|
||||
auto fg = n->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto mgr = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mgr);
|
||||
auto &node_users = mgr->node_users();
|
||||
auto iter = node_users.find(n);
|
||||
if (iter == node_users.end()) {
|
||||
MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString();
|
||||
}
|
||||
|
||||
auto &users = iter->second;
|
||||
for (auto &user : users) {
|
||||
if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
|
||||
(user.first)->cast<CNodePtr>()->set_input(1, u_input);
|
||||
(user.first)->cast<CNodePtr>()->set_input(2, adam);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return adam;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -34,6 +34,8 @@ class AdamFusion : public PatternProcessPass {
|
|||
m_ = std::make_shared<Var>();
|
||||
v_ = std::make_shared<Var>();
|
||||
gradient_ = std::make_shared<Var>();
|
||||
u_ = std::make_shared<Var>();
|
||||
u2_ = std::make_shared<Var>();
|
||||
}
|
||||
~AdamFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
|
@ -50,6 +52,8 @@ class AdamFusion : public PatternProcessPass {
|
|||
VarPtr m_;
|
||||
VarPtr v_;
|
||||
VarPtr gradient_;
|
||||
VarPtr u_;
|
||||
VarPtr u2_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,11 +34,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
|
|||
std::vector<TypeId> outputs_type;
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
|
||||
inputs_format.push_back(kOpFormat_DEFAULT);
|
||||
}
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
|
||||
outputs_format.push_back(kOpFormat_DEFAULT);
|
||||
}
|
||||
|
@ -51,11 +53,14 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
|
|||
} // namespace
|
||||
|
||||
const BaseRef AdamWeightDecayFusion::DefinePattern() const {
|
||||
VectorRef next_m = VectorRef(
|
||||
{prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, m_}), VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
|
||||
VectorRef load_m = VectorRef({prim::kPrimLoad, m_, u_});
|
||||
VectorRef next_m = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, load_m}),
|
||||
VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
|
||||
VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_});
|
||||
VectorRef next_v =
|
||||
VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
|
||||
VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, load_v}),
|
||||
VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
|
||||
|
||||
VectorRef update =
|
||||
VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
|
||||
VectorRef new_update = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update});
|
||||
|
@ -63,9 +68,16 @@ const BaseRef AdamWeightDecayFusion::DefinePattern() const {
|
|||
VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update});
|
||||
VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
|
||||
|
||||
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, param_, next_param})});
|
||||
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, m_, next_m})});
|
||||
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, v_, next_v})});
|
||||
VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, u2_});
|
||||
VectorRef next_state = VectorRef({prim::kPrimUpdateState, u2_, assign_param});
|
||||
next_param = VectorRef({prim::kPrimDepend, next_param, assign_param});
|
||||
|
||||
VectorRef assign_m = VectorRef({prim::kPrimAssign, m_, next_m, next_state});
|
||||
next_state = VectorRef({prim::kPrimUpdateState, next_state, assign_m});
|
||||
next_param = VectorRef({prim::kPrimDepend, next_param, assign_m});
|
||||
|
||||
VectorRef assign_v = VectorRef({prim::kPrimAssign, v_, next_v, next_state});
|
||||
next_param = VectorRef({prim::kPrimDepend, next_param, assign_v});
|
||||
return next_param;
|
||||
}
|
||||
|
||||
|
@ -85,6 +97,7 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const
|
|||
auto m_input = utils::cast<AnfNodePtr>((*equiv)[m_]);
|
||||
auto v_input = utils::cast<AnfNodePtr>((*equiv)[v_]);
|
||||
auto gradient_input = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
|
||||
auto u_input = utils::cast<AnfNodePtr>((*equiv)[u_]);
|
||||
MS_EXCEPTION_IF_NULL(beta1_input);
|
||||
MS_EXCEPTION_IF_NULL(one_sub_beta1_input);
|
||||
MS_EXCEPTION_IF_NULL(beta2_input);
|
||||
|
@ -96,13 +109,31 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const
|
|||
MS_EXCEPTION_IF_NULL(m_input);
|
||||
MS_EXCEPTION_IF_NULL(v_input);
|
||||
MS_EXCEPTION_IF_NULL(gradient_input);
|
||||
MS_EXCEPTION_IF_NULL(u_input);
|
||||
|
||||
// Use depend(param, u) to maintain the execution order of FusedAdamWeightDecay and the previous operators.
|
||||
auto prim_depend = std::make_shared<Primitive>(prim::kPrimDepend->name());
|
||||
MS_EXCEPTION_IF_NULL(prim_depend);
|
||||
std::vector<AnfNodePtr> param_inputs = {NewValueNode(prim_depend), param_input, u_input};
|
||||
auto param = graph->NewCNode(param_inputs);
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
param->set_abstract(param_input->abstract());
|
||||
|
||||
// Fused into a FusedAdamWeightDecay operator.
|
||||
auto prim = std::make_shared<Primitive>(kFusedAdamWeightDecayName);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {
|
||||
NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input,
|
||||
eps_input, lr_input, param_input, m_input, v_input,
|
||||
gradient_input, weight_decay_input};
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim),
|
||||
beta1_input,
|
||||
one_sub_beta1_input,
|
||||
beta2_input,
|
||||
one_sub_beta2_input,
|
||||
eps_input,
|
||||
lr_input,
|
||||
param,
|
||||
m_input,
|
||||
v_input,
|
||||
gradient_input,
|
||||
weight_decay_input};
|
||||
auto adam_weight_decay = graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(adam_weight_decay);
|
||||
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
|
||||
|
@ -112,6 +143,30 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const
|
|||
|
||||
auto build_info = GenerateKernelBuildInfo(adam_weight_decay);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(build_info, adam_weight_decay.get());
|
||||
|
||||
// Replace the parameters of the last UpdateState to maintain
|
||||
// the execution order of FusedAdamWeightDecay and the following operators.
|
||||
// n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v}
|
||||
auto n = node->cast<CNodePtr>()->input(2);
|
||||
auto fg = n->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto mgr = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mgr);
|
||||
auto &node_users = mgr->node_users();
|
||||
auto iter = node_users.find(n);
|
||||
if (iter == node_users.end()) {
|
||||
MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString();
|
||||
}
|
||||
|
||||
auto &users = iter->second;
|
||||
for (auto &user : users) {
|
||||
if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) {
|
||||
(user.first)->cast<CNodePtr>()->set_input(1, u_input);
|
||||
(user.first)->cast<CNodePtr>()->set_input(2, adam_weight_decay);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return adam_weight_decay;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -35,6 +35,8 @@ class AdamWeightDecayFusion : public PatternProcessPass {
|
|||
m_ = std::make_shared<Var>();
|
||||
v_ = std::make_shared<Var>();
|
||||
gradient_ = std::make_shared<Var>();
|
||||
u_ = std::make_shared<Var>();
|
||||
u2_ = std::make_shared<Var>();
|
||||
}
|
||||
~AdamWeightDecayFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
|
@ -52,6 +54,8 @@ class AdamWeightDecayFusion : public PatternProcessPass {
|
|||
VarPtr m_;
|
||||
VarPtr v_;
|
||||
VarPtr gradient_;
|
||||
VarPtr u_;
|
||||
VarPtr u2_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,11 +34,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
|
|||
std::vector<TypeId> outputs_type;
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
|
||||
inputs_format.push_back(kOpFormat_DEFAULT);
|
||||
}
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
|
||||
outputs_format.push_back(kOpFormat_DEFAULT);
|
||||
}
|
||||
|
|
|
@ -34,11 +34,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
|
|||
std::vector<TypeId> outputs_type;
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
|
||||
inputs_format.push_back(kOpFormat_DEFAULT);
|
||||
}
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
|
||||
outputs_format.push_back(kOpFormat_DEFAULT);
|
||||
}
|
||||
|
@ -78,7 +80,8 @@ const AnfNodePtr AddReluV2Fusion::Process(const FuncGraphPtr &graph, const AnfNo
|
|||
|
||||
std::vector<TypeId> types;
|
||||
std::vector<std::vector<size_t>> shapes;
|
||||
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node); i++) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
|
||||
for (size_t i = 0; i < output_num; i++) {
|
||||
types.push_back(AnfAlgo::GetOutputInferDataType(node, i));
|
||||
shapes.push_back(AnfAlgo::GetOutputInferShape(node, i));
|
||||
}
|
||||
|
|
|
@ -51,7 +51,7 @@ bool ApplyMomentumScaleFusion::IsScalar(const BaseRef &n) {
|
|||
const BaseRef ApplyMomentumScaleFusion::DefinePattern() const {
|
||||
VectorRef scale = VectorRef({prim::kPrimMul, gradient_, scale_});
|
||||
VectorRef apply_momentum =
|
||||
VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_});
|
||||
VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_, monad_state_});
|
||||
return apply_momentum;
|
||||
}
|
||||
|
||||
|
@ -66,17 +66,19 @@ const AnfNodePtr ApplyMomentumScaleFusion::Process(const FuncGraphPtr &graph, co
|
|||
auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]);
|
||||
auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
|
||||
auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]);
|
||||
auto monad_state = utils::cast<AnfNodePtr>((*equiv)[monad_state_]);
|
||||
MS_EXCEPTION_IF_NULL(scale);
|
||||
MS_EXCEPTION_IF_NULL(variable);
|
||||
MS_EXCEPTION_IF_NULL(accumulation);
|
||||
MS_EXCEPTION_IF_NULL(learning_rate);
|
||||
MS_EXCEPTION_IF_NULL(gradient);
|
||||
MS_EXCEPTION_IF_NULL(momentum);
|
||||
MS_EXCEPTION_IF_NULL(monad_state);
|
||||
|
||||
auto prim = std::make_shared<Primitive>(kFusedScaleApplyMomentum);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), scale, variable, accumulation,
|
||||
learning_rate, gradient, momentum};
|
||||
learning_rate, gradient, momentum, monad_state};
|
||||
auto replace_node = graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(replace_node);
|
||||
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
|
||||
|
|
|
@ -31,6 +31,7 @@ class ApplyMomentumScaleFusion : public PatternProcessPass {
|
|||
learning_rate_ = std::make_shared<Var>();
|
||||
gradient_ = std::make_shared<Var>();
|
||||
momentum_ = std::make_shared<Var>();
|
||||
monad_state_ = std::make_shared<Var>();
|
||||
}
|
||||
~ApplyMomentumScaleFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
|
@ -45,6 +46,7 @@ class ApplyMomentumScaleFusion : public PatternProcessPass {
|
|||
VarPtr learning_rate_;
|
||||
VarPtr gradient_;
|
||||
VarPtr momentum_;
|
||||
VarPtr monad_state_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -49,10 +49,11 @@ bool ApplyMomentumWeightDecayFusion::IsScalar(const BaseRef &n) {
|
|||
}
|
||||
|
||||
const BaseRef ApplyMomentumWeightDecayFusion::DefinePattern() const {
|
||||
VectorRef load_para = VectorRef({prim::kPrimLoad, variable_, monad_});
|
||||
VectorRef weight_decay =
|
||||
VectorRef({prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), gradient_});
|
||||
VectorRef apply_momentum =
|
||||
VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, weight_decay, momentum_});
|
||||
VectorRef({prim::kPrimAddN, VectorRef({prim::kPrimMul, load_para, weight_decay_}), gradient_});
|
||||
VectorRef apply_momentum = VectorRef(
|
||||
{prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, weight_decay, momentum_, monad_state_});
|
||||
return apply_momentum;
|
||||
}
|
||||
|
||||
|
@ -67,17 +68,19 @@ const AnfNodePtr ApplyMomentumWeightDecayFusion::Process(const FuncGraphPtr &gra
|
|||
auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]);
|
||||
auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
|
||||
auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]);
|
||||
auto monad_state = utils::cast<AnfNodePtr>((*equiv)[monad_state_]);
|
||||
MS_EXCEPTION_IF_NULL(weight_decay);
|
||||
MS_EXCEPTION_IF_NULL(variable);
|
||||
MS_EXCEPTION_IF_NULL(accumulation);
|
||||
MS_EXCEPTION_IF_NULL(learning_rate);
|
||||
MS_EXCEPTION_IF_NULL(gradient);
|
||||
MS_EXCEPTION_IF_NULL(momentum);
|
||||
MS_EXCEPTION_IF_NULL(monad_state);
|
||||
|
||||
auto prim = std::make_shared<Primitive>(kFusedWeightApplyMomentum);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), weight_decay, variable, accumulation,
|
||||
learning_rate, gradient, momentum};
|
||||
learning_rate, gradient, momentum, monad_state};
|
||||
auto replace_node = graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(replace_node);
|
||||
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
|
||||
|
|
|
@ -25,12 +25,14 @@ class ApplyMomentumWeightDecayFusion : public PatternProcessPass {
|
|||
public:
|
||||
explicit ApplyMomentumWeightDecayFusion(bool multigraph = true)
|
||||
: PatternProcessPass("momentum_weightdecay_fusion", multigraph) {
|
||||
monad_ = std::make_shared<Var>();
|
||||
weight_decay_ = std::make_shared<Var>();
|
||||
variable_ = std::make_shared<Var>();
|
||||
accumulation_ = std::make_shared<Var>();
|
||||
learning_rate_ = std::make_shared<Var>();
|
||||
gradient_ = std::make_shared<Var>();
|
||||
momentum_ = std::make_shared<Var>();
|
||||
monad_state_ = std::make_shared<Var>();
|
||||
}
|
||||
~ApplyMomentumWeightDecayFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
|
@ -39,12 +41,14 @@ class ApplyMomentumWeightDecayFusion : public PatternProcessPass {
|
|||
private:
|
||||
static bool IsScalar(const BaseRef &n);
|
||||
|
||||
VarPtr monad_;
|
||||
VarPtr weight_decay_;
|
||||
VarPtr variable_;
|
||||
VarPtr accumulation_;
|
||||
VarPtr learning_rate_;
|
||||
VarPtr gradient_;
|
||||
VarPtr momentum_;
|
||||
VarPtr monad_state_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -49,11 +49,12 @@ bool ApplyMomentumWeightDecayScaleFusion::IsScalar(const BaseRef &n) {
|
|||
}
|
||||
|
||||
const BaseRef ApplyMomentumWeightDecayScaleFusion::DefinePattern() const {
|
||||
VectorRef load_para = VectorRef({prim::kPrimLoad, variable_, monad_});
|
||||
VectorRef weight = VectorRef(
|
||||
{prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), VectorRef({prim::kPrimCast, gradient_})});
|
||||
{prim::kPrimAddN, VectorRef({prim::kPrimMul, load_para, weight_decay_}), VectorRef({prim::kPrimCast, gradient_})});
|
||||
VectorRef scale = VectorRef({prim::kPrimMul, weight, scale_});
|
||||
VectorRef apply_momentum =
|
||||
VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_});
|
||||
VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_, monad_state_});
|
||||
return apply_momentum;
|
||||
}
|
||||
|
||||
|
@ -69,6 +70,8 @@ const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr
|
|||
auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]);
|
||||
auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
|
||||
auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]);
|
||||
auto monad_state = utils::cast<AnfNodePtr>((*equiv)[monad_state_]);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(weight_decay);
|
||||
MS_EXCEPTION_IF_NULL(scale);
|
||||
MS_EXCEPTION_IF_NULL(variable);
|
||||
|
@ -76,11 +79,12 @@ const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr
|
|||
MS_EXCEPTION_IF_NULL(learning_rate);
|
||||
MS_EXCEPTION_IF_NULL(gradient);
|
||||
MS_EXCEPTION_IF_NULL(momentum);
|
||||
MS_EXCEPTION_IF_NULL(monad_state);
|
||||
|
||||
auto prim = std::make_shared<Primitive>(kFusedWeightScaleApplyMomentum);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), weight_decay, scale, variable,
|
||||
accumulation, learning_rate, gradient, momentum};
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), weight_decay, scale, variable, accumulation,
|
||||
learning_rate, gradient, momentum, monad_state};
|
||||
auto replace_node = graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(replace_node);
|
||||
auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
|
||||
|
|
|
@ -25,6 +25,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
|
|||
public:
|
||||
explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true)
|
||||
: PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) {
|
||||
monad_ = std::make_shared<Var>();
|
||||
weight_decay_ = std::make_shared<Var>();
|
||||
scale_ = std::make_shared<CondVar>(IsScalar);
|
||||
variable_ = std::make_shared<Var>();
|
||||
|
@ -32,6 +33,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
|
|||
learning_rate_ = std::make_shared<Var>();
|
||||
gradient_ = std::make_shared<Var>();
|
||||
momentum_ = std::make_shared<Var>();
|
||||
monad_state_ = std::make_shared<Var>();
|
||||
}
|
||||
~ApplyMomentumWeightDecayScaleFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
|
@ -40,6 +42,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
|
|||
private:
|
||||
static bool IsScalar(const BaseRef &n);
|
||||
|
||||
VarPtr monad_;
|
||||
VarPtr weight_decay_;
|
||||
VarPtr scale_;
|
||||
VarPtr variable_;
|
||||
|
@ -47,6 +50,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass {
|
|||
VarPtr learning_rate_;
|
||||
VarPtr gradient_;
|
||||
VarPtr momentum_;
|
||||
VarPtr monad_state_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -37,11 +37,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr>
|
|||
for (size_t idx = 0; idx < node_list.size(); ++idx) {
|
||||
auto cnode = utils::cast<CNodePtr>(node_list[idx]);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_device_format.push_back(kOpFormat_DEFAULT);
|
||||
inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
|
||||
}
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
outputs_device_format.push_back(kOpFormat_DEFAULT);
|
||||
outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
|
||||
|
@ -57,16 +59,39 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr>
|
|||
bool GetDealList(const std::vector<AnfNodePtr> &node_list, std::vector<std::vector<AnfNodePtr>> *deal_list) {
|
||||
std::vector<AnfNodePtr> cast_32to16_list;
|
||||
std::vector<AnfNodePtr> cast_16to32_list;
|
||||
AnfNodePtr cast_32to16_load_monad = nullptr;
|
||||
AnfNodePtr cast_16to32_load_monad = nullptr;
|
||||
constexpr size_t second_input_index = 2;
|
||||
for (auto &cast_node : node_list) {
|
||||
// currently, we only deal with the construct : [Param->Cast->] to avoid being a cycle.
|
||||
if (cast_node != nullptr && cast_node->isa<CNode>() && AnfAlgo::GetCNodeName(cast_node) == "Cast" &&
|
||||
(AnfAlgo::GetInputNode(utils::cast<CNodePtr>(cast_node), 0))->isa<Parameter>()) {
|
||||
auto dst = AnfAlgo::GetOutputInferDataType(cast_node, 0);
|
||||
auto src = AnfAlgo::GetPrevNodeOutputInferDataType(cast_node, 0);
|
||||
if (dst == kNumberTypeFloat16 && src == kNumberTypeFloat32) {
|
||||
cast_32to16_list.push_back(cast_node);
|
||||
} else if (dst == kNumberTypeFloat32 && src == kNumberTypeFloat16) {
|
||||
cast_16to32_list.push_back(cast_node);
|
||||
// { prim::kPrimCast, { prim::kPrimLoad, Parameter, U }}
|
||||
if (IsPrimitiveCNode(cast_node, prim::kPrimCast)) {
|
||||
auto input0 = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(cast_node), 0);
|
||||
if (input0->isa<Parameter>() || (IsPrimitiveCNode(input0, prim::kPrimLoad) &&
|
||||
(AnfAlgo::GetInputNode(utils::cast<CNodePtr>(input0), 0))->isa<Parameter>())) {
|
||||
auto dst = AnfAlgo::GetOutputInferDataType(cast_node, 0);
|
||||
auto src = AnfAlgo::GetPrevNodeOutputInferDataType(cast_node, 0);
|
||||
if (dst == kNumberTypeFloat16 && src == kNumberTypeFloat32) {
|
||||
cast_32to16_list.push_back(cast_node);
|
||||
if (IsPrimitiveCNode(input0, prim::kPrimLoad)) {
|
||||
auto &monad = input0->cast<CNodePtr>()->inputs().at(second_input_index);
|
||||
if (cast_32to16_load_monad == nullptr) {
|
||||
cast_32to16_load_monad = monad;
|
||||
} else if (cast_32to16_load_monad != monad) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else if (dst == kNumberTypeFloat32 && src == kNumberTypeFloat16) {
|
||||
cast_16to32_list.push_back(cast_node);
|
||||
if (IsPrimitiveCNode(input0, prim::kPrimLoad)) {
|
||||
auto &monad = input0->cast<CNodePtr>()->inputs().at(second_input_index);
|
||||
if (cast_16to32_load_monad == nullptr) {
|
||||
cast_16to32_load_monad = monad;
|
||||
} else if (cast_16to32_load_monad != monad) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -36,11 +36,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr>
|
|||
for (size_t idx = 0; idx < node_list.size(); ++idx) {
|
||||
auto cnode = utils::cast<CNodePtr>(node_list[idx]);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_device_format.push_back(kOpFormat_DEFAULT);
|
||||
inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
|
||||
}
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
outputs_device_format.push_back(kOpFormat_DEFAULT);
|
||||
outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
|
||||
|
|
|
@ -53,6 +53,8 @@ std::set<string> kSkipOpNames = {
|
|||
std::map<string, uint32_t> kAggregatesOpNames = {
|
||||
{kConv2DBackpropInputOpName, 0}, {kmaxPoolGradOpName, 2}, {kFusedBatchNormGradExWithAddAndActivation, 0}};
|
||||
|
||||
constexpr size_t inplace_node_size = 2;
|
||||
|
||||
template <typename T>
|
||||
void SetPrimAttr(AnfNodePtr inplace_node, const string &key, const T &value) {
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(inplace_node);
|
||||
|
@ -60,40 +62,103 @@ void SetPrimAttr(AnfNodePtr inplace_node, const string &key, const T &value) {
|
|||
primitive->AddAttr(key, MakeValue(value));
|
||||
}
|
||||
|
||||
void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<AnfNodeIndex> *inplace_node) {
|
||||
std::pair<size_t, bool> GetCoverIndex(const std::vector<AnfNodeIndex> &inplace_node) {
|
||||
if (inplace_node.size() != inplace_node_size) {
|
||||
return {0, false};
|
||||
}
|
||||
auto first_node = inplace_node[0].node;
|
||||
auto second_node = inplace_node[1].node;
|
||||
if (AnfAlgo::GetCNodeName(first_node) != kConv2DBackpropInputOpName ||
|
||||
AnfAlgo::GetCNodeName(second_node) != kConv2DBackpropInputOpName) {
|
||||
return {0, false};
|
||||
}
|
||||
|
||||
auto first_node_prim = AnfAlgo::GetCNodePrimitive(first_node);
|
||||
auto first_node_channel = first_node_prim.get()->GetAttr("out_channel");
|
||||
MS_EXCEPTION_IF_NULL(first_node_channel);
|
||||
size_t first_channel = first_node_channel->cast<Int64ImmPtr>()->value();
|
||||
auto second_node_prim = AnfAlgo::GetCNodePrimitive(second_node);
|
||||
auto second_node_channel = second_node_prim.get()->GetAttr("out_channel");
|
||||
MS_EXCEPTION_IF_NULL(second_node_channel);
|
||||
size_t second_channel = second_node_channel->cast<Int64ImmPtr>()->value();
|
||||
size_t cover_index = (first_channel >= second_channel) ? 0 : 1;
|
||||
return {cover_index, true};
|
||||
}
|
||||
|
||||
void CopyKernelInfo(AnfNodePtr src, AnfNodePtr dst) {
|
||||
auto build_info = AnfAlgo::GetSelectKernelBuildInfo(src);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(build_info, dst.get());
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(src);
|
||||
std::vector<TypeId> types;
|
||||
std::vector<std::vector<size_t>> shapes;
|
||||
for (size_t i = 0; i < output_num; i++) {
|
||||
types.emplace_back(AnfAlgo::GetOutputInferDataType(src, i));
|
||||
shapes.emplace_back(AnfAlgo::GetOutputInferShape(src, i));
|
||||
}
|
||||
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, dst.get());
|
||||
}
|
||||
|
||||
void CheckInplaceNodeInputs(std::vector<AnfNodeIndex> *inplace_node, const FuncGraphPtr &graph) {
|
||||
if (inplace_node->size() == inplace_node_size) {
|
||||
auto first_cnode = (*inplace_node)[0].node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(first_cnode);
|
||||
auto first_node_input = first_cnode->input(1);
|
||||
auto second_cnode = (*inplace_node)[1].node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(second_cnode);
|
||||
auto second_node_input = second_cnode->input(1);
|
||||
|
||||
// if two inplace nodes have same input, will be have loop after insert depend
|
||||
// so copy a new input for one of inplace node
|
||||
if (first_node_input == second_node_input) {
|
||||
auto cnode = first_node_input->cast<CNodePtr>();
|
||||
auto new_input = graph->NewCNode(cnode->inputs());
|
||||
new_input->set_abstract(first_node_input->abstract());
|
||||
CopyKernelInfo(first_node_input, new_input);
|
||||
auto new_inplace_node = graph->NewCNode({first_cnode->input(0), new_input, first_cnode->input(2)});
|
||||
new_inplace_node->set_abstract(first_cnode->abstract());
|
||||
CopyKernelInfo(first_cnode, new_inplace_node);
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->Replace(first_cnode, new_inplace_node);
|
||||
(*inplace_node)[0].node = new_inplace_node;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector<AnfNodeIndex> *inplace_node,
|
||||
const FuncGraphPtr &graph) {
|
||||
SetPrimAttr(aggregate_node.node, "aggregate", true);
|
||||
SetPrimAttr(aggregate_node.node, "aggregate_input_index", aggregate_node.index);
|
||||
SetPrimAttr(skip_node, "skip", true);
|
||||
|
||||
static uint32_t group = 0;
|
||||
auto [cover_index, order_required] = GetCoverIndex(*inplace_node);
|
||||
if (order_required) {
|
||||
CheckInplaceNodeInputs(inplace_node, graph);
|
||||
}
|
||||
for (size_t i = 0; i < inplace_node->size(); i++) {
|
||||
auto algo = (i == 0) ? "cover" : "accumulation";
|
||||
SetPrimAttr((*inplace_node)[i].node, "inplace_algo", algo);
|
||||
SetPrimAttr((*inplace_node)[i].node, "inplace_group", group);
|
||||
SetPrimAttr((*inplace_node)[i].node, "inplace_output_index", (*inplace_node)[i].index);
|
||||
auto algo = (i == cover_index) ? "cover" : "accumulation";
|
||||
auto node = (*inplace_node)[i].node;
|
||||
SetPrimAttr(node, "inplace_algo", algo);
|
||||
SetPrimAttr(node, "inplace_group", group);
|
||||
SetPrimAttr(node, "inplace_output_index", (*inplace_node)[i].index);
|
||||
// for Conv2DBackpropInputOp, need insert depend node to keep order, set the larger channel to cover
|
||||
if (order_required && i != cover_index) {
|
||||
auto acc_node = node;
|
||||
auto cover_node = (*inplace_node)[cover_index].node;
|
||||
auto acc_node_input = acc_node->cast<CNodePtr>()->input(1);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
|
||||
acc_node_input, cover_node};
|
||||
auto depend_node = graph->NewCNode(inputs);
|
||||
depend_node->set_abstract(acc_node_input->abstract());
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->Replace(acc_node_input, depend_node);
|
||||
}
|
||||
}
|
||||
group++;
|
||||
}
|
||||
|
||||
void InsertControlDependToGraph(const FuncGraphPtr &graph, const std::vector<AnfNodeIndex> &inplace_nodes,
|
||||
const AnfNodePtr aggregate_node) {
|
||||
std::vector<AnfNodePtr> inputs1 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())),
|
||||
inplace_nodes[0].node, inplace_nodes[1].node};
|
||||
auto control_depend_node = graph->NewCNode(inputs1);
|
||||
|
||||
std::vector<AnfNodePtr> inputs2 = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
|
||||
aggregate_node, control_depend_node};
|
||||
auto depend_node = graph->NewCNode(inputs2);
|
||||
|
||||
auto users = GetRealNodeUsedList(graph, aggregate_node);
|
||||
if (users->size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "No users found: " << aggregate_node->DebugString();
|
||||
}
|
||||
auto mount_node = users->at(0).first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(mount_node);
|
||||
mount_node->set_input(kFirstDataInputIndex, depend_node);
|
||||
}
|
||||
|
||||
bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeIndex *aggregate, AnfNodePtr *skip_node,
|
||||
std::vector<AnfNodeIndex> *inplace) {
|
||||
MS_EXCEPTION_IF_NULL(skip_node);
|
||||
|
@ -117,7 +182,8 @@ bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeInde
|
|||
|
||||
auto cnode = (*skip_node)->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t i = 0; i < input_num; i++) {
|
||||
auto inplace_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(*skip_node), i);
|
||||
if (!inplace_node->isa<CNode>()) {
|
||||
return false;
|
||||
|
@ -187,9 +253,7 @@ bool CudnnInplaceAggregate::Run(const FuncGraphPtr &graph) {
|
|||
<< "; inplace node 1: " << inplace_node[1].index << ", " << inplace_node[1].node->DebugString()
|
||||
<< std::endl;
|
||||
// 2. Set Node attr
|
||||
SetNodeAttr(aggregate_node, skip_node, &inplace_node);
|
||||
// 3. Set dependence for inplace nodes
|
||||
InsertControlDependToGraph(graph, inplace_node, aggregate_node.node);
|
||||
SetNodeAttr(aggregate_node, skip_node, &inplace_node, graph);
|
||||
}
|
||||
|
||||
return true;
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/gpu/post_batch_norm_add_relu_fusion.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "runtime/device/gpu/kernel_info_setter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const BaseRef PostBatchNormAddReluFusion::DefinePattern() const {
|
||||
VectorRef batch_norm_ex = VectorRef({prim::kPrimFusedBatchNormEx, x_, scale_, bias_, mean_, var_});
|
||||
VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm_ex, index_});
|
||||
VectorRef tensor_add = VectorRef({prim::kPrimAdd, z_, tuple_get_item});
|
||||
VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add});
|
||||
return relu;
|
||||
}
|
||||
|
||||
const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
auto tensor_add = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
||||
MS_EXCEPTION_IF_NULL(tensor_add);
|
||||
auto tuple_get_item = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), 1);
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item);
|
||||
auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
|
||||
MS_EXCEPTION_IF_NULL(batch_norm_ex);
|
||||
auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm_ex)->GetAttr("format");
|
||||
MS_EXCEPTION_IF_NULL(format_attr);
|
||||
auto format = GetValue<std::string>(format_attr);
|
||||
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") {
|
||||
return nullptr;
|
||||
}
|
||||
auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0);
|
||||
if (shape.back() % kBNChannelMultipleFactor != 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0);
|
||||
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1);
|
||||
auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 2);
|
||||
auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 3);
|
||||
auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 4);
|
||||
auto z = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), 0);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
MS_EXCEPTION_IF_NULL(scale);
|
||||
MS_EXCEPTION_IF_NULL(bias);
|
||||
MS_EXCEPTION_IF_NULL(mean);
|
||||
MS_EXCEPTION_IF_NULL(var);
|
||||
MS_EXCEPTION_IF_NULL(z);
|
||||
|
||||
auto prim = std::make_shared<Primitive>(kFusedBatchNormExWithAddAndActivation);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var, z};
|
||||
auto fused_batch_norm_with_add_relu = graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(fused_batch_norm_with_add_relu);
|
||||
|
||||
std::vector<TypeId> outputs_type;
|
||||
std::vector<std::vector<size_t>> outputs_shape;
|
||||
auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm_ex);
|
||||
for (size_t i = 0; i < output_num; i++) {
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm_ex, i));
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm_ex, i));
|
||||
}
|
||||
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get());
|
||||
AnfAlgo::CopyNodeAttrs(batch_norm_ex, fused_batch_norm_with_add_relu);
|
||||
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->Replace(batch_norm_ex, fused_batch_norm_with_add_relu);
|
||||
device::gpu::SetKernelInfo(fused_batch_norm_with_add_relu);
|
||||
return tuple_get_item;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_POST_BATCH_NORM_ADD_RELU_FUSION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_POST_BATCH_NORM_ADD_RELU_FUSION_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class PostBatchNormAddReluFusion : public PatternProcessPass {
|
||||
public:
|
||||
explicit PostBatchNormAddReluFusion(bool multigraph = true)
|
||||
: PatternProcessPass("post_batch_norm_add_relu_fusion", multigraph) {
|
||||
x_ = std::make_shared<Var>();
|
||||
scale_ = std::make_shared<Var>();
|
||||
bias_ = std::make_shared<Var>();
|
||||
mean_ = std::make_shared<Var>();
|
||||
var_ = std::make_shared<Var>();
|
||||
index_ = std::make_shared<Var>();
|
||||
z_ = std::make_shared<Var>();
|
||||
}
|
||||
~PostBatchNormAddReluFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
VarPtr x_;
|
||||
VarPtr scale_;
|
||||
VarPtr bias_;
|
||||
VarPtr mean_;
|
||||
VarPtr var_;
|
||||
VarPtr index_;
|
||||
VarPtr z_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_POST_BATCH_NORM_ADD_RELU_FUSION_H_
|
|
@ -32,9 +32,7 @@ const size_t kReluV2OutputNum = 2;
|
|||
|
||||
CNodePtr GetRelu(const CNodePtr &relu_grad) {
|
||||
MS_EXCEPTION_IF_NULL(relu_grad);
|
||||
if (relu_grad->size() != kReluGradInputNum) {
|
||||
MS_LOG_EXCEPTION << "ReluGrad has wrong input size " << relu_grad->size();
|
||||
}
|
||||
CheckCNodeInputSize(relu_grad, kReluGradInputTensorNum);
|
||||
auto relu_anf = relu_grad->input(2);
|
||||
MS_EXCEPTION_IF_NULL(relu_anf);
|
||||
return relu_anf->cast<CNodePtr>();
|
||||
|
@ -47,11 +45,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
|
|||
std::vector<TypeId> outputs_type;
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
|
||||
inputs_format.push_back(kOpFormat_DEFAULT);
|
||||
}
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
|
||||
outputs_format.push_back(kOpFormat_DEFAULT);
|
||||
}
|
||||
|
@ -65,9 +65,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
|
|||
CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(relu);
|
||||
if (relu->size() != kReluInputNum) {
|
||||
MS_LOG_EXCEPTION << "Relu has wrong input size " << relu->size();
|
||||
}
|
||||
CheckCNodeInputSize(relu, kReluInputTensorNum);
|
||||
|
||||
auto prim = std::make_shared<Primitive>(kReluV2OpName);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), relu->input(1)};
|
||||
|
@ -106,7 +104,8 @@ CNodePtr CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad,
|
|||
|
||||
std::vector<TypeId> types;
|
||||
std::vector<std::vector<size_t>> shapes;
|
||||
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(relu_grad); i++) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(relu_grad);
|
||||
for (size_t i = 0; i < output_num; i++) {
|
||||
types.push_back(AnfAlgo::GetOutputInferDataType(relu_grad, i));
|
||||
shapes.push_back(AnfAlgo::GetOutputInferShape(relu_grad, i));
|
||||
}
|
||||
|
|
|
@ -305,52 +305,14 @@ void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNo
|
|||
user_cnode->set_input(index, depend_cnode);
|
||||
}
|
||||
|
||||
AnfNodePtr AtomicCleanInsertter::AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node,
|
||||
const AnfNodePtr &behind_node, const AnfNodePtr &patron_node) {
|
||||
// Create control depend, first input is composite op, second is user
|
||||
AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), prior_node, behind_node};
|
||||
auto control_depend_cnode = main_graph->NewCNode(cd_inputs);
|
||||
main_graph->AddNode(control_depend_cnode);
|
||||
|
||||
// Create depend node to hold new control depend node.
|
||||
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), patron_node, control_depend_cnode};
|
||||
auto depend_cnode = main_graph->NewCNode(d_inputs);
|
||||
depend_cnode->set_abstract(patron_node->abstract());
|
||||
main_graph->AddNode(depend_cnode);
|
||||
|
||||
return depend_cnode;
|
||||
}
|
||||
|
||||
std::tuple<AnfNodePtr, AnfNodePtr, int> AtomicCleanInsertter::FindPatronNode(const KernelGraphPtr &main_graph) {
|
||||
auto mng = main_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(main_graph, true);
|
||||
main_graph->set_manager(mng);
|
||||
}
|
||||
|
||||
AnfNodePtr patron_node;
|
||||
|
||||
auto return_cnode = main_graph->get_return()->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(return_cnode);
|
||||
auto output_node = return_cnode->input(kFirstDataInputIndex);
|
||||
if (IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) {
|
||||
auto output_cnode = output_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(output_cnode);
|
||||
patron_node = output_cnode->input(kFirstDataInputIndex);
|
||||
} else {
|
||||
patron_node = output_node;
|
||||
}
|
||||
|
||||
auto &user_nodes = mng->node_users()[patron_node];
|
||||
auto user = user_nodes.begin();
|
||||
return std::make_tuple(patron_node, user->first, user->second);
|
||||
}
|
||||
|
||||
void AtomicCleanInsertter::PostprocessForLastPatron(const AnfNodePtr &patron_node, const AnfNodePtr &patron_user,
|
||||
int index) {
|
||||
auto patron_user_cnode = patron_user->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(patron_user_cnode);
|
||||
patron_user_cnode->set_input(index, patron_node);
|
||||
CNodePtr AtomicCleanInsertter::InsertUpdateState(const KernelGraphPtr &main_graph, const CNodePtr &composite_node) {
|
||||
// Insert update_state_node, need mount a monad node.
|
||||
auto u = NewValueNode(kUMonad);
|
||||
u->set_abstract(kUMonad->ToAbstract());
|
||||
AnfNodePtrList update_state_inputs = {NewValueNode(prim::kPrimUpdateState), u, composite_node};
|
||||
auto update_state_cnode = main_graph->NewCNode(update_state_inputs);
|
||||
main_graph->AddNode(update_state_cnode);
|
||||
return update_state_cnode;
|
||||
}
|
||||
|
||||
CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type) {
|
||||
|
@ -474,24 +436,21 @@ std::vector<std::pair<AnfNodePtr, int> > AtomicCleanInsertter::FindOriginCNodeUs
|
|||
}
|
||||
|
||||
void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
|
||||
const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng) {
|
||||
const AnfNodePtr &broadcast_to_node,
|
||||
const AnfNodePtr &update_state_node, const FuncGraphManagerPtr &mng) {
|
||||
// 1. find users, change getitem index if needed.
|
||||
std::vector<std::pair<AnfNodePtr, int> > reduce_user_nodes =
|
||||
FindOriginCNodeUsers(main_graph, composite_node, mng, true);
|
||||
for (const auto &[user_node, index] : reduce_user_nodes) {
|
||||
// 2. set ac output as user's input.
|
||||
// 3. Make sure modified composite node running first.
|
||||
// * To not change the origin node's dependency relation, add ControlDepend and Depend node.
|
||||
// * For Return node and output node, ControlDepend node will change the order of these two node, which will may
|
||||
// main graph running failed. So only add Depend node to meet the need of execute order.
|
||||
if (IsPrimitiveCNode(user_node, prim::kPrimReturn) || user_node == main_graph->output()) {
|
||||
AddDepend(main_graph, broadcast_to_node, composite_node, user_node, index);
|
||||
} else {
|
||||
auto user_cnode = user_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(user_cnode);
|
||||
user_cnode->set_input(index, broadcast_to_node);
|
||||
to_process_order_.emplace_back(composite_node, user_node);
|
||||
}
|
||||
// 2. Make sure modified composite node running first, So firstly, create load_node, then add edge to connect
|
||||
// update_state_node, broadcat_node and load_node to keep order.
|
||||
AnfNodePtrList load_inputs = {NewValueNode(prim::kPrimLoad), broadcast_to_node, update_state_node};
|
||||
auto load_node = main_graph->NewCNode(load_inputs);
|
||||
main_graph->AddNode(load_node);
|
||||
auto user_cnode = user_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(user_cnode);
|
||||
user_cnode->set_input(index, load_node);
|
||||
to_process_order_.emplace_back(composite_node, user_node);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -509,8 +468,11 @@ void AtomicCleanInsertter::InsertAtomicClean(const KernelGraphPtr &main_graph, c
|
|||
// Note: if it's single output, this will increase total memory because of a fake out.
|
||||
ProcessOriginCNode(origin_composite_node, broadcast_to_node, mng);
|
||||
|
||||
// Replace origin ReduceSum's user with atomic clean output, and add control depend from composite op to user.
|
||||
ProcessOriginCNodeUser(main_graph, origin_composite_node, broadcast_to_node, mng);
|
||||
// Insert update_state_node to keep execution order.
|
||||
auto update_state_node = InsertUpdateState(main_graph, origin_composite_node);
|
||||
|
||||
// Replace origin ReduceSum's user with atomic clean output
|
||||
ProcessOriginCNodeUser(main_graph, origin_composite_node, broadcast_to_node, update_state_node, mng);
|
||||
MS_LOG(INFO) << "Target node: " << origin_composite_node->fullname_with_scope()
|
||||
<< ", clean node: " << broadcast_to_node->fullname_with_scope();
|
||||
}
|
||||
|
@ -554,14 +516,6 @@ bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
|
||||
if (changed) {
|
||||
if (!to_process_order_.empty()) {
|
||||
auto [patron_node, patron_user, user_index] = FindPatronNode(kernel_graph);
|
||||
for (const auto &[prior, behind] : to_process_order_) {
|
||||
patron_node = AddControlDepend(kernel_graph, prior, behind, patron_node);
|
||||
}
|
||||
PostprocessForLastPatron(patron_node, patron_user, user_index);
|
||||
}
|
||||
|
||||
mng->RemoveRoots();
|
||||
mng->KeepRoots({func_graph});
|
||||
}
|
||||
|
|
|
@ -37,9 +37,10 @@ class AtomicCleanInsertter : public Pass {
|
|||
virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input);
|
||||
virtual void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input,
|
||||
const FuncGraphManagerPtr &mng);
|
||||
void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng);
|
||||
void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node,
|
||||
const AnfNodePtr &user_node, int index);
|
||||
void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng);
|
||||
CNodePtr InsertUpdateState(const KernelGraphPtr &main_graph, const CNodePtr &composite_node);
|
||||
CNodePtr atomic_add_node_{nullptr};
|
||||
|
||||
private:
|
||||
|
@ -48,11 +49,8 @@ class AtomicCleanInsertter : public Pass {
|
|||
CNodePtr CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type);
|
||||
void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
|
||||
void ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
|
||||
const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng);
|
||||
std::tuple<AnfNodePtr, AnfNodePtr, int> FindPatronNode(const KernelGraphPtr &main_graph);
|
||||
AnfNodePtr AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node,
|
||||
const AnfNodePtr &behind_node, const AnfNodePtr &patron_node);
|
||||
void PostprocessForLastPatron(const AnfNodePtr &patron_node, const AnfNodePtr &patron_user, int index);
|
||||
const AnfNodePtr &broadcast_to_node, const AnfNodePtr &update_state_node,
|
||||
const FuncGraphManagerPtr &mng);
|
||||
std::vector<std::pair<AnfNodePtr, int>> FindOriginCNodeUsers(const KernelGraphPtr &main_graph,
|
||||
const AnfNodePtr &composite_node,
|
||||
const FuncGraphManagerPtr &mng, bool correct_index);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -149,9 +149,18 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
|
|||
}
|
||||
|
||||
auto fuse_nodes = FindFuseCNodes(node, depend_prior);
|
||||
if (fuse_nodes.empty() || (fuse_nodes.size() == 1 && AnfAlgo::IsGraphKernel(fuse_nodes[0]))) {
|
||||
if (fuse_nodes.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (fuse_nodes.size() == 1) {
|
||||
// Do not fuse a single GraphKernel again.
|
||||
// Do not fuse a single Assign.
|
||||
if (AnfAlgo::IsGraphKernel(fuse_nodes[0]) || IsPrimitiveCNode(fuse_nodes[0], prim::kPrimAssign)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
changed = true;
|
||||
fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end());
|
||||
AnfNodePtr fused_new_node;
|
||||
|
|
|
@ -109,7 +109,10 @@ bool DependFormater::Run(const FuncGraphPtr &func_graph) {
|
|||
// 1. Try to remove redundant depend.
|
||||
bool changed = false;
|
||||
auto nodes = TopoSort(func_graph->get_return());
|
||||
std::for_each(nodes.rbegin(), nodes.rend(), [&changed, &mng](const AnfNodePtr &node) {
|
||||
std::for_each(nodes.rbegin(), nodes.rend(), [&changed, &mng](const AnfNodePtr &node) -> void {
|
||||
if (HasAbstractMonad(node)) {
|
||||
return;
|
||||
}
|
||||
if (RemoveRedundantDepend(node, mng)) {
|
||||
changed = true;
|
||||
}
|
||||
|
@ -126,7 +129,8 @@ bool DependFormater::Run(const FuncGraphPtr &func_graph) {
|
|||
|
||||
// Find depend and its free nodes.
|
||||
for (const auto &node : nodes) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimDepend) ||
|
||||
HasAbstractMonad(node->cast<CNodePtr>()->input(kDependAttachNodeIndex))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
@ -177,6 +177,7 @@ bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
|
|||
std::shared_ptr<Pass> pass = std::make_shared<opt::SubstituteDropout>();
|
||||
pass->Run(func_graph);
|
||||
}
|
||||
|
||||
auto mng = func_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(func_graph, true);
|
||||
|
|
|
@ -494,8 +494,8 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
|
|||
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
||||
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
||||
prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater,
|
||||
prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
|
||||
prim::kPrimCast, prim::kPrimExpandDims};
|
||||
prim::kPrimCast, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
|
||||
prim::kPrimAssign, prim::kPrimExpandDims};
|
||||
#else
|
||||
std::vector<PrimitivePtr> fusible_basic_ops;
|
||||
#endif
|
||||
|
|
|
@ -629,7 +629,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {
|
|||
}
|
||||
GetValidKernelNodes();
|
||||
// call CostModel to get a split plan.
|
||||
if (!SplitByCostModel() || split_plan_.size() != need_inline_.size()) {
|
||||
if (!SplitByCostModel() || split_plan_.size() != need_inline_.size() || split_plan_.empty()) {
|
||||
split_plan_.clear();
|
||||
need_inline_.clear();
|
||||
return;
|
||||
|
|
|
@ -103,28 +103,23 @@ bool HasPathToParamUser(const AnfNodePtr &gk_node, const AnfNodePtr ¶m_user)
|
|||
return result;
|
||||
}
|
||||
|
||||
AnfNodePtr AddControlDepend(const FuncGraphPtr &func_graph, const AnfNodePtr &getitem, const AnfNodePtr ¶m_user) {
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), getitem, param_user};
|
||||
auto cd_node = func_graph->NewCNode(cd_inputs);
|
||||
func_graph->AddNode(cd_node);
|
||||
return cd_node;
|
||||
}
|
||||
void KeepExecOrder(const FuncGraphPtr &func_graph, const AnfNodePtr &gk_node, const AnfNodePtr &par_user_node,
|
||||
const FuncGraphManagerPtr &mng) {
|
||||
// Insert update_state_node, need mount a monad node.
|
||||
auto u = NewValueNode(kUMonad);
|
||||
u->set_abstract(kUMonad->ToAbstract());
|
||||
AnfNodePtrList update_state_inputs = {NewValueNode(prim::kPrimUpdateState), u, gk_node};
|
||||
auto update_state_node = func_graph->NewCNode(update_state_inputs);
|
||||
update_state_node->set_abstract(gk_node->abstract());
|
||||
func_graph->AddNode(update_state_node);
|
||||
|
||||
void LinkControlDepends(const FuncGraphPtr &func_graph, const AnfNodePtrList &cd_nodes) {
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
auto output_tuple = func_graph->output()->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(output_tuple);
|
||||
auto cur_node = output_tuple->input(1);
|
||||
for (const auto &cd : cd_nodes) {
|
||||
AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), cur_node, cd};
|
||||
auto depend_node = func_graph->NewCNode(depend_inputs);
|
||||
depend_node->set_abstract(depend_inputs[1]->abstract());
|
||||
cur_node = depend_node;
|
||||
}
|
||||
mng->Replace(output_tuple->input(1), cur_node);
|
||||
// Insert load_node
|
||||
AnfNodePtrList load_inputs = {NewValueNode(prim::kPrimLoad), par_user_node, update_state_node};
|
||||
auto load_node = func_graph->NewCNode(load_inputs);
|
||||
load_node->set_abstract(par_user_node->abstract());
|
||||
func_graph->AddNode(load_node);
|
||||
|
||||
mng->Replace(gk_node, par_user_node);
|
||||
}
|
||||
|
||||
int64_t GetitemIndex(const AnfNodePtr &getitem) {
|
||||
|
@ -133,11 +128,10 @@ int64_t GetitemIndex(const AnfNodePtr &getitem) {
|
|||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
AnfNodePtrList UpdateUsersOfGraphKernel(const FuncGraphPtr &func_graph, const AnfNodePtr &cnode,
|
||||
const AnfNodePtr &assign_to, int64_t removed_index) {
|
||||
void UpdateUsersOfGraphKernel(const FuncGraphPtr &func_graph, const AnfNodePtr &cnode, const AnfNodePtr &assign_to,
|
||||
int64_t removed_index) {
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
AnfNodePtrList cd_nodes;
|
||||
for (const auto &getitem_iter : mng->node_users()[cnode]) {
|
||||
auto getitem = getitem_iter.first;
|
||||
if (GetitemIndex(getitem) != removed_index) continue;
|
||||
|
@ -152,13 +146,10 @@ AnfNodePtrList UpdateUsersOfGraphKernel(const FuncGraphPtr &func_graph, const An
|
|||
if (!AnfAlgo::IsRealKernel(getitem_user) || HasPathToParamUser(cnode, getitem_user)) {
|
||||
continue;
|
||||
}
|
||||
// keep execution order: cnode -> getitem_user
|
||||
auto cd_node = AddControlDepend(func_graph, getitem, getitem_user);
|
||||
cd_nodes.push_back(cd_node);
|
||||
KeepExecOrder(func_graph, cnode, getitem_user, mng);
|
||||
}
|
||||
break;
|
||||
}
|
||||
return cd_nodes;
|
||||
}
|
||||
|
||||
bool RepalceOutputByParameter(const FuncGraphPtr &func_graph) {
|
||||
|
@ -166,7 +157,6 @@ bool RepalceOutputByParameter(const FuncGraphPtr &func_graph) {
|
|||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
||||
bool changed = false;
|
||||
AnfNodePtrList control_depend_nodes;
|
||||
for (const auto &n : todos) {
|
||||
if (!AnfAlgo::IsGraphKernel(n)) continue;
|
||||
auto cnode = n->cast<CNodePtr>();
|
||||
|
@ -174,11 +164,9 @@ bool RepalceOutputByParameter(const FuncGraphPtr &func_graph) {
|
|||
if (replaceable_nodes.empty()) continue;
|
||||
changed = true;
|
||||
for (const auto &iter : replaceable_nodes) {
|
||||
auto cd_nodes = UpdateUsersOfGraphKernel(func_graph, cnode, iter.second, iter.first);
|
||||
control_depend_nodes.insert(control_depend_nodes.end(), cd_nodes.begin(), cd_nodes.end());
|
||||
UpdateUsersOfGraphKernel(func_graph, cnode, iter.second, iter.first);
|
||||
}
|
||||
}
|
||||
LinkControlDepends(func_graph, control_depend_nodes);
|
||||
return changed;
|
||||
}
|
||||
|
||||
|
|
|
@ -97,7 +97,8 @@ void ProcessThroughPassCNode(std::function<bool(const AnfNodePtr &)> pass_fn,
|
|||
|
||||
void ProcessDependCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
|
||||
for (auto &[node, node_rel] : (*node_rels)) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimDepend) ||
|
||||
HasAbstractMonad(node->cast<CNodePtr>()->input(kDependAttachNodeIndex))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -118,96 +119,6 @@ void ProcessDependCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
|
|||
ProcessThroughPassCNode([](const AnfNodePtr &node) { return IsOneOf(node, {prim::kPrimDepend}); }, node_rels);
|
||||
}
|
||||
|
||||
std::tuple<std::pair<AnfNodePtr, AnfNodePtr>, std::pair<AnfNodePtrList, AnfNodePtrList>> FindRelationOfControlDepend(
|
||||
const AnfNodePtr &node, OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto prior_node = cnode->input(kControlDependPriorIndex);
|
||||
auto behind_node = cnode->input(kControlDependBehindIndex);
|
||||
MS_EXCEPTION_IF_NULL(prior_node);
|
||||
MS_EXCEPTION_IF_NULL(behind_node);
|
||||
|
||||
OrderedSet<AnfNodePtr> prior_nodes;
|
||||
prior_nodes.insert(prior_node);
|
||||
OrderedSet<AnfNodePtr> behind_nodes;
|
||||
behind_nodes.insert(behind_node);
|
||||
|
||||
int64_t depend_mode = 0;
|
||||
if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
|
||||
depend_mode = AnfAlgo::GetNodeAttr<int64_t>(cnode, kControlDependMode);
|
||||
}
|
||||
if (prior_node->isa<Parameter>() && depend_mode == 1) {
|
||||
prior_nodes = (*node_rels)[prior_node].nexts;
|
||||
}
|
||||
if (behind_node->isa<Parameter>()) {
|
||||
behind_nodes = depend_mode == 1 ? (*node_rels)[behind_node].nexts : OrderedSet<AnfNodePtr>();
|
||||
}
|
||||
|
||||
// Get real nodes.
|
||||
AnfNodePtrList real_prior_nodes;
|
||||
std::set<AnfNodePtr> prior_visited;
|
||||
for (const auto &tmp : prior_nodes) {
|
||||
AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
|
||||
}
|
||||
AnfNodePtrList real_behind_nodes;
|
||||
std::set<AnfNodePtr> behind_visited;
|
||||
for (const auto &tmp : behind_nodes) {
|
||||
AnfAlgo::GetAllFatherRealNode(tmp, &real_behind_nodes, &behind_visited);
|
||||
}
|
||||
|
||||
return std::make_tuple(std::make_pair(prior_node, behind_node), std::make_pair(real_prior_nodes, real_behind_nodes));
|
||||
}
|
||||
|
||||
void ReLinkNodesOfControlDependByRelation(const std::unordered_map<AnfNodePtr, AnfNodePtrList> &control_depend_info,
|
||||
OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
|
||||
// Relink and its log.
|
||||
for (const auto &m : control_depend_info) {
|
||||
const auto &prior = m.second[0];
|
||||
const auto &behind = m.second[1];
|
||||
(*node_rels)[prior].nexts.insert(behind);
|
||||
(*node_rels)[behind].pres.insert(prior);
|
||||
MS_LOG(DEBUG) << "Relink relation of " << m.first->fullname_with_scope() << ": " << prior->fullname_with_scope()
|
||||
<< " -> " << behind->fullname_with_scope();
|
||||
}
|
||||
}
|
||||
|
||||
void ProcessControlDependCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtrList> control_depend_info;
|
||||
AnfNodePtrList latter_to_be_erased;
|
||||
|
||||
// Collect ControlDepend node and its input and output nodes.
|
||||
for (auto &[node, node_rel] : (*node_rels)) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimControlDepend)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto [direct_relation, real_relations] = FindRelationOfControlDepend(node, node_rels);
|
||||
auto &[prior_node, behind_node] = direct_relation;
|
||||
auto &[real_prior_nodes, real_behind_nodes] = real_relations;
|
||||
|
||||
(*node_rels)[prior_node].nexts.erase(node);
|
||||
(*node_rels)[behind_node].nexts.erase(node);
|
||||
node_rel.pres.erase(prior_node);
|
||||
node_rel.pres.erase(behind_node);
|
||||
|
||||
for (auto &first_node : real_prior_nodes) {
|
||||
for (auto &second_node : real_behind_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(first_node);
|
||||
MS_EXCEPTION_IF_NULL(second_node);
|
||||
control_depend_info.insert({node, {first_node, second_node}});
|
||||
}
|
||||
}
|
||||
latter_to_be_erased.push_back(node);
|
||||
}
|
||||
|
||||
// Delete ControlDepend node before relink its relation.
|
||||
for (const auto &node : latter_to_be_erased) {
|
||||
node_rels->erase(node);
|
||||
}
|
||||
|
||||
// Rebuild relation between prior and behind node.
|
||||
ReLinkNodesOfControlDependByRelation(control_depend_info, node_rels);
|
||||
}
|
||||
|
||||
void ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
|
||||
AnfNodePtrList latter_to_be_erased;
|
||||
for (auto &[node, node_rel] : (*node_rels)) {
|
||||
|
@ -538,7 +449,6 @@ OrderedMap<AnfNodePtr, NodeRelation> ParallelOpFusion::GenAnalysisGraph(const An
|
|||
}
|
||||
|
||||
ProcessDependCNode(&node_rels);
|
||||
ProcessControlDependCNode(&node_rels);
|
||||
ProcessThroughPassCNode(
|
||||
[](const AnfNodePtr &node) {
|
||||
return IsOneOf(node, {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimSqueeze, prim::kPrimTupleGetItem});
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/graph_kernel/split_assign.h"
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include "base/core_ops.h"
|
||||
#include "utils/utils.h"
|
||||
#include "runtime/device/kernel_info.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
||||
const BaseRef SplitAssign::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<Var>();
|
||||
VarPtr Us = std::make_shared<Var>();
|
||||
VarPtr UMonad = std::make_shared<Var>();
|
||||
return VectorRef({prim::kPrimAssign, Xs, Us, UMonad});
|
||||
}
|
||||
|
||||
const AnfNodePtr SplitAssign::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
CheckCNodeInputSize(cnode, kAssignInputTensorNum);
|
||||
// Get original assign op's abstract and inputs
|
||||
AbstractBasePtr original_abstract = cnode->abstract()->Clone();
|
||||
auto original_inputs = cnode->inputs();
|
||||
// Create depend node
|
||||
AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), original_inputs[1], original_inputs[3]};
|
||||
auto depend_cnode = func_graph->NewCNode(depend_inputs);
|
||||
depend_cnode->set_abstract(original_inputs[1]->abstract());
|
||||
depend_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
// Create new assign node, delete U from inputs.
|
||||
AnfNodePtrList new_assign_inputs = {NewValueNode(prim::kPrimAssign), depend_cnode, original_inputs[2]};
|
||||
auto new_assign_cnode = func_graph->NewCNode(new_assign_inputs);
|
||||
new_assign_cnode->set_abstract(original_abstract);
|
||||
new_assign_cnode->set_kernel_info(cnode->kernel_info_ptr());
|
||||
return new_assign_cnode;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class SplitAssign : public PatternProcessPass {
|
||||
public:
|
||||
explicit SplitAssign(bool multigraph = true) : PatternProcessPass("split_assign", multigraph) {}
|
||||
~SplitAssign() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_
|
|
@ -41,13 +41,15 @@ const BaseRef SubstituteDropout::DefinePattern() const {
|
|||
void SetNewKernelInfo(const CNodePtr &kernel_node) {
|
||||
std::vector<std::string> inputs_format;
|
||||
std::vector<TypeId> inputs_type;
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index));
|
||||
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index));
|
||||
}
|
||||
std::vector<std::string> outputs_format;
|
||||
std::vector<TypeId> outputs_type;
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index));
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
|
||||
}
|
||||
|
@ -69,15 +71,13 @@ const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, cons
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().size() < kDropoutInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Dropout's input num is wrong";
|
||||
}
|
||||
CheckCNodeInputSize(cnode, kDropoutInputTensorNum);
|
||||
AbstractBasePtr old_abstract = cnode->abstract()->Clone();
|
||||
auto shape = AnfAlgo::GetInputDeviceShape(cnode, 0);
|
||||
ShapeVector shape_i64;
|
||||
std::transform(shape.begin(), shape.end(), std::back_inserter(shape_i64), [](size_t x) { return SizeToLong(x); });
|
||||
|
||||
// The primitive should use a clone, otherwise the attr seed will be overrode.
|
||||
// The primitive should use a clone, otherwise the attr seed will be overridden.
|
||||
AnfNodePtrList uniform_input = {NewValueNode(prim::kPrimCudnnUniformReal->Clone())};
|
||||
auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, ShapeVector(1, SizeToLong(shape.size())),
|
||||
static_cast<void *>(&shape[0]), kNumberTypeInt64);
|
||||
|
|
|
@ -249,7 +249,8 @@ KernelRefCountPtr MemReuseUtil::GetRef(const AnfNodePtr &node, int output_idx) {
|
|||
if (node == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The node pointer is a nullptr.";
|
||||
}
|
||||
if (node->isa<CNode>()) {
|
||||
// Get ref count for cnode, except monad cnode.
|
||||
if (node->isa<CNode>() && !HasAbstractMonad(node)) {
|
||||
auto ak_node = node->cast<CNodePtr>();
|
||||
auto key = ak_node.get();
|
||||
MemReuseChecker::GetInstance().CheckOutRef(kernel_output_refs_, ak_node, IntToSize(output_idx));
|
||||
|
@ -314,7 +315,8 @@ void MemReuseUtil::SetKernelDefInputs() {
|
|||
MS_LOG(EXCEPTION) << "kernel [" << kernel->fullname_with_scope() << "] is not init.";
|
||||
}
|
||||
auto kernel_def = iter->second;
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto ref_ptr = GetKernelInputRef(kernel, i);
|
||||
if (ref_ptr != nullptr) {
|
||||
// set the inputs of this kernel_def
|
||||
|
|
|
@ -214,7 +214,8 @@ bool MemReuseChecker::CheckGraphOutputAssigned(const session::KernelGraph *graph
|
|||
// set real graph output node to be special who's refcount equal kMaxRefCount
|
||||
for (const auto &output : graph->outputs()) {
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(output); ++i) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(output);
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
if (output->isa<CNode>()) {
|
||||
auto cnode = output->cast<CNodePtr>();
|
||||
auto input_node = cnode->input(i + 1);
|
||||
|
@ -364,7 +365,8 @@ void MemReuseChecker::CheckNormalIR(const session::KernelGraph *graph) {
|
|||
const auto &cnodes = graph->execution_order();
|
||||
for (const auto &node : cnodes) {
|
||||
std::vector<const void *> curr_ous;
|
||||
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node); ++i) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
|
||||
for (size_t i = 0; i < output_num; ++i) {
|
||||
auto it = AnfAlgo::GetOutputAddr(node, i);
|
||||
MS_EXCEPTION_IF_NULL(it);
|
||||
auto ptr = it->GetPtr();
|
||||
|
@ -374,7 +376,8 @@ void MemReuseChecker::CheckNormalIR(const session::KernelGraph *graph) {
|
|||
}
|
||||
(void)node_ous_.insert(std::make_pair(node.get(), curr_ous));
|
||||
std::vector<const void *> curr_ins;
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(node);
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
if (i + 1 >= node->inputs().size()) {
|
||||
MS_LOG(EXCEPTION) << "Input index: " << i
|
||||
<< " is larger than input number: " << AnfAlgo::GetInputTensorNum(node);
|
||||
|
|
|
@ -37,7 +37,8 @@ bool MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph, s
|
|||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
auto output_sizes = kernel_mod->GetOutputSizeList();
|
||||
|
||||
for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(kernel); ++output_idx) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel);
|
||||
for (size_t output_idx = 0; output_idx < output_num; ++output_idx) {
|
||||
TensorInfo tensor_info = {output_sizes[output_idx], kernel, output_idx};
|
||||
ordered_tensors_.push_back(tensor_info);
|
||||
}
|
||||
|
|
|
@ -51,12 +51,14 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &co
|
|||
rank_size = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
|
||||
inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index));
|
||||
}
|
||||
for (size_t rank_index = 0; rank_index < IntToSize(rank_size); ++rank_index) {
|
||||
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index));
|
||||
outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index));
|
||||
std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(cnode, output_index);
|
||||
|
@ -170,6 +172,117 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
|
|||
return CheckSegments(segments, communication_op_node_size, segment_index);
|
||||
}
|
||||
|
||||
// Hard coded Load(%paraxxx, cnode()) to Load(%paraxxx, U) to prevent
|
||||
// cycle after AllReduce fused. It's a workaround.
|
||||
// case 1:
|
||||
// cnode_load = Load(%para2, cnode_u)
|
||||
// %100 = UpdateState(cnode_u, cnode_load)
|
||||
// ...
|
||||
// %109 = AssignAdd(%para485, Tensor(34), %100)
|
||||
// %110 = UpdateState(%100, xxx)
|
||||
// will convert to:
|
||||
// cnode_load = Load(%para2, U)
|
||||
// ...
|
||||
// %109 = AssignAdd(%para485, Tensor(34), cnode_u)
|
||||
// %110 = UpdateState(cnode_u, xxx)
|
||||
//
|
||||
// case 2:
|
||||
// cnode_load = Load(%para2, cnode_u)
|
||||
// %99 = make_tuple(yyy, ..., cnode_load, ...)
|
||||
// %100 = UpdateState(cnode_u, %99)
|
||||
// ...
|
||||
// %109 = AssignAdd(%para485, Tensor(34), %100)
|
||||
// %110 = UpdateState(%100, xxx)
|
||||
// will convert to:
|
||||
// cnode_load = Load(%para2, U)
|
||||
// %99 = make_tuple(yyy, ...)
|
||||
// %100 = UpdateState(cnode_u, %99)
|
||||
// ...
|
||||
// %109 = AssignAdd(%para485, Tensor(34), %100)
|
||||
// %110 = UpdateState(%100, xxx)
|
||||
//
|
||||
// case 3:
|
||||
// cnode_load = Load(%para2, cnode_u)
|
||||
// %99 = make_tuple(cnode_load)
|
||||
// %100 = UpdateState(cnode_u, %99)
|
||||
// ...
|
||||
// %109 = AssignAdd(%para485, Tensor(34), %100)
|
||||
// %110 = UpdateState(%100, xxx)
|
||||
// will convert to:
|
||||
// cnode_load = Load(%para2, U)
|
||||
// ...
|
||||
// %109 = AssignAdd(%para485, Tensor(34), cnode_u)
|
||||
// %110 = UpdateState(cnode_u, xxx)
|
||||
static void AdjustAllReduceInputWithLoad(const CNodePtr &cnode) {
|
||||
auto cnode_load = BroadFirstSearchFirstOf({cnode}, [](const CNodePtr &search_cnode) {
|
||||
if (!IsPrimitiveCNode(search_cnode, prim::kPrimLoad)) {
|
||||
return false;
|
||||
}
|
||||
if (search_cnode->inputs().size() != 3) {
|
||||
MS_LOG(EXCEPTION) << "Load CNode should have 3 inputs, but: " << search_cnode->DebugString();
|
||||
}
|
||||
return search_cnode->input(2)->isa<CNode>();
|
||||
});
|
||||
if (cnode_load != nullptr) {
|
||||
const auto &const_u_monad = NewValueNode(kUMonad);
|
||||
const auto &cnode_u = cnode_load->input(2);
|
||||
MS_LOG(DEBUG) << "Replace Load with CNode U to constant U for cnode: " << cnode_load->DebugString();
|
||||
MS_EXCEPTION_IF_NULL(cnode->func_graph());
|
||||
MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());
|
||||
auto manager = cnode->func_graph()->manager();
|
||||
manager->SetEdge(cnode_load, 2, const_u_monad);
|
||||
// Update the u_monad input of UpdateState from CNode U same as Load to constant U.
|
||||
CNodePtr cnode_update_state = nullptr;
|
||||
CNodePtr cnode_make_tuple = nullptr;
|
||||
const auto &cnode_load_users = manager->node_users()[cnode_load];
|
||||
for (auto &load_user : cnode_load_users) {
|
||||
if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
|
||||
const auto &cnode_make_tuple_users = manager->node_users()[load_user.first];
|
||||
for (auto &make_tuple_user : cnode_make_tuple_users) {
|
||||
if (IsPrimitiveCNode(make_tuple_user.first, prim::kPrimUpdateState)) {
|
||||
const auto &cnode_user = make_tuple_user.first->cast<CNodePtr>();
|
||||
if (cnode_user->input(1) == cnode_u) {
|
||||
cnode_update_state = cnode_user;
|
||||
cnode_make_tuple = load_user.first->cast<CNodePtr>();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (cnode_update_state != nullptr) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
|
||||
const auto &cnode_user = load_user.first->cast<CNodePtr>();
|
||||
if (cnode_user->input(1) == cnode_u) {
|
||||
cnode_update_state = cnode_user;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (cnode_update_state != nullptr) {
|
||||
if (cnode_make_tuple == nullptr || cnode_make_tuple->inputs().size() == 2) {
|
||||
// case 1 and case 3: Replace cnode_update_state to cnode_u;
|
||||
MS_LOG(DEBUG) << "Replace UpdateState with CNode U: " << cnode_update_state->DebugString()
|
||||
<< " ::TO:: " << cnode_u->DebugString();
|
||||
manager->Replace(cnode_update_state, cnode_u);
|
||||
} else if (cnode_make_tuple->inputs().size() > 2) {
|
||||
// case 2: remove cnode_load from cnode_make_tuple;
|
||||
MS_LOG(DEBUG) << "Drop " << cnode_load->DebugString() << " from " << cnode_make_tuple->DebugString();
|
||||
const auto &make_tuple_inputs = cnode_make_tuple->inputs();
|
||||
AnfNodePtrList new_tuple_inputs(make_tuple_inputs.size() - 1);
|
||||
std::copy_if(make_tuple_inputs.cbegin(), make_tuple_inputs.cend(), new_tuple_inputs.begin(),
|
||||
[cnode_load](const auto &inp) { return inp != cnode_load; });
|
||||
auto new_cnode_make_tuple = cnode_make_tuple->func_graph()->NewCNode(new_tuple_inputs);
|
||||
manager->Replace(cnode_make_tuple, new_cnode_make_tuple);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Cannot replace UpdateState with CNode U: " << cnode_update_state->DebugString()
|
||||
<< " as make_tuple CNode cannot match " << cnode_make_tuple->DebugString();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph,
|
||||
const CommunicationOpInfo &communication_op_info,
|
||||
size_t start_index, size_t end_index) const {
|
||||
|
@ -184,6 +297,9 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
|
|||
for (size_t idx = start_index; idx <= end_index; ++idx) {
|
||||
auto cnode = communication_op_info.communication_op_nodes[idx];
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (idx != start_index) {
|
||||
AdjustAllReduceInputWithLoad(cnode);
|
||||
}
|
||||
fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
|
||||
}
|
||||
CheckInputs(fusion_inputs);
|
||||
|
|
|
@ -107,9 +107,7 @@ AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) {
|
|||
auto mng = sub_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
std::vector<AnfNodePtr> todo;
|
||||
std::vector<std::pair<AnfNodePtr, size_t>> graph_rets;
|
||||
kernel::GetValidKernelNodes(sub_graph, &todo);
|
||||
kernel::GetGraphRealOutput(sub_graph, &graph_rets);
|
||||
|
||||
for (auto &t : todo) {
|
||||
auto t_new_node = ConstInputToTensorInput(sub_graph, t->cast<CNodePtr>());
|
||||
|
|
|
@ -37,7 +37,8 @@ void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePt
|
|||
std::vector<AnfNodePtr> plant_inputs;
|
||||
std::vector<int64_t> dyn_input_sizes;
|
||||
plant_inputs.push_back(AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode_ptr); ++i) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(cnode_ptr);
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto input_node = AnfAlgo::GetInputNode(cnode_ptr, i);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (input_node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
|
||||
|
@ -45,7 +46,8 @@ void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePt
|
|||
dyn_input_sizes.push_back(input_size);
|
||||
auto make_tuple = input_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
for (size_t j = 0; j < AnfAlgo::GetInputTensorNum(make_tuple); ++j) {
|
||||
size_t tuple_input_num = AnfAlgo::GetInputTensorNum(make_tuple);
|
||||
for (size_t j = 0; j < tuple_input_num; ++j) {
|
||||
auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j);
|
||||
MS_EXCEPTION_IF_NULL(dyn_input_node);
|
||||
if (IsValueNode<tensor::Tensor>(dyn_input_node)) {
|
||||
|
|
|
@ -65,7 +65,7 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
|
|||
return nullptr;
|
||||
}
|
||||
}
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimUpdateState)) {
|
||||
return nullptr;
|
||||
}
|
||||
bool cnode_input_changed = false;
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue