forked from mindspore-Ecosystem/mindspore
!3775 remove the dtype convert when update output
Merge pull request !3775 from lianliguang/test-xiu-bug
This commit is contained in:
commit
73c4022ef4
|
@ -51,33 +51,19 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i
|
|||
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
|
||||
AnfNodePtr trans_node = nullptr;
|
||||
AnfNodePtr input_node = nullptr;
|
||||
CNodePtr trans_data = nullptr;
|
||||
std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0);
|
||||
std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT;
|
||||
std::vector<Axis> padding_axis;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// if insert transdata for input we need to change the input
|
||||
if (is_insert_input) {
|
||||
if (!node->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
dst_format = AnfAlgo::GetInputFormat(cnode, insert_index);
|
||||
input_node = AnfAlgo::GetInputNode(cnode, insert_index);
|
||||
padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index);
|
||||
} else {
|
||||
input_node = node;
|
||||
padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
|
||||
}
|
||||
// Init
|
||||
AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node;
|
||||
std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, insert_index);
|
||||
std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : kOpFormat_DEFAULT;
|
||||
std::vector<Axis> padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index)
|
||||
: AnfAlgo::GetOutputReshapeType(node, insert_index);
|
||||
auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index)
|
||||
: AnfAlgo::GetOutputInferShape(input_node, insert_index);
|
||||
bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size())
|
||||
: trans::IsNeedPadding(input_format, input_node_out_shape.size());
|
||||
|
||||
auto input_node_out_shape = AnfAlgo::GetOutputInferShape(input_node, 0);
|
||||
bool need_padding = false;
|
||||
if (is_insert_input) {
|
||||
need_padding = (trans::IsNeedPadding(dst_format, input_node_out_shape.size()));
|
||||
} else {
|
||||
need_padding = (trans::IsNeedPadding(input_format, input_node_out_shape.size()));
|
||||
}
|
||||
if (!need_padding) {
|
||||
// don't need padding insert transdata only
|
||||
trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
|
||||
|
@ -89,6 +75,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
|
|||
auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
|
||||
trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name());
|
||||
trans_node = trans_data;
|
||||
trans_data->set_abstract(input_node->abstract());
|
||||
} else {
|
||||
// if need padding & is output need insert a transdata
|
||||
// node -> transdata[padding shape] -> reshape[ori_shape]
|
||||
|
@ -303,7 +290,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
|
|||
const auto infer_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
|
||||
TypeId origin_type(kTypeUnknown);
|
||||
auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
|
||||
auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0);
|
||||
auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(cur_input, 0);
|
||||
auto real_input_node = kernel_with_index.first;
|
||||
if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
// weight
|
||||
|
|
|
@ -28,7 +28,8 @@ namespace opt {
|
|||
class RectifyDoMaskKernelInfo : public PatternProcessPass {
|
||||
public:
|
||||
explicit RectifyDoMaskKernelInfo(bool multigraph = true)
|
||||
: PatternProcessPass("batch_norm_bert_fission", multigraph), kernel_selecter(std::make_shared<KernelSelect>()) {}
|
||||
: PatternProcessPass("rectify_do_mask_kernel_info", multigraph),
|
||||
kernel_selecter(std::make_shared<KernelSelect>()) {}
|
||||
~RectifyDoMaskKernelInfo() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
|
|
@ -87,6 +87,7 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
|
|||
new_transdata_node =
|
||||
NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name());
|
||||
RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node);
|
||||
new_transdata_node->set_abstract(node->abstract());
|
||||
new_replace_node = new_transdata_node;
|
||||
}
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "utils/utils.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -32,21 +34,21 @@ const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, cons
|
|||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
auto reshape_op_1 = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum);
|
||||
MS_EXCEPTION_IF_NULL(reshape_op_1);
|
||||
auto out_reshape = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum);
|
||||
MS_EXCEPTION_IF_NULL(out_reshape);
|
||||
// If reshape operator used by more than one other operators, reshape operator cant not be deleted directly
|
||||
if (IsUsedByOthers(func_graph, reshape_op_1)) {
|
||||
if (IsUsedByOthers(func_graph, out_reshape)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto reshape_op_2 = CheckAnfNodeIfCNodeAndInputSize(reshape_op_1->input(1), kBackendReshapeInputNum);
|
||||
MS_EXCEPTION_IF_NULL(reshape_op_2);
|
||||
if (IsUsedByOthers(func_graph, reshape_op_2)) {
|
||||
auto in_reshape = CheckAnfNodeIfCNodeAndInputSize(AnfAlgo::GetInputNode(out_reshape, 0), kBackendReshapeInputNum);
|
||||
MS_EXCEPTION_IF_NULL(in_reshape);
|
||||
if (IsUsedByOthers(func_graph, in_reshape)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto output_shape = AnfAlgo::GetOutputDeviceShape(reshape_op_2, 0);
|
||||
auto input_shape = AnfAlgo::GetInputDeviceShape(reshape_op_1, 0);
|
||||
if (input_shape == output_shape) {
|
||||
auto input_node = reshape_op_2->input(1);
|
||||
auto output_shape = AnfAlgo::GetOutputDeviceShape(out_reshape, 0);
|
||||
auto input_shape = AnfAlgo::GetInputDeviceShape(in_reshape, 0);
|
||||
if (kernel::IsSameShape(input_shape, output_shape)) {
|
||||
auto input_node = AnfAlgo::GetInputNode(in_reshape, 0);
|
||||
return input_node;
|
||||
}
|
||||
return nullptr;
|
||||
|
|
|
@ -71,7 +71,8 @@ bool CastEliminateCondition(const CNodePtr &node1, const CNodePtr &node2) {
|
|||
|
||||
bool TransDataOpEliminateCondition(const CNodePtr &node1, const CNodePtr &node2) {
|
||||
return AnfAlgo::GetInputFormat(node1, 0) == AnfAlgo::GetOutputFormat(node2, 0) &&
|
||||
AnfAlgo::GetOutputFormat(node1, 0) == AnfAlgo::GetInputFormat(node2, 0);
|
||||
AnfAlgo::GetOutputFormat(node1, 0) == AnfAlgo::GetInputFormat(node2, 0) &&
|
||||
kernel::IsSameShape(AnfAlgo::GetInputDeviceShape(node2, 0), AnfAlgo::GetOutputDeviceShape(node1, 0));
|
||||
}
|
||||
|
||||
const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const CNodePtr &prev_cnode,
|
||||
|
|
|
@ -106,12 +106,12 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
|
||||
// if node is a value node, no need sync addr from device to host
|
||||
if (node->isa<ValueNode>()) {
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
return value_node->value();
|
||||
}
|
||||
if (!AnfAlgo::OutputAddrExist(node, output_index)) {
|
||||
if (node->isa<ValueNode>()) {
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
return value_node->value();
|
||||
}
|
||||
if (node->isa<Parameter>()) {
|
||||
for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
|
||||
if (input_idx >= input_tensors.size()) {
|
||||
|
@ -252,6 +252,7 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
|
|||
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{device_address->format()});
|
||||
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{device_address->type_id()});
|
||||
kernel_build_info_builder->SetOutputsReshapeType({input_tensor->padding_type()});
|
||||
AnfAlgo::SetOutputAddr(device_address, 0, param.get());
|
||||
}
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
|
||||
// construct abstract of parameter
|
||||
|
|
|
@ -481,13 +481,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
|
|||
if (op_info != nullptr) {
|
||||
is_ref = op_info->is_ref();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
|
||||
if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
|
||||
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
|
||||
continue;
|
||||
}
|
||||
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown &&
|
||||
AnfAlgo::OutputAddrExist(real_input_node, 0)) {
|
||||
if (AnfAlgo::OutputAddrExist(real_input_node, 0)) {
|
||||
continue;
|
||||
}
|
||||
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
|
||||
|
|
Loading…
Reference in New Issue