!14564 fix cpu operator with unsupported type

From: @huaweib
Reviewed-by: @jjfeing, @kisnwang
Signed-off-by: @jjfeing
Commit: f2fd3c5e85
@@ -34,8 +34,8 @@ void Cast(const S *in, T *out, size_t size) {
 template <typename S, typename T>
 void CastCPUKernel<S, T>::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
-  source_dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, 0);
-  target_dtype = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
+  source_dtype = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
+  target_dtype = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
 }

 template <typename S, typename T>
@@ -45,7 +45,6 @@ bool CastCPUKernel<S, T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
  S *input = reinterpret_cast<S *>(inputs[0]->addr);
  T *output = reinterpret_cast<T *>(outputs[0]->addr);
  MS_LOG(DEBUG) << "Type source: " << typeid(S).name() << "; target: " << typeid(T).name();

  size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
  Cast<S, T>(input, output, lens);
  return true;
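For context, the hunk header above names the helper `void Cast(const S *in, T *out, size_t size)` that `Launch` calls. A minimal sketch of such an element-wise conversion helper, assuming a plain loop (the real kernel may add parallelization or rounding rules):

```cpp
#include <cstddef>

// Hedged sketch: convert each element from source type S to target type T.
// This mirrors the Cast<S, T>(input, output, lens) call shown in Launch above.
template <typename S, typename T>
void Cast(const S *in, T *out, size_t size) {
  for (size_t i = 0; i < size; ++i) {
    out[i] = static_cast<T>(in[i]);
  }
}
```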
@@ -27,7 +27,7 @@ void MaximumGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   dout_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
   dx_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
   dy_shape = AnfAlgo::GetOutputInferShape(kernel_node, 1);
-  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
+  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
   if (!x_shape_.size() || !y_shape_.size() || !dout_shape.size()) {
     MS_LOG(EXCEPTION) << "Input NULL";
   }
@@ -36,6 +36,11 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
     -Wno-overloaded-virtual -Wno-unused-const-variable -Wno-pessimizing-move")
 endif()

+if(ENABLE_CPU)
+    file(GLOB_RECURSE _CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
+    list(APPEND _PREACTIVATE_SRC_LIST ${_CPU_SRC_LIST})
+endif()
+
 set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS
     SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT)
 add_library(_mindspore_backend_optimizer_obj OBJECT ${_PREACTIVATE_SRC_LIST})
@@ -0,0 +1,174 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/cpu/insert_cast_cpu.h"

#include <memory>
#include <string>
#include <vector>
#include <utility>
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "utils/utils.h"
#include "backend/kernel_compiler/common_utils.h"

namespace mindspore {
namespace opt {
namespace {
AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
                                const TypeId &input_type, const TypeId &output_type,
                                const std::vector<size_t> &origin_shape, const TypeId &origin_type) {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::string input_format = format;
  std::string output_format = format;
  CNodePtr cast = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())), input});
  MS_EXCEPTION_IF_NULL(cast);
  // set kernel build info
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetInputsFormat({input_format});
  builder.SetOutputsFormat({output_format});
  builder.SetInputsDeviceType({input_type});
  builder.SetOutputsDeviceType({output_type});

  // if kernel info is null , it remarks this function is running ut
  if (cast->kernel_info() == nullptr) {
    auto kernel_info = std::make_shared<device::KernelInfo>();
    cast->set_kernel_info(kernel_info);
  }
  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
  AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get());
  AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
  return cast;
}

AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                       const std::vector<bool> &need_insert_cast) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(cnode);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  size_t out_num = AnfAlgo::GetOutputTensorNum(cnode);
  for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
    AnfNodePtr replace_node = nullptr;
    const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx);
    const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx);
    auto idx = NewValueNode(SizeToLong(output_idx));
    MS_EXCEPTION_IF_NULL(idx);
    auto imm = std::make_shared<Int64Imm>(output_idx);
    idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
    auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
    AnfAlgo::SetOutputInferTypeAndShape({infer_type}, {origin_shape}, getitem.get());
    if (need_insert_cast[output_idx]) {
      const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx);
      const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx);
      if (infer_type != device_type) {
        replace_node =
          AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, infer_type, origin_shape, infer_type);
        MS_EXCEPTION_IF_NULL(replace_node);
        replace_node->set_scope(cnode->scope());
        AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
        if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, output_idx)) {
          kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0);
        }
      }
    }
  }
  return cnode;
}

void InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  auto mng = kernel_graph->manager();
  for (size_t input_index = 0; input_index < in_num; ++input_index) {
    auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, input_index);
    const auto infer_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
    auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);

    const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
    const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second);

    if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); infer_type != device_type) {
      auto cast =
        AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, infer_type, device_type, origin_shape, device_type);
      MS_EXCEPTION_IF_NULL(cast);
      cast->set_scope(cnode->scope());
      AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast);
      mng->Replace(cur_input, cast);
    }
  }
}

AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                               const std::vector<bool> &need_insert_cast) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(cnode);
  if (AnfAlgo::GetOutputTensorNum(cnode) == 0) {
    return cnode;
  }
  MS_EXCEPTION_IF_NULL(cnode->Type());
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  // Single output
  if (!cnode->Type()->isa<Tuple>()) {
    if (!need_insert_cast[0]) {
      return cnode;
    }
    const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0);
    std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
    const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, 0);

    const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0);
    AnfNodePtr replace_node = cnode;
    if (infer_type != device_type) {
      replace_node =
        AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, infer_type, origin_shape, infer_type);
      MS_EXCEPTION_IF_NULL(replace_node);
      replace_node->set_scope(cnode->scope());
      AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
      if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, 0)) {
        kernel_graph->ReplaceInternalOutput(cnode, replace_node);
      }
    }
    return replace_node;
  }
  // Multiple output
  return InsertCastForMultipleOutput(func_graph, cnode, need_insert_cast);
}
}  // namespace

const BaseRef InsertCastCPU::DefinePattern() const {
  VarPtr V = std::make_shared<CondVar>(UnVisited);
  VarPtr Xs = std::make_shared<SeqVar>();
  return VectorRef({V, Xs});
}

const AnfNodePtr InsertCastCPU::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(node);
  if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) {
    return nullptr;
  }
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
  // process input
  CNodePtr cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  InsertCastForInput(func_graph, cnode);
  // process output
  return InsertCastForOutput(func_graph, cnode, std::vector<bool>(AnfAlgo::GetOutputTensorNum(cnode), true));
}
}  // namespace opt
}  // namespace mindspore
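The core of this new pass is comparing an edge's inferred dtype against the dtype the selected CPU kernel actually supports; only mismatching edges get a Cast spliced in. A standalone sketch of that decision with hypothetical, simplified types (not the MindSpore API):

```cpp
#include <cstdio>

// Hypothetical stand-ins for the dtype pair the pass inspects on each edge.
enum class TypeId { kFloat16, kFloat32, kInt32 };

struct Edge {
  TypeId infer_type;   // dtype the graph says the producer emits
  TypeId device_type;  // dtype the selected CPU kernel expects or produces
};

// Mirrors the `infer_type != device_type` checks in InsertCastForInput/Output.
bool NeedsCast(const Edge &e) { return e.infer_type != e.device_type; }

int main() {
  Edge e{TypeId::kFloat16, TypeId::kFloat32};
  std::printf("insert Cast: %s\n", NeedsCast(e) ? "yes" : "no");  // prints "yes"
  return 0;
}
```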
@@ -0,0 +1,36 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_CAST_CPU_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_CAST_CPU_H

#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "ir/anf.h"

namespace mindspore {
namespace opt {
class InsertCastCPU : public PatternProcessPass {
 public:
  explicit InsertCastCPU(bool multigraph = true) : PatternProcessPass("insert_cast_cpu", multigraph) {}
  ~InsertCastCPU() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_CAST_CPU_H
@@ -27,7 +27,9 @@
 #include "runtime/device/cpu/kernel_select_cpu.h"
 #include "backend/optimizer/common/optimizer.h"
 #include "backend/optimizer/common/pass_manager.h"
+#include "backend/optimizer/cpu/insert_cast_cpu.h"
 #include "backend/optimizer/pass/replace_node_by_proxy.h"
+#include "backend/optimizer/pass/erase_visit_attr.h"
 #include "debug/anf_ir_dump.h"
 #include "debug/dump_proto.h"
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
@@ -61,9 +63,21 @@ void CPUSession::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderPos
 void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  std::string pass_name = "replace_node_by_proxy";
-  pass_name.append(std::to_string(graph_sum_));
-  pm->AddPass(std::make_shared<opt::ReplaceNodeByProxy>(pass_name));
+#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && ps::PSContext::instance()->is_ps_mode()) {
+    AssignParamKey(kernel_graph);
+    if (ps::PSContext::instance()->is_worker()) {
+      std::string pass_name = "replace_node_by_proxy";
+      pass_name.append(std::to_string(graph_sum_));
+      pm->AddPass(std::make_shared<opt::ReplaceNodeByProxy>(pass_name));
+    }
+  }
+#endif
+  pm->AddPass(std::make_shared<opt::InsertCastCPU>());
+  pm->AddPass(std::make_shared<opt::EraseVisitAttr>());
+  MS_LOG(INFO) << "insert cast pass";
   optimizer->AddPassManager(pm);
   (void)optimizer->Optimize(kernel_graph);
   kernel_graph->SetExecOrderByDefault();
@@ -77,14 +91,8 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
   graph->UpdateGraphDynamicAttr();
   MS_LOG(INFO) << "Set kernel info";
   SetKernelInfo(graph.get());
-#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
-  if (ps::PSContext::instance()->is_ps_mode()) {
-    AssignParamKey(graph);
-    if (ps::PSContext::instance()->is_worker()) {
-      Optimize(graph);
-    }
-  }
-#endif
+  MS_LOG(INFO) << "Set kernel info end";
+  Optimize(graph);
   MS_LOG(INFO) << "Build kernel";
   BuildKernel(graph.get());
@@ -158,6 +166,7 @@ void CPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap
   auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
   MS_EXCEPTION_IF_NULL(kernel_graph);
   SetKernelInfo(kernel_graph.get());
+  Optimize(kernel_graph);
   BuildKernel(kernel_graph.get());
   run_op_graphs_[graph_info] = kernel_graph;
 }
@@ -35,21 +35,6 @@ bool IsInputNotCNode(const CNodePtr &kernel_node, size_t input_index) {
   return false;
 }

-void UpdatePrevNotCNodeFormatDtype(const KernelAttr &kernel_attr, const std::vector<size_t> &input_not_cnode_indexes,
-                                   const CNodePtr kernel_node) {
-  for (auto &input_index : input_not_cnode_indexes) {
-    auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first;
-    MS_EXCEPTION_IF_NULL(input_node);
-    std::vector<TypeId> output_types;
-    output_types.emplace_back(kernel_attr.GetInputAttr(input_index).first);
-    auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
-    MS_EXCEPTION_IF_NULL(builder);
-    builder->SetOutputsFormat({kOpFormat_DEFAULT});
-    builder->SetOutputsDeviceType(output_types);
-    AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_node.get());
-  }
-}
-
 void GetOutputInferFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::string> *output_formats,
                                     std::vector<TypeId> *output_types) {
   size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
@@ -142,35 +127,11 @@ std::pair<int, int> GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr,
  int format_matched_num = 0;
  auto input_num = input_types.size();
  for (size_t i = 0; i < input_num; ++i) {
    bool is_not_cnode_idx = std::any_of(input_not_cnode_indexes.begin(), input_not_cnode_indexes.end(),
                                        [i](size_t index) { return index == i; });
    bool have_cnode_input = (input_types.size() != input_not_cnode_indexes.size());
    if (have_cnode_input && is_not_cnode_idx) {
      data_type_matched_num++;
      format_matched_num++;
      continue;
    }
    if (is_not_cnode_idx) {
      if (!InputDtypeMatch(kernel_attr.GetInputAttr(i).first, input_types[i], strict)) {
        MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first
                      << ", actual input dtype:" << input_types[i];
      } else {
        data_type_matched_num++;
      }
      format_matched_num++;
      continue;
    }
    if (kernel_attr.GetInputAttr(i).first != input_types[i]) {
    if (!InputDtypeMatch(kernel_attr.GetInputAttr(i).first, input_types[i], strict)) {
      MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first
                    << ", actual input dtype:" << input_types[i];
    } else {
      data_type_matched_num++;
    }

    if (kernel_attr.GetInputAttr(i).second != input_formats[i]) {
      MS_LOG(DEBUG) << "required format:" << kernel_attr.GetInputAttr(i).second
                    << ", actual input format:" << input_formats[i];
    } else {
      format_matched_num++;
    }
  }
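The function in this hunk scores how well a candidate kernel attribute matches the actual input dtypes and formats. A standalone, simplified sketch of that counting idea (hypothetical types, not the MindSpore API):

```cpp
#include <string>
#include <utility>
#include <vector>

struct InputAttr {
  int dtype;           // stand-in for TypeId
  std::string format;  // e.g. "DefaultFormat"
};

// For each input, bump one counter when the dtype matches the candidate kernel
// attribute and another when the format matches, so the caller can pick the
// best-scoring kernel attribute, as GetInputDtypeFormatMatchedNum does above.
std::pair<int, int> CountMatches(const std::vector<InputAttr> &required,
                                 const std::vector<InputAttr> &actual) {
  int dtype_matched = 0;
  int format_matched = 0;
  for (size_t i = 0; i < required.size() && i < actual.size(); ++i) {
    if (required[i].dtype == actual[i].dtype) ++dtype_matched;
    if (required[i].format == actual[i].format) ++format_matched;
  }
  return {dtype_matched, format_matched};
}
```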
@@ -320,9 +281,8 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
       (matched.first || input_types.size() == input_not_cnode_indexes.size())) {
     MS_LOG(INFO) << "Input format and dtype is matched";
     GetOutputFormatsAndDtypes(kernel_node, selected_kernel_attr, &output_formats, &output_types);
-    UpdatePrevNotCNodeFormatDtype(selected_kernel_attr, input_not_cnode_indexes, kernel_node);
-    for (auto &input_index : input_not_cnode_indexes) {
-      input_types[input_index] = selected_kernel_attr.GetInputAttr(input_index).first;
+    for (size_t i = 0; i < selected_kernel_attr.GetInputSize(); ++i) {
+      input_types[SizeToInt(i)] = selected_kernel_attr.GetInputAttr(i).first;
     }
   }
   SetKernelBuildInfo(input_formats, input_types, output_formats, output_types, kernel_node.get());