!14564 Fix CPU operators with unsupported data types

From: @huaweib
Reviewed-by: @jjfeing,@kisnwang
Signed-off-by: @jjfeing
This commit is contained in:
mindspore-ci-bot 2021-04-06 16:59:15 +08:00 committed by Gitee
commit f2fd3c5e85
7 changed files with 241 additions and 58 deletions

View File

@ -34,8 +34,8 @@ void Cast(const S *in, T *out, size_t size) {
template <typename S, typename T>
void CastCPUKernel<S, T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
source_dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, 0);
target_dtype = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
source_dtype = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
target_dtype = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
}
template <typename S, typename T>
@ -45,7 +45,6 @@ bool CastCPUKernel<S, T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
S *input = reinterpret_cast<S *>(inputs[0]->addr);
T *output = reinterpret_cast<T *>(outputs[0]->addr);
MS_LOG(DEBUG) << "Type source: " << typeid(S).name() << "; target: " << typeid(T).name();
size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
Cast<S, T>(input, output, lens);
return true;

View File

@ -27,7 +27,7 @@ void MaximumGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dout_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
dx_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
dy_shape = AnfAlgo::GetOutputInferShape(kernel_node, 1);
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
if (!x_shape_.size() || !y_shape_.size() || !dout_shape.size()) {
MS_LOG(EXCEPTION) << "Input NULL";
}

View File

@ -36,6 +36,11 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
-Wno-overloaded-virtual -Wno-unused-const-variable -Wno-pessimizing-move")
endif()
if(ENABLE_CPU)
file(GLOB_RECURSE _CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
list(APPEND _PREACTIVATE_SRC_LIST ${_CPU_SRC_LIST})
endif()
set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS
SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT)
add_library(_mindspore_backend_optimizer_obj OBJECT ${_PREACTIVATE_SRC_LIST})

View File

