Init GraphKernel.

- It provides a unified style to express graph and kernel for user.
- It provides a unified IR to represent graph and kernel for developer.
- It breaks the boundary between graph and kernel.
- It provides more opportunities to do compile optimization.
This commit is contained in:
gong chen 2020-06-04 09:59:26 +08:00 committed by Xian Weizhao
parent 01216a9a57
commit a6dfa281ea
232 changed files with 12791 additions and 624 deletions

3
.gitmodules vendored
View File

@ -13,3 +13,6 @@
[submodule "graphengine"] [submodule "graphengine"]
path = graphengine path = graphengine
url = https://gitee.com/mindspore/graphengine.git url = https://gitee.com/mindspore/graphengine.git
[submodule "akg"]
path = akg
url = https://gitee.com/mindspore/akg.git

View File

@ -86,10 +86,14 @@ if (ENABLE_GE OR ENABLE_D OR ENABLE_TESTCASES)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc/toolchain) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc/toolchain)
endif() endif()
if (ENABLE_AKG AND ENABLE_D)
add_subdirectory("${CMAKE_SOURCE_DIR}/akg")
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
add_subdirectory(mindspore/ccsrc) add_subdirectory(mindspore/ccsrc)
if (ENABLE_TESTCASES) if (ENABLE_TESTCASES)
add_subdirectory(tests) add_subdirectory(tests)
endif() endif()
include(cmake/package.cmake) include(cmake/package.cmake)

1
akg Submodule

@ -0,0 +1 @@
Subproject commit c460176523d039c8995f1d71089753725ebc0792

View File

@ -246,6 +246,9 @@ checkopts "$@"
echo "---------------- mindspore: build start ----------------" echo "---------------- mindspore: build start ----------------"
mkdir -pv "${BUILD_PATH}/package/mindspore/lib" mkdir -pv "${BUILD_PATH}/package/mindspore/lib"
git submodule update --init graphengine git submodule update --init graphengine
if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
git submodule update --init --recursive akg
fi
build_exit() build_exit()
{ {
@ -308,7 +311,7 @@ build_mindspore()
if [[ "X$USE_GLOG" = "Xon" ]]; then if [[ "X$USE_GLOG" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DUSE_GLOG=ON" CMAKE_ARGS="${CMAKE_ARGS} -DUSE_GLOG=ON"
fi fi
if [[ "X$ENABLE_AKG" = "Xon" ]]; then if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON" CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON"
fi fi
echo "${CMAKE_ARGS}" echo "${CMAKE_ARGS}"

View File

@ -236,6 +236,16 @@ if (ENABLE_GPU)
endif () endif ()
endif () endif ()
if (ENABLE_D AND ENABLE_AKG)
set (AKG_PATH ${CMAKE_SOURCE_DIR}/build/mindspore/akg)
install(
DIRECTORY
${AKG_PATH}/akg
DESTINATION ${INSTALL_PY_DIR}/..
COMPONENT mindspore
)
endif ()
if (EXISTS ${CMAKE_SOURCE_DIR}/mindspore/dataset) if (EXISTS ${CMAKE_SOURCE_DIR}/mindspore/dataset)
install( install(
DIRECTORY ${CMAKE_SOURCE_DIR}/mindspore/dataset DIRECTORY ${CMAKE_SOURCE_DIR}/mindspore/dataset

View File

@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

View File

@ -0,0 +1,35 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Providing akg compile with json"""
import sys
def run_compiler(op_json):
"""
Run AKG compiler to compile op with subprocess, if this process of
compilation failed, an exception will be raised
Args:
op_json (str): json string of the op
Returns:
None
"""
p = __import__("akg", globals(), locals(), ['ms'], 0)
func = getattr(p.ms, "compilewithjson")
res = func(op_json)
if not res:
raise ValueError("Compile error")
if __name__ == "__main__":
run_compiler(sys.argv[1])

View File

@ -0,0 +1,71 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Providing multi process compile with json"""
import os
import subprocess
import sys
from multiprocessing import Pool, cpu_count
def _compile_akg_task(*json_strs):
"""
compile func called in single process
Parameters:
json_strs: list. List contains multiple kernel infos, suitable for json compile api.
"""
akg_compiler = os.path.join(os.path.split(
os.path.realpath(__file__))[0], "compiler.py")
for json_str in json_strs:
res = subprocess.run(
[sys.executable, akg_compiler, json_str], text=True)
if res.returncode != 0:
raise ValueError("Failed, args: {}!".format(json_str))
def compile_akg_kernel_parallel(json_infos, process, waitime):
"""
compile kernel use multi processes
Parameters:
json_infos: list. list contain kernel info(task id and json str)
process: int. processes num
waittime: int. max time the function blocked
Returns:
True for all compile success, False for some failed.
"""
if not isinstance(json_infos, list):
raise ValueError("json_infos must be a list")
if not isinstance(process, int):
raise ValueError("process must be a num")
if not isinstance(waitime, int):
raise ValueError("waittime must be a num")
if process == 0 and json_infos:
process = 1
cpu_proc_num = cpu_count()
max_proc_num = 16
process = min([cpu_proc_num, max_proc_num, process])
args = [[] for _ in range(process)]
for p, info in enumerate(json_infos):
args[p % process].append(info)
with Pool(processes=process) as pool:
res = pool.starmap_async(_compile_akg_task, args)
res.get(timeout=waitime)
return True

View File

@ -1,107 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Providing multi process compile with json"""
import json
import math
import os
import subprocess
import sys
from multiprocessing import Pool
def _compiletask(platform, *jsons):
"""
compile func called in single process
Parameters:
platform: str. AKG platform or TBE platform
*jsons: str. json str contain kernel info, suitable for json compile
api
"""
if platform == "AKG":
p = __import__("_akg", globals(), locals(), ['ms'], 0)
func = getattr(p.ms, "compilewithjson")
for json_item in jsons:
res = func(json_item)
if not res:
raise ValueError("Compile error")
if platform == "TBE":
tbe_compiler = os.path.join(os.path.split(os.path.realpath(__file__))[0], "tbe_compiler", "compiler.py")
for json_item in jsons:
res = subprocess.run([sys.executable, tbe_compiler], input=json_item, text=True)
if res.returncode != 0:
raise ValueError("Tbe compile error")
def compilekernelparallel(jsons, process, waitime):
"""
compile kernel use multi processes
Parameters:
jsons: list. json str list contain kernel info
process: int. processes num
waittime: int. max time the function blocked
"""
if not isinstance(jsons, list):
raise ValueError("jsons must be a list")
if not isinstance(process, int):
raise ValueError("process must be a num")
if not isinstance(waitime, int):
raise ValueError("waittime must be a num")
jsons_akg = []
jsons_tbe = []
for json_ in jsons:
j = json.loads(json_)
if j["platform"] == "TBE":
jsons_tbe.append(json_)
continue
if j["platform"] == "AKG":
jsons_akg.append(json_)
continue
raise RuntimeError(
"not support this platform {0}".format(j["platform"]))
if jsons_akg:
process_akg = math.floor(len(jsons)/len(jsons_akg)*process)
else:
process_akg = 0
if process_akg == 0 and jsons_akg:
process_akg = 1
process_tbe = process-process_akg
if process_tbe == 0 and jsons_tbe:
process_tbe = 1
raise RuntimeWarning("we add a process for compile more operator")
args = [[] for _ in range(process_akg+process_tbe)]
args_lens = len(args)
for p in range(args_lens):
if p < process_tbe:
args[p].append("TBE")
else:
args[p].append("AKG")
jsons_tbe_lens = len(jsons_tbe)
for p in range(jsons_tbe_lens):
args[p % process_tbe].append(jsons_tbe[p])
jsons_akg_lens = len(jsons_akg)
for p in range(jsons_akg_lens):
args[process-p % process_akg-1].append(jsons_akg[p])
for p in range(args_lens):
args[p] = tuple(args[p])
with Pool(processes=process) as pool:
res = pool.starmap_async(_compiletask, args)
res.get(timeout=waitime)
return True

View File

@ -39,7 +39,7 @@ if(ENABLE_GPU)
"device/gpu/*.cu" "device/gpu/*.cu"
"kernel/gpu/*.cu" "kernel/gpu/*.cu"
"kernel/akg/gpu/*.cc" "kernel/akg/gpu/*.cc"
"kernel/akg/akgkernelbuild.cc" "kernel/akg/akg_kernel_build.cc"
"kernel/akg/akg_kernel_attrs_process.cc" "kernel/akg/akg_kernel_attrs_process.cc"
) )

View File

@ -428,6 +428,10 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
auto temp_shape = shape; auto temp_shape = shape;
std::vector<size_t> device_shape; std::vector<size_t> device_shape;
if (format == kOpFormat_FRAC_NZ) { if (format == kOpFormat_FRAC_NZ) {
if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) {
// For [1] and [1024] shape we can trait it as NZ shape
return shape;
}
if (shape.size() < 2) { if (shape.size() < 2) {
MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size(); MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size();
} else { } else {

View File

@ -111,9 +111,15 @@ void DumpGlobalInfoEntry(const FuncGraphPtr &graph, std::ostringstream &buffer)
} }
buffer << "#IR entry : @" << graph->ToString() << "." << graph->debug_info()->get_id() << std::endl; buffer << "#IR entry : @" << graph->ToString() << "." << graph->debug_info()->get_id() << std::endl;
buffer << "#flags :" << std::endl; buffer << "#attrs :" << std::endl;
for (const auto &flag : graph->flags()) { for (const auto &attr : graph->attrs()) {
buffer << flag.first << " : " << flag.second << std::endl; buffer << attr.first << " : ";
if (attr.second->isa<BoolImm>()) {
buffer << GetValue<bool>(attr.second);
} else if (attr.second->isa<StringImm>()) {
buffer << GetValue<std::string>(attr.second);
}
buffer << std::endl;
} }
} }
@ -417,10 +423,16 @@ void DumpSubgraph(const OrderedMap<FuncGraphPtr, std::shared_ptr<SubGraphIRInfo>
fout << std::endl; fout << std::endl;
for (const auto &sg : *sub_graphs) { for (const auto &sg : *sub_graphs) {
fout << "subgraph flag:" << std::endl; fout << "subgraph attr:" << std::endl;
MS_EXCEPTION_IF_NULL(sg.first); MS_EXCEPTION_IF_NULL(sg.first);
for (const auto &flag : sg.first->flags()) { for (const auto &attr : sg.first->attrs()) {
fout << flag.first << " : " << flag.second << std::endl; fout << attr.first << " : ";
if (attr.second->isa<BoolImm>()) {
fout << GetValue<bool>(attr.second);
} else if (attr.second->isa<StringImm>()) {
fout << GetValue<std::string>(attr.second);
}
fout << std::endl;
} }
fout << "subgraph @" << sg.first->ToString() << "."; fout << "subgraph @" << sg.first->ToString() << ".";
fout << sg.first->debug_info()->get_id() << "("; fout << sg.first->debug_info()->get_id() << "(";

View File

@ -548,9 +548,15 @@ void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr<session::KernelGr
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
cur_cnode_ptr = cnode_ptr_list[i]; cur_cnode_ptr = cnode_ptr_list[i];
MS_EXCEPTION_IF_NULL(cur_cnode_ptr); MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
ValuePtr value_ptr = nullptr;
auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr);
MS_EXCEPTION_IF_NULL(primitive); if (primitive != nullptr) {
auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst); value_ptr = primitive->GetAttr(kStreamNeedActivedFirst);
} else {
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cur_cnode_ptr);
MS_EXCEPTION_IF_NULL(func_graph);
value_ptr = func_graph->get_attr(kStreamNeedActivedFirst);
}
if (value_ptr == nullptr) { if (value_ptr == nullptr) {
continue; continue;
} }

View File

@ -26,10 +26,12 @@
#include "kernel/kernel.h" #include "kernel/kernel.h"
#include "kernel/tbe/tbe_kernel_build.h" #include "kernel/tbe/tbe_kernel_build.h"
#include "kernel/tbe/tbe_kernel_parallel_build.h" #include "kernel/tbe/tbe_kernel_parallel_build.h"
#include "kernel/akg/ascend/akg_ascend_kernel_build.h"
#include "kernel/aicpu/aicpu_kernel_build.h" #include "kernel/aicpu/aicpu_kernel_build.h"
#include "kernel/hccl/hccl_kernel_build.h" #include "kernel/hccl/hccl_kernel_build.h"
#include "kernel/rts/rt_kernel_build.h" #include "kernel/rts/rt_kernel_build.h"
#include "kernel/tbe/tbe_utils.h" #include "kernel/tbe/tbe_utils.h"
#include "kernel/common_utils.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
#include "./common.h" #include "./common.h"
@ -91,6 +93,7 @@ static bool KernelPreBuildParallelCompile(const mindspore::session::KernelGraph
static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) { static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) {
MS_EXCEPTION_IF_NULL(kernel_graph_ptr); MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
std::vector<AnfNodePtr> tbe_nodes; std::vector<AnfNodePtr> tbe_nodes;
std::vector<AnfNodePtr> akg_nodes;
std::vector<AnfNodePtr> other_nodes; std::vector<AnfNodePtr> other_nodes;
for (const auto &anf_node : kernel_graph_ptr->execution_order()) { for (const auto &anf_node : kernel_graph_ptr->execution_order()) {
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
@ -105,19 +108,26 @@ static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *ke
} }
break; break;
} }
case KernelType::AKG_KERNEL: {
akg_nodes.push_back(anf_node);
break;
}
default: { default: {
other_nodes.push_back(anf_node); other_nodes.push_back(anf_node);
break; break;
} }
} }
} }
bool ret = kernel::TbeOpParallelBuild(tbe_nodes); bool tbe_ret = kernel::TbeOpParallelBuild(tbe_nodes);
bool akg_ret = kernel::AkgAscendKernelParallelBuild(akg_nodes);
auto bin_map = kernel::tbe::KernelMeta::GetInstance();
(void)bin_map->ReadIndex(kernel::kCceKernelMeta);
for (const auto &anf_node : other_nodes) { for (const auto &anf_node : other_nodes) {
kernel::KernelModPtr kernel_mod_ptr = SerialCompileImpl(anf_node); kernel::KernelModPtr kernel_mod_ptr = SerialCompileImpl(anf_node);
MS_EXCEPTION_IF_NULL(kernel_mod_ptr); MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
} }
return ret; return tbe_ret && akg_ret;
} }
static std::vector<int> CalCleanZerosSize(const CNodePtr &pre_node) { static std::vector<int> CalCleanZerosSize(const CNodePtr &pre_node) {
@ -234,7 +244,7 @@ void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
for (const auto &anf_node : kernel_graph->execution_order()) { for (const auto &anf_node : kernel_graph->execution_order()) {
std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node); std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
if (apply_function_name == prim::kPrimMaxPoolGrad->name() && if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
AnfAlgo::GetKernelType(anf_node) == KernelType::AUTO_DIFF_KERNEL) { AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName); auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
MS_EXCEPTION_IF_NULL(clear_zero_prim); MS_EXCEPTION_IF_NULL(clear_zero_prim);
auto new_value_node = NewValueNode(clear_zero_prim); auto new_value_node = NewValueNode(clear_zero_prim);

View File

@ -15,16 +15,27 @@
*/ */
#include "device/ascend/kernel_select_ascend.h" #include "device/ascend/kernel_select_ascend.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <algorithm>
#include <map> #include <map>
#include "kernel/oplib/oplib.h" #include <unordered_map>
#include "kernel/kernel_query.h" #include <unordered_set>
#include "session/anf_runtime_algorithm.h"
#include "utils/context/ms_context.h" #include "common/utils.h"
#include "debug/anf_ir_dump.h" #include "debug/anf_ir_dump.h"
#include "operator/ops.h"
#include "ir/func_graph.h"
#include "utils/context/ms_context.h"
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
#include "kernel/common_utils.h"
#include "kernel/kernel_query.h"
#include "kernel/oplib/oplib.h"
#include "kernel/kernel_build_info.h"
namespace mindspore { namespace mindspore {
namespace device { namespace device {
@ -121,12 +132,23 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
} }
auto pri_match_format = GetPriorityMatchFormat(kernel_node); auto pri_match_format = GetPriorityMatchFormat(kernel_node);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
auto input_anf_node = kernel_node->input(input_index + 1);
// we do not take ValueNode into consideration in graph kernel.
if (kernel_build_info.kernel_type() == KernelType::AKG_KERNEL) {
if (input_anf_node->isa<ValueNode>() && AnfAlgo::GetOutputDeviceDataType(input_anf_node, 0) == kTypeUnknown) {
continue;
}
}
auto base_score = AnfAlgo::IsFeatureMapInput(kernel_node, input_index) ? kFeatureMapBaseScore : kWegihtBaseScore; auto base_score = AnfAlgo::IsFeatureMapInput(kernel_node, input_index) ? kFeatureMapBaseScore : kWegihtBaseScore;
if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) { if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
(*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT] += base_score; (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT] += base_score;
} }
if (kernel_build_info.GetInputDeviceType(input_index) == // we match output fix precision first.
AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)) { auto prev_device_type = AnfAlgo::GetPrevNodeOutputPrecision(kernel_node, input_index);
if (prev_device_type == kTypeUnknown) {
prev_device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
}
if (kernel_build_info.GetInputDeviceType(input_index) == prev_device_type) {
(*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT] += base_score; (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT] += base_score;
} }
if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) { if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) {
@ -146,41 +168,6 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
} }
} }
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
MS_EXCEPTION_IF_NULL(input_kernel_node);
auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
MS_EXCEPTION_IF_NULL(input_with_index.first);
auto real_input_node = input_with_index.first;
if (real_input_node->isa<CNode>()) {
continue;
}
std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
bool is_ref = false;
auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
if (op_info != nullptr) {
is_ref = op_info->is_ref();
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode &&
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
continue;
}
// we set special device info of a input tensor.
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {AnfAlgo::GetInputDeviceDataType(kernel_node, input_index)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
}
}
}
void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *support_index) { void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *support_index) {
MS_EXCEPTION_IF_NULL(support_index); MS_EXCEPTION_IF_NULL(support_index);
int index = kUnSupportMixedDataTypeIndex; int index = kUnSupportMixedDataTypeIndex;
@ -467,6 +454,51 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis
} }
} // namespace } // namespace
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
MS_EXCEPTION_IF_NULL(input_kernel_node);
auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
MS_EXCEPTION_IF_NULL(input_with_index.first);
auto real_input_node = input_with_index.first;
if (real_input_node->isa<CNode>()) {
continue;
}
if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
continue;
}
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
continue;
}
// we set special device info of a input tensor.
bool is_ref = false;
auto op_info = kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
if (op_info != nullptr) {
is_ref = op_info->is_ref();
}
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
continue;
}
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
}
}
}
KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node, KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) { const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
@ -498,11 +530,17 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
return select_status; return select_status;
} }
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) { KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> aicpu_kernel_info_list; std::vector<std::shared_ptr<kernel::KernelBuildInfo>> aicpu_kernel_info_list;
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
kernel::KernelQuery(kernel_node, &kernel_info_list); if (AnfAlgo::IsGraphKernel(kernel_node)) {
auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex));
MS_EXCEPTION_IF_NULL(func_graph);
SelectGraphKernelInfo(kernel_node, func_graph);
return kStatusAllMatched;
}
kernel::KernelQuery(kernel_node, &kernel_info_list, kernel_type);
auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
// If aicore not find valid kernel info reloading aicpu kernel info list to find it // If aicore not find valid kernel info reloading aicpu kernel info list to find it
if (select_status == kNoMatched) { if (select_status == kNoMatched) {

View File

@ -27,7 +27,10 @@ enum KernelSelectStatus {
kStatusReducePrecision = 1, kStatusReducePrecision = 1,
kStatusRaisePrecision = 2, kStatusRaisePrecision = 2,
}; };
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node); KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node,
KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node);
void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph);
} // namespace ascend } // namespace ascend
} // namespace device } // namespace device
} // namespace mindspore } // namespace mindspore

View File

@ -0,0 +1,516 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "device/ascend/kernel_select_ascend.h"
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
#include "ir/func_graph.h"
#include "kernel/common_utils.h"
#include "kernel/kernel_query.h"
#include "kernel/kernel_build_info.h"
namespace mindspore {
namespace device {
namespace ascend {
TypeId GetPrimitivePrecision(const CNodePtr &cnode) {
auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(primitive);
TypeId except_type = kTypeUnknown;
if (primitive->GetAttr(kAttrFixPrecision) != nullptr) {
auto strExceptDtype = GetValue<std::string>(primitive->GetAttr(kAttrFixPrecision));
if (strExceptDtype == "float16") {
except_type = kNumberTypeFloat16;
} else if (strExceptDtype == "float32") {
except_type = kNumberTypeFloat32;
} else {
MS_LOG(EXCEPTION) << "The fix precision must be float16 or float32, but got" << strExceptDtype;
}
}
return except_type;
}
void ResetKernelBuildInfo(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
MS_EXCEPTION_IF_NULL(input_kernel_node);
auto kernel_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
if (!kernel::IsWeightBoundary(kernel_with_index.first)) {
continue;
}
// reset format and dtype.
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
builder.SetOutputsDeviceType(std::vector<TypeId>{kTypeUnknown});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_kernel_node.get());
}
}
void UpdateKernelInfo(const std::vector<AnfNodePtr> &node_list) {
for (size_t i = 0; i < node_list.size(); ++i) {
// select nodes in subgraph.
auto anf_node = node_list[i];
MS_EXCEPTION_IF_NULL(anf_node);
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto fix_precision_type = GetPrimitivePrecision(cnode);
if (fix_precision_type != kTypeUnknown) {
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
kernel::KernelQuery(cnode, &kernel_info_list, KernelType::AKG_KERNEL);
for (size_t index = 0; index < kernel_info_list.size(); ++index)
// only math the first input
if (kernel_info_list[index]->GetInputDeviceType(0) == fix_precision_type &&
kernel_info_list[index]->GetInputFormat(0) == AnfAlgo::GetPrevNodeOutputFormat(cnode, 0) &&
AnfAlgo::GetInputDeviceDataType(cnode, 0) != fix_precision_type) {
auto selected_kernel_info_ptr = kernel_info_list[index];
ResetKernelBuildInfo(cnode);
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, cnode.get());
SetTensorDeviceInfo(*selected_kernel_info_ptr, cnode);
break;
}
}
}
}
bool CanConvertDefaultShapeToNZ(const std::vector<size_t> &shape) {
for (size_t i = 1; i <= shape.size(); ++i) {
if (i > 2) {
break;
}
if (shape[shape.size() - i] != 1 && shape[shape.size() - i] % kCubeSize != 0) {
return false;
}
}
return true;
}
std::vector<int> DefaultToFracNZAxis(const std::vector<size_t> &ori_shape, const std::vector<int> &axis) {
std::vector<int> frac_nz_axis = axis;
auto shape_len = ori_shape.size();
for (size_t i = 0; i < axis.size(); ++i) {
auto axis_idx = (frac_nz_axis[i] + shape_len) % shape_len;
if (axis_idx == shape_len - 1) {
frac_nz_axis[i] = axis_idx - 1;
frac_nz_axis.push_back(axis_idx + 2);
} else if (axis_idx == shape_len - 2) {
frac_nz_axis[i] = axis_idx + 1;
frac_nz_axis.push_back(axis_idx + 2);
} else {
frac_nz_axis[i] = axis_idx;
}
}
return frac_nz_axis;
}
std::vector<size_t> GetReducedFracNZShape(const std::vector<size_t> &ori_shape, const std::vector<int> &axis,
bool keep_dims) {
std::vector<size_t> result;
std::set<size_t> positive_idx;
for (const auto &a : axis) {
positive_idx.insert(a >= 0 ? a : ori_shape.size() + a);
}
for (size_t i = 0; i < ori_shape.size(); ++i) {
if (positive_idx.count(i) == 0) {
result.push_back(ori_shape[i]);
} else if (keep_dims) {
result.push_back(1);
}
}
return result;
}
void UpdateFracNZReduceOp(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0);
if (input_format == kOpFormat_FRAC_NZ) {
// Clone primitive to modify it
auto prim = GetCNodePrimitive(cnode);
auto new_prim = std::make_shared<Primitive>(*prim);
auto new_prim_node = NewValueNode(new_prim);
cnode->set_input(0, new_prim_node);
auto axis_value = new_prim->GetAttr(kAttrAxis);
std::vector<int> default_axis;
if (axis_value->isa<ValueList>()) {
auto value_list = dyn_cast<ValueList>(axis_value);
for (const auto &item : value_list->value()) {
if (item->isa<Int32Imm>()) {
default_axis.push_back(GetValue<int32_t>(item));
}
}
} else if (axis_value->isa<ValueTuple>()) {
auto value_tuple = dyn_cast<ValueTuple>(axis_value);
for (const auto &item : value_tuple->value()) {
if (item->isa<Int32Imm>()) {
default_axis.push_back(GetValue<int32_t>(item));
}
}
} else {
MS_LOG(ERROR) << "Axis attr type is not correct!";
}
auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
std::vector<int> frac_nz_axis = DefaultToFracNZAxis(infer_shape, default_axis);
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<std::vector<int>>(frac_nz_axis), cnode);
auto output_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
if (output_shape.size() == 1) {
AnfAlgo::SetNodeAttr(kAttrOutputDefault, MakeValue<bool>(true), cnode);
}
}
}
void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format, bool *use_same_format) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(default_format);
MS_EXCEPTION_IF_NULL(use_same_format);
std::unordered_map<std::string, size_t> all_input_formats;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t i = 0; i < input_num; ++i) {
auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
MS_EXCEPTION_IF_NULL(input_kernel_node);
if (!input_kernel_node->isa<Parameter>()) {
auto pre_format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i);
++all_input_formats[pre_format];
continue;
}
auto para = input_kernel_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(para);
if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) {
auto pre_format = AnfAlgo::GetOutputFormat(para, 0);
++all_input_formats[pre_format];
continue;
}
*use_same_format = false;
}
if (all_input_formats.empty()) {
// all inputs are parameter.
*default_format = kOpFormat_NC1HWC0;
} else {
std::vector<std::pair<std::string, size_t>> pairs;
for (auto iter = all_input_formats.begin(); iter != all_input_formats.end(); ++iter) {
pairs.push_back(std::make_pair(iter->first, iter->second));
}
auto cmp_func = [](const std::pair<std::string, size_t> &a, const std::pair<std::string, size_t> &b) {
if (a.second != b.second) {
return a.second > b.second;
} else if (a.first == kOpFormat_DEFAULT) {
return a.second + 1 > b.second;
} else if (b.first == kOpFormat_DEFAULT) {
return a.second > b.second + 1;
}
return a.second > b.second;
};
std::sort(pairs.begin(), pairs.end(), cmp_func);
*default_format = pairs.begin()->first;
}
for (size_t i = 0; i < input_num; ++i) {
auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
MS_EXCEPTION_IF_NULL(input_kernel_node);
if (!input_kernel_node->isa<Parameter>() ||
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) != kTypeUnknown) {
continue;
}
auto weight_infer_shape = AnfAlgo::GetOutputInferShape(input_kernel_node, 0);
if (weight_infer_shape.size() < 2 && *default_format == kOpFormat_FRAC_NZ) {
*default_format = kOpFormat_DEFAULT;
*use_same_format = true;
break;
}
}
}
void UpdateGraphKernelInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list,
const std::string &default_format, bool use_same_format,
std::vector<std::string> *graph_input_format,
std::vector<TypeId> *graph_input_type) {
MS_EXCEPTION_IF_NULL(graph_input_format);
MS_EXCEPTION_IF_NULL(graph_input_type);
// We set same format to all inputs of graph kernel subgraph, and process this latter.
// We set dtype to inputs of graph kernel subgraph same as infer dtypes.
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t i = 0; i < input_num; ++i) {
auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
MS_EXCEPTION_IF_NULL(input_kernel_node);
if (use_same_format) {
bool can_convert = true;
if (default_format == kOpFormat_FRAC_NZ) {
auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
if (!CanConvertDefaultShapeToNZ(infer_shape)) {
MS_LOG(WARNING) << "Shape can't be converted to frac nz shape, so use default format instead";
can_convert = false;
}
}
if (can_convert) {
graph_input_format->push_back(default_format);
} else {
graph_input_format->push_back(kOpFormat_DEFAULT);
}
graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i));
continue;
}
if (!input_kernel_node->isa<Parameter>()) {
// subgraph parameter from output of other nodes.
graph_input_format->push_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i));
graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i));
continue;
}
auto para = input_kernel_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(para);
if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) {
// parameter already selected.
graph_input_format->push_back(AnfAlgo::GetOutputFormat(para, 0));
graph_input_type->push_back(AnfAlgo::GetOutputDeviceDataType(para, 0));
continue;
}
// weight parameter.
graph_input_format->push_back(default_format);
graph_input_type->push_back(AnfAlgo::GetOutputInferDataType(input_kernel_node, 0));
}
for (size_t i = 0; i < input_num; ++i) {
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
std::vector<std::string> outputs_format = {(*graph_input_format)[i]};
std::vector<TypeId> outputs_device_type = {(*graph_input_type)[i]};
builder.SetOutputsFormat(outputs_format);
builder.SetOutputsDeviceType(outputs_device_type);
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
}
}
void UpdateEquivFormat(const std::vector<std::pair<AnfNodePtr, size_t>> &output_index,
const std::vector<AnfNodePtr> &node_list, const FuncGraphPtr &func_graph,
const FuncGraphManagerPtr &mng) {
MS_EXCEPTION_IF_NULL(mng);
for (size_t i = 0; i < node_list.size(); ++i) {
// select nodes in subgraph.
auto anf_node = node_list[i];
MS_EXCEPTION_IF_NULL(anf_node);
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
SelectKernelInfo(cnode, KernelType::AKG_KERNEL);
// Update ReduceSum
if (!IsPrimitiveCNode(cnode, prim::kPrimReduceSum)) {
continue;
}
UpdateFracNZReduceOp(cnode);
// If ReduceSum's output is 1d and not Default format, convert it to Default format
auto out_format = AnfAlgo::GetOutputFormat(cnode, 0);
if (out_format == kOpFormat_DEFAULT || !AnfAlgo::HasNodeAttr(kAttrOutputDefault, cnode)) {
continue;
}
auto infer_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
// Insert EquivFormat node, then select kernel info again
std::vector<AnfNodePtr> trans_inputs;
trans_inputs.push_back(NewValueNode(prim::kPrimEquivFormat));
trans_inputs.push_back(cnode);
CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0)},
{AnfAlgo::GetOutputInferShape(cnode, 0)}, trans_node.get());
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue<std::vector<std::string>>({"x"}), trans_node);
if (trans_node->kernel_info() == nullptr) {
trans_node->set_kernel_info(std::make_shared<device::KernelInfo>());
}
SelectKernelInfo(trans_node, KernelType::AKG_KERNEL);
mng->Replace(cnode, trans_node);
}
}
void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list, const FuncGraphManagerPtr &mng,
const std::string &default_format, std::vector<std::string> *graph_input_format,
std::vector<TypeId> *graph_input_type) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(mng);
MS_EXCEPTION_IF_NULL(graph_input_format);
MS_EXCEPTION_IF_NULL(graph_input_type);
// update graph input format and dtype use inner ops.
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (graph_input_format->size() != input_num) {
MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString()
<< "], [%" << graph_input_format->size() << "] != [%" << input_num << "]";
}
std::vector<bool> need_update(input_num, false);
auto &node_users = mng->node_users();
for (size_t i = 0; i < input_num; ++i) {
auto &input = input_list[i];
auto iter = node_users.find(input);
if (iter == node_users.end() || iter->second.empty()) {
continue;
}
for (auto &node_user : iter->second) {
if (node_user.first->kernel_info() == nullptr ||
node_user.first->kernel_info()->select_kernel_build_info() == nullptr) {
// maybe not a real kernel.
continue;
}
auto user_format = AnfAlgo::GetInputFormat(node_user.first, IntToSize(node_user.second - 1));
if (user_format != (*graph_input_format)[i]) {
MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of ["
<< kernel_node->DebugString()
<< "] selected different format. we use defult: " << default_format;
(*graph_input_format)[i] = default_format;
need_update[i] = true;
}
if (kernel_node->input(i + 1)->isa<Parameter>()) {
auto user_dtype = AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1));
if (user_dtype != (*graph_input_type)[i]) {
TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0);
MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of ["
<< kernel_node->DebugString()
<< "] selected different dtype. we use default: " << TypeIdLabel(default_dtype);
(*graph_input_type)[i] = default_dtype;
need_update[i] = true;
}
}
}
}
for (size_t i = 0; i < input_num; ++i) {
if (!need_update[i]) {
continue;
}
need_update[i] = false;
MS_LOG(DEBUG) << "Update input format: " << i << " of: [" << kernel_node->DebugString()
<< "] to: " << (*graph_input_format)[i];
MS_LOG(DEBUG) << "Update input dtype: " << i << " of: [" << kernel_node->DebugString()
<< "] to: " << TypeIdLabel((*graph_input_type)[i]);
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
std::vector<std::string> outputs_format = {(*graph_input_format)[i]};
std::vector<TypeId> outputs_device_type = {(*graph_input_type)[i]};
builder.SetOutputsFormat(outputs_format);
builder.SetOutputsDeviceType(outputs_device_type);
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
}
ResetKernelBuildInfo(kernel_node);
// select nodes in subgraph again.
for (size_t i = 0; i < node_list.size(); ++i) {
auto anf_node = node_list[i];
MS_EXCEPTION_IF_NULL(anf_node);
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
size_t cnode_input_num = AnfAlgo::GetInputTensorNum(cnode);
for (size_t j = 0; j < cnode_input_num; ++j) {
auto input_node = cnode->input(j + 1);
MS_EXCEPTION_IF_NULL(input_node);
if (!IsValueNode<tensor::Tensor>(input_node)) {
continue;
}
// reset format and dtype of const tensor.
builder.SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
builder.SetOutputsDeviceType(std::vector<TypeId>{kTypeUnknown});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_node.get());
}
SelectKernelInfo(node_list[i]->cast<CNodePtr>(), KernelType::AKG_KERNEL);
}
}
void SetGraphKernelInfo(const CNodePtr &kernel_node, const std::vector<std::pair<AnfNodePtr, size_t>> &output_index,
const std::vector<std::string> &graph_input_format,
const std::vector<TypeId> &graph_input_type) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<std::string> graph_output_format;
std::vector<TypeId> graph_output_type;
for (size_t i = 0; i < output_index.size(); ++i) {
auto const &output = output_index[i];
graph_output_format.push_back(AnfAlgo::GetOutputFormat(output.first, output.second));
TypeId output_type(kTypeUnknown);
if (output.first->isa<CNode>()) {
output_type = AnfAlgo::GetCNodeOutputPrecision(output.first);
}
if (output_type == kTypeUnknown) {
output_type = AnfAlgo::GetOutputDeviceDataType(output.first, output.second);
}
graph_output_type.push_back(output_type);
}
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
graph_info_builder.SetInputsFormat(graph_input_format);
graph_info_builder.SetInputsDeviceType(graph_input_type);
graph_info_builder.SetOutputsFormat(graph_output_format);
graph_info_builder.SetOutputsDeviceType(graph_output_type);
graph_info_builder.SetProcessor(kernel::Processor::AICORE);
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
auto graph_selected_info = graph_info_builder.Build();
MS_EXCEPTION_IF_NULL(graph_selected_info);
AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get());
SetTensorDeviceInfo(*graph_selected_info, kernel_node);
}
void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(func_graph);
// collect input info of funcgraph
std::vector<AnfNodePtr> node_list;
std::vector<AnfNodePtr> input_list;
std::vector<AnfNodePtr> output_list;
kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
if (input_list.size() != kernel_node->inputs().size() - 1) {
MS_EXCEPTION(ArgumentError) << "Input num of funcgraph[" << func_graph->ToString() << "] not equal input of cnode["
<< kernel_node->DebugString() << "], [%" << input_list.size() << "] != ["
<< kernel_node->inputs().size() << "]";
}
std::string default_format;
bool use_same_format = true;
GetDefaultFormat(kernel_node, &default_format, &use_same_format);
MS_LOG(DEBUG) << "GraphKernel[" << func_graph->ToString() << "] use same input format[" << default_format
<< "] for ParameterWeight.";
std::vector<std::string> graph_input_format;
std::vector<TypeId> graph_input_type;
UpdateGraphKernelInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format,
&graph_input_type);
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
}
auto output_index = kernel::GetOutputIndex(node_list, input_list, output_list);
UpdateEquivFormat(output_index, node_list, func_graph, mng);
node_list.clear();
input_list.clear();
output_list.clear();
kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
// update graph input format and dtype use inner ops.
UpdateFormatsAndDtypes(kernel_node, node_list, input_list, mng, default_format, &graph_input_format,
&graph_input_type);
// set fix_precision for kernel when the me prim has fix_precision attr
UpdateKernelInfo(node_list);
output_index = kernel::GetOutputIndex(node_list, input_list, output_list);
SetGraphKernelInfo(kernel_node, output_index, graph_input_format, graph_input_type);
}
} // namespace ascend
} // namespace device
} // namespace mindspore

View File

@ -24,7 +24,7 @@ namespace device {
namespace ascend { namespace ascend {
void GraphDescReporter::ReportData() { void GraphDescReporter::ReportData() {
for (const auto &node : cnode_list_) { for (const auto &node : cnode_list_) {
if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AUTO_DIFF_KERNEL) { if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AKG_KERNEL) {
MS_LOG(WARNING) << "Skip non tbe kernel"; MS_LOG(WARNING) << "Skip non tbe kernel";
continue; continue;
} }

View File

@ -31,7 +31,7 @@ void TaskDescReporter::ReportData() {
size_t task_index = 0; size_t task_index = 0;
for (const auto &node : cnode_list_) { for (const auto &node : cnode_list_) {
if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AUTO_DIFF_KERNEL) { if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AKG_KERNEL) {
MS_LOG(WARNING) << "Skip non tbe kernel"; MS_LOG(WARNING) << "Skip non tbe kernel";
++task_index; ++task_index;
continue; continue;

View File

@ -43,7 +43,37 @@ bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::ve
void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) { void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) {
MS_EXCEPTION_IF_NULL(anf_node_ptr); MS_EXCEPTION_IF_NULL(anf_node_ptr);
if (anf_node_ptr->inputs().size() != 2) { if (anf_node_ptr->inputs().size() != 2) {
MS_LOG(EXCEPTION) << "atomic Addr clean Node Input nodes not equal 2."; // akg process
// set atomic clean addr
if (AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, anf_node_ptr)) {
auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(anf_node_ptr, kAttrAutomicOutputIndexs);
auto graph = anf_node_ptr->func_graph();
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto node_users = manager->node_users();
if (node_users[anf_node_ptr].empty()) {
MS_LOG(EXCEPTION) << "Node users of " << anf_node_ptr->ToString() << " is empty.";
}
auto depend_node = node_users[anf_node_ptr].pop().first;
if (!IsPrimitiveCNode(depend_node, prim::kPrimDepend)) {
MS_LOG(EXCEPTION) << "Checking Depend node failed";
}
if (node_users[depend_node].empty()) {
MS_LOG(EXCEPTION) << "Node users of " << depend_node->ToString() << " is empty.";
}
auto post_node = node_users[depend_node].pop().first;
for (auto index : clean_output_indexs) {
auto device_address = AnfAlgo::GetOutputAddr(post_node, index);
kernel::AddressPtr input = std::make_shared<kernel::Address>();
input->addr = device_address->ptr_;
MS_EXCEPTION_IF_NULL(input->addr);
input->size = device_address->size_;
kernel_inputs->push_back(input);
}
MS_LOG(DEBUG) << "AtomicAddClean clean output size: " << clean_output_indexs.size();
}
return;
} }
MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]); MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]);
auto pre_node = (anf_node_ptr->inputs()[1])->cast<CNodePtr>(); auto pre_node = (anf_node_ptr->inputs()[1])->cast<CNodePtr>();
@ -59,7 +89,7 @@ void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressP
input->size = device_address->size_; input->size = device_address->size_;
kernel_inputs->push_back(input); kernel_inputs->push_back(input);
} }
MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexs.size();
} }
// set clean workspace address // set clean workspace address
if (AnfAlgo::HasNodeAttr(kAttrAutomicWorkspaceSize, pre_node)) { if (AnfAlgo::HasNodeAttr(kAttrAutomicWorkspaceSize, pre_node)) {

View File

@ -16,7 +16,7 @@
#include "device/gpu/gpu_kernel_build.h" #include "device/gpu/gpu_kernel_build.h"
#include <string> #include <string>
#include "kernel/kernel.h" #include "kernel/kernel.h"
#include "kernel/akg/akgkernelbuild.h" #include "kernel/akg/akg_kernel_build.h"
#include "kernel/akg/gpu/akg_gpu_kernel_build.h" #include "kernel/akg/gpu/akg_gpu_kernel_build.h"
#include "kernel/gpu/gpu_kernel_factory.h" #include "kernel/gpu/gpu_kernel_factory.h"
#include "operator/ops.h" #include "operator/ops.h"
@ -37,7 +37,7 @@ void GpuBuild(const KernelGraphPtr &kernel_graph) {
continue; continue;
} }
if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) == KernelType::AUTO_DIFF_KERNEL) { if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) == KernelType::AKG_KERNEL) {
auto gpu_kernel_ptr = kernel::AkgGpuKernelBuild(kernel); auto gpu_kernel_ptr = kernel::AkgGpuKernelBuild(kernel);
if (!gpu_kernel_ptr) { if (!gpu_kernel_ptr) {
MS_LOG(EXCEPTION) << "Build akg kernel op[" << kernel_name << "] failed"; MS_LOG(EXCEPTION) << "Build akg kernel op[" << kernel_name << "] failed";

View File

@ -184,7 +184,7 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
if (!result) { if (!result) {
result = SelectAkgKernel(kernel_node, builder->Build()); result = SelectAkgKernel(kernel_node, builder->Build());
kernel_type = AUTO_DIFF_KERNEL; kernel_type = AKG_KERNEL;
} }
if (!result) { if (!result) {

View File

@ -26,6 +26,8 @@
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/primitive_base.h" #include "ir/primitive_base.h"
#include "operator/ops.h"
namespace mindspore { namespace mindspore {
// namespace to support intermediate representation definition // namespace to support intermediate representation definition
CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph) CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph)
@ -106,10 +108,14 @@ std::string ValueNode::fullname_with_scope() {
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) { bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
if (cnode != nullptr) { if (cnode == nullptr) {
return false;
}
if (value != nullptr) {
return cnode->IsApply(value); return cnode->IsApply(value);
} }
return false; const auto &prim = GetValueNode<PrimitivePtr>(cnode->input(0));
return prim != nullptr;
} }
PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) { PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) {

View File

@ -124,6 +124,7 @@ class AnfNode : public Base {
const KernelInfoDevice *kernel_info() const { return kernel_info_.get(); } const KernelInfoDevice *kernel_info() const { return kernel_info_.get(); }
KernelInfoDevice *kernel_info() { return kernel_info_.get(); } KernelInfoDevice *kernel_info() { return kernel_info_.get(); }
const KernelInfoDevicePtr &kernel_info_ptr() { return kernel_info_; }
void set_kernel_info(const KernelInfoDevicePtr &kernel_info) { kernel_info_ = kernel_info; } void set_kernel_info(const KernelInfoDevicePtr &kernel_info) { kernel_info_ = kernel_info; }
AbstractBasePtr abstract() const { return abstract_; } AbstractBasePtr abstract() const { return abstract_; }
@ -395,9 +396,9 @@ static S GetValue(const ValuePtr &value) {
std::string GetCNodeFuncName(CNodePtr cnode); std::string GetCNodeFuncName(CNodePtr cnode);
// used to check whether an AnfNode is a cnode with a kind of Primitive as first input // used to check whether an AnfNode is a cnode with a kind of Primitive as first input
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value); bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value = nullptr);
// used to check whether an AnfNode is a cnode with a Primitive as first input // used to get PrimitivePtr from a cnode first input
PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node); PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);
// used to check whether an AnfNode is a valuenode having some Primitive value // used to check whether an AnfNode is a valuenode having some Primitive value

View File

@ -70,7 +70,7 @@ std::string CNode::fullname_with_scope() {
} }
fullname_with_scope_ = name; fullname_with_scope_ = name;
} else { } else {
// cnode input 0 should be primitive ptr // cnode input 0 should be primitive ptr or funcgraph ptr
auto value_ptr = input(0)->cast<ValueNodePtr>(); auto value_ptr = input(0)->cast<ValueNodePtr>();
if (value_ptr == nullptr) { if (value_ptr == nullptr) {
MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << "."; MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << ".";
@ -84,11 +84,23 @@ std::string CNode::fullname_with_scope() {
return fullname_with_scope_; return fullname_with_scope_;
} }
PrimitivePtr prim = GetValue<PrimitivePtr>(input_value); auto prim = input_value->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(scope()); MS_EXCEPTION_IF_NULL(scope());
MS_EXCEPTION_IF_NULL(prim); fullname_with_scope_ = scope()->name() + "/";
fullname_with_scope_ = if (prim != nullptr) {
scope()->name() + "/" + prim->name() + "-op" + id_generator::get_id(shared_from_base<CNode>()); fullname_with_scope_ += prim->name();
} else {
auto func_graph = input_value->cast<FuncGraphPtr>();
MS_EXCEPTION_IF_NULL(func_graph);
auto fg_flag = func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
if (fg_flag != nullptr) {
auto fg_name = GetValue<std::string>(fg_flag);
fullname_with_scope_ += "GraphKernel_" + fg_name;
} else {
fullname_with_scope_ += func_graph->ToString();
}
}
fullname_with_scope_ += "-op" + id_generator::get_id(shared_from_base<CNode>());
} }
return fullname_with_scope_; return fullname_with_scope_;

View File

@ -77,9 +77,9 @@ class Bool : public Number {
TypeId generic_type_id() const override { return kNumberTypeBool; } TypeId generic_type_id() const override { return kNumberTypeBool; }
TypePtr DeepCopy() const override { return std::make_shared<Bool>(); } TypePtr DeepCopy() const override { return std::make_shared<Bool>(); }
std::string ToString() const override { return "Bool_"; } std::string ToString() const override { return "Bool"; }
std::string ToReprString() const override { return "bool_"; } std::string ToReprString() const override { return "bool"; }
std::string DumpText() const override { return "Bool_"; } std::string DumpText() const override { return "Bool"; }
}; };
// Int // Int

View File

@ -34,7 +34,7 @@ namespace mindspore {
* Methods of Graph * Methods of Graph
*/ */
FuncGraph::FuncGraph() FuncGraph::FuncGraph()
: flags_(), : attrs_(),
transforms_(), transforms_(),
parameter_default_value_(), parameter_default_value_(),
seen_(0), seen_(0),
@ -95,13 +95,27 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
return p; return p;
} }
bool FuncGraph::has_flag(const std::string &flag) { bool FuncGraph::has_flag(const std::string &key) {
if (flags_.count(flag)) { auto iter = attrs_.find(key);
return flags_[flag]; if (iter != attrs_.cend()) {
if (iter->second->isa<BoolImm>()) {
return GetValue<bool>(iter->second);
}
MS_LOG(WARNING) << "key " << key << " is not a flag, please use has_attr function.";
} }
return false; return false;
} }
bool FuncGraph::has_attr(const std::string &key) {
auto iter = attrs_.find(key);
return !(iter == attrs_.cend());
}
ValuePtr FuncGraph::get_attr(const std::string &key) {
auto iter = attrs_.find(key);
return iter == attrs_.cend() ? nullptr : iter->second;
}
CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
CNodePtr cnode = std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>()); CNodePtr cnode = std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>());
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {

View File

@ -74,6 +74,7 @@ using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;
const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
const char FUNC_GRAPH_FLAG_CORE[] = "core"; const char FUNC_GRAPH_FLAG_CORE[] = "core";
const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel";
const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param"; const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
namespace abstract { namespace abstract {
@ -195,10 +196,19 @@ class FuncGraph : public FuncGraphBase {
void set_is_generate(bool generated) { is_generated_ = generated; } void set_is_generate(bool generated) { is_generated_ = generated; }
bool is_generated() const { return is_generated_; } bool is_generated() const { return is_generated_; }
bool has_flag(const std::string &flag); std::unordered_map<std::string, ValuePtr> &attrs() { return attrs_; }
std::unordered_map<std::string, bool> &flags() { return flags_; } void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
void set_flags(const std::unordered_map<std::string, bool> &flags) { flags_ = flags; } for (auto &attr : attrs) {
void set_flags(const std::string &key, const bool value) { flags_[key] = value; } attrs_[attr.first] = attr.second;
}
}
bool has_flag(const std::string &key);
void set_flag(const std::string &key, bool flag) { attrs_[key] = MakeValue(flag); }
void erase_flag(const std::string &key) { (void)attrs_.erase(key); }
bool has_attr(const std::string &key);
ValuePtr get_attr(const std::string &key);
void set_attr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; }
std::unordered_map<std::string, FuncGraphTransform> &transforms() { return transforms_; } std::unordered_map<std::string, FuncGraphTransform> &transforms() { return transforms_; }
void set_transforms(const std::unordered_map<std::string, FuncGraphTransform> &transforms) { void set_transforms(const std::unordered_map<std::string, FuncGraphTransform> &transforms) {
@ -317,7 +327,7 @@ class FuncGraph : public FuncGraphBase {
std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; } std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; }
std::unordered_map<std::string, bool> flags_; std::unordered_map<std::string, ValuePtr> attrs_;
std::unordered_map<std::string, FuncGraphTransform> transforms_; std::unordered_map<std::string, FuncGraphTransform> transforms_;
// parameter default value // parameter default value
std::map<std::string, AnfNodePtr> parameter_default_value_; std::map<std::string, AnfNodePtr> parameter_default_value_;

View File

@ -90,6 +90,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
new_node->set_abstract(old_node->abstract()); new_node->set_abstract(old_node->abstract());
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
new_node->set_scope(scope); new_node->set_scope(scope);
new_node->set_kernel_info(old_node->kernel_info_ptr());
repl_node_[old_node] = new_node; repl_node_[old_node] = new_node;
nodes_.emplace_back(old_node, new_node); nodes_.emplace_back(old_node, new_node);
TraceManager::EndTrace(); TraceManager::EndTrace();
@ -211,7 +212,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
MS_EXCEPTION_IF_NULL(target_func_graph); MS_EXCEPTION_IF_NULL(target_func_graph);
TraceManager::DebugTrace(func_graph->debug_info(), target_relation_); TraceManager::DebugTrace(func_graph->debug_info(), target_relation_);
*target_func_graph = std::make_shared<FuncGraph>(); *target_func_graph = std::make_shared<FuncGraph>();
(*target_func_graph)->set_flags(func_graph->flags()); (*target_func_graph)->set_attrs(func_graph->attrs());
(*target_func_graph)->set_transforms(func_graph->transforms()); (*target_func_graph)->set_transforms(func_graph->transforms());
(*target_func_graph)->set_has_vararg(func_graph->has_vararg()); (*target_func_graph)->set_has_vararg(func_graph->has_vararg());
(*target_func_graph)->set_has_kwarg(func_graph->has_kwarg()); (*target_func_graph)->set_has_kwarg(func_graph->has_kwarg());
@ -636,9 +637,14 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
if (MsContext::GetInstance()->is_multi_graph_sink()) { if (MsContext::GetInstance()->is_multi_graph_sink()) {
if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
new_func_graph->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
} }
} }
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
}
return new_func_graph; return new_func_graph;
} }
} // namespace mindspore } // namespace mindspore

View File

@ -399,8 +399,8 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() {
depend_inputs.push_back(*iter); depend_inputs.push_back(*iter);
} }
} }
set_flags(GRAPH_FLAG_HAS_EFFECT, false); set_flag(GRAPH_FLAG_HAS_EFFECT, false);
set_flags(GRAPH_FLAG_EFFECT_PATIAL_ORDER, true); set_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER, true);
if (!depend_inputs.empty()) { if (!depend_inputs.empty()) {
SetEffectDepends(depend_inputs); SetEffectDepends(depend_inputs);
} }

View File

@ -9,6 +9,10 @@ if (ENABLE_D)
file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"kernel_query.cc" "kernel_query.cc"
"kernel_fusion.cc" "kernel_fusion.cc"
"akg/ascend/*.cc"
"akg/akg_kernel_build.cc"
"akg/akg_kernel_attrs_process.cc"
"akg/akg_kernel_metadata.cc"
"tbe/*.cc" "tbe/*.cc"
"aicpu/*.cc" "aicpu/*.cc"
"rts/*.cc" "rts/*.cc"
@ -33,7 +37,7 @@ if (ENABLE_GPU)
file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"gpu/*.cu" "gpu/*.cu"
"akg/gpu/*.cc" "akg/gpu/*.cc"
"akg/akgkernelbuild.cc" "akg/akg_kernel_build.cc"
"akg/akg_kernel_attrs_process.cc" "akg/akg_kernel_attrs_process.cc"
) )

View File

@ -24,7 +24,7 @@
#include <map> #include <map>
#include "device/kernel_runtime.h" #include "device/kernel_runtime.h"
#include "kernel/aicpu/aicpu_kernel_mod.h" #include "kernel/aicpu/aicpu_kernel_mod.h"
#include "kernel/akg/akgkernelbuild.h" #include "kernel/akg/akg_kernel_build.h"
#include "proto/tensor.pb.h" #include "proto/tensor.pb.h"
#include "proto/tensor_shape.pb.h" #include "proto/tensor_shape.pb.h"
#include "proto/attr.pb.h" #include "proto/attr.pb.h"

View File

@ -79,6 +79,10 @@ void SetAkgAttrsForCast(const AnfNodePtr &anf_node) {
dst_type = "float32"; dst_type = "float32";
} else if (output_type == kFloat16->type_id()) { } else if (output_type == kFloat16->type_id()) {
dst_type = "float16"; dst_type = "float16";
} else if (output_type == kInt32->type_id()) {
dst_type = "int32";
} else {
MS_LOG(WARNING) << "Unknown cast_to type: " << TypeIdToType(output_type)->ToString();
} }
AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node); AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node);
} }

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "kernel/akg/akgkernelbuild.h" #include "kernel/akg/akg_kernel_build.h"
#include <Python.h> #include <Python.h>
#include <sys/types.h> #include <sys/types.h>
#include <signal.h> #include <signal.h>
@ -43,7 +43,9 @@ namespace kernel {
constexpr int ME_MAX_KERNEL_NAME_LENGTH = 200; constexpr int ME_MAX_KERNEL_NAME_LENGTH = 200;
constexpr int32_t ARGS_SIZE = 1; constexpr int32_t ARGS_SIZE = 1;
constexpr auto kCompileWithJsonFunc = "compilewithjson"; constexpr auto kCompileWithJsonFunc = "compilewithjson";
// json key // json key
constexpr auto kOpDesc = "op_desc";
constexpr auto kInputDesc = "input_desc"; constexpr auto kInputDesc = "input_desc";
constexpr auto kShape = "shape"; constexpr auto kShape = "shape";
constexpr auto kDataType = "data_type"; constexpr auto kDataType = "data_type";
@ -51,13 +53,24 @@ constexpr auto kOutputDesc = "output_desc";
constexpr auto kName = "name"; constexpr auto kName = "name";
constexpr auto kTensorName = "tensor_name"; constexpr auto kTensorName = "tensor_name";
constexpr auto kValue = "value"; constexpr auto kValue = "value";
constexpr auto KInpputNames = "input_names"; constexpr auto KDynInputSizes = "dyn_input_sizes";
constexpr auto KInputNames = "input_names";
constexpr auto KInput = "input"; constexpr auto KInput = "input";
constexpr auto KDtype = "dtype"; constexpr auto KDtype = "dtype";
int AkgKernelBuild::op_cnt_ = 0; namespace {
std::mutex AkgKernelBuild::op_cnt_mtx_; template <typename T>
std::string Vector2Str(const std::vector<T> &inputs) {
if (!inputs.empty()) {
std::ostringstream oss;
(void)std::copy(inputs.begin(), inputs.end() - 1, std::ostream_iterator<T>(oss, ", "));
oss << inputs.back();
return oss.str();
}
return "";
}
} // namespace
std::string PyObjectToStr(PyObject *const PyObj) { std::string AkgKernelBuild::PyObjectToStr(PyObject *const PyObj) {
char *pChar = nullptr; char *pChar = nullptr;
std::string str_res; std::string str_res;
if (PyObj == nullptr) { if (PyObj == nullptr) {
@ -76,6 +89,72 @@ std::string PyObjectToStr(PyObject *const PyObj) {
return str_res; return str_res;
} }
std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag,
const std::pair<size_t, size_t> &position) {
if (node_json.count(tag) == 0) {
MS_LOG(ERROR) << "Node [" << node_json.dump() << "] has no key [" << tag << "].";
return "";
}
auto const &tag_desc = node_json[tag];
nlohmann::json first_index;
if (tag == kOutputDesc) {
first_index = tag_desc;
} else if (!tag_desc.is_array() || tag_desc.size() <= position.first) {
MS_LOG(ERROR) << "Node [" << tag_desc.dump() << "] has no enough value [" << position.first << "].";
return "";
} else {
first_index = tag_desc[position.first];
}
if (!first_index.is_array() || first_index.size() <= position.second) {
MS_LOG(ERROR) << "Node [" << first_index.dump() << "] has no enough value [" << position.second << "].";
return "";
}
auto const &second_index = first_index[position.second];
if (second_index.count(kTensorName) == 0) {
MS_LOG(ERROR) << "Node [" << second_index.dump() << "] has no key [" << kTensorName << "].";
return "";
}
return second_index[kTensorName];
}
void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair<size_t, size_t> &position,
nlohmann::json *const node_json) {
MS_EXCEPTION_IF_NULL(node_json);
if (node_json->count(tag) == 0) {
MS_LOG(ERROR) << "Node [" << node_json->dump() << "] has no key [" << tag << "].";
return;
}
nlohmann::json *tag_desc = &((*node_json)[tag]);
nlohmann::json *first_index;
if (tag == kOutputDesc) {
first_index = tag_desc;
} else if (!tag_desc->is_array() || tag_desc->size() <= position.first) {
MS_LOG(ERROR) << "Node [" << tag_desc->dump() << "] has no enough value [" << position.first << "].";
return;
} else {
first_index = &((*tag_desc)[position.first]);
}
if (!first_index->is_array() || first_index->size() <= position.second) {
MS_LOG(ERROR) << "Node [" << first_index->dump() << "] has no enough value [" << position.second << "].";
return;
}
nlohmann::json *second_index = &((*first_index)[position.second]);
if (second_index->count(kTensorName) == 0) {
MS_LOG(ERROR) << "Node [" << second_index->dump() << "] has no key [" << kTensorName << "].";
return;
}
(*second_index)[kTensorName] = new_name;
return;
}
int AkgKernelBuild::op_cnt_ = 0;
std::mutex AkgKernelBuild::op_cnt_mtx_;
std::string AkgKernelBuild::GetProcessor(const AnfNodePtr &anf_node) { std::string AkgKernelBuild::GetProcessor(const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
std::string device; std::string device;
@ -187,10 +266,7 @@ bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::j
for (size_t input_i = 0; input_i < input_tensor_num; input_i++) { for (size_t input_i = 0; input_i < input_tensor_num; input_i++) {
// dtype : float16 // dtype : float16
auto type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index); auto type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index);
TypePtr type_ptr = TypeIdToType(type_id); std::string dtype = TypeId2String(type_id);
MS_EXCEPTION_IF_NULL(type_ptr);
std::string dtype = type_ptr->ToString();
dtype = Dtype2String(dtype);
if (dtype.empty()) { if (dtype.empty()) {
MS_LOG(ERROR) << "Op [" << op_name << "] input [" << input_i << "] data type is null. "; MS_LOG(ERROR) << "Op [" << op_name << "] input [" << input_i << "] data type is null. ";
return false; return false;
@ -198,13 +274,23 @@ bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::j
nlohmann::json input_desc_json; nlohmann::json input_desc_json;
input_desc_json[kDataType] = dtype; input_desc_json[kDataType] = dtype;
input_desc_json[kName] = op_input_name; input_desc_json[kName] = op_input_name;
input_desc_json[kTensorName] = input_desc_json[kTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index));
op_input_name + "_" + std::to_string(real_input_index) + "_" + std::to_string(input_i); auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index);
input_desc_json[kShape] = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index); if (GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
MS_LOG(WARNING) << "we take input[" << real_input_index << "] of [" << anf_node->DebugString(2)
<< "] as const tensor, shape: [" << Vector2Str(input_shape)
<< "], value: " << input_desc_json[kValue];
input_shape.clear();
}
if (input_shape.empty()) {
input_shape.push_back(1);
}
input_desc_json[kShape] = input_shape;
input_list.emplace_back(input_desc_json); input_list.emplace_back(input_desc_json);
real_input_index++;
} }
inputs_json->emplace_back(input_list); inputs_json->emplace_back(input_list);
real_input_index++;
} }
return true; return true;
} }
@ -220,10 +306,7 @@ bool AkgKernelBuild::CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::
for (size_t i = 0; i < output_tensor_num; i++) { for (size_t i = 0; i < output_tensor_num; i++) {
nlohmann::json output_json; nlohmann::json output_json;
auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, i); auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, i);
TypePtr type_ptr = TypeIdToType(type_id); std::string dtype = TypeId2String(type_id);
MS_EXCEPTION_IF_NULL(type_ptr);
std::string dtype = type_ptr->ToString();
dtype = Dtype2String(dtype);
if (dtype.empty()) { if (dtype.empty()) {
MS_LOG(ERROR) << "Op [" << op_name << "] output [" << i << "] data type is null. "; MS_LOG(ERROR) << "Op [" << op_name << "] output [" << i << "] data type is null. ";
return false; return false;
@ -232,7 +315,7 @@ bool AkgKernelBuild::CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::
std::string output_name = outputs[i]->name(); std::string output_name = outputs[i]->name();
output_json[kDataType] = dtype; output_json[kDataType] = dtype;
output_json[kName] = output_name; output_json[kName] = output_name;
output_json[kTensorName] = output_name + "_" + std::to_string(i); output_json[kTensorName] = "output_" + std::to_string(i) + "_" + std::to_string(GetOutputTensorIdxInc());
output_json[kShape] = AnfAlgo::GetOutputDeviceShape(anf_node, i); output_json[kShape] = AnfAlgo::GetOutputDeviceShape(anf_node, i);
outputs_json->push_back(output_json); outputs_json->push_back(output_json);
} }
@ -358,15 +441,14 @@ bool AkgKernelBuild::GenerateSingleKernelJson(const AnfNodePtr &anf_node, const
MS_EXCEPTION_IF_NULL(op_info_ptr); MS_EXCEPTION_IF_NULL(op_info_ptr);
// get basic params from currentNodeOpDesc // get basic params from currentNodeOpDesc
(*node_json)["platform"] = "AKG";
(*node_json)[kName] = op_name; (*node_json)[kName] = op_name;
(*node_json)["fusion_type"] = AnfAlgo::GetFusionType(anf_node);
(*node_json)["impl_path"] = op_info_ptr->impl_path(); (*node_json)["impl_path"] = op_info_ptr->impl_path();
(*node_json)["process"] = AkgKernelBuild::GetProcessor(anf_node); (*node_json)["process"] = AkgKernelBuild::GetProcessor(anf_node);
(*node_json)["composite"] = false;
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
ValuePtr input_names_v = primitive->GetAttr(KInpputNames); ValuePtr input_names_v = primitive->GetAttr(KInputNames);
if (input_names_v == nullptr) { if (input_names_v == nullptr) {
MS_LOG(ERROR) << "ApplyKernel has no input_names, op[" << op_name << "]."; MS_LOG(ERROR) << "ApplyKernel has no input_names, op[" << op_name << "].";
return false; return false;
@ -465,12 +547,12 @@ KernelPackPtr AkgKernelBuild::OpBuild(const std::string &node_json, const AnfNod
(void)alarm(0); (void)alarm(0);
if (pRes == nullptr) { if (pRes == nullptr) {
MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileWithJsonFunc << "], args:\n("
<< PyObjectToStr(pArg) << ")."; << AkgKernelBuild::PyObjectToStr(pArg) << ").";
return nullptr; return nullptr;
} }
if (PyObject_IsTrue(pRes) != 1) { if (PyObject_IsTrue(pRes) != 1) {
MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileWithJsonFunc << "], args:\n("
<< PyObjectToStr(pArg) << ")."; << AkgKernelBuild::PyObjectToStr(pArg) << ").";
return nullptr; return nullptr;
} }
@ -513,5 +595,29 @@ KernelPackPtr AkgKernelBuild::BuildByJson(const AnfNodePtr &anf_node, std::vecto
<< "]"; << "]";
return kernel_pack; return kernel_pack;
} }
size_t AkgKernelBuild::GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx) {
MS_EXCEPTION_IF_NULL(anf_node);
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (input_idx + 1 >= cnode->inputs().size()) {
MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
<< cnode->inputs().size() - 1 << "][" << cnode->DebugString() << "]";
}
auto input_node = cnode->input(input_idx + 1);
if (input_tensor_idx_.find(input_node) == input_tensor_idx_.end()) {
size_t index = input_tensor_idx_.size();
input_tensor_idx_[input_node] = index;
}
return input_tensor_idx_[input_node];
}
size_t AkgKernelBuild::GetOutputTensorIdxInc() {
size_t idx = output_tensor_idx_++;
return idx;
}
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -32,29 +32,45 @@ namespace mindspore {
namespace kernel { namespace kernel {
class AkgKernelBuild { class AkgKernelBuild {
public: public:
AkgKernelBuild() = default; AkgKernelBuild() {
input_tensor_idx_ = {};
output_tensor_idx_ = 0;
}
~AkgKernelBuild() = default; ~AkgKernelBuild() = default;
KernelPackPtr BuildByJson(const AnfNodePtr &anf_node, std::vector<size_t> *const input_size, KernelPackPtr BuildByJson(const AnfNodePtr &anf_node, std::vector<size_t> *const input_size,
std::vector<size_t> *const output_size); std::vector<size_t> *const output_size);
static std::string GetProcessor(const AnfNodePtr &anf_node);
static std::string PyObjectToStr(PyObject *const PyObj);
private: protected:
bool CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json); bool CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json);
bool CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json); bool CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json);
bool CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name, bool CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name,
const std::shared_ptr<OpInfo> &op_info, nlohmann::json *const attrs_json); const std::shared_ptr<OpInfo> &op_info, nlohmann::json *const attrs_json);
KernelPackPtr OpBuild(const std::string &node_json, const AnfNodePtr &anf_node);
int GetOpCntInc();
size_t GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx);
size_t GetOutputTensorIdxInc();
bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name, bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name,
nlohmann::json *const node_json); nlohmann::json *const node_json);
KernelPackPtr OpBuild(const std::string &node_json, const AnfNodePtr &anf_node);
int GetOpCntInc();
std::string GetProcessor(const AnfNodePtr &anf_node);
static int op_cnt_; static int op_cnt_;
// lock for variable fusionOpCnt in singleton mode // lock for variable fusionOpCnt in singleton mode
static std::mutex op_cnt_mtx_; static std::mutex op_cnt_mtx_;
std::string json_name_; std::string json_name_;
std::string json_info_; std::string json_info_;
std::unordered_map<AnfNodePtr, size_t> input_tensor_idx_;
size_t output_tensor_idx_;
}; };
bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *const input_size,
std::vector<size_t> *const output_size);
void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair<size_t, size_t> &position,
nlohmann::json *const node_json);
std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag,
const std::pair<size_t, size_t> &position);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -0,0 +1,50 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/akg/akg_kernel_metadata.h"
#include <memory>
#include "session/anf_runtime_algorithm.h"
#include "kernel/oplib/oplib.h"
#include "kernel/common_utils.h"
namespace mindspore {
namespace kernel {
void AkgMetadataInfo(const CNodePtr &kernel_node,
std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
for (size_t i = 0; i < support_devices.size(); i++) {
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG);
if (op_info_ptr == nullptr) {
continue;
}
if (!ParseMetadata(kernel_node, op_info_ptr, Processor(i), kernel_info_list)) {
MS_LOG(WARNING) << "Akg parsed metadata of op[" << op_name << "], device[" << support_devices[i] << "] failed.";
} else {
MS_LOG(DEBUG) << "Akg parsed metadata of op[" << op_name << "], device[" << support_devices[i] << "].";
break;
}
}
if (kernel_info_list->empty()) {
MS_LOG(WARNING) << "Akg dose not has metadata of op[" << op_name << "].";
}
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,31 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_
#define MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_
#include <string>
#include <vector>
#include <unordered_map>
#include <memory>
#include "kernel/kernel_build_info.h"
namespace mindspore {
namespace kernel {
void AkgMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_

View File

@ -0,0 +1,385 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/akg/ascend/akg_ascend_kernel_build.h"
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include <Python.h>
#include "ir/dtype.h"
#include "ir/func_graph.h"
#include "kernel/kernel.h"
#include "kernel/common_utils.h"
#include "kernel/tbe/tbe_utils.h"
#include "kernel/akg/ascend/akg_ascend_kernel_mod.h"
#include "kernel/akg/akg_kernel_attrs_process.h"
#include "session/anf_runtime_algorithm.h"
namespace mindspore {
namespace kernel {
constexpr int32_t PARALLEL_ARGS_SIZE = 3;
constexpr int32_t PROCESS_NUM = 16;
constexpr int32_t TIME_OUT = 300;
constexpr auto kOpDesc = "op_desc";
constexpr auto kShape = "shape";
constexpr auto kDataType = "data_type";
constexpr auto kInputDesc = "input_desc";
constexpr auto kOutputDesc = "output_desc";
constexpr auto kTensorName = "tensor_name";
constexpr auto kCompileAkgKernelParallelFunc = "compile_akg_kernel_parallel";
constexpr auto kMultiProcModule = "mindspore._extends.parallel_compile.akg_compiler.multi_process_compiler";
bool AkgAscendKernelBuilder::CollectJson(const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
MS_LOG(INFO) << "AKG start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]";
auto it = kAkgKernelAttrsProcessMap.find(op_name);
if (it != kAkgKernelAttrsProcessMap.end()) {
it->second(anf_node);
}
MS_LOG(INFO) << "Akg start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]";
nlohmann::json node_json;
if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) {
MS_LOG(ERROR) << "Op[" << op_name << "] create single kernel json failed.";
}
kernel_json_ = node_json.dump();
if (!GetIOSize(node_json, &input_size_list_, &output_size_list_)) {
MS_LOG(ERROR) << "Cal mem size failed.";
return false;
}
return true;
}
bool AkgAscendKernelBuilder::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes,
const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list) {
if (anf_nodes.empty() || input_list.empty()) {
MS_LOG(ERROR) << "Invalid input size, anf_nodes [" << anf_nodes.size() << "], input_list [" << input_list.size()
<< "].";
return false;
}
MS_LOG(INFO) << "anf_nodes [" << output_list.size() << "], input_list [" << anf_nodes.size() << "], output_list ["
<< input_list.size() << "].";
std::map<AnfNodePtr, nlohmann::json> node_json_map;
for (auto const &anf_node : anf_nodes) {
MS_EXCEPTION_IF_NULL(anf_node);
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
if (!AnfAlgo::IsRealKernel(anf_node)) {
MS_LOG(ERROR) << "Invalid anf node to build [" << anf_node->fullname_with_scope() << "].";
return false;
}
auto it = kAkgKernelAttrsProcessMap.find(op_name);
if (it != kAkgKernelAttrsProcessMap.end()) {
it->second(anf_node);
}
nlohmann::json node_json;
if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) {
MS_LOG(ERROR) << "Op [" << op_name << "] create single kernel json failed.";
return false;
}
// No need for composite op.
node_json.erase("id");
node_json.erase("op");
node_json.erase("composite");
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive);
if (primitive->GetAttr("fusion") != nullptr) {
node_json["fusion"] = primitive->GetAttr("fusion")->ToString();
}
node_json_map[anf_node] = node_json;
}
for (auto const &anf_node : anf_nodes) {
std::vector<int> dyn_input_sizes;
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive);
if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) {
dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
}
bool is_dynamic_input = !dyn_input_sizes.empty();
size_t input_num = is_dynamic_input ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node);
size_t real_input_index = 0;
for (size_t i = 0; i < input_num; ++i) {
size_t input_tensor_num = is_dynamic_input ? IntToSize(dyn_input_sizes[i]) : 1;
for (size_t j = 0; j < input_tensor_num; ++j) {
auto tmp_input = GetKernelInput(anf_node, real_input_index);
std::string tensor_name = GetTensorName(node_json_map[anf_node], kInputDesc, std::make_pair(i, j));
if (node_json_map.find(tmp_input.first) != node_json_map.end()) {
std::string new_tensor_name =
GetTensorName(node_json_map[tmp_input.first], kOutputDesc, std::make_pair(0, tmp_input.second));
SetTensorName(kInputDesc, new_tensor_name, std::make_pair(i, j), &(node_json_map[anf_node]));
MS_LOG(DEBUG) << "Update [" << real_input_index << "] input [" << tensor_name << "] of ["
<< anf_node->fullname_with_scope() << "] to [" << tmp_input.second << "] output ["
<< new_tensor_name << "] of [" << tmp_input.first->fullname_with_scope() << "].";
} else {
MS_LOG(DEBUG) << "[" << real_input_index << "] input " << tensor_name << "] of ["
<< anf_node->fullname_with_scope() << "] is out input.";
}
real_input_index++;
}
}
}
nlohmann::json fused_node_json;
std::vector<nlohmann::json> node_json_desc;
std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc),
[&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; });
fused_node_json[kOpDesc] = node_json_desc;
nlohmann::json inputs_json;
auto input_index = GetInputIndex(anf_nodes, input_list);
for (size_t i = 0; i < input_index.size(); ++i) {
auto tmp_input = input_index[i];
auto type_id = AnfAlgo::GetInputDeviceDataType(tmp_input.first, tmp_input.second.first);
std::string dtype = TypeId2String(type_id);
nlohmann::json input_desc_json;
input_desc_json[kTensorName] = GetTensorName(node_json_map[tmp_input.first], kInputDesc, tmp_input.second);
input_desc_json[kDataType] = dtype;
input_desc_json[kShape] = AnfAlgo::GetInputDeviceShape(tmp_input.first, tmp_input.second.first);
inputs_json.emplace_back(std::vector<nlohmann::json>{input_desc_json});
}
fused_node_json[kInputDesc] = inputs_json;
nlohmann::json outputs_json;
auto output_index = GetOutputIndex(anf_nodes, input_list, output_list);
for (size_t i = 0; i < output_index.size(); ++i) {
auto tmp_output = output_index[i];
bool found = false;
nlohmann::json output_desc_json;
for (size_t input_i = 0; input_i < input_list.size(); ++input_i) {
if (tmp_output.first == input_list[input_i]) {
output_desc_json = inputs_json[input_i][0];
found = true;
break;
}
}
if (!found) {
auto type_id = AnfAlgo::GetOutputDeviceDataType(tmp_output.first, tmp_output.second);
std::string dtype = TypeId2String(type_id);
output_desc_json[kTensorName] =
GetTensorName(node_json_map[tmp_output.first], kOutputDesc, std::make_pair(0, tmp_output.second));
output_desc_json[kDataType] = dtype;
auto output_shape = AnfAlgo::GetOutputDeviceShape(tmp_output.first, tmp_output.second);
if (output_shape.empty()) {
output_shape.push_back(1);
}
output_desc_json[kShape] = output_shape;
}
outputs_json.emplace_back(output_desc_json);
}
fused_node_json[kOutputDesc] = outputs_json;
size_t hash_id = std::hash<std::string>()(fused_node_json.dump());
json_name_ = "Fused_";
auto fg = anf_nodes[0]->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
if (attr_val != nullptr) {
auto fg_attr = GetValue<std::string>(attr_val);
(void)json_name_.append(fg_attr).append("_");
}
(void)json_name_.append(std::to_string(hash_id));
fused_node_json["composite_graph"] = fg->ToString();
fused_node_json["op"] = json_name_;
fused_node_json["platform"] = "AKG";
fused_node_json["process"] = "aicore";
fused_node_json["composite"] = true;
kernel_json_ = fused_node_json.dump();
if (!GetIOSize(fused_node_json, &input_size_list_, &output_size_list_)) {
MS_LOG(ERROR) << "Cal mem size failed.";
return false;
}
return true;
}
void GenParallelCompileFuncArgs(const std::vector<std::string> &kernel_jsons, PyObject **p_args) {
MS_EXCEPTION_IF_NULL(p_args);
*p_args = PyTuple_New(PARALLEL_ARGS_SIZE);
PyObject *arg1 = PyList_New(kernel_jsons.size());
for (int i = 0; i < PyList_Size(arg1); ++i) {
PyList_SetItem(arg1, i, Py_BuildValue("s", kernel_jsons[i].c_str()));
}
PyObject *arg2 = Py_BuildValue("i", PROCESS_NUM);
PyObject *arg3 = Py_BuildValue("i", TIME_OUT);
(void)PyTuple_SetItem(*p_args, 0, arg1);
(void)PyTuple_SetItem(*p_args, 1, arg2);
(void)PyTuple_SetItem(*p_args, 2, arg3);
}
bool AkgOpParallelBuild(const std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> &build_args) {
// Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess.
std::vector<std::string> jsons;
std::unordered_set<std::string> json_name_set;
std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> repeat_nodes;
for (const auto &[builder, anf_node] : build_args) {
MS_EXCEPTION_IF_NULL(anf_node);
auto json_name = builder.json_name();
MS_LOG(DEBUG) << "Akg start compile op: " << json_name;
auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node));
if (cached_kernel_pack != nullptr) {
MS_LOG(DEBUG) << "Use cached kernel, json_name_[" << json_name << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "].";
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack);
kernel_mod_ptr->SetInputSizeList(builder.input_size_list());
kernel_mod_ptr->SetOutputSizeList(builder.output_size_list());
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
continue;
}
if (json_name_set.count(json_name) != 0) {
repeat_nodes.push_back({builder, anf_node});
continue;
}
json_name_set.insert(json_name);
auto node_json = builder.kernel_json();
kernel::SaveJsonInfo(json_name, node_json);
jsons.push_back(node_json);
}
// No nodes need to be compiled!
if (jsons.empty()) {
return true;
}
// Try to call python method to compile nodes parallely.
PyObject *p_module = nullptr;
PyObject *p_func = nullptr;
PyObject *p_arg = nullptr;
PyObject *p_res = nullptr;
p_module = PyImport_ImportModule(kMultiProcModule);
if (p_module == nullptr) {
MS_LOG(ERROR) << "Failed to import [" << kMultiProcModule << "].";
return false;
}
p_func = PyObject_GetAttrString(p_module, kCompileAkgKernelParallelFunc);
GenParallelCompileFuncArgs(jsons, &p_arg);
MS_LOG(DEBUG) << "Call function [" << kCompileAkgKernelParallelFunc << "], try to compile " << jsons.size()
<< " Akg kernels parallelly.";
p_res = PyEval_CallObject(p_func, p_arg);
if (p_res == nullptr) {
PyErr_Print();
MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n("
<< AkgKernelBuild::PyObjectToStr(p_arg) << ").";
return false;
}
if (PyObject_IsTrue(p_res) != 1) {
PyErr_Print();
MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n("
<< AkgKernelBuild::PyObjectToStr(p_arg) << ").";
return false;
}
// All unique done here, cache them and set kernel.
for (const auto &[builder, anf_node] : build_args) {
auto json_name = builder.json_name();
auto new_kernel_pack = tbe::TbeUtils::InsertCache(json_name, AkgKernelBuild::GetProcessor(anf_node));
if (new_kernel_pack == nullptr) {
MS_LOG(ERROR) << "Insert to cache failed, json_name_[" << json_name << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "].";
return false;
}
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(new_kernel_pack);
kernel_mod_ptr->SetInputSizeList(builder.input_size_list());
kernel_mod_ptr->SetOutputSizeList(builder.output_size_list());
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
MS_LOG(DEBUG) << "Akg compile " << json_name << " kernel and insert cache successfully!";
}
// Handle repeated nodes.
for (const auto &[builder, anf_node] : repeat_nodes) {
auto node_json = builder.kernel_json();
auto json_name = builder.json_name();
auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node));
if (cached_kernel_pack == nullptr) return false;
MS_LOG(INFO) << "Use just compiled kernel, json_name_[" << json_name << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "].";
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack);
kernel_mod_ptr->SetInputSizeList(builder.input_size_list());
kernel_mod_ptr->SetOutputSizeList(builder.output_size_list());
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
}
return true;
}
bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> json_and_node;
for (const auto &anf_node : anf_nodes) {
MS_EXCEPTION_IF_NULL(anf_node);
AkgAscendKernelBuilder akg_cce_kernel_builder;
KernelPackPtr kernel_pack = nullptr;
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::IsGraphKernel(cnode)) {
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> node_list;
std::vector<AnfNodePtr> input_list;
std::vector<AnfNodePtr> output_list;
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
MS_LOG(INFO) << "Akg start compile composite op[" << op_name << "]";
GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
if (!akg_cce_kernel_builder.CollectFusedJson(node_list, input_list, output_list)) {
MS_EXCEPTION(UnknownError) << "Akg build failed composite op[" << op_name << "].";
}
} else {
if (!akg_cce_kernel_builder.CollectJson(anf_node)) {
MS_EXCEPTION(UnknownError) << "Akg build failed op[" << AnfAlgo::GetCNodeName(anf_node) << "].";
}
}
json_and_node.push_back({akg_cce_kernel_builder, anf_node});
}
if (json_and_node.empty()) {
MS_LOG(DEBUG) << "There is no kernel needed to be compiled.";
return true;
}
return AkgOpParallelBuild(json_and_node);
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_
#define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_
#include <string>
#include <memory>
#include <vector>
#include "ir/anf.h"
#include "kernel/kernel.h"
#include "kernel/akg/akg_kernel_build.h"
namespace mindspore {
namespace kernel {
class AkgAscendKernelBuilder : public AkgKernelBuild {
public:
AkgAscendKernelBuilder() = default;
~AkgAscendKernelBuilder() = default;
bool CollectJson(const AnfNodePtr &anf_node);
bool CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list);
std::string json_name() const { return json_name_; }
std::string kernel_json() const { return kernel_json_; }
const std::vector<size_t> &input_size_list() const { return input_size_list_; }
const std::vector<size_t> &output_size_list() const { return output_size_list_; }
private:
std::string kernel_json_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
};
bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_

View File

@ -0,0 +1,181 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/akg/ascend/akg_ascend_kernel_mod.h"
#include <algorithm>
#include <fstream>
#include <map>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <vector>
#include "nlohmann/json.hpp"
#include "runtime/rt.h"
#include "utils/log_adapter.h"
#include "utils/convert_utils.h"
namespace mindspore {
namespace kernel {
using std::fstream;
using std::map;
using std::mutex;
using std::string;
using TbeTaskInfoPtr = std::shared_ptr<ge::model_runner::TbeTaskInfo>;
using tbe::KernelManager;
constexpr uint32_t DEFAULT_BLOCK_DIM = 1;
/**
* @brief infotable contain func_stub\blockdim\kernel file buffer
*/
AkgKernelMod::AkgKernelMod(const KernelPackPtr &kernel_pack) : kernel_pack_(kernel_pack) {}
void AkgKernelMod::SetInputSizeList(const std::vector<size_t> &size_list) { input_size_list_ = size_list; }
void AkgKernelMod::SetOutputSizeList(const std::vector<size_t> &size_list) { output_size_list_ = size_list; }
void AkgKernelMod::SetWorkspaceSizeList(const std::vector<size_t> &size_list) { workspace_size_list_ = size_list; }
const std::vector<size_t> &AkgKernelMod::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &AkgKernelMod::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &AkgKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; }
void DumpData(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
const char *dump_data = getenv("MS_KERNEL_DUMP_DATA");
if (dump_data) {
int idx = 0;
for (const auto &x : inputs) {
std::vector<char> buf(x->size);
if (RT_ERROR_NONE != rtMemcpy(buf.data(), buf.size(), reinterpret_cast<const void *>(x->addr), x->size,
RT_MEMCPY_DEVICE_TO_HOST)) {
MS_LOG(WARNING) << "Call runtime rtMemcpy error.";
return;
}
std::string file_name("input_");
file_name += std::to_string(idx);
std::ofstream file(file_name, std::ios::binary);
if (file.is_open()) {
(void)file.write(buf.data(), SizeToLong(buf.size()));
file.close();
idx++;
} else {
MS_LOG(ERROR) << "Open file failed.";
return;
}
}
idx = 0;
for (const auto &x : outputs) {
std::vector<char> buf(x->size);
if (RT_ERROR_NONE != rtMemcpy(buf.data(), buf.size(), reinterpret_cast<const void *>(x->addr), x->size,
RT_MEMCPY_DEVICE_TO_HOST)) {
MS_LOG(WARNING) << "Call runtime rtMemcpy error.";
return;
}
std::string file_name("output_");
file_name += std::to_string(idx);
std::ofstream file(file_name, std::ios::binary);
if (file.is_open()) {
(void)file.write(buf.data(), SizeToLong(buf.size()));
file.close();
idx++;
} else {
MS_LOG(ERROR) << "Open file failed.";
return;
}
}
}
}
bool AkgKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
if (stream_ptr == 0) {
MS_LOG(ERROR) << "stream_ptr should not be nullptr.";
return false;
}
if (kernel_pack_ == nullptr) {
MS_LOG(ERROR) << "kernel pack should not be nullptr.";
return false;
}
uint32_t block_dim = DEFAULT_BLOCK_DIM; // default blockdim equal to 1.
auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim);
if (func_stub == 0) {
MS_LOG(ERROR) << "GenFuncStub failed.";
return false;
}
// pack all addresses into a vector.
std::vector<void *> runtime_args;
(void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtime_args),
[](const AddressPtr &input) -> void * { return input->addr; });
(void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtime_args),
[](const AddressPtr &output) -> void * { return output->addr; });
rtL2Ctrl_t *l2ctrl = nullptr;
auto stream = reinterpret_cast<rtStream_t *>(stream_ptr);
if (RT_ERROR_NONE != rtKernelLaunch(reinterpret_cast<void *>(func_stub), block_dim, runtime_args.data(),
SizeToUint(sizeof(void *) * runtime_args.size()), l2ctrl, stream)) {
MS_LOG(ERROR) << "Call runtime rtKernelLaunch error.";
return false;
}
DumpData(inputs, outputs);
return true;
}
std::vector<TaskInfoPtr> AkgKernelMod::GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
if (kernel_pack_ == nullptr) {
MS_LOG(EXCEPTION) << "kernel pack should not be nullptr.";
}
std::vector<uint8_t> args;
uint32_t args_size = 0;
std::vector<uint8_t> sm_desc;
void *binary = nullptr;
uint32_t binary_size = 0;
std::vector<uint8_t> meta_data;
std::vector<void *> input_data_addrs;
std::vector<void *> output_data_addrs;
std::vector<void *> workspace_addrs;
// pack all addresses into a vector.
(void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(input_data_addrs),
[](const AddressPtr &input) -> void * { return input->addr; });
(void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs),
[](const AddressPtr &output) -> void * { return output->addr; });
uint32_t block_dim = DEFAULT_BLOCK_DIM; // default blockdim equal to 1.
auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim);
if (func_stub == 0) {
MS_LOG(EXCEPTION) << "GenFuncStub failed.";
}
std::string stub_func = KernelManager::GetStubFuncName(kernel_pack_);
MS_LOG(DEBUG) << "The block_dim is:" << block_dim;
TbeTaskInfoPtr task_info_ptr = make_shared<ge::model_runner::TbeTaskInfo>(
stream_id, stub_func, block_dim, args, args_size, sm_desc, binary, binary_size, meta_data, input_data_addrs,
output_data_addrs, workspace_addrs);
return {task_info_ptr};
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_
#define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_
#include <string>
#include <vector>
#include <memory>
#include "kernel/ascend_kernel_mod.h"
#include "kernel/tbe/tbe_utils.h"
namespace mindspore {
namespace kernel {
class AkgKernelMod : public AscendKernelMod {
public:
explicit AkgKernelMod(const KernelPackPtr &kernel_pack);
~AkgKernelMod() final {}
void SetInputSizeList(const std::vector<size_t> &size_list);
void SetOutputSizeList(const std::vector<size_t> &size_list);
void SetWorkspaceSizeList(const std::vector<size_t> &size_list);
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
const std::vector<size_t> &GetWorkspaceSizeList() const override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;
private:
KernelPackPtr kernel_pack_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
using AkgKernelModPtr = std::shared_ptr<AkgKernelMod>;
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_

View File

@ -18,7 +18,7 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "kernel/kernel.h" #include "kernel/kernel.h"
#include "kernel/akg/akgkernelbuild.h" #include "kernel/akg/akg_kernel_build.h"
#include "kernel/akg/gpu/akg_gpu_kernel_mod.h" #include "kernel/akg/gpu/akg_gpu_kernel_mod.h"
#include "common/utils.h" #include "common/utils.h"

View File

@ -23,6 +23,11 @@
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
#include "common/utils.h" #include "common/utils.h"
#include "ir/manager.h"
#include "ir/meta_tensor.h"
#include "ir/func_graph.h"
#include "operator/ops.h"
#include "utils/graph_utils.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
@ -48,12 +53,6 @@ const std::map<TypeId, std::string> type_id_str_map = {
{TypeId::kNumberTypeBool, "bool"}, {TypeId::kNumberTypeBool, "bool"},
}; };
const std::map<std::string, std::string> DATATYPE_STRING_MAP{
{"Float32", "float32"}, {"Float16", "float16"}, {"Int8", "int8"}, {"Int16", "int16"},
{"UInt16", "uint16"}, {"UInt8", "uint8"}, {"Int32", "int32"}, {"UInt32", "uint32"},
{"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool_", "bool"}, {"Float64", "double"},
};
const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = { const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
{"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"}, {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},
{"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"}, {"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"},
@ -243,14 +242,6 @@ TypeId DtypeToTypeId(const std::string &dtypes) {
} }
} }
std::string Dtype2String(const std::string &dtypes) {
auto iter = DATATYPE_STRING_MAP.find(dtypes);
if (iter == DATATYPE_STRING_MAP.end()) {
MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes;
}
return iter->second;
}
std::string TypeId2String(TypeId type_id) { std::string TypeId2String(TypeId type_id) {
auto iter = type_id_str_map.find(type_id); auto iter = type_id_str_map.find(type_id);
if (iter == type_id_str_map.end()) { if (iter == type_id_str_map.end()) {
@ -361,7 +352,7 @@ bool SetOutputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
output_num = 1; output_num = 1;
} else { } else {
if (output_idx < real_output_num) { if (output_idx < real_output_num) {
MS_LOG(INFO) << "Set output kernel builder info, output type is optional, output index is :" << output_idx; MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
output_num = 1; output_num = 1;
} }
} }
@ -403,7 +394,7 @@ void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBu
} }
if (imply_type == kAKG) { if (imply_type == kAKG) {
builder->SetKernelType(AUTO_DIFF_KERNEL); builder->SetKernelType(AKG_KERNEL);
} else if (imply_type == kAICPU) { } else if (imply_type == kAICPU) {
builder->SetKernelType(AICPU_KERNEL); builder->SetKernelType(AICPU_KERNEL);
} else { } else {
@ -634,5 +625,256 @@ void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradie
} }
unique_grad->indices_size_ = unique_indices_size + 1; unique_grad->indices_size_ = unique_indices_size + 1;
} }
std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) {
MS_EXCEPTION_IF_NULL(anf_node);
if (index >= AnfAlgo::GetInputTensorNum(anf_node)) {
MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs.";
}
auto cnode = anf_node->cast<CNodePtr>();
if (cnode == nullptr) {
return AnfAlgo::VisitKernel(anf_node, 0);
} else {
return AnfAlgo::VisitKernel(anf_node->cast<CNodePtr>()->input(index + 1), 0);
}
}
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list) {
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> input_index;
for (size_t i = 0; i < input_list.size(); ++i) {
auto const &input = input_list[i];
MS_EXCEPTION_IF_NULL(input);
bool found = false;
// using NodeUsersMap = std::unordered_map<AnfNodePtr, std::set<std::pair<AnfNodePtr, int>>>;
auto mng = input->func_graph()->manager();
MS_EXCEPTION_IF_NULL(mng);
const NodeUsersMap &users = mng->node_users();
auto input_users = users.find(input);
if (input_users == users.end() || input_users->second.empty()) {
MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
<< input->func_graph()->ToString() << "] has no users.";
}
for (auto const &input_user : input_users->second) {
for (auto const &anf_node : node_list) {
if (anf_node != input_user.first) {
continue;
}
std::vector<int> dyn_input_sizes;
auto prim = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(prim);
if (prim->GetAttr(kAttrDynInputSizes) != nullptr) {
dyn_input_sizes = GetValue<const std::vector<int>>(prim->GetAttr(kAttrDynInputSizes));
}
if (dyn_input_sizes.empty()) {
input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0)));
found = true;
break;
} else {
int used_as_idx = input_user.second - 1;
int accum_idx = 0;
size_t dyn_i = 0;
for (; dyn_i < dyn_input_sizes.size(); ++dyn_i) {
accum_idx += dyn_input_sizes[dyn_i];
if (used_as_idx < accum_idx) {
input_index.push_back(std::make_pair(
anf_node, std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i])))));
break;
}
}
if (dyn_i != dyn_input_sizes.size()) {
found = true;
break;
}
}
}
if (found) {
break;
}
}
if (!found) {
MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
<< input->func_graph()->ToString() << "] found no related kernel info.";
}
}
return input_index;
}
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list) {
std::vector<std::pair<AnfNodePtr, size_t>> output_index;
for (size_t i = 0; i < output_list.size(); ++i) {
auto const &output = output_list[i];
MS_EXCEPTION_IF_NULL(output);
bool found = false;
auto pree_node = AnfAlgo::VisitKernel(output, 0);
auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first);
if (pos != std::end(node_list)) {
output_index.push_back(pree_node);
continue;
}
auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first);
if (ret != std::end(input_list)) {
output_index.push_back(std::make_pair(pree_node.first, 0));
found = true;
}
if (!found) {
MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of ["
<< output->func_graph()->ToString() << "] found no related kernel info.";
}
}
return output_index;
}
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list) {
MS_EXCEPTION_IF_NULL(node_list);
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> node_lists = TopoSort(func_graph->get_return());
for (auto const &node : node_lists) {
if (!AnfAlgo::IsRealKernel(node) || !node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex))) {
node_list->push_back(node);
}
}
}
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) {
MS_EXCEPTION_IF_NULL(node_list);
MS_EXCEPTION_IF_NULL(input_list);
MS_EXCEPTION_IF_NULL(output_list);
MS_EXCEPTION_IF_NULL(func_graph);
GetValidKernelNodes(func_graph, node_list);
auto parameters = func_graph->parameters();
input_list->insert(input_list->begin(), parameters.begin(), parameters.end());
auto func_output = func_graph->output();
MS_EXCEPTION_IF_NULL(func_output);
if (func_output->isa<CNode>()) {
// multi output.
auto cnode = func_output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input0 = cnode->input(kAnfPrimitiveIndex);
MS_EXCEPTION_IF_NULL(input0);
if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) {
auto input_node = cnode->input(input_idx);
MS_EXCEPTION_IF_NULL(input_node);
output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first);
}
} else {
// single output.
output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
}
} else {
// single output.
output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
}
}
bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(node_json);
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (input_idx + 1 >= cnode->size()) {
MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
<< cnode->inputs().size() << "][" << cnode->DebugString() << "]";
}
auto input_node = cnode->input(input_idx + 1);
if (!IsValueNode<tensor::Tensor>(input_node)) {
return false;
}
auto tensor = GetValueNode<tensor::TensorPtr>(input_node);
if (tensor == nullptr) {
return false;
}
auto type_id = tensor->data_type();
auto *data = tensor->data_c();
MS_EXCEPTION_IF_NULL(data);
if (tensor->DataDim() > 1 || tensor->DataSize() != 1) {
// not const tensor.
MS_LOG(WARNING) << "We take first value of tensor whose datasize != 1, [" << input_node->DebugString(2) << "]";
}
if (type_id == kFloat32->type_id()) {
float *val = static_cast<float *>(data);
MS_EXCEPTION_IF_NULL(val);
(*node_json)["value"] = val[0];
MS_LOG(DEBUG) << "Value of tensor[" << cnode->DebugString() << "] is [float32][" << *val << "].";
return true;
} else if (type_id == kFloat16->type_id()) {
float16 *val = static_cast<float16 *>(data);
MS_EXCEPTION_IF_NULL(val);
(*node_json)["value"] = static_cast<float>(val[0]);
MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [float16][" << *val << "].";
return true;
} else if (type_id == kInt32->type_id()) {
int *val = static_cast<int *>(data);
MS_EXCEPTION_IF_NULL(val);
(*node_json)["value"] = val[0];
MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [int32][" << *val << "].";
return true;
}
MS_LOG(ERROR) << "Unknown value type of tensor[" << cnode->DebugString() << "]";
return false;
}
void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node_list);
auto output = func_graph->output();
MS_EXCEPTION_IF_NULL(output);
if (AnfAlgo::IsRealKernel(output)) {
// single output.
node_list->push_back(std::make_pair(output, 0));
return;
} else if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
auto output_cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_cnode);
// multi output.
auto &inputs = output_cnode->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
auto in_with_idx = AnfAlgo::VisitKernel(inputs[i], 0);
node_list->push_back(in_with_idx);
}
return;
}
MS_EXCEPTION(ArgumentError) << "Unknown output type: " << output->DebugString(2)
<< " of graph: " << func_graph->ToString();
}
bool IsWeightBoundary(const AnfNodePtr &node) {
if (node->isa<ValueNode>()) {
return true;
}
if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
return true;
}
return false;
}
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -20,9 +20,12 @@
#include <dirent.h> #include <dirent.h>
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility>
#include <nlohmann/json.hpp>
#include "kernel/kernel.h" #include "kernel/kernel.h"
#include "kernel/oplib/opinfo.h" #include "kernel/oplib/opinfo.h"
#include "kernel/kernel_build_info.h" #include "kernel/kernel_build_info.h"
@ -79,13 +82,11 @@ bool CheckCache(const std::string &kernel_name);
KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor); KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor);
KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor); KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor);
TypeId DtypeToTypeId(const std::string &dtypes); TypeId DtypeToTypeId(const std::string &dtypes);
std::string Dtype2String(const std::string &dtypes);
std::string Dtype2ShortType(const std::string &dtypes); std::string Dtype2ShortType(const std::string &dtypes);
std::string TypeId2String(TypeId type_id); std::string TypeId2String(TypeId type_id);
size_t GetDtypeNbyte(const std::string &dtypes); size_t GetDtypeNbyte(const std::string &dtypes);
bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor, bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list); std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list);
bool IsAtomicNode(const CNodePtr &kernel_node);
void SaveJsonInfo(const std::string &json_name, const std::string &info); void SaveJsonInfo(const std::string &json_name, const std::string &info);
std::string GetProcessor(const AnfNodePtr &anf_node); std::string GetProcessor(const AnfNodePtr &anf_node);
bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b); bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b);
@ -94,6 +95,18 @@ void DeduplicateIndexedSlices(const SparseGradient &origin_sparse_grad, SparseGr
size_t outer_dim); size_t outer_dim);
void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim,
size_t outer_dim); size_t outer_dim);
std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index);
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list);
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list);
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list);
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list);
bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json);
void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list);
bool IsWeightBoundary(const AnfNodePtr &node);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -17,7 +17,7 @@
#include <fstream> #include <fstream>
#include "mindspore/ccsrc/kernel/kernel.h" #include "mindspore/ccsrc/kernel/kernel.h"
#include "kernel/kernel.h" #include "kernel/kernel.h"
#include "kernel/akg/akgkernelbuild.h" #include "kernel/akg/akg_kernel_build.h"
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
#include "securec/include/securec.h" #include "securec/include/securec.h"
#include "pipeline/parse/python_adapter.h" #include "pipeline/parse/python_adapter.h"

View File

@ -27,7 +27,7 @@
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
namespace mindspore { namespace mindspore {
enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AUTO_DIFF_KERNEL, AICPU_KERNEL, RT_KERNEL, HCCL_KERNEL, TBE_KERNEL }; enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AKG_KERNEL, AICPU_KERNEL, RT_KERNEL, HCCL_KERNEL, TBE_KERNEL };
namespace kernel { namespace kernel {

View File

@ -21,6 +21,7 @@
#include "kernel/rts/rt_kernel_info.h" #include "kernel/rts/rt_kernel_info.h"
#include "kernel/hccl/hccl_kernel_metadata.h" #include "kernel/hccl/hccl_kernel_metadata.h"
#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h" #include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
#include "kernel/akg/akg_kernel_metadata.h"
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
namespace mindspore { namespace mindspore {
@ -59,10 +60,14 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
} }
} }
} // namespace } // namespace
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
void KernelQueryAll(const CNodePtr &kernel_node,
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list); MS_EXCEPTION_IF_NULL(kernel_info_list);
TbeMetadataInfo(kernel_node, kernel_info_list); TbeMetadataInfo(kernel_node, kernel_info_list);
if (kernel_info_list->empty()) { if (kernel_info_list->empty()) {
AicpuMetadataInfo(kernel_node, kernel_info_list); AicpuMetadataInfo(kernel_node, kernel_info_list);
if (!kernel_info_list->empty()) { if (!kernel_info_list->empty()) {
@ -82,6 +87,28 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
if (kernel_info_list->empty()) { if (kernel_info_list->empty()) {
MS_LOG(EXCEPTION) << "Op " << kernel_node->DebugString() << "kernel query fail!"; MS_LOG(EXCEPTION) << "Op " << kernel_node->DebugString() << "kernel query fail!";
} }
}
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list,
KernelType kernel_type) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
switch (kernel_type) {
case KernelType::AKG_KERNEL:
AkgMetadataInfo(kernel_node, kernel_info_list);
break;
default:
KernelQueryAll(kernel_node, kernel_info_list);
break;
}
if (kernel_info_list->empty()) {
MS_EXCEPTION(NotExistsError) << "Op[" << kernel_node->DebugString() << "] kernel query fail!";
}
// check output
FilterInvalidKernelInfo(kernel_node, kernel_info_list); FilterInvalidKernelInfo(kernel_node, kernel_info_list);
} }

View File

@ -25,7 +25,8 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list); void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list,
KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list); void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);

View File

@ -272,8 +272,7 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im
auto context = MsContext::GetInstance(); auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
bool is_gpu = (context->device_target() == kGPUDevice); bool is_gpu = (context->device_target() == kGPUDevice);
if ((is_gpu && (imply_type == kTBE || imply_type == kAICPU)) || if (is_gpu && (imply_type == kTBE || imply_type == kAICPU)) {
(!is_gpu && (imply_type != kTBE && imply_type != kAICPU))) {
MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
<< ", current op num: " << op_info_.size(); << ", current op num: " << op_info_.size();
return nullptr; return nullptr;

View File

@ -347,7 +347,7 @@ static int TypeStrToDstType(const std::string &type_str) {
ret = 4; ret = 4;
} else if (type_str == "UInt64") { } else if (type_str == "UInt64") {
ret = 10; ret = 10;
} else if (type_str == "Bool_") { } else if (type_str == "Bool") {
ret = 12; ret = 12;
} else { } else {
MS_LOG(INFO) << "Error type str is invailed: " << type_str; MS_LOG(INFO) << "Error type str is invailed: " << type_str;

View File

@ -51,7 +51,7 @@ const std::map<TypeId, std::string> type_id_str_maps = {
const std::map<std::string, std::string> type_str_maps = { const std::map<std::string, std::string> type_str_maps = {
{"Float32", "float32"}, {"Float16", "float16"}, {"Int8", "int8"}, {"Int16", "int16"}, {"Float32", "float32"}, {"Float16", "float16"}, {"Int8", "int8"}, {"Int16", "int16"},
{"UInt16", "uint16"}, {"UInt8", "uint8"}, {"Int32", "int32"}, {"UInt32", "uint32"}, {"UInt16", "uint16"}, {"UInt8", "uint8"}, {"Int32", "int32"}, {"UInt32", "uint32"},
{"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool_", "int8"}, {"Float64", "float64"}, {"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool", "int8"}, {"Float64", "float64"},
}; };
const std::unordered_map<std::string, size_t> type_nbyte_maps = { const std::unordered_map<std::string, size_t> type_nbyte_maps = {

View File

@ -334,8 +334,8 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL
FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
ptrGraph->debug_info()->set_name("hyper_map"); ptrGraph->debug_info()->set_name("hyper_map");
AnfNodePtr ptrFnArg = nullptr; AnfNodePtr ptrFnArg = nullptr;
@ -389,7 +389,7 @@ FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tu
MS_EXCEPTION_IF_NULL(a_tuple); MS_EXCEPTION_IF_NULL(a_tuple);
FuncGraphPtr ret = std::make_shared<FuncGraph>(); FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret->debug_info()->set_name("tail"); ret->debug_info()->set_name("tail");
AnfNodePtr ptrTup = ret->add_parameter(); AnfNodePtr ptrTup = ret->add_parameter();
@ -409,7 +409,7 @@ FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list
MS_EXCEPTION_IF_NULL(a_list); MS_EXCEPTION_IF_NULL(a_list);
FuncGraphPtr ret = std::make_shared<FuncGraph>(); FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret->debug_info()->set_name("tail"); ret->debug_info()->set_name("tail");
AnfNodePtr ptrList = ret->add_parameter(); AnfNodePtr ptrList = ret->add_parameter();
@ -481,10 +481,10 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg
grads.push_back(b->NewCNode({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)})); grads.push_back(b->NewCNode({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)}));
} }
b->set_flags(FUNC_GRAPH_FLAG_CORE, true); b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
b->set_output(b->NewCNode(grads)); b->set_output(b->NewCNode(grads));
fg->set_flags(FUNC_GRAPH_FLAG_CORE, true); fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)})); fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
(void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple)); (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple));
return fg; return fg;
@ -504,7 +504,7 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
const std::vector<AnfNodePtr> &params_list, const std::vector<AnfNodePtr> &args, const std::vector<AnfNodePtr> &params_list, const std::vector<AnfNodePtr> &args,
bool applyJ) { bool applyJ) {
FuncGraphPtr ret = std::make_shared<FuncGraph>(); FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
auto weights_node = weights; auto weights_node = weights;
if (weights == nullptr && !args.empty()) { if (weights == nullptr && !args.empty()) {
@ -625,7 +625,7 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
std::ostringstream ss; std::ostringstream ss;
ss << "grad{" << nparam << "}"; ss << "grad{" << nparam << "}";
dfBuilder->set_flags(FUNC_GRAPH_FLAG_CORE, true); dfBuilder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
dfBuilder->debug_info()->set_name(ss.str()); dfBuilder->debug_info()->set_name(ss.str());
ParameterPtr param_graph = dfBuilder->add_parameter(); ParameterPtr param_graph = dfBuilder->add_parameter();
@ -671,7 +671,7 @@ FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_lis
} }
FuncGraphPtr fg_ptr = std::make_shared<FuncGraph>(); FuncGraphPtr fg_ptr = std::make_shared<FuncGraph>();
fg_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true); fg_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
fg_ptr->debug_info()->set_name("list_map"); fg_ptr->debug_info()->set_name("list_map");
AnfNodePtr fn = fg_ptr->add_parameter(); AnfNodePtr fn = fg_ptr->add_parameter();
@ -741,7 +741,7 @@ void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr
// cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts) // cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts)
FuncGraphPtr fgtrue_ptr = std::make_shared<FuncGraph>(); FuncGraphPtr fgtrue_ptr = std::make_shared<FuncGraph>();
fgtrue_ptr->debug_info()->set_name("ftrue"); fgtrue_ptr->debug_info()->set_name("ftrue");
fgtrue_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true); fgtrue_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNode({NewValueNode(fgnext_ptr), fn, resl}); CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNode({NewValueNode(fgnext_ptr), fn, resl});
auto inputs = fgtrue_output_cnode->inputs(); auto inputs = fgtrue_output_cnode->inputs();
@ -751,7 +751,7 @@ void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr
FuncGraphPtr fgfalse_ptr = std::make_shared<FuncGraph>(); FuncGraphPtr fgfalse_ptr = std::make_shared<FuncGraph>();
fgfalse_ptr->debug_info()->set_name("ffalse"); fgfalse_ptr->debug_info()->set_name("ffalse");
fgfalse_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true); fgfalse_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
fgfalse_ptr->set_output(resl); fgfalse_ptr->set_output(resl);
AnfNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")), AnfNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")),
@ -808,7 +808,7 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_li
} }
FuncGraphPtr ret = std::make_shared<FuncGraph>(); FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr p_tup_a = ret->add_parameter(); AnfNodePtr p_tup_a = ret->add_parameter();
AnfNodePtr p_tup_b = ret->add_parameter(); AnfNodePtr p_tup_b = ret->add_parameter();
@ -912,7 +912,7 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value); GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value);
FuncGraphPtr ret = std::make_shared<FuncGraph>(); FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr p_tuple = ret->add_parameter(); AnfNodePtr p_tuple = ret->add_parameter();
(void)ret->add_parameter(); (void)ret->add_parameter();
@ -941,7 +941,7 @@ FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &ar
AbstractBasePtrList branches = branches_abs->elements(); AbstractBasePtrList branches = branches_abs->elements();
if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) { if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) {
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr functions = ret_graph->add_parameter(); AnfNodePtr functions = ret_graph->add_parameter();
auto index = ret_graph->add_parameter(); auto index = ret_graph->add_parameter();

View File

@ -304,7 +304,7 @@ FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrLi
} }
auto new_cnode = BuildNewCNode(func_graph, name_, function_, args_spec_list, func_graph->parameters()); auto new_cnode = BuildNewCNode(func_graph, name_, function_, args_spec_list, func_graph->parameters());
func_graph->set_output(new_cnode); func_graph->set_output(new_cnode);
func_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); func_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
return func_graph; return func_graph;
} }
} // namespace prim } // namespace prim

View File

@ -35,7 +35,7 @@ FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList &
MS_EXCEPTION_IF_NULL(arg0_list); MS_EXCEPTION_IF_NULL(arg0_list);
FuncGraphPtr ret = std::make_shared<FuncGraph>(); FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret->debug_info()->set_name("append"); ret->debug_info()->set_name("append");
AnfNodePtr arg0_node = ret->add_parameter(); AnfNodePtr arg0_node = ret->add_parameter();

View File

@ -51,8 +51,8 @@ AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &f
FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) { FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) {
// Generate func for leaf nodes // Generate func for leaf nodes
FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
ptrGraph->debug_info()->set_name("map"); ptrGraph->debug_info()->set_name("map");
AnfNodePtr ptrFnArg = nullptr; AnfNodePtr ptrFnArg = nullptr;
if (fn_leaf_ == nullptr) { if (fn_leaf_ == nullptr) {
@ -237,8 +237,8 @@ AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, c
FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) { FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) {
FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
ptrGraph->debug_info()->set_name("map"); ptrGraph->debug_info()->set_name("map");
AnfNodePtr ptrFnArg = nullptr; AnfNodePtr ptrFnArg = nullptr;

View File

@ -51,7 +51,7 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
(void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0); (void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
auto ret_graph = std::make_shared<FuncGraph>(); auto ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr fnNode = ret_graph->add_parameter(); AnfNodePtr fnNode = ret_graph->add_parameter();
std::vector<AnfNodePtr> elems; std::vector<AnfNodePtr> elems;

View File

@ -57,7 +57,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spe
return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size()); return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size());
}); });
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
for (size_t idx = 0; idx < args_spec_list.size(); idx++) { for (size_t idx = 0; idx < args_spec_list.size(); idx++) {
(void)ret_graph->add_parameter(); (void)ret_graph->add_parameter();
} }

View File

@ -50,6 +50,12 @@ const PrimitivePtr kPrimBoolNot = std::make_shared<Primitive>("bool_not");
const PrimitivePtr kPrimBoolAnd = std::make_shared<Primitive>("bool_and"); const PrimitivePtr kPrimBoolAnd = std::make_shared<Primitive>("bool_and");
const PrimitivePtr kPrimBoolOr = std::make_shared<Primitive>("bool_or"); const PrimitivePtr kPrimBoolOr = std::make_shared<Primitive>("bool_or");
const PrimitivePtr kPrimBoolEq = std::make_shared<Primitive>("bool_eq"); const PrimitivePtr kPrimBoolEq = std::make_shared<Primitive>("bool_eq");
const PrimitivePtr kPrimGreater = std::make_shared<Primitive>("Greater");
const PrimitivePtr kPrimGreaterEqual = std::make_shared<Primitive>("GreaterEqual");
const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>("NotEqual");
// Type introspection // Type introspection
const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof"); const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof");
@ -166,17 +172,20 @@ const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul");
const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum"); const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum");
const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum"); const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum");
const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square"); const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square");
const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum"); const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum");
const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd"); const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd");
const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscalar"); const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscalar");
const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd"); const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd");
const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub"); const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub");
const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow");
const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv");
const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt");
const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal");
const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims");
// NN // NN
const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
const PrimitivePtr kPrimSoftmax = std::make_shared<Primitive>("Softmax");
const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax"); const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax");
const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad"); const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad");
const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh"); const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh");
@ -253,6 +262,7 @@ const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict"); const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast"); const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant"); const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant");
const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat");
// Comm ops // Comm ops
const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");

View File

@ -59,6 +59,12 @@ extern const PrimitivePtr kPrimBoolNot;
extern const PrimitivePtr kPrimBoolAnd; extern const PrimitivePtr kPrimBoolAnd;
extern const PrimitivePtr kPrimBoolOr; extern const PrimitivePtr kPrimBoolOr;
extern const PrimitivePtr kPrimBoolEq; extern const PrimitivePtr kPrimBoolEq;
extern const PrimitivePtr kPrimGreater;
extern const PrimitivePtr kPrimGreaterEqual;
extern const PrimitivePtr kPrimLess;
extern const PrimitivePtr kPrimLessEqual;
extern const PrimitivePtr kPrimEqual;
extern const PrimitivePtr kPrimNotEqual;
// Type introspection // Type introspection
extern const PrimitivePtr kPrimTypeOf; extern const PrimitivePtr kPrimTypeOf;
@ -157,6 +163,10 @@ extern const PrimitivePtr KPrimTransData;
extern const PrimitivePtr kPrimNMSWithMask; extern const PrimitivePtr kPrimNMSWithMask;
extern const PrimitivePtr kPrimPad; extern const PrimitivePtr kPrimPad;
extern const PrimitivePtr kPrimArgMaxWithValue; extern const PrimitivePtr kPrimArgMaxWithValue;
extern const PrimitivePtr kPrimRealDiv;
extern const PrimitivePtr kPrimSqrt;
extern const PrimitivePtr kPrimReciprocal;
extern const PrimitivePtr kPrimExpandDims;
// Maths // Maths
extern const PrimitivePtr kPrimTensorAdd; extern const PrimitivePtr kPrimTensorAdd;
@ -183,9 +193,11 @@ extern const PrimitivePtr kPrimCumProd;
extern const PrimitivePtr kPrimSubscalar; extern const PrimitivePtr kPrimSubscalar;
extern const PrimitivePtr kPrimInplaceAdd; extern const PrimitivePtr kPrimInplaceAdd;
extern const PrimitivePtr kPrimInplaceSub; extern const PrimitivePtr kPrimInplaceSub;
extern const PrimitivePtr kPrimPow;
// NN // NN
extern const PrimitivePtr kPrimFlatten; extern const PrimitivePtr kPrimFlatten;
extern const PrimitivePtr kPrimSoftmax;
extern const PrimitivePtr kPrimLogSoftmax; extern const PrimitivePtr kPrimLogSoftmax;
extern const PrimitivePtr kPrimLogSoftmaxGrad; extern const PrimitivePtr kPrimLogSoftmaxGrad;
extern const PrimitivePtr kPrimApplyCenteredRMSProp; extern const PrimitivePtr kPrimApplyCenteredRMSProp;
@ -263,6 +275,7 @@ extern const PrimitivePtr kPrimInDict;
extern const PrimitivePtr kPrimNotInDict; extern const PrimitivePtr kPrimNotInDict;
extern const PrimitivePtr kPrimMixedPrecisionCast; extern const PrimitivePtr kPrimMixedPrecisionCast;
extern const PrimitivePtr kPrimIsConsant; extern const PrimitivePtr kPrimIsConsant;
extern const PrimitivePtr kPrimEquivFormat;
// Comm ops // Comm ops
extern const PrimitivePtr kPrimAllReduce; extern const PrimitivePtr kPrimAllReduce;

View File

@ -45,10 +45,19 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
: primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) { : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(primal_graph->debug_info())); TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
k_graph_ = std::make_shared<FuncGraph>(); k_graph_ = std::make_shared<FuncGraph>();
if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
}
TraceManager::EndTrace(); TraceManager::EndTrace();
TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info())); TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info()));
tape_ = std::make_shared<FuncGraph>(); tape_ = std::make_shared<FuncGraph>();
// Add "_Grad" postfix
if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) + "_Grad";
tape_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
}
TraceManager::EndTrace(); TraceManager::EndTrace();
dout_ = tape_->add_parameter(); dout_ = tape_->add_parameter();
@ -368,7 +377,7 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
(void)primal->transforms().insert(std::make_pair("grad", FuncGraphTransform(fg))); (void)primal->transforms().insert(std::make_pair("grad", FuncGraphTransform(fg)));
(void)fg->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal))); (void)fg->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal)));
// Reset defer_inline to enable successive inlining // Reset defer_inline to enable successive inlining
primal->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false); primal->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
auto functor = std::make_shared<DFunctor>(primal, resources_); auto functor = std::make_shared<DFunctor>(primal, resources_);
functor->Init(); functor->Init();

View File

@ -37,7 +37,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) { auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) {
if (MsContext::GetInstance()->is_multi_graph_sink()) { if (MsContext::GetInstance()->is_multi_graph_sink()) {
if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
f->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); f->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
} }
} }
}; };

View File

@ -78,7 +78,10 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(cons); MS_EXCEPTION_IF_NULL(cons);
auto dt = data->abstract(); auto dt = data->abstract();
MS_EXCEPTION_IF_NULL(dt); if (dt == nullptr) {
return nullptr;
}
if (!dt->isa<AbstractClass>()) { if (!dt->isa<AbstractClass>()) {
MS_LOG(EXCEPTION) << "First parameter of getattr is not AbstractClass, but " << dt->type_name() << "."; MS_LOG(EXCEPTION) << "First parameter of getattr is not AbstractClass, but " << dt->type_name() << ".";
} }

View File

@ -0,0 +1,157 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "optimizer/graph_kernel_reuse.h"
#include <vector>
#include <algorithm>
#include <string>
#include "./common.h"
#include "utils/graph_utils.h"
namespace mindspore {
/* namespace to support opt */
namespace opt {
bool GraphKernelReuse::CompareNode(const AnfNodePtr a, const AnfNodePtr b) {
if (a->abstract() && b->abstract()) {
auto a_type = a->abstract()->GetTypeTrack();
auto b_type = b->abstract()->GetTypeTrack();
if (a_type != b_type) {
return false;
}
auto a_shape = a->abstract()->GetShapeTrack();
auto b_shape = b->abstract()->GetShapeTrack();
if (a_shape != nullptr && a_shape == b_shape) {
return true;
}
if (a_shape != nullptr && b_shape != nullptr && a_shape->isa<abstract::Shape>() &&
b_shape->isa<abstract::Shape>()) {
return a_shape->cast<abstract::ShapePtr>()->shape() == b_shape->cast<abstract::ShapePtr>()->shape();
}
}
return false;
}
bool GraphKernelReuse::DoReplace(const FuncGraphManagerPtr manager) {
bool changed = false;
auto fgs = manager->func_graphs();
for (FuncGraphPtr &fg : fgs) {
if (!fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
continue;
}
std::string key = GetValue<std::string>(fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
if (graph_kernel_ops.find(key) != graph_kernel_ops.end()) {
if (find(graph_kernel_ops[key].begin(), graph_kernel_ops[key].end(), fg) == graph_kernel_ops[key].end()) {
FuncGraphPtr new_fg = nullptr;
for (auto &cfg : graph_kernel_ops[key]) {
// If two graphs have different size then continue
auto fg_topos = TopoSort(fg->get_return());
auto cfg_topos = TopoSort(cfg->get_return());
if (fg_topos.size() != cfg_topos.size()) {
continue;
}
// Compare const tensor
bool has_same = true;
for (size_t i = 0; i < fg_topos.size(); ++i) {
if (IsValueNode<tensor::Tensor>(fg_topos[i])) {
if (!IsValueNode<tensor::Tensor>(cfg_topos[i])) {
has_same = false;
break;
}
auto tensor1 = GetValueNode<tensor::TensorPtr>(fg_topos[i]);
auto tensor2 = GetValueNode<tensor::TensorPtr>(cfg_topos[i]);
if (!tensor1->ValueEqual(*tensor2)) {
has_same = false;
break;
}
}
}
if (!has_same) {
continue;
}
auto fg_input = fg->parameters();
auto cfg_input = cfg->parameters();
if (fg_input.size() != cfg_input.size()) {
continue;
}
// Compare input
for (size_t i = 0; i < fg_input.size(); ++i) {
if (!CompareNode(fg_input[i], cfg_input[i])) {
has_same = false;
break;
}
}
if (!has_same) {
continue;
}
// Compare output
if (!CompareNode(fg->output(), cfg->output())) {
continue;
}
// Find reusable fg
new_fg = cfg;
break;
}
if (new_fg != nullptr) {
// Replace current fg with existing fg
auto users = fg->func_graph_cnodes_index();
for (auto &iter : users) {
auto cnode = iter.first->first->cast<CNodePtr>();
auto new_input = cnode->inputs();
auto main_graph = cnode->func_graph();
MS_EXCEPTION_IF_NULL(main_graph);
if (IsPrimitiveCNode(cnode, prim::kPrimPartial)) {
new_input[1] = NewValueNode(new_fg);
} else {
new_input[0] = NewValueNode(new_fg);
}
auto new_cnode = main_graph->NewCNode(new_input);
manager->Replace(iter.first->first, new_cnode);
changed = true;
}
} else {
// Add current fg to map
graph_kernel_ops[key].push_back(fg);
}
}
} else {
graph_kernel_ops[key] = {fg};
}
}
return changed;
}
bool GraphKernelReuse::ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager) {
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(root);
return DoReplace(manager);
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,53 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H
#define MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H
#include <mindspore/ccsrc/session/anf_runtime_algorithm.h>
#include <unordered_map>
#include <string>
#include <vector>
#include "optimizer/optimizer.h"
namespace mindspore {
namespace opt {
// Common subexpression elimination.
class GraphKernelReuse {
public:
GraphKernelReuse() : count(0) {}
virtual ~GraphKernelReuse() = default;
bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) {
bool chg = ReuseGraphKernel(root, optimizer->resource()->manager());
return chg;
}
bool CompareNode(const AnfNodePtr a, const AnfNodePtr other);
bool DoReplace(const FuncGraphManagerPtr manager);
bool ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager);
private:
std::unordered_map<std::string, std::vector<FuncGraphPtr>> graph_kernel_ops;
int count;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H

View File

@ -41,6 +41,8 @@
#include "optimizer/irpass/incorporate_call.h" #include "optimizer/irpass/incorporate_call.h"
#include "optimizer/irpass/grad_var_prepare.h" #include "optimizer/irpass/grad_var_prepare.h"
#include "optimizer/irpass/param_replace.h" #include "optimizer/irpass/param_replace.h"
#include "optimizer/irpass/mark_interface_fusion.h"
#include "optimizer/opt.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@ -48,7 +50,7 @@ namespace irpass {
OptimizeIRPassLib::OptimizeIRPassLib() { OptimizeIRPassLib::OptimizeIRPassLib() {
arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul}); prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
special_op_eliminate_ = special_op_eliminate_ =
MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward,
@ -90,7 +92,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
replace_refkey_by_param_ = replace_refkey_by_param_ =
MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM); MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM);
replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam); replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam);
// Gradient transforms // Gradient transforms
expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ); expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ);
minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem); minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem);
@ -115,6 +116,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Incorporation // Incorporation
incorporate_getitem_set_ = incorporate_getitem_set_ =
MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem); MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
incorporate_getitem_from_param_ =
MakeSubstitution(IncorporateGetitemFromParam(), "incorporate_getitem_from_param", IsCNodeGraphKernel);
incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup); incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup);
incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup); incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup);
@ -124,6 +127,17 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Convert // Convert
print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint); print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint);
// Unused parameter eliminate
unused_parameter_eliminate_ =
MakeSubstitution(UnusedParasEliminater(), "unused_parameter_eliminate", IsCNodeGraphKernel);
unused_output_eliminate_ = MakeSubstitution(UnusedOutputEliminater(), "unused_output_eliminate", IsCNodeGraphKernel);
// AddN eliminate
addn_eliminate_ = MakeSubstitution(AddNEliminater(), "addn_eliminate", IsCNodeGraphKernel);
// Mark interface fusion
mark_interface_fusion_ = MakeSubstitution(MarkInterfaceFusion(), "mark_interface_fusion", prim::kPrimSelect);
} }
ResolveIRPassLib::ResolveIRPassLib() { ResolveIRPassLib::ResolveIRPassLib() {

View File

@ -84,6 +84,7 @@ class OptimizeIRPassLib {
// Incorporation // Incorporation
SubstitutionPtr incorporate_getitem_set_; SubstitutionPtr incorporate_getitem_set_;
SubstitutionPtr incorporate_getitem_from_param_;
SubstitutionPtr incorporate_call_; SubstitutionPtr incorporate_call_;
SubstitutionPtr incorporate_call_switch_; SubstitutionPtr incorporate_call_switch_;
@ -92,6 +93,16 @@ class OptimizeIRPassLib {
// Convert // Convert
SubstitutionPtr print_tuple_wrapper_; SubstitutionPtr print_tuple_wrapper_;
// Unused parameter eliminate
SubstitutionPtr unused_parameter_eliminate_;
SubstitutionPtr unused_output_eliminate_;
// AddN eliminate
SubstitutionPtr addn_eliminate_;
// Fusion
SubstitutionPtr mark_interface_fusion_;
}; };
// the collection of irpass for resolve action // the collection of irpass for resolve action
@ -145,6 +156,23 @@ inline bool IsCNodeGraph(const AnfNodePtr &node) {
return IsValueNode<FuncGraph>(inp0); return IsValueNode<FuncGraph>(inp0);
} }
// Check if CNode Input 0 is Func Graph of graph kernel.
inline bool IsCNodeGraphKernel(const AnfNodePtr &node) {
if (node == nullptr || !node->isa<CNode>()) {
return false;
}
auto inp0 = node->cast<CNodePtr>()->input(0);
if (IsValueNode<FuncGraph>(inp0)) {
auto fg = GetValueNode<FuncGraphPtr>(inp0);
if (fg == nullptr) {
return false;
}
return fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
}
return false;
}
// Check if CNode Input 0 is CNode // Check if CNode Input 0 is CNode
inline bool IsCNodeDup(const AnfNodePtr &node) { inline bool IsCNodeDup(const AnfNodePtr &node) {
if (node == nullptr || !node->isa<CNode>()) { if (node == nullptr || !node->isa<CNode>()) {

View File

@ -83,6 +83,216 @@ class MultiplyByZeroOrOne : public AnfVisitor {
AnfNodePtr x_{nullptr}; AnfNodePtr x_{nullptr};
}; };
// Support class used for checking if all values of a Tensor are equal `check_value_`
// Supported data types: double, float/float32, int/int32
class CheckTensorConstant {
public:
explicit CheckTensorConstant(int _check_value = 0) : check_value_(_check_value) {}
~CheckTensorConstant() = default;
bool IsTensorConstant(const ValuePtr &value) {
if (!value->isa<tensor::Tensor>()) {
return false;
}
auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
TypeId tensor_type = tensor_ptr->Dtype()->type_id();
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
float *data2 = reinterpret_cast<float *>(tensor_ptr->data_c());
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
if (fabs(data2[i] - check_value_) > FLT_EPSILON) {
return false;
}
}
return true;
} else if (tensor_type == TypeId::kNumberTypeFloat64) {
double *data2 = reinterpret_cast<double *>(tensor_ptr->data_c());
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
if (fabs(data2[i] - check_value_) > DBL_EPSILON) {
return false;
}
}
return true;
} else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) {
int *data2 = reinterpret_cast<int *>(tensor_ptr->data_c());
for (int i = 0; i < tensor_ptr->DataSize(); i++) {
if (data2[i] != check_value_) {
return false;
}
}
return true;
}
// Un-support Data Types
return false;
}
bool IsTensorScalarConstant(const ValuePtr &value) {
if (!value->isa<tensor::Tensor>()) {
return false;
}
auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) {
return false;
}
return IsTensorConstant(value);
}
private:
int check_value_;
};
// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0}
// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1}
class TensorMultiplyByZeroOrOne : public AnfVisitor {
public:
TensorMultiplyByZeroOrOne() : zero_(MakeValue(0)) {}
~TensorMultiplyByZeroOrOne() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimMul)(node);
if (is_zero_) {
if (x_->func_graph() != node->func_graph()) {
return nullptr;
}
return NewTensorFilledWithData(node);
}
if (is_one_) {
return NewTensorFilledWithData(node, x_);
}
return nullptr;
}
void Visit(const AnfNodePtr &node) override {
if (is_zero_ || is_one_) {
x_ = node;
return;
}
if (IsParam(node)) {
x_ = node;
return;
}
if (IsCNode(node)) {
CNodePtr cnode = node->cast<CNodePtr>();
if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) {
is_zero_ = true;
return;
}
x_ = node;
return;
}
auto value = node->cast<ValueNodePtr>()->value();
if (CheckTensorConstant(0).IsTensorConstant(value)) {
is_zero_ = true;
return;
} else if (CheckTensorConstant(1).IsTensorConstant(value)) {
is_one_ = true;
return;
}
x_ = node;
}
void Visit(const ValueNodePtr &vnode) override {
auto value = vnode->value();
if (CheckTensorConstant(0).IsTensorConstant(value)) {
is_zero_ = true;
return;
} else if (CheckTensorConstant(1).IsTensorConstant(value)) {
is_one_ = true;
return;
}
x_ = vnode;
}
void Reset() {
x_ = nullptr;
is_one_ = false;
is_zero_ = false;
}
void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false) {
if (!node->isa<ValueNode>()) {
return nullptr;
}
auto value = node->cast<ValueNodePtr>()->value();
if (!value->isa<tensor::Tensor>()) {
return nullptr;
}
tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value);
return tensor_ptr->data_c(writable);
}
// Make a new tensor (when possible) with the same shape as of `node`
// If x is nullptr then fill new tensor will "0"
// If x is a tensor with empty shape then fill new tensor with the single value of x
// If x is a tensor with same shape as `node` then return x as result
AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr) {
if ((node->abstract() == nullptr) || !node->abstract()->isa<abstract::AbstractTensor>()) {
return nullptr;
}
auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>();
TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
std::vector<int> tensor_shape = tensor_abstract->shape()->shape();
auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true));
if (x == nullptr) {
std::memset(data, 0, mem_size);
auto new_vnode = NewValueNode(new_tensor_ptr);
new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
return new_vnode;
}
// x is not nullptr
if (x->isa<CNode>()) {
if ((x->abstract() == nullptr) || !x->abstract()->isa<abstract::AbstractTensor>()) {
return nullptr;
}
auto x_abstract = x->abstract()->cast<abstract::AbstractTensorPtr>();
std::vector<int> x_shape = x_abstract->shape()->shape();
if (x_shape != tensor_shape) {
return nullptr;
}
return x;
}
if (!x->isa<ValueNode>()) {
return nullptr;
}
auto x_value = x->cast<ValueNodePtr>()->value();
if (!x_value->isa<tensor::Tensor>()) {
return nullptr;
}
auto x_tensor_ptr = dyn_cast<tensor::Tensor>(x_value);
if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) {
return nullptr;
}
char *source_data = reinterpret_cast<char *>(GetPointerToTensorData(x));
if (x_tensor_ptr->DataSize() == 1) {
for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) {
memcpy(source_data, data + i * GetTypeByte(tensor_type_ptr), GetTypeByte(tensor_type_ptr));
}
} else {
memcpy(source_data, data, mem_size);
}
auto new_vnode = NewValueNode(new_tensor_ptr);
new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
return new_vnode;
}
private:
bool is_zero_{false}, is_one_{false};
ValuePtr zero_;
AnfNodePtr x_{nullptr};
};
// {prim::kPrimScalarAdd, X, 0} // {prim::kPrimScalarAdd, X, 0}
// {prim::kPrimScalarAdd, 0, X} // {prim::kPrimScalarAdd, 0, X}
class AddByZero : public AnfVisitor { class AddByZero : public AnfVisitor {
@ -101,7 +311,8 @@ class AddByZero : public AnfVisitor {
} }
void Visit(const AnfNodePtr &node) override { void Visit(const AnfNodePtr &node) override {
if (node->isa<ValueNode>() && *GetValueNode(node) == *zero_) { if (node->isa<ValueNode>() &&
((*GetValueNode(node) == *zero_) || CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node)))) {
is_zero_ = true; is_zero_ = true;
return; return;
} }
@ -139,10 +350,22 @@ class TensorAddByZero : public AnfVisitor {
is_zero_ = true; is_zero_ = true;
return; return;
} }
if (node->isa<ValueNode>() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) {
is_zero_ = true;
return;
}
x_ = node; x_ = node;
} }
void Visit(const ValueNodePtr &vnode) override {
auto value = vnode->value();
if (CheckTensorConstant(0).IsTensorConstant(value)) {
is_zero_ = true;
return;
}
}
void Reset() { void Reset() {
x_ = nullptr; x_ = nullptr;
is_zero_ = false; is_zero_ = false;
@ -183,29 +406,143 @@ class OptUpdateZeroTensor : public AnfVisitor {
// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} // {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
class ConstantDuplicateMul : public AnfVisitor { class ConstantDuplicateMul : public AnfVisitor {
public: public:
// Support function to multiply two constant tensors: partially support broadcasting shapes
template <typename T>
void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data,
int out_data_size) {
T *data_1 = reinterpret_cast<T *>(in_data_1);
T *data_2 = reinterpret_cast<T *>(in_data_2);
T *data_out = new T[out_data_size];
if (in_data_1_size == 1) {
for (int i = 0; i < out_data_size; i++) {
data_out[i] = data_1[0];
}
} else {
for (int i = 0; i < out_data_size; i++) {
data_out[i] = data_1[i];
}
}
if (in_data_2_size == 1) {
for (int i = 0; i < out_data_size; i++) {
data_out[i] *= data_2[0];
}
} else {
for (int i = 0; i < out_data_size; i++) {
data_out[i] *= data_2[i];
}
}
*out_data = reinterpret_cast<void *>(data_out);
return;
}
AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3) {
if (!vnode_1->isa<ValueNode>() || !vnode_2->isa<ValueNode>() || (vnode_1->abstract() == nullptr) ||
(vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) {
return nullptr;
}
auto value_1 = GetValueNode(vnode_1);
auto value_2 = GetValueNode(vnode_2);
if (!value_1->isa<tensor::Tensor>() || !value_2->isa<tensor::Tensor>()) {
return nullptr;
}
auto tensor_ptr_1 = dyn_cast<tensor::Tensor>(value_1);
auto tensor_ptr_2 = dyn_cast<tensor::Tensor>(value_2);
auto tensor_1_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>();
auto tensor_2_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>();
auto tensor_3_abstract = node_3->abstract()->cast<abstract::AbstractTensorPtr>();
TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType();
TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType();
TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType();
if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) ||
(tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) {
return nullptr;
}
std::vector<int> tensor_out_shape = tensor_3_abstract->shape()->shape();
int data_out_size = 1;
for (auto it : tensor_out_shape) {
data_out_size *= it;
}
if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) {
return nullptr;
}
if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) {
return nullptr;
}
void *data_out;
if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) ||
(tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) {
Multiply<float>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
tensor_ptr_2->DataSize(), &data_out, data_out_size);
} else {
if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) {
Multiply<double>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
tensor_ptr_2->DataSize(), &data_out, data_out_size);
} else {
if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) ||
(tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) {
Multiply<int>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
tensor_ptr_2->DataSize(), &data_out, data_out_size);
} else {
// Un-support data types
return nullptr;
}
}
}
auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_3_type_ptr->type_id(), tensor_out_shape);
size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true));
memcpy(data, data_out, mem_size);
auto new_vnode = NewValueNode(new_tensor_ptr);
new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
return new_vnode;
}
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset(); Reset();
// {prim::kPrimMul, Tensor1, {...}} // {prim::kPrimMul, Tensor1, {...}}
AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node); AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node);
if (vnode_ == nullptr || cnode_ == nullptr) { if (vnode_ == nullptr || c_p_node_ == nullptr) {
return nullptr; return nullptr;
} }
if (!IsCNode(c_p_node_)) {
return nullptr;
}
auto tensor1 = vnode_; auto tensor1 = vnode_;
auto mul = cnode_; auto mul = c_p_node_->cast<CNodePtr>();
Reset(); Reset();
// {prim::kPrimMul, Tensor2, {...}} // {prim::kPrimMul, Tensor2, {...}}
AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul); AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul);
if (vnode_ == nullptr || cnode_ == nullptr) { if (vnode_ == nullptr || c_p_node_ == nullptr) {
return nullptr; return nullptr;
} }
auto tensor2 = vnode_; auto tensor2 = vnode_;
auto cnode = cnode_; auto c_p_node = c_p_node_;
auto PrimMul = GetValueNode<PrimitivePtr>(mul->input(0)); auto PrimMul = GetValueNode<PrimitivePtr>(mul->input(0));
auto fg = node->func_graph(); auto fg = node->func_graph();
auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg);
return NewCNode({NewValueNode(PrimMul), cnode, ttmul}, fg); auto new_mul_tensor = MulConstantTensors(tensor1, tensor2, c_p_node);
if (new_mul_tensor == nullptr) {
auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg);
return NewCNode({NewValueNode(PrimMul), c_p_node, ttmul}, fg);
}
return NewCNode({NewValueNode(PrimMul), c_p_node, new_mul_tensor}, fg);
} }
void Visit(const AnfNodePtr &node) override { void Visit(const AnfNodePtr &node) override {
@ -213,19 +550,40 @@ class ConstantDuplicateMul : public AnfVisitor {
vnode_ = node; vnode_ = node;
} }
if (IsCNode(node)) { if (IsCNode(node) || IsParam(node)) {
cnode_ = node->cast<CNodePtr>(); c_p_node_ = node;
} }
} }
void Reset() { void Reset() {
vnode_ = nullptr; vnode_ = nullptr;
cnode_ = nullptr; c_p_node_ = nullptr;
} }
private: private:
AnfNodePtr vnode_; AnfNodePtr vnode_;
CNodePtr cnode_; AnfNodePtr c_p_node_;
};
class PowerOneEliminate : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimPow) || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
if (!IsValueNode<Scalar>(inputs[2])) {
return nullptr;
}
auto scalar = GetValueNode<ScalarPtr>(inputs[2]);
if (scalar->isa<FloatImm>() && GetValue<float>(scalar) == 1.0) {
return inputs[1];
} else if (scalar->isa<IntergerImm>() && GetValue<int>(scalar) == 1) {
return inputs[1];
}
return nullptr;
}
}; };
// grad = AllReduce(grad) / worker_number // grad = AllReduce(grad) / worker_number
@ -341,17 +699,21 @@ class ArithmeticSimplify {
public: public:
ArithmeticSimplify() ArithmeticSimplify()
: multiply_by_zero_or_one_(), : multiply_by_zero_or_one_(),
tensor_multiply_by_zero_or_one_(),
add_by_zero_(), add_by_zero_(),
tensor_add_by_zero_(), tensor_add_by_zero_(),
identity_(prim::kPrimIdentity), identity_(prim::kPrimIdentity),
opt_update_zero_tensor_(), opt_update_zero_tensor_(),
constant_duplicate_mul_() { constant_duplicate_mul_(),
power_one_() {
eliminaters_.emplace_back(multiply_by_zero_or_one_); eliminaters_.emplace_back(multiply_by_zero_or_one_);
eliminaters_.emplace_back(tensor_multiply_by_zero_or_one_);
eliminaters_.emplace_back(add_by_zero_); eliminaters_.emplace_back(add_by_zero_);
eliminaters_.emplace_back(tensor_add_by_zero_); eliminaters_.emplace_back(tensor_add_by_zero_);
eliminaters_.emplace_back(identity_); eliminaters_.emplace_back(identity_);
eliminaters_.emplace_back(opt_update_zero_tensor_); eliminaters_.emplace_back(opt_update_zero_tensor_);
eliminaters_.emplace_back(constant_duplicate_mul_); eliminaters_.emplace_back(constant_duplicate_mul_);
eliminaters_.emplace_back(power_one_);
} }
~ArithmeticSimplify() = default; ~ArithmeticSimplify() = default;
@ -368,11 +730,13 @@ class ArithmeticSimplify {
private: private:
MultiplyByZeroOrOne multiply_by_zero_or_one_; MultiplyByZeroOrOne multiply_by_zero_or_one_;
TensorMultiplyByZeroOrOne tensor_multiply_by_zero_or_one_;
AddByZero add_by_zero_; AddByZero add_by_zero_;
TensorAddByZero tensor_add_by_zero_; TensorAddByZero tensor_add_by_zero_;
PrimEliminater identity_; PrimEliminater identity_;
OptUpdateZeroTensor opt_update_zero_tensor_; OptUpdateZeroTensor opt_update_zero_tensor_;
ConstantDuplicateMul constant_duplicate_mul_; ConstantDuplicateMul constant_duplicate_mul_;
PowerOneEliminate power_one_;
std::vector<TransformFuncType> eliminaters_{}; std::vector<TransformFuncType> eliminaters_{};
}; };
} // namespace irpass } // namespace irpass

View File

@ -21,6 +21,7 @@
#include <algorithm> #include <algorithm>
#include <unordered_map> #include <unordered_map>
#include <memory> #include <memory>
#include <unordered_set>
#include "optimizer/irpass.h" #include "optimizer/irpass.h"
#include "optimizer/optimizer.h" #include "optimizer/optimizer.h"
@ -28,7 +29,6 @@
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/func_graph_cloner.h" #include "ir/func_graph_cloner.h"
#include "operator/ops.h" #include "operator/ops.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace irpass { namespace irpass {
@ -81,13 +81,32 @@ class IncorporateGetitem : public AnfVisitor {
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset(); Reset();
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int32Imm>})(node); AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int32Imm>})(node);
if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr) {
if (node->func_graph() != nullptr && idx_ >= 0 && fg_ != nullptr) { return nullptr;
auto new_fg = getitem_transform_(fg_, idx_);
(void)args_.insert(args_.begin(), NewValueNode(new_fg));
return node->func_graph()->NewCNode(args_);
} }
return nullptr;
if (fg_->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
// If graph kernel has muti output, do not split.
// some graph kernel output has EnvInstance node or DeadCode node should split.
auto output = fg_->output();
if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
auto output_cnode = output->cast<CNodePtr>();
auto outputs = output_cnode->inputs();
int real_output_cnt = 0;
for (size_t i = 1; i < outputs.size(); ++i) {
if (IsCNode(outputs[i]) || IsValueNode<tensor::Tensor>(outputs[i]) || IsParam(outputs[i])) {
real_output_cnt++;
if (real_output_cnt > 1) {
return nullptr;
}
}
}
}
}
auto new_fg = getitem_transform_(fg_, idx_);
(void)args_.insert(args_.begin(), NewValueNode(new_fg));
return node->func_graph()->NewCNode(args_);
} }
void Visit(const CNodePtr &cnode) override { void Visit(const CNodePtr &cnode) override {
@ -115,6 +134,172 @@ class IncorporateGetitem : public AnfVisitor {
internal::GetitemTransform getitem_transform_; internal::GetitemTransform getitem_transform_;
}; };
class IncorporateGetitemFromParam : public AnfVisitor {
public:
void Process(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &param, size_t input_idx) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto &node_users = mng->node_users();
if (node_users.find(param) == node_users.end() || node_users[param].empty()) {
args_.push_back(cnode->input(input_idx + 1));
return;
}
for (auto &user : node_users[param]) {
if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
// we do not process this case.
args_.push_back(cnode->input(input_idx + 1));
return;
}
}
// update new args.
if (IsPrimitiveCNode(cnode->input(input_idx + 1), prim::kPrimMakeTuple)) {
// case 1
replace_parameters_[input_idx] = true;
need_update_ = true;
auto make_tuple_cnode = cnode->input(input_idx + 1)->cast<CNodePtr>();
auto &make_tuple_cnode_inputs = make_tuple_cnode->inputs();
inputs_num_[input_idx] = make_tuple_cnode_inputs.size() - 1;
args_.insert(args_.end(), make_tuple_cnode_inputs.begin() + 1, make_tuple_cnode_inputs.end());
} else {
// case 2
auto prev_cnode = cnode->input(input_idx + 1)->cast<CNodePtr>();
auto prev_fg = GetValueNode<FuncGraphPtr>(prev_cnode->input(0));
auto fg_output = prev_fg->output();
if (!IsPrimitiveCNode(fg_output, prim::kPrimMakeTuple)) {
MS_LOG(ERROR) << "The return of: " << prev_fg->ToString()
<< " should be a make tuple, but got: " << fg_output->DebugString();
return;
}
replace_parameters_[input_idx] = true;
need_update_ = true;
auto make_tuple_cnode = fg_output->cast<CNodePtr>();
inputs_num_[input_idx] = make_tuple_cnode->inputs().size() - 1;
for (size_t output_i = 0; output_i < inputs_num_[input_idx]; ++output_i) {
auto new_getitem =
func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), prev_cnode, NewValueNode(SizeToInt(output_i))});
auto aptr = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(SizeToInt(output_i)));
new_getitem->input(2)->set_abstract(aptr);
new_getitem->set_abstract(make_tuple_cnode->input(output_i + 1)->abstract());
args_.push_back(new_getitem);
}
}
}
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (node->func_graph() == nullptr) {
return nullptr;
}
Reset();
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
return nullptr;
}
auto &inputs = cnode->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
if (fg == nullptr) {
return nullptr;
}
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
auto parameters = fg->parameters();
if (parameters.size() != inputs.size() - 1) {
return nullptr;
}
replace_parameters_ = std::vector<bool>(parameters.size(), false);
inputs_num_ = std::vector<size_t>(parameters.size(), 1);
auto node_fg = node->func_graph();
for (size_t i = 1; i < inputs.size(); ++i) {
if (IsPrimitiveCNode(inputs[i], prim::kPrimMakeTuple) || IsCNodeGraphKernel(inputs[i])) {
Process(node_fg, cnode, parameters[i - 1], i - 1);
} else {
args_.push_back(inputs[i]);
}
}
if (!need_update_) {
return nullptr;
}
FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("sp"));
mng->AddFuncGraph(new_fg);
auto node_users = mng->node_users();
std::vector<AnfNodePtr> new_fg_parameters = new_fg->parameters();
std::vector<AnfNodePtr> new_parameters;
size_t curr_input_idx{0};
for (size_t param_i = 0; param_i < new_fg_parameters.size(); ++param_i) {
if (!replace_parameters_[param_i]) {
if (parameters[param_i]->abstract() != nullptr) {
new_fg_parameters[param_i]->set_abstract(parameters[param_i]->abstract());
}
new_parameters.push_back(new_fg_parameters[param_i]);
curr_input_idx++;
continue;
}
// make a new parameter.
for (size_t input_i = 0; input_i < inputs_num_[param_i]; ++input_i) {
auto new_param = std::make_shared<Parameter>(new_fg);
new_param->set_abstract(args_.at(curr_input_idx)->abstract());
// update users of new parameter.
for (auto &user : node_users[new_fg_parameters[param_i]]) {
idx_ = -1;
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsParam, IsValueNode<Int32Imm>})(user.first);
if (idx_ == -1) {
MS_LOG(ERROR) << "User of: " << new_fg_parameters[param_i]->DebugString()
<< " must be tuple getitem here, but got: " << user.first->DebugString();
return nullptr;
}
if (input_i == IntToSize(idx_)) {
for (auto &sub_user : node_users[user.first]) {
auto sub_user_cnode = sub_user.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sub_user_cnode);
sub_user_cnode->set_input(sub_user.second, new_param);
(void)mng->Replace(sub_user.first, sub_user_cnode);
}
}
}
// (void)mng->Replace(new_fg_parameters[param_i], new_param);
new_parameters.push_back(new_param);
curr_input_idx++;
}
}
mng->SetParameters(new_fg, new_parameters);
(void)args_.insert(args_.begin(), NewValueNode(new_fg));
auto new_call = node_fg->NewCNode(args_);
new_call->set_abstract(node->abstract());
return new_call;
}
void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue<int>(vnode->value()); }
void Visit(const CNodePtr &cnode) override {}
void Reset() {
replace_parameters_.clear();
args_.clear();
inputs_num_.clear();
need_update_ = false;
idx_ = -1;
}
private:
std::vector<bool> replace_parameters_{};
std::vector<AnfNodePtr> args_{};
std::vector<size_t> inputs_num_{};
bool need_update_{false};
int idx_{-1};
};
// {prim::kPrimTupleGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C} // {prim::kPrimTupleGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C}
class IncorporateGetitemSwitch : public AnfVisitor { class IncorporateGetitemSwitch : public AnfVisitor {
public: public:

View File

@ -86,20 +86,10 @@ bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) {
bool IsInside(FuncGraphPtr, const AnfNodePtr &node) { bool IsInside(FuncGraphPtr, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node->func_graph()); MS_EXCEPTION_IF_NULL(node->func_graph());
auto &flags = node->func_graph()->flags(); return node->func_graph()->has_flag("inline_inside");
if (flags.find("inline_inside") != flags.end()) {
return flags["inline_inside"];
}
return false;
} }
bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); }
auto &flags = fg->flags();
if (flags.find("core") != flags.end()) {
return flags["core"];
}
return false;
}
bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; } bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; }
@ -123,6 +113,13 @@ class InlinerBase : public AnfVisitor {
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) { if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
return nullptr; return nullptr;
} }
// Do not inline GraphKernel to Cell.
if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
// If the GraphKernel only contains a return node, we make it inlined.
if (fg->nodes().size() - fg->parameters().size() > 1) {
return nullptr;
}
}
Reset(); Reset();
bool is_match = false; bool is_match = false;

View File

@ -0,0 +1,86 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H
#include <string>
#include <sstream>
#include <unordered_map>
#include "session/anf_runtime_algorithm.h"
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "utils/graph_utils.h"
#include "operator/composite/composite.h"
namespace mindspore {
namespace opt {
namespace irpass {
static int count = 0;
std::string GetFusionNumber() {
std::stringstream ss;
ss << std::setw(4) << std::setfill('0') << count;
std::string num = ss.str();
++count;
return "_" + num;
}
// Mark CNodes which can be merged in kernel build
class MarkInterfaceFusion : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsPrimitiveCNode(node, prim::kPrimSelect)) {
auto cnode = node->cast<CNodePtr>();
auto condition = cnode->input(1);
std::string cmp;
std::unordered_map<std::string, std::string> cmp_list = {{"GreaterEqual", "GE"}, {"Greater", "GT"},
{"LessEqual", "LE"}, {"Less", "LT"},
{"Equal", "EQ"}, {"NotEqual", "NE"}};
if (IsPrimitiveCNode(condition)) {
auto prim_name = GetCNodeFuncName(condition->cast<CNodePtr>());
if (cmp_list.count(prim_name) != 0) {
// Mark Select and compare node
cmp = cmp_list[prim_name];
auto cnt = GetFusionNumber();
AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), condition);
AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt + "_end"), node);
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
if (IsPrimitiveCNode(cnode->input(i), prim::kPrimZerosLike)) {
AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), cnode->input(i));
}
}
}
}
}
return nullptr;
}
void Visit(const AnfNodePtr &) override {}
private:
AnfNodePtr y_{nullptr};
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H

View File

@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include <memory>
#include "optimizer/irpass.h" #include "optimizer/irpass.h"
#include "optimizer/optimizer.h" #include "optimizer/optimizer.h"
@ -196,6 +197,131 @@ class AddNZeroFilter : public AnfVisitor {
std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{}; std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
bool has_zero_like_{false}; bool has_zero_like_{false};
}; };
// {PrimAddN, {kPrimMakeTuple, Xs}}
// Akg don't support AddN(ValueNode, Tensor, ...), converted to TensorAdd.
// case0: AddN(inputs)(inputs size < 2) -> error
// case1: AddN(inputs)(all inputs is ValueNode) -> error
// case2: AddN(inputs)(inputs size = 2) -> TensorAdd(Tensor, Tensor)
// case3: AddN(ValueNode, Tensor, Tensor, ...)(has one ValueNode input)
// -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...))
class AddNEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
MS_EXCEPTION_IF_NULL(fg);
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
if (fg->recursive()) {
return nullptr;
}
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg"));
mng->AddFuncGraph(new_fg);
need_update_ = false;
bool changed = false;
do {
changed = false;
changed |= Process(new_fg);
} while (changed);
if (!need_update_) {
return nullptr;
} else {
auto new_sx = inputs;
new_sx[0] = NewValueNode(new_fg);
return node->func_graph()->NewCNode(new_sx);
}
}
bool Process(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto nodes = TopoSort(func_graph->output());
bool changed = false;
for (size_t i = 0; i < nodes.size(); ++i) {
auto node = nodes[i];
if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &tuple_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(tuple_input);
auto tuple_input_cnode = tuple_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_input_cnode);
auto &tuple_inputs = tuple_input_cnode->inputs();
if (tuple_inputs.size() < 3) {
// case0: inputs size < 2, error
MS_EXCEPTION(ArgumentError) << "Inputs size of AddN less than 2. " << cnode->DebugString(2);
}
int valuenode_num =
std::accumulate(tuple_inputs.begin() + 1, tuple_inputs.end(), 0, [](int accumulator, const AnfNodePtr &node) {
if (IsValueNode<tensor::Tensor>(node)) {
return accumulator + 1;
} else {
return accumulator;
}
});
if (IntToSize(valuenode_num) == tuple_inputs.size()) {
// case1: all inputs is ValueNode, error
MS_EXCEPTION(ArgumentError) << "All inputs of AddN is ValueNode. " << cnode->DebugString(2);
}
if (tuple_inputs.size() == 3) {
// case2: inputs size = 2, -> TensorAdd(Tensor, Tensor)
MS_LOG(DEBUG) << "Replace AddN with two inputs with TensorAdd. " << cnode->DebugString(2);
ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations");
std::vector<AnfNodePtr> new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1],
tuple_inputs[2]};
mng->Replace(node, func_graph->NewCNode(new_xs));
changed = true;
continue;
}
auto first_valuenode = std::find_if(tuple_inputs.begin() + 1, tuple_inputs.end(),
[](const AnfNodePtr &node) { return IsValueNode<tensor::Tensor>(node); });
if (first_valuenode == tuple_inputs.end()) {
// no ValueNode input found.
continue;
} else {
// case3: has one ValueNode input -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...))
std::vector<AnfNodePtr> make_tuple_new_xs{
NewValueNode(prim::kPrimMakeTuple),
};
std::for_each(tuple_inputs.begin() + 1, tuple_inputs.end(),
[&make_tuple_new_xs, &first_valuenode](const AnfNodePtr &node) {
if (node != *first_valuenode) {
make_tuple_new_xs.push_back(node);
}
});
ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations");
auto new_addn = func_graph->NewCNode(
{func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)});
ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations");
auto new_add =
func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn});
(void)mng->Replace(node, new_add);
changed = true;
continue;
}
}
need_update_ |= changed;
return changed;
}
private:
bool need_update_{false};
};
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@ -79,7 +79,7 @@ class ReduceOneEliminater : public AnfVisitor {
} }
void Visit(const AnfNodePtr &node) override { void Visit(const AnfNodePtr &node) override {
if (x_ == nullptr) { if (!IsVNode(node) && x_ == nullptr) {
if (IsValueNode<tensor::Tensor>(node)) { if (IsValueNode<tensor::Tensor>(node)) {
is_tensor_ = true; is_tensor_ = true;
} }

View File

@ -23,6 +23,8 @@
#include "optimizer/irpass.h" #include "optimizer/irpass.h"
#include "ir/visitor.h" #include "ir/visitor.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "utils/graph_utils.h"
#include "operator/composite/composite.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@ -36,6 +38,7 @@ class MakeRefEliminater : public AnfVisitor {
this->y_ = node; this->y_ = node;
return true; return true;
}; };
AnfVisitor::Match(prim::kPrimMakeRef, {IsNode, gety, IsNode})(node); AnfVisitor::Match(prim::kPrimMakeRef, {IsNode, gety, IsNode})(node);
return y_; return y_;
} }

View File

@ -142,7 +142,7 @@ class ResetDeferInline : public AnfVisitor {
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (IsValueNode<FuncGraph>(node)) { if (IsValueNode<FuncGraph>(node)) {
auto fg = GetValueNode<FuncGraphPtr>(node); auto fg = GetValueNode<FuncGraphPtr>(node);
fg->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false); fg->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
} }
return nullptr; return nullptr;
} }

View File

@ -22,6 +22,7 @@
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include "optimizer/irpass.h" #include "optimizer/irpass.h"
#include "optimizer/optimizer.h" #include "optimizer/optimizer.h"
@ -41,7 +42,7 @@ class SpecializeTransform {
~SpecializeTransform() = default; ~SpecializeTransform() = default;
FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector<FuncGraphPtr> graph_args, FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector<FuncGraphPtr> graph_args,
std::vector<PrimitivePtr> prim_args) { std::vector<PrimitivePtr> prim_args, std::vector<tensor::TensorPtr> value_args) {
if (cache_.count(func_graph) == 0) { if (cache_.count(func_graph) == 0) {
cache_[func_graph] = {}; cache_[func_graph] = {};
} }
@ -69,6 +70,13 @@ class SpecializeTransform {
(void)mng->Replace(params[i], arg); (void)mng->Replace(params[i], arg);
continue; continue;
} }
if (value_args[i] != nullptr) {
auto const_tensor = *value_args[i];
auto const_tensor_ptr = std::make_shared<tensor::Tensor>(const_tensor);
AnfNodePtr arg = NewValueNode(const_tensor_ptr);
(void)mng->Replace(params[i], arg);
continue;
}
new_params.push_back(params[i]); new_params.push_back(params[i]);
} }
@ -108,6 +116,7 @@ class SpecializeOnGraphArguments : public AnfVisitor {
std::vector<FuncGraphPtr> graph_args; std::vector<FuncGraphPtr> graph_args;
std::vector<PrimitivePtr> prim_args; std::vector<PrimitivePtr> prim_args;
std::vector<tensor::TensorPtr> value_node_args;
std::vector<AnfNodePtr> new_xs; std::vector<AnfNodePtr> new_xs;
bool hasVNode = false; bool hasVNode = false;
for (size_t i = 1; i < inputs.size(); i++) { for (size_t i = 1; i < inputs.size(); i++) {
@ -115,15 +124,24 @@ class SpecializeOnGraphArguments : public AnfVisitor {
auto fg_vnode = GetValueNode<FuncGraphPtr>(inputs[i]); auto fg_vnode = GetValueNode<FuncGraphPtr>(inputs[i]);
graph_args.push_back(fg_vnode); graph_args.push_back(fg_vnode);
prim_args.emplace_back(nullptr); prim_args.emplace_back(nullptr);
value_node_args.emplace_back(nullptr);
hasVNode = true; hasVNode = true;
} else if (IsValueNode<Primitive>(inputs[i])) { } else if (IsValueNode<Primitive>(inputs[i])) {
auto p_vnode = GetValueNode<PrimitivePtr>(inputs[i]); auto p_vnode = GetValueNode<PrimitivePtr>(inputs[i]);
graph_args.emplace_back(nullptr); graph_args.emplace_back(nullptr);
prim_args.push_back(p_vnode); prim_args.push_back(p_vnode);
value_node_args.emplace_back(nullptr);
hasVNode = true;
} else if (IsValueNode<tensor::Tensor>(inputs[i])) {
tensor::TensorPtr t_vnode = GetValueNode<tensor::TensorPtr>(inputs[i]);
graph_args.emplace_back(nullptr);
prim_args.emplace_back(nullptr);
value_node_args.emplace_back(t_vnode);
hasVNode = true; hasVNode = true;
} else { } else {
graph_args.emplace_back(nullptr); graph_args.emplace_back(nullptr);
prim_args.emplace_back(nullptr); prim_args.emplace_back(nullptr);
value_node_args.emplace_back(nullptr);
new_xs.push_back(inputs[i]); new_xs.push_back(inputs[i]);
} }
} }
@ -132,7 +150,7 @@ class SpecializeOnGraphArguments : public AnfVisitor {
return nullptr; return nullptr;
} }
auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args); auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, value_node_args);
(void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg));
return node->func_graph()->NewCNode(new_xs); return node->func_graph()->NewCNode(new_xs);
@ -141,6 +159,146 @@ class SpecializeOnGraphArguments : public AnfVisitor {
private: private:
internal::SpecializeTransform specialize_transform_; internal::SpecializeTransform specialize_transform_;
}; };
// Eliminate unused parameters.
// {G, Xs}
class UnusedParasEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
MS_EXCEPTION_IF_NULL(fg);
std::vector<AnfNodePtr> parameters = fg->parameters();
size_t size = parameters.size();
if (size != inputs.size() - 1) {
return nullptr;
}
std::vector<AnfNodePtr> new_xs;
std::vector<bool> keep_parameters;
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
auto &node_users = mng->node_users();
bool has_unused_para = false;
for (size_t i = 0; i < size; ++i) {
auto iter = node_users.find(parameters[i]);
if (iter != node_users.end() && !iter->second.empty()) {
keep_parameters.push_back(true);
new_xs.push_back(inputs[i + 1]);
continue;
}
keep_parameters.push_back(false);
has_unused_para = true;
}
if (!has_unused_para) {
return nullptr;
}
FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("sp"));
mng->AddFuncGraph(new_fg);
std::vector<AnfNodePtr> new_fg_parameters = new_fg->parameters();
std::vector<AnfNodePtr> new_parameters;
for (size_t i = 0; i < size; i++) {
if (keep_parameters[i]) {
if (parameters[i]->abstract() != nullptr) {
new_fg_parameters[i]->set_abstract(parameters[i]->abstract());
}
new_parameters.push_back(new_fg_parameters[i]);
}
}
mng->SetParameters(new_fg, new_parameters);
(void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg));
return node->func_graph()->NewCNode(new_xs);
}
};
// Eliminate unused outputs.
// {G, Xs}
class UnusedOutputEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
MS_EXCEPTION_IF_NULL(fg);
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
if (fg->recursive()) {
return nullptr;
}
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg"));
mng->AddFuncGraph(new_fg);
auto new_fg_output = new_fg->output();
if (!IsPrimitiveCNode(new_fg_output, prim::kPrimMakeTuple)) {
return nullptr;
}
auto output_cnode = new_fg_output->cast<CNodePtr>();
auto &node_users = mng->node_users();
if (node_users.count(node) == 0 || node_users[node].empty()) {
return nullptr;
}
std::unordered_set<int> used_output_idx;
std::vector<std::pair<AnfNodePtr, int>> all_users;
for (auto &node_user : node_users[node]) {
if (!IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
return nullptr;
}
auto user_cnode = node_user.first->cast<CNodePtr>();
size_t used_idx = GetValue<int>(user_cnode->input(2)->cast<ValueNodePtr>()->value());
used_output_idx.insert(used_idx);
all_users.push_back(std::make_pair(node_user.first, used_idx));
}
if (used_output_idx.size() >= output_cnode->inputs().size() - 1) {
// all output has users.
return nullptr;
}
if (used_output_idx.empty()) {
// we do not process this case.
return nullptr;
} else if (used_output_idx.size() == 1) {
// after eliminate, only one output left.
new_fg->set_output(output_cnode->input(*used_output_idx.begin() + 1));
// update users.
for (auto &ret_user : all_users) {
(void)mng->Replace(ret_user.first, node);
}
} else {
// after eliminate, create new multi output.
std::vector<AnfNodePtr> new_output_inputs{output_cnode->input(0)};
std::unordered_map<int, int> new_idx_map;
for (auto idx : used_output_idx) {
new_idx_map[idx] = SizeToInt(new_output_inputs.size() - 1);
new_output_inputs.push_back(output_cnode->input(idx + 1));
}
new_fg->set_output(new_fg->NewCNode(new_output_inputs));
// update users.
for (auto &ret_user : all_users) {
auto ret_user_cnode = ret_user.first->cast<CNodePtr>();
ret_user_cnode->set_input(2, NewValueNode(new_idx_map[ret_user.second]));
}
}
auto new_sx = inputs;
new_sx[0] = NewValueNode(new_fg);
return node->func_graph()->NewCNode(new_sx);
}
};
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@ -89,7 +89,7 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>;
class Optimizer : public std::enable_shared_from_this<Optimizer> { class Optimizer : public std::enable_shared_from_this<Optimizer> {
public: public:
Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr) Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr)
: name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false) {} : name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false), is_enable_(true) {}
virtual ~Optimizer() = default; virtual ~Optimizer() = default;
void Init(const OptPassGroupMap &passes, bool run_only_once) { void Init(const OptPassGroupMap &passes, bool run_only_once) {
@ -132,6 +132,9 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
} }
FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) { FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) {
if (!is_enable_) {
return func_graph;
}
// Optimizer step counter; // Optimizer step counter;
int counter = -1; int counter = -1;
bool changes = true; bool changes = true;
@ -171,7 +174,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
}; };
use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func(); use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func();
if (IS_OUTPUT_ON(mindspore::DEBUG) && MsContext::GetInstance()->save_graphs_flag()) { if (IS_OUTPUT_ON(mindspore::DEBUG) && MsContext::GetInstance()->save_graphs_flag()) {
MS_LOG(DEBUG) << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end.";
auto fg_name = auto fg_name =
"opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];
func_graph->DumpFuncGraph(fg_name); func_graph->DumpFuncGraph(fg_name);
@ -211,6 +214,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
void enable_watch_renormalize() { is_watch_renormalize_ = true; } void enable_watch_renormalize() { is_watch_renormalize_ = true; }
void disable_watch_renormalize() { is_watch_renormalize_ = false; } void disable_watch_renormalize() { is_watch_renormalize_ = false; }
bool is_watch_renormalize() { return is_watch_renormalize_; } bool is_watch_renormalize() { return is_watch_renormalize_; }
void set_enable(bool enable) { is_enable_ = enable; }
private: private:
const std::string name_; const std::string name_;
@ -220,6 +224,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
bool run_only_once_; bool run_only_once_;
std::vector<AnfNodePtr> untyped_nodes_; std::vector<AnfNodePtr> untyped_nodes_;
bool is_watch_renormalize_; bool is_watch_renormalize_;
bool is_enable_;
}; };
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@ -64,7 +64,7 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
DumpGraph(root, std::string(ALLREDUCE_FUSION_END)); DumpGraph(root, std::string(ALLREDUCE_FUSION_END));
// allreduce fusion only run once // allreduce fusion only run once
root->flags()[ALLREDUCE_FUSION_RUN_ONCE_ONLY] = true; root->set_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY, true);
res->results()[pipeline::kStepParallelGraph] = root; res->results()[pipeline::kStepParallelGraph] = root;
#if defined(_WIN32) || defined(_WIN64) #if defined(_WIN32) || defined(_WIN64)
auto end_time = std::chrono::steady_clock::now(); auto end_time = std::chrono::steady_clock::now();

View File

@ -158,8 +158,8 @@ void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph,
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(param_node); MS_EXCEPTION_IF_NULL(param_node);
MS_EXCEPTION_IF_NULL(ptr); MS_EXCEPTION_IF_NULL(ptr);
if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->flags().count(TRAINING) == 0) || if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->attrs().count(TRAINING) == 0) ||
func_graph->flags()[TRAINING]) { func_graph->has_flag(TRAINING)) {
return; return;
} }

View File

@ -107,7 +107,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec); time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us"; MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us";
root->flags()[AUTO_PARALLEL_RUN_ONCE_ONLY] = true; root->set_flag(AUTO_PARALLEL_RUN_ONCE_ONLY, true);
return changes; return changes;
} }

View File

@ -2270,10 +2270,10 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
(root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) { (root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) {
if (!root->has_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY)) { if (!root->has_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY)) {
if (HasStrategy(root)) { if (HasStrategy(root)) {
MS_LOG(INFO) << "strategies ignored in " << parallel_mode MS_LOG(INFO) << "Strategies ignored in " << parallel_mode
<< ", set_strategy() only valid in [semi_]auto_parallel."; << ", set_strategy() only valid in [semi_]auto_parallel.";
} }
root->flags()[CHECK_SET_STRATEGY_VALID_ONCE_ONLY] = true; root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true);
} }
return changes; return changes;
@ -2330,11 +2330,11 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
DumpGraph(root, std::string(STEP_PARALLEL_END)); DumpGraph(root, std::string(STEP_PARALLEL_END));
// step parallel only run once // step parallel only run once
root->flags()[SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY] = true; root->set_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY, true);
res->results()[pipeline::kStepParallelGraph] = root; res->results()[pipeline::kStepParallelGraph] = root;
// in auto parallel mode, no need to check if stategies set // in auto parallel mode, no need to check if stategies set
root->flags()[CHECK_SET_STRATEGY_VALID_ONCE_ONLY] = true; root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true);
(void)gettimeofday(&end_time, nullptr); (void)gettimeofday(&end_time, nullptr);
uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec); uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);

View File

@ -151,7 +151,10 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.") .def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.")
.def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.") .def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.")
.def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.") .def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.")
.def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print."); .def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.")
.def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel,
"Set the GraphKernel switch to on or off.")
.def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.");
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig") (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")

View File

@ -278,7 +278,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
if (bprop_graph != nullptr) { if (bprop_graph != nullptr) {
(void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph))); (void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph))); (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true); func_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
} }
} }
*data = func_graph; *data = func_graph;

View File

@ -1448,15 +1448,23 @@ bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) {
} }
py::dict flags = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG); py::dict flags = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG);
for (auto &item : flags) { for (auto &item : flags) {
if (!py::isinstance<py::str>(item.first) || !py::isinstance<py::bool_>(item.second)) { if (!py::isinstance<py::str>(item.first)) {
MS_LOG(ERROR) << "Type error in flags dict convert"; MS_LOG(ERROR) << "Type error in flags dict convert";
return false; return false;
} }
auto name = py::cast<std::string>(item.first); auto name = py::cast<std::string>(item.first);
auto value = py::cast<bool>(item.second); if (py::isinstance<py::bool_>(item.second)) {
MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value; auto value = py::cast<bool>(item.second);
MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
func_graph->set_flags(name, value); func_graph->set_flag(name, value);
} else if (py::isinstance<py::str>(item.second)) {
auto value = py::cast<std::string>(item.second);
MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
func_graph->set_attr(name, MakeValue(value));
} else {
MS_LOG(ERROR) << "Type error in flags/attrs dict convert";
return false;
}
} }
return true; return true;

View File

@ -223,8 +223,8 @@ class Parser {
FunctionBlockPtr block = std::make_shared<FunctionBlock>(parse); FunctionBlockPtr block = std::make_shared<FunctionBlock>(parse);
// In order to keep effect order in the sub-graphs which generated by control flow. // In order to keep effect order in the sub-graphs which generated by control flow.
// We copy the flags from the top graph to the sub-graphs. // We copy the flags from the top graph to the sub-graphs.
if (func_graph_ && !func_graph_->flags().empty()) { if (func_graph_ && !func_graph_->attrs().empty()) {
block->func_graph()->set_flags(func_graph_->flags()); block->func_graph()->set_attrs(func_graph_->attrs());
} }
func_block_list_.push_back(block); func_block_list_.push_back(block);
return block; return block;

View File

@ -25,12 +25,14 @@
#include <functional> #include <functional>
#include "ir/func_graph_cloner.h" #include "ir/func_graph_cloner.h"
#include "debug/anf_ir_utils.h"
#include "pipeline/parse/parse_base.h" #include "pipeline/parse/parse_base.h"
#include "pipeline/parse/data_converter.h" #include "pipeline/parse/data_converter.h"
#include "pipeline/resource.h" #include "pipeline/resource.h"
#include "pipeline/validator.h" #include "pipeline/validator.h"
#include "optimizer/optimizer.h" #include "optimizer/optimizer.h"
#include "optimizer/cse.h" #include "optimizer/cse.h"
#include "optimizer/graph_kernel_reuse.h"
#include "optimizer/clean.h" #include "optimizer/clean.h"
#include "optimizer/irpass.h" #include "optimizer/irpass.h"
#include "optimizer/control_depend.h" #include "optimizer/control_depend.h"
@ -38,6 +40,7 @@
#include "parallel/step_auto_parallel.h" #include "parallel/step_auto_parallel.h"
#include "parallel/allreduce_fusion/step_allreduce_fusion.h" #include "parallel/allreduce_fusion/step_allreduce_fusion.h"
#include "utils/any.h" #include "utils/any.h"
#include "utils/log_adapter.h"
namespace mindspore { namespace mindspore {
namespace pipeline { namespace pipeline {
@ -162,6 +165,40 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
return map; return map;
} }
OptPassGroupMap GetOptPassesGraphKernelA(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig interface_fusion = opt::OptPassConfig({
irpass.mark_interface_fusion_,
});
OptPassGroupMap map({
{"graph_kernel_reuse", opt::OptPassConfig(opt::GraphKernelReuse())},
{"interface_fusion", interface_fusion},
{"renormalize", opt::OptPassConfig::Renormalize()},
{"cse", opt::OptPassConfig(opt::CSE(false))},
});
return map;
}
OptPassGroupMap GetOptPassesGraphKernelB(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig elim_1 = opt::OptPassConfig({
irpass.addn_eliminate_,
irpass.incorporate_getitem_from_param_,
});
opt::OptPassConfig elim_2 = opt::OptPassConfig({
irpass.unused_parameter_eliminate_,
irpass.unused_output_eliminate_,
});
OptPassGroupMap map({
{"elim_1", elim_1},
{"renormalize", opt::OptPassConfig::Renormalize()},
{"elim_2", elim_2},
});
return map;
}
OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &irpass) {
return OptPassGroupMap({{"renormalize", opt::OptPassConfig::Renormalize()}});
}
OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) { OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_}, true); opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_}, true);
OptPassGroupMap map({ OptPassGroupMap map({
@ -191,8 +228,19 @@ void InitOpt(const ResourcePtr &res) {
opt::irpass::OptimizeIRPassLib irpass; opt::irpass::OptimizeIRPassLib irpass;
g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass));
g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true);
g_pass_opts["opt_graph_kernel_a"] =
Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true);
g_pass_opts["opt_graph_kernel_b"] =
Optimizer::MakeOptimizer("opt_graph_kernel_b", res, GetOptPassesGraphKernelB(irpass), false);
g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass));
g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true); g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true);
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->enable_graph_kernel())) {
g_pass_opts["opt_graph_kernel_a"]->set_enable(false);
g_pass_opts["opt_graph_kernel_b"]->set_enable(false);
}
} }
} }
} // namespace } // namespace
@ -224,9 +272,13 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); }
bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); }
bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); }
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); } bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); }
bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); }
bool AddControlDependPass(const ResourcePtr &res) { bool AddControlDependPass(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph(); FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
@ -270,8 +322,10 @@ bool InferenceOptPreparePass(const ResourcePtr &res) {
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_a", OptPassAGroup}, {"opt_a", OptPassAGroup},
{"opt_b", OptPassBGroup}, {"opt_b", OptPassBGroup},
{"add_control_depend", AddControlDependPass}, {"cconv", CconvPass},
{"cconv", CconvPass}}; {"opt_graph_kernel_a", OptPassGraphKernelGroupA},
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
{"add_control_depend", AddControlDependPass}};
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_a", OptPassAGroup}, {"opt_a", OptPassAGroup},

View File

@ -488,7 +488,7 @@ py::object ExecDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const
#ifdef ENABLE_INFER #ifdef ENABLE_INFER
// Now don't use the graph because the exec ge function don't take effect // Now don't use the graph because the exec ge function don't take effect
MS_EXCEPTION_IF_NULL(info.at(phase)->func_graph); MS_EXCEPTION_IF_NULL(info.at(phase)->func_graph);
if (ENABLE_TRAIN != info.at(phase)->func_graph->flags()["training"]) { if (ENABLE_TRAIN != info.at(phase)->func_graph->has_flag("training")) {
MS_LOG(ERROR) << "Graph training mode mismatch mode of libraries"; MS_LOG(ERROR) << "Graph training mode mismatch mode of libraries";
ConfigManager::GetInstance().ResetConfig(); ConfigManager::GetInstance().ResetConfig();
return py::none(); return py::none();

View File

@ -165,7 +165,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
MS_LOG(DEBUG) << "Joined args: " << ::mindspore::ToString(joined_args_spec_list); MS_LOG(DEBUG) << "Joined args: " << ::mindspore::ToString(joined_args_spec_list);
// If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
if (!(joined_args_spec_list == args_spec_list)) { if (!(joined_args_spec_list == args_spec_list)) {
func_graph_->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
} }
return joined_args_spec_list; return joined_args_spec_list;
} }
@ -178,7 +178,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
// If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
if (!(joined_args_spec_list == args_spec_list)) { if (!(joined_args_spec_list == args_spec_list)) {
trace_.push_back(joined_args_spec_list); trace_.push_back(joined_args_spec_list);
func_graph_->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
} }
MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list); MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list);
return joined_args_spec_list; return joined_args_spec_list;

View File

@ -479,7 +479,7 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
if (undetermined_fgs) { if (undetermined_fgs) {
auto fg_parent = fg->parent(); auto fg_parent = fg->parent();
MS_EXCEPTION_IF_NULL(fg_parent); MS_EXCEPTION_IF_NULL(fg_parent);
fg_parent->set_flags(kFuncGraphFlagUndetermined, true); fg_parent->set_flag(kFuncGraphFlagUndetermined, true);
MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString(); MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString();
} }
} }

View File

@ -16,6 +16,7 @@
#include "pre_activate/ascend/ascend_backend_optimization.h" #include "pre_activate/ascend/ascend_backend_optimization.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include <set>
#include "pre_activate/common/optimizer.h" #include "pre_activate/common/optimizer.h"
#include "pre_activate/ascend/ir_fission/bn_split.h" #include "pre_activate/ascend/ir_fission/bn_split.h"
#include "pre_activate/ascend/ir_fission/bn_grad_split.h" #include "pre_activate/ascend/ir_fission/bn_grad_split.h"
@ -63,6 +64,9 @@
#include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" #include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h"
#include "pre_activate/pass/eliminate_redundant_op.h" #include "pre_activate/pass/eliminate_redundant_op.h"
#include "pre_activate/pass/common_subexpression_elimination.h" #include "pre_activate/pass/common_subexpression_elimination.h"
#include "pre_activate/pass/fuse_graph_kernel.h"
#include "pre_activate/pass/fuse_basic.h"
#include "pre_activate/pass/add_atomic_clean.h"
#include "pre_activate/ascend/format_type/merge_cast_to_op.h" #include "pre_activate/ascend/format_type/merge_cast_to_op.h"
#include "pre_activate/ascend/format_type/check_consistency.h" #include "pre_activate/ascend/format_type/check_consistency.h"
#include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h" #include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h"
@ -88,6 +92,8 @@
#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h" #include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h"
#include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h" #include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h"
#include "pre_activate/ascend/ir_fission/split_fission.h" #include "pre_activate/ascend/ir_fission/split_fission.h"
#include "pre_activate/ascend/format_type/modify_ops_attrs.h"
#include "pre_activate/ascend/format_type/remove_no_use_reshape_op.h"
#include "utils/context/ms_context.h" #include "utils/context/ms_context.h"
#include "utils/config_manager.h" #include "utils/config_manager.h"
#include "debug/anf_ir_dump.h" #include "debug/anf_ir_dump.h"
@ -164,6 +170,19 @@ void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_g
kernel_graph->SetExecOrderByDefault(); kernel_graph->SetExecOrderByDefault();
} }
void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<GraphOptimizer>();
MS_EXCEPTION_IF_NULL(optimizer);
auto common_process = std::make_shared<PassManager>("graph_kernel_common_process");
MS_EXCEPTION_IF_NULL(common_process);
common_process->AddPass(std::make_shared<ModifyOpAttrs>());
common_process->AddPass(std::make_shared<RemoveNoUseReshapeOp>());
optimizer->AddPassManager(common_process);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
}
void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) { void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<GraphOptimizer>(); auto optimizer = std::make_shared<GraphOptimizer>();
@ -332,7 +351,94 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
std::string file_path = std::string file_path =
save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(file_path, kernel_graph, true); DumpIR(file_path, kernel_graph, true);
DumpIRProto(kernel_graph, "after_hwopt_" + std::to_string(kernel_graph->graph_id())); DumpIRProto(kernel_graph, "after_hwopt");
kernel_graph->DumpFuncGraph("hwopt_d_end");
}
}
void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph,
bool is_before_kernel_select) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->enable_graph_kernel())) {
return;
}
bool save_graphs = context_ptr->save_graphs_flag();
auto save_graphs_path = context_ptr->save_graphs_path();
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_before_graph_" +
std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) +
".ir";
DumpIR(file_path, kernel_graph);
}
// Fuse graph kernels with basic ops
FuseGraphKernel(kernel_graph, is_before_kernel_select);
if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_end_graph_" +
std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) +
".ir";
DumpIR(file_path, kernel_graph, true);
}
}
void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph,
bool is_before_kernel_select) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->enable_graph_kernel())) {
return;
}
bool save_graphs = context_ptr->save_graphs_flag();
auto save_graphs_path = context_ptr->save_graphs_path();
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_before_graph_" +
std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) +
".ir";
DumpIR(file_path, kernel_graph, true);
}
// Fuse basic ops with basic ops
FuseBasic(kernel_graph, is_before_kernel_select);
if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_end_graph_" +
std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) +
".ir";
DumpIR(file_path, kernel_graph, true);
}
}
void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->enable_graph_kernel())) {
return;
}
bool save_graphs = context_ptr->save_graphs_flag();
auto save_graphs_path = context_ptr->save_graphs_path();
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_add_atomic_clean_before" + "_graph_" +
std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(file_path, kernel_graph);
}
AddAtomicClean(kernel_graph);
if (save_graphs) {
std::string file_path =
save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(file_path, kernel_graph, true);
} }
} }

View File

@ -24,6 +24,12 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph); void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph); void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph,
bool is_before_kernel_select = false);
void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kernel_graph,
bool is_before_kernel_select = false);
void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
} // namespace opt } // namespace opt

View File

@ -22,6 +22,7 @@
#include "utils/utils.h" #include "utils/utils.h"
#include "device/kernel_info.h" #include "device/kernel_info.h"
#include "kernel/oplib/oplib.h" #include "kernel/oplib/oplib.h"
#include "kernel/common_utils.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
#include "session/kernel_graph.h" #include "session/kernel_graph.h"
@ -229,7 +230,7 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr
if (kernel::OpLib::FindOp(prim::kPrimCast->name(), kernel::kTBE) != nullptr) { if (kernel::OpLib::FindOp(prim::kPrimCast->name(), kernel::kTBE) != nullptr) {
builder.SetKernelType(KernelType::TBE_KERNEL); builder.SetKernelType(KernelType::TBE_KERNEL);
} else { } else {
builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL); builder.SetKernelType(KernelType::AKG_KERNEL);
} }
// if kernel info is null , it remarks this function is running ut // if kernel info is null , it remarks this function is running ut
if (cast->kernel_info() == nullptr) { if (cast->kernel_info() == nullptr) {
@ -284,22 +285,17 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
TypeId origin_type; const auto infer_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
TypeId origin_type(kTypeUnknown);
auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0); auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0);
auto is_weight_boundary = [](const AnfNodePtr &node) -> bool {
if (node->isa<ValueNode>()) {
return true;
}
if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
return true;
}
return false;
};
auto real_input_node = kernel_with_index.first; auto real_input_node = kernel_with_index.first;
if (is_weight_boundary(real_input_node)) { if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
// weight // weight
origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index); origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index);
if (origin_type == kTypeUnknown) {
origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index);
}
} else { } else {
// feature map // feature map
origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
@ -307,9 +303,13 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index); const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
const std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_index); const std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_index);
const TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); const TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index);
if (origin_type != device_type) { // In graph kernel, we check parameter,
// the eliminate pass will not eliminate this case, so we just do not insert the noused cast.
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode<tensor::Tensor>(cur_input)) {
new_inputs.push_back(cur_input);
} else if (origin_type != device_type) {
auto cast = auto cast =
AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, origin_type); AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type);
MS_EXCEPTION_IF_NULL(cast); MS_EXCEPTION_IF_NULL(cast);
cast->set_scope(cnode->scope()); cast->set_scope(cnode->scope());
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast); AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast);

View File

@ -17,9 +17,12 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include <vector>
#include "utils/utils.h" #include "utils/utils.h"
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
#include "common/utils.h"
#include "kernel/common_utils.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@ -74,11 +77,21 @@ const AnfNodePtr CheckConsistency::Process(const FuncGraphPtr &, const AnfNodePt
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) { if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
return nullptr; return nullptr;
} }
CNodePtr cnode = node->cast<CNodePtr>();
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) { std::vector<AnfNodePtr> todos = {node};
if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) { if (AnfAlgo::IsGraphKernel(node)) {
MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(node) << "[" auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
<< node->DebugString() << "]"; MS_EXCEPTION_IF_NULL(sub_graph);
kernel::GetValidKernelNodes(sub_graph, &todos);
}
for (auto &t : todos) {
CNodePtr cnode = t->cast<CNodePtr>();
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) {
if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) {
MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "["
<< cnode->DebugString() << "]";
}
} }
} }
return nullptr; return nullptr;

View File

@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility>
#include "device/kernel_info.h" #include "device/kernel_info.h"
#include "pre_activate/ascend/ascend_helper.h" #include "pre_activate/ascend/ascend_helper.h"
@ -27,34 +28,45 @@
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
#include "session/kernel_graph.h" #include "session/kernel_graph.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "kernel/common_utils.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<bool> &need_insert_cast) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
std::vector<AnfNodePtr> make_tuple_inputs; std::vector<AnfNodePtr> make_tuple_inputs;
AbstractBasePtrList abstract_list; AbstractBasePtrList abstract_list;
make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(cnode); ++output_idx) { for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(cnode); ++output_idx) {
const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx); AnfNodePtr replace_node = nullptr;
const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx); const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx);
const TypeId origin_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx); const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx);
const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx);
auto idx = NewValueNode(SizeToInt(output_idx)); auto idx = NewValueNode(SizeToInt(output_idx));
MS_EXCEPTION_IF_NULL(idx); MS_EXCEPTION_IF_NULL(idx);
auto imm = std::make_shared<Int32Imm>(output_idx); auto imm = std::make_shared<Int32Imm>(output_idx);
idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm)); idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx}); auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, getitem.get()); AnfAlgo::SetOutputInferTypeAndShape({infer_type}, {origin_shape}, getitem.get());
AnfNodePtr replace_node = nullptr; if (need_insert_cast[output_idx]) {
if (origin_type != device_type) { const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx);
replace_node = TypeId origin_type(kTypeUnknown);
AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape, origin_type); if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
MS_EXCEPTION_IF_NULL(replace_node); origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode);
replace_node->set_scope(cnode->scope()); }
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); origin_type = origin_type == kTypeUnknown ? infer_type : origin_type;
const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx);
if (origin_type != device_type) {
replace_node =
AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape, infer_type);
MS_EXCEPTION_IF_NULL(replace_node);
replace_node->set_scope(cnode->scope());
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
} else {
replace_node = getitem;
}
} else { } else {
replace_node = getitem; replace_node = getitem;
} }
@ -65,9 +77,10 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
MS_EXCEPTION_IF_NULL(make_tuple); MS_EXCEPTION_IF_NULL(make_tuple);
make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
return make_tuple; return make_tuple;
} } // namespace
AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<bool> &need_insert_cast) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetOutputTensorNum(cnode) == 0) { if (AnfAlgo::GetOutputTensorNum(cnode) == 0) {
@ -76,14 +89,23 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
MS_EXCEPTION_IF_NULL(cnode->Type()); MS_EXCEPTION_IF_NULL(cnode->Type());
// Single output // Single output
if (!cnode->Type()->isa<Tuple>()) { if (!cnode->Type()->isa<Tuple>()) {
if (!need_insert_cast[0]) {
return cnode;
}
const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0); const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0);
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0); std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
const TypeId origin_type = AnfAlgo::GetOutputInferDataType(cnode, 0); const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, 0);
TypeId origin_type(kTypeUnknown);
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode);
}
origin_type = origin_type == kTypeUnknown ? infer_type : origin_type;
const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0); const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0);
AnfNodePtr replace_node = cnode; AnfNodePtr replace_node = cnode;
if (origin_type != device_type) { if (origin_type != device_type) {
replace_node = replace_node =
AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape, origin_type); AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape, infer_type);
MS_EXCEPTION_IF_NULL(replace_node); MS_EXCEPTION_IF_NULL(replace_node);
replace_node->set_scope(cnode->scope()); replace_node->set_scope(cnode->scope());
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
@ -91,7 +113,57 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
return replace_node; return replace_node;
} }
// Multiple output // Multiple output
return InsertCastForMultipleOutput(func_graph, cnode); return InsertCastForMultipleOutput(func_graph, cnode, need_insert_cast);
}
AnfNodePtr ProcessGraphKernelOp(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
// insert cast for ops in graph kernel.
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(sub_graph);
auto mng = sub_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
std::vector<AnfNodePtr> todo;
std::vector<std::pair<AnfNodePtr, size_t>> graph_rets;
kernel::GetValidKernelNodes(sub_graph, &todo);
kernel::GetGraphRealOutput(sub_graph, &graph_rets);
for (auto &t : todo) {
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), t);
// process input
CNodePtr t_cnode = t->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(t_cnode);
auto t_new_node = InsertCastForInput(sub_graph, t_cnode);
AnfNodePtr t_new_node_1 = nullptr;
std::vector<bool> need_insert_cast(AnfAlgo::GetOutputTensorNum(t), true);
// process output
auto iter = std::find_if(graph_rets.begin(), graph_rets.end(),
[&t](const std::pair<AnfNodePtr, size_t> &ret) { return ret.first == t; });
if (iter != graph_rets.end()) {
auto t_fix_output_type = AnfAlgo::GetCNodeOutputPrecision(t);
auto t_output_type = AnfAlgo::GetOutputDeviceDataType(t, iter->second);
auto graph_output_type = AnfAlgo::GetOutputDeviceDataType(node, iter - graph_rets.begin());
if (t_fix_output_type == kTypeUnknown && t_output_type == graph_output_type) {
need_insert_cast[iter->second] = false;
} else if (t_fix_output_type == t_output_type && t_output_type == graph_output_type) {
need_insert_cast[iter->second] = false;
}
t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast);
} else {
t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast);
}
if (t_new_node_1 != nullptr && t_new_node_1 != t) {
(void)mng->Replace(t, t_new_node_1);
}
}
// insert cast for graph kernel.
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
// process input
CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto new_node = InsertCastForInput(func_graph, cnode);
// process output
return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true));
} }
} // namespace } // namespace
@ -106,13 +178,27 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo
if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) { if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) {
return nullptr; return nullptr;
} }
if (AnfAlgo::IsGraphKernel(node)) {
return ProcessGraphKernelOp(func_graph, node);
} else {
// insert cast for single op.
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
// process input
CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto new_node = InsertCastForInput(func_graph, cnode);
// process output
return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true));
}
// insert cast for single op.
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
// process input // process input
CNodePtr cnode = node->cast<CNodePtr>(); CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
auto new_node = InsertCastForInput(func_graph, cnode); auto new_node = InsertCastForInput(func_graph, cnode);
// process output // process output
return InsertCastForOutput(func_graph, new_node); return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true));
} }
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@ -133,6 +133,9 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co
return nullptr; return nullptr;
} }
auto next_cnode = next_node->cast<CNodePtr>(); auto next_cnode = next_node->cast<CNodePtr>();
if (AnfAlgo::IsGraphKernel(next_node)) {
return nullptr;
}
auto next_op_name = AnfAlgo::GetCNodeName(next_node); auto next_op_name = AnfAlgo::GetCNodeName(next_node);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
kernel_query->Query(next_cnode, &kernel_info_list); kernel_query->Query(next_cnode, &kernel_info_list);
@ -206,6 +209,9 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_nod
return nullptr; return nullptr;
} }
MS_EXCEPTION_IF_NULL(prior_op); MS_EXCEPTION_IF_NULL(prior_op);
if (AnfAlgo::IsGraphKernel(prior_op)) {
return nullptr;
}
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
kernel_query->Query(prior_op, &kernel_info_list); kernel_query->Query(prior_op, &kernel_info_list);

View File

@ -0,0 +1,99 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/format_type/modify_ops_attrs.h"
#include <vector>
#include <memory>
#include "utils/utils.h"
#include "pre_activate/common/helper.h"
#include "kernel/common_utils.h"
#include "session/anf_runtime_algorithm.h"
#include "operator/ops.h"
namespace mindspore {
namespace opt {
namespace {
AnfNodePtr ModifyReduceOpsAttrs(const CNodePtr &cnode) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
auto input_format = AnfAlgo::GetInputFormat(cnode, 0);
if (input_shape.size() == 5 || input_format != kOpFormat_NC1HWC0) {
return nullptr;
}
if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) {
return nullptr;
}
AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(true), cnode);
return cnode;
}
AnfNodePtr ModifyTileOpAttrs(const CNodePtr &cnode) {
auto input_shape = AnfAlgo::GetInputDeviceShape(cnode, 0);
if (input_shape.size() != 5) {
return nullptr;
}
if (!AnfAlgo::HasNodeAttr(kAttrMultiples, cnode)) {
return nullptr;
}
auto multiples = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrMultiples);
if (multiples.size() == 4 && multiples[1] == 1) {
multiples.push_back(1);
AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), cnode);
}
return cnode;
}
AnfNodePtr ModifyAttrs(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto op_name = AnfAlgo::GetCNodeName(cnode);
if (op_name == prim::kPrimTile->name()) {
return ModifyTileOpAttrs(cnode);
} else if (op_name == prim::kPrimReduceSum->name()) {
// kPrimReduceMean
// kPrimReduceSum
// kPrimReduceAll
// kPrimReduceMax
// kPrimReduceMin
return ModifyReduceOpsAttrs(cnode);
}
return nullptr;
}
} // namespace
const AnfNodePtr ModifyOpAttrs::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsGraphKernel(node)) {
return nullptr;
}
MS_LOG(DEBUG) << "====Process op: " << AnfAlgo::GetCNodeName(node);
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(fg);
auto manager = fg->manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<AnfNodePtr> todos;
kernel::GetValidKernelNodes(fg, &todos);
for (auto &t : todos) {
auto new_node = ModifyAttrs(t->cast<CNodePtr>());
if (new_node != nullptr && new_node != t) {
(void)manager->Replace(t, new_node);
}
}
return node;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H
#include "pre_activate/common/optimizer.h"
namespace mindspore {
namespace opt {
class ModifyOpAttrs : public PatternProcessPass {
public:
explicit ModifyOpAttrs(bool multigraph = true) : PatternProcessPass("modify_ops_attrs", multigraph) {}
~ModifyOpAttrs() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H

View File

@ -0,0 +1,66 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/format_type/remove_no_use_reshape_op.h"
#include <vector>
#include <memory>
#include "pre_activate/common/helper.h"
#include "kernel/common_utils.h"
#include "session/anf_runtime_algorithm.h"
#include "operator/ops.h"
namespace mindspore {
namespace opt {
namespace {
AnfNodePtr RemoveReshapeOp(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto op_name = AnfAlgo::GetCNodeName(cnode);
if (op_name != prim::kPrimReshape->name()) {
return nullptr;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0);
if (input_shape.size() != 1 || input_format != kOpFormat_NC1HWC0) {
return nullptr;
}
return cnode->input(1);
}
} // namespace
const AnfNodePtr RemoveNoUseReshapeOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsGraphKernel(node)) {
return nullptr;
}
MS_LOG(DEBUG) << "====process op: " << AnfAlgo::GetCNodeName(node);
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(fg);
auto manager = fg->manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<AnfNodePtr> todos;
kernel::GetValidKernelNodes(fg, &todos);
for (auto &t : todos) {
auto new_node = RemoveReshapeOp(t->cast<CNodePtr>());
if (new_node != nullptr && new_node != t) {
(void)manager->Replace(t, new_node);
}
}
return node;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
#include "pre_activate/common/optimizer.h"
namespace mindspore {
namespace opt {
class RemoveNoUseReshapeOp : public PatternProcessPass {
public:
explicit RemoveNoUseReshapeOp(bool multigraph = true) : PatternProcessPass("remove_no_use_reshape_op", multigraph) {}
~RemoveNoUseReshapeOp() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H

View File

@ -121,6 +121,9 @@ const AnfNodePtr LayerNormBetaGammaBackpropFusion::Process(const FuncGraphPtr &f
if (node == nullptr || !node->isa<CNode>()) { if (node == nullptr || !node->isa<CNode>()) {
return nullptr; return nullptr;
} }
if (AnfAlgo::IsGraphKernel(node)) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
std::vector<CNodePtr> cast_nodes; std::vector<CNodePtr> cast_nodes;

View File

@ -102,9 +102,12 @@ bool UnVisited(const BaseRef &n) {
auto prim_py = value->cast<PrimitivePtr>(); auto prim_py = value->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(prim_py); MS_EXCEPTION_IF_NULL(prim_py);
return !prim_py->HasAttr(kAttrVisited); return !prim_py->HasAttr(kAttrVisited);
} else { } else if (IsValueNode<FuncGraph>(in)) {
return false; auto func_graph = GetValueNode<FuncGraphPtr>(in);
MS_EXCEPTION_IF_NULL(func_graph);
return !func_graph->has_flag(kAttrVisited);
} }
return false;
} }
return false; return false;
} }
@ -188,9 +191,12 @@ bool Visited(const BaseRef &n) {
auto prim_py = value->cast<PrimitivePtr>(); auto prim_py = value->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(prim_py); MS_EXCEPTION_IF_NULL(prim_py);
return prim_py->HasAttr(kAttrVisited); return prim_py->HasAttr(kAttrVisited);
} else { } else if (IsValueNode<FuncGraph>(in)) {
return false; auto func_graph = GetValueNode<FuncGraphPtr>(in);
MS_EXCEPTION_IF_NULL(func_graph);
return func_graph->has_flag(kAttrVisited);
} }
return false;
} }
return false; return false;
} }

Some files were not shown because too many files have changed in this diff Show More