@ -0,0 +1,174 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/cpu/insert_cast_cpu.h"
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "utils/utils.h"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore {
namespace opt {
namespace {
// Build a Cast CNode converting `input` from `input_type` to `output_type`, attach its
// kernel build info (format preserved, only the device dtype changes) plus the inferred
// type/shape metadata, and return the new node.
AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
                                const TypeId &input_type, const TypeId &output_type,
                                const std::vector<size_t> &origin_shape, const TypeId &origin_type) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto cast_prim = NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name()));
  CNodePtr cast_node = func_graph->NewCNode({cast_prim, input});
  MS_EXCEPTION_IF_NULL(cast_node);
  // Describe the selected kernel: same format in and out, dtype changed on the device side.
  kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
  info_builder.SetInputsFormat({format});
  info_builder.SetOutputsFormat({format});
  info_builder.SetInputsDeviceType({input_type});
  info_builder.SetOutputsDeviceType({output_type});
  // A null kernel info indicates this function is running under UT; create one on demand.
  if (cast_node->kernel_info() == nullptr) {
    cast_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  }
  AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), cast_node.get());
  AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast_node.get());
  AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast_node);
  return cast_node;
}
// For a node with tuple output: build a TupleGetItem per output and, where the device
// dtype differs from the inferred dtype, a Cast that restores the inferred type.
// Always returns the original cnode.
// NOTE(review): the getitem/cast nodes built here are not wired back into users by this
// function (only internal-output bookkeeping is updated) — confirm against callers how
// consumers are redirected.
AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                       const std::vector<bool> &need_insert_cast) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(cnode);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  size_t out_num = AnfAlgo::GetOutputTensorNum(cnode);
  for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
    AnfNodePtr replace_node = nullptr;
    const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx);
    const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx);
    // TupleGetItem index node carrying the output position as an Int64 scalar.
    auto idx = NewValueNode(SizeToLong(output_idx));
    MS_EXCEPTION_IF_NULL(idx);
    auto imm = std::make_shared<Int64Imm>(output_idx);
    idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
    auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
    AnfAlgo::SetOutputInferTypeAndShape({infer_type}, {origin_shape}, getitem.get());
    if (need_insert_cast[output_idx]) {
      const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx);
      const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx);
      // Only insert a Cast when the device dtype deviates from the inferred dtype.
      if (infer_type != device_type) {
        replace_node =
          AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, infer_type, origin_shape, infer_type);
        MS_EXCEPTION_IF_NULL(replace_node);
        replace_node->set_scope(cnode->scope());
        AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
        // Keep the kernel graph's internal-output table pointing at the cast result.
        if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, output_idx)) {
          kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0);
        }
      }
    }
  }
  return cnode;
}
// Insert a Cast in front of every input whose inferred dtype differs from the device
// dtype the selected CPU kernel expects, then redirect users to the cast via the manager.
void InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  // fix: func_graph was dereferenced below without a null check.
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(cnode);
  size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  // fix: cast<KernelGraphPtr>() may yield nullptr (siblings guard with `kernel_graph != nullptr`).
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto mng = kernel_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  for (size_t input_index = 0; input_index < in_num; ++input_index) {
    auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, input_index);
    const auto infer_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
    auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
    const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
    const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second);
    // Cast from the producer's inferred dtype to the dtype this kernel was selected with.
    if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); infer_type != device_type) {
      auto cast =
        AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, infer_type, device_type, origin_shape, device_type);
      MS_EXCEPTION_IF_NULL(cast);
      cast->set_scope(cnode->scope());
      AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast);
      // Rewire every user of the old input edge to consume the cast's output.
      mng->Replace(cur_input, cast);
    }
  }
}
// Append a Cast after each output whose device dtype differs from the inferred dtype,
// so downstream consumers observe the originally inferred types. Returns the node that
// should replace `cnode` in the graph (the cast for a single casted output, otherwise
// the original node).
AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                               const std::vector<bool> &need_insert_cast) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(cnode);
  if (AnfAlgo::GetOutputTensorNum(cnode) == 0) {
    return cnode;
  }
  MS_EXCEPTION_IF_NULL(cnode->Type());
  // Tuple outputs are handled element by element.
  if (cnode->Type()->isa<Tuple>()) {
    return InsertCastForMultipleOutput(func_graph, cnode, need_insert_cast);
  }
  // Single output.
  if (!need_insert_cast[0]) {
    return cnode;
  }
  const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0);
  std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
  const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, 0);
  const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0);
  if (infer_type == device_type) {
    return cnode;
  }
  AnfNodePtr replace_node =
    AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, infer_type, origin_shape, infer_type);
  MS_EXCEPTION_IF_NULL(replace_node);
  replace_node->set_scope(cnode->scope());
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
  // Keep the kernel graph's internal-output table pointing at the cast result.
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, 0)) {
    kernel_graph->ReplaceInternalOutput(cnode, replace_node);
  }
  return replace_node;
}
} // namespace
// Match any not-yet-visited node with an arbitrary number of inputs.
const BaseRef InsertCastCPU::DefinePattern() const {
  VarPtr unvisited_node = std::make_shared<CondVar>(UnVisited);
  VarPtr any_inputs = std::make_shared<SeqVar>();
  return VectorRef({unvisited_node, any_inputs});
}
// Pass entry point: mark the node visited, cast mismatched inputs in place, then insert
// output casts and return the node that should replace `node` (nullptr when skipped).
const AnfNodePtr InsertCastCPU::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(node);
  if (func_graph == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
    return nullptr;
  }
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // Fix up inputs first so the selected kernel receives supported device dtypes.
  InsertCastForInput(func_graph, cnode);
  // Then restore the inferred dtypes on every output for downstream nodes.
  const std::vector<bool> cast_all_outputs(AnfAlgo::GetOutputTensorNum(cnode), true);
  return InsertCastForOutput(func_graph, cnode, cast_all_outputs);
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_CAST_CPU_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_CAST_CPU_H
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "ir/anf.h"
namespace mindspore {
namespace opt {
// Backend optimizer pass ("insert_cast_cpu") for the CPU backend: inserts Cast nodes
// where a kernel's selected device dtypes differ from the inferred dtypes.
class InsertCastCPU : public PatternProcessPass {
 public:
  // multigraph = true lets the pass match across sub-graphs as well.
  explicit InsertCastCPU(bool multigraph = true) : PatternProcessPass("insert_cast_cpu", multigraph) {}
  ~InsertCastCPU() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_CAST_CPU_H

View File

@ -27,7 +27,9 @@
#include "runtime/device/cpu/kernel_select_cpu.h"
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/pass_manager.h"
#include "backend/optimizer/cpu/insert_cast_cpu.h"
#include "backend/optimizer/pass/replace_node_by_proxy.h"
#include "backend/optimizer/pass/erase_visit_attr.h"
#include "debug/anf_ir_dump.h"
#include "debug/dump_proto.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
@ -61,9 +63,21 @@ void CPUSession::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderPos
void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
std::string pass_name = "replace_node_by_proxy";
pass_name.append(std::to_string(graph_sum_));
pm->AddPass(std::make_shared<opt::ReplaceNodeByProxy>(pass_name));
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && ps::PSContext::instance()->is_ps_mode()) {
AssignParamKey(kernel_graph);
if (ps::PSContext::instance()->is_worker()) {
std::string pass_name = "replace_node_by_proxy";
pass_name.append(std::to_string(graph_sum_));
pm->AddPass(std::make_shared<opt::ReplaceNodeByProxy>(pass_name));
}
}
#endif
pm->AddPass(std::make_shared<opt::InsertCastCPU>());
pm->AddPass(std::make_shared<opt::EraseVisitAttr>());
MS_LOG(INFO) << "insert cast pass";
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
@ -77,14 +91,8 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
graph->UpdateGraphDynamicAttr();
MS_LOG(INFO) << "Set kernel info";
SetKernelInfo(graph.get());
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::PSContext::instance()->is_ps_mode()) {
AssignParamKey(graph);
if (ps::PSContext::instance()->is_worker()) {
Optimize(graph);
}
}
#endif
MS_LOG(INFO) << "Set kernel info end";
Optimize(graph);
MS_LOG(INFO) << "Build kernel";
BuildKernel(graph.get());
@ -158,6 +166,7 @@ void CPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap
auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
MS_EXCEPTION_IF_NULL(kernel_graph);
SetKernelInfo(kernel_graph.get());
Optimize(kernel_graph);
BuildKernel(kernel_graph.get());
run_op_graphs_[graph_info] = kernel_graph;
}

View File

@ -35,21 +35,6 @@ bool IsInputNotCNode(const CNodePtr &kernel_node, size_t input_index) {
return false;
}
void UpdatePrevNotCNodeFormatDtype(const KernelAttr &kernel_attr, const std::vector<size_t> &input_not_cnode_indexes,
const CNodePtr kernel_node) {
for (auto &input_index : input_not_cnode_indexes) {
auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first;
MS_EXCEPTION_IF_NULL(input_node);
std::vector<TypeId> output_types;
output_types.emplace_back(kernel_attr.GetInputAttr(input_index).first);
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
MS_EXCEPTION_IF_NULL(builder);
builder->SetOutputsFormat({kOpFormat_DEFAULT});
builder->SetOutputsDeviceType(output_types);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_node.get());
}
}
void GetOutputInferFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::string> *output_formats,
std::vector<TypeId> *output_types) {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
@ -142,35 +127,11 @@ std::pair<int, int> GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr,
int format_matched_num = 0;
auto input_num = input_types.size();
for (size_t i = 0; i < input_num; ++i) {
bool is_not_cnode_idx = std::any_of(input_not_cnode_indexes.begin(), input_not_cnode_indexes.end(),
[i](size_t index) { return index == i; });
bool have_cnode_input = (input_types.size() != input_not_cnode_indexes.size());
if (have_cnode_input && is_not_cnode_idx) {
data_type_matched_num++;
format_matched_num++;
continue;
}
if (is_not_cnode_idx) {
if (!InputDtypeMatch(kernel_attr.GetInputAttr(i).first, input_types[i], strict)) {
MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first
<< ", actual input dtype:" << input_types[i];
} else {
data_type_matched_num++;
}
format_matched_num++;
continue;
}
if (kernel_attr.GetInputAttr(i).first != input_types[i]) {
if (!InputDtypeMatch(kernel_attr.GetInputAttr(i).first, input_types[i], strict)) {
MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first
<< ", actual input dtype:" << input_types[i];
} else {
data_type_matched_num++;
}
if (kernel_attr.GetInputAttr(i).second != input_formats[i]) {
MS_LOG(DEBUG) << "required format:" << kernel_attr.GetInputAttr(i).second
<< ", actual input format:" << input_formats[i];
} else {
format_matched_num++;
}
}
@ -320,9 +281,8 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
(matched.first || input_types.size() == input_not_cnode_indexes.size())) {
MS_LOG(INFO) << "Input format and dtype is matched";
GetOutputFormatsAndDtypes(kernel_node, selected_kernel_attr, &output_formats, &output_types);
UpdatePrevNotCNodeFormatDtype(selected_kernel_attr, input_not_cnode_indexes, kernel_node);
for (auto &input_index : input_not_cnode_indexes) {
input_types[input_index] = selected_kernel_attr.GetInputAttr(input_index).first;
for (size_t i = 0; i < selected_kernel_attr.GetInputSize(); ++i) {
input_types[SizeToInt(i)] = selected_kernel_attr.GetInputAttr(i).first;
}
}
SetKernelBuildInfo(input_formats, input_types, output_formats, output_types, kernel_node.get